diff --git a/README.md b/README.md index 65f34b5..c37629f 100644 --- a/README.md +++ b/README.md @@ -112,6 +112,7 @@ log | path to log file| | LOG | cors-domains | comma separated list of domains for CORS, setting it enable CORS | | CORS_DOMAINS | clamav-host | host for clamav feature | | CLAMAV_HOST | rate-limit | request per minute | | RATE_LIMIT | +max-upload-size | max upload size in kilobytes | | MAX_UPLOAD_SIZE | If you want to use TLS using lets encrypt certificates, set lets-encrypt-hosts to your domain, set tls-listener to :443 and enable force-https. diff --git a/cmd/cmd.go b/cmd/cmd.go index 5abd46e..f289af6 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -34,85 +34,85 @@ VERSION: var globalFlags = []cli.Flag{ cli.StringFlag{ - Name: "listener", - Usage: "127.0.0.1:8080", - Value: "127.0.0.1:8080", + Name: "listener", + Usage: "127.0.0.1:8080", + Value: "127.0.0.1:8080", EnvVar: "LISTENER", }, // redirect to https? // hostnames cli.StringFlag{ - Name: "profile-listener", - Usage: "127.0.0.1:6060", - Value: "", + Name: "profile-listener", + Usage: "127.0.0.1:6060", + Value: "", EnvVar: "PROFILE_LISTENER", }, cli.BoolFlag{ - Name: "force-https", - Usage: "", + Name: "force-https", + Usage: "", EnvVar: "FORCE_HTTPS", }, cli.StringFlag{ - Name: "tls-listener", - Usage: "127.0.0.1:8443", - Value: "", + Name: "tls-listener", + Usage: "127.0.0.1:8443", + Value: "", EnvVar: "TLS_LISTENER", }, cli.BoolFlag{ - Name: "tls-listener-only", - Usage: "", + Name: "tls-listener-only", + Usage: "", EnvVar: "TLS_LISTENER_ONLY", }, cli.StringFlag{ - Name: "tls-cert-file", - Value: "", + Name: "tls-cert-file", + Value: "", EnvVar: "TLS_CERT_FILE", }, cli.StringFlag{ - Name: "tls-private-key", - Value: "", + Name: "tls-private-key", + Value: "", EnvVar: "TLS_PRIVATE_KEY", }, cli.StringFlag{ - Name: "temp-path", - Usage: "path to temp files", - Value: os.TempDir(), + Name: "temp-path", + Usage: "path to temp files", + Value: os.TempDir(), EnvVar: "TEMP_PATH", }, cli.StringFlag{ - Name: "web-path", - Usage: "path to static web files", - Value: "", + Name: "web-path", + Usage: "path to static web files", + Value: "", EnvVar: "WEB_PATH", }, cli.StringFlag{ - Name: "proxy-path", - Usage: "path prefix when service is run behind a proxy", - Value: "", + Name: "proxy-path", + Usage: "path prefix when service is run behind a proxy", + Value: "", EnvVar: "PROXY_PATH", }, cli.StringFlag{ - Name: "proxy-port", - Usage: "port of the proxy when the service is run behind a proxy", - Value: "", + Name: "proxy-port", + Usage: "port of the proxy when the service is run behind a proxy", + Value: "", EnvVar: "PROXY_PORT", }, cli.StringFlag{ - Name: "ga-key", - Usage: "key for google analytics (front end)", - Value: "", + Name: "ga-key", + Usage: "key for google analytics (front end)", + Value: "", EnvVar: "GA_KEY", }, cli.StringFlag{ - Name: "uservoice-key", - Usage: "key for user voice (front end)", - Value: "", + Name: "uservoice-key", + Usage: "key for user voice (front end)", + Value: "", EnvVar: "USERVOICE_KEY", }, cli.StringFlag{ - Name: "provider", - Usage: "s3|gdrive|local", - Value: "", + Name: "provider", + Usage: "s3|gdrive|local", + Value: "", EnvVar: "PROVIDER", }, cli.StringFlag{ @@ -146,31 +146,31 @@ var globalFlags = []cli.Flag{ EnvVar: "BUCKET", }, cli.BoolFlag{ - Name: "s3-no-multipart", - Usage: "Disables S3 Multipart Puts", + Name: "s3-no-multipart", + Usage: "Disables S3 Multipart Puts", EnvVar: "S3_NO_MULTIPART", }, cli.BoolFlag{ - Name: "s3-path-style", - Usage: "Forces path style URLs, required for Minio.", + Name: "s3-path-style", + Usage: "Forces path style URLs, required for Minio.", EnvVar: "S3_PATH_STYLE", }, cli.StringFlag{ - Name: "gdrive-client-json-filepath", - Usage: "", - Value: "", + Name: "gdrive-client-json-filepath", + Usage: "", + Value: "", EnvVar: "GDRIVE_CLIENT_JSON_FILEPATH", }, cli.StringFlag{ - Name: "gdrive-local-config-path", - Usage: "", - Value: "", + Name: "gdrive-local-config-path", + Usage: "", + Value: "", EnvVar: "GDRIVE_LOCAL_CONFIG_PATH", }, cli.IntFlag{ - Name: "gdrive-chunk-size", - Usage: "", - Value: googleapi.DefaultUploadChunkSize / 1024 / 1024, + Name: "gdrive-chunk-size", + Usage: "", + Value: googleapi.DefaultUploadChunkSize / 1024 / 1024, EnvVar: "GDRIVE_CHUNK_SIZE", }, cli.StringFlag{ @@ -191,6 +191,12 @@ var globalFlags = []cli.Flag{ Value: 0, EnvVar: "RATE_LIMIT", }, + cli.Int64Flag{ + Name: "max-upload-size", + Usage: "max limit for upload, in kilobytes", + Value: 0, + EnvVar: "MAX_UPLOAD_SIZE", + }, cli.StringFlag{ Name: "lets-encrypt-hosts", Usage: "host1, host2", @@ -198,15 +204,15 @@ var globalFlags = []cli.Flag{ EnvVar: "HOSTS", }, cli.StringFlag{ - Name: "log", - Usage: "/var/log/transfersh.log", - Value: "", + Name: "log", + Usage: "/var/log/transfersh.log", + Value: "", EnvVar: "LOG", }, cli.StringFlag{ - Name: "basedir", - Usage: "path to storage", - Value: "", + Name: "basedir", + Usage: "path to storage", + Value: "", EnvVar: "BASEDIR", }, cli.StringFlag{ @@ -222,38 +228,38 @@ var globalFlags = []cli.Flag{ EnvVar: "VIRUSTOTAL_KEY", }, cli.BoolFlag{ - Name: "profiler", - Usage: "enable profiling", + Name: "profiler", + Usage: "enable profiling", EnvVar: "PROFILER", }, cli.StringFlag{ - Name: "http-auth-user", - Usage: "user for http basic auth", - Value: "", + Name: "http-auth-user", + Usage: "user for http basic auth", + Value: "", EnvVar: "HTTP_AUTH_USER", }, cli.StringFlag{ - Name: "http-auth-pass", - Usage: "pass for http basic auth", - Value: "", + Name: "http-auth-pass", + Usage: "pass for http basic auth", + Value: "", EnvVar: "HTTP_AUTH_PASS", }, cli.StringFlag{ - Name: "ip-whitelist", - Usage: "comma separated list of ips allowed to connect to the service", - Value: "", + Name: "ip-whitelist", + Usage: "comma separated list of ips allowed to connect to the service", + Value: "", EnvVar: "IP_WHITELIST", }, cli.StringFlag{ - Name: "ip-blacklist", - Usage: "comma separated list of ips not allowed to connect to the service", - Value: "", + Name: "ip-blacklist", + Usage: "comma separated list of ips not allowed to connect to the service", + Value: "", EnvVar: "IP_BLACKLIST", }, cli.StringFlag{ - Name: "cors-domains", - Usage: "comma separated list of domains allowed for CORS requests", - Value: "", + Name: "cors-domains", + Usage: "comma separated list of domains allowed for CORS requests", + Value: "", EnvVar: "CORS_DOMAINS", }, } @@ -351,6 +357,10 @@ func New() *Cmd { options = append(options, server.ClamavHost(v)) } + if v := c.Int64("max-upload-size"); v > 0 { + options = append(options, server.MaxUploadSize(v)) + } + if v := c.Int("rate-limit"); v > 0 { options = append(options, server.RateLimit(v)) } diff --git a/server/handlers.go b/server/handlers.go index 067f493..67f3779 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -311,6 +311,12 @@ func (s *Server) postHandler(w http.ResponseWriter, r *http.Request) { contentLength := n + if s.maxUploadSize > 0 && contentLength > s.maxUploadSize { + log.Print("Entity too large") + http.Error(w, http.StatusText(http.StatusRequestEntityTooLarge), http.StatusRequestEntityTooLarge) + return + } + metadata := MetadataForRequest(contentType, r) buffer := &bytes.Buffer{} @@ -455,6 +461,12 @@ func (s *Server) putHandler(w http.ResponseWriter, r *http.Request) { contentLength = n } + if s.maxUploadSize > 0 && contentLength > s.maxUploadSize { + log.Print("Entity too large") + http.Error(w, http.StatusText(http.StatusRequestEntityTooLarge), http.StatusRequestEntityTooLarge) + return + } + if contentLength == 0 { log.Print("Empty content-length") http.Error(w, errors.New("Could not upload empty file").Error(), 400) diff --git a/server/server.go b/server/server.go index 3fc0b97..2c2d6ec 100644 --- a/server/server.go +++ b/server/server.go @@ -25,11 +25,11 @@ THE SOFTWARE. package server import ( + crypto_rand "crypto/rand" + "encoding/binary" "errors" gorillaHandlers "github.com/gorilla/handlers" "log" - crypto_rand "crypto/rand" - "encoding/binary" "math/rand" "mime" "net/http" @@ -175,6 +175,12 @@ func Logger(logger *log.Logger) OptionFn { } } +func MaxUploadSize(kbytes int64) OptionFn { + return func(srvr *Server) { + srvr.maxUploadSize = kbytes * 1024 + } + +} func RateLimit(requests int) OptionFn { return func(srvr *Server) { srvr.rateLimitRequests = requests @@ -271,6 +277,7 @@ type Server struct { locks map[string]*sync.Mutex + maxUploadSize int64 rateLimitRequests int storage Storage