From baa2fdc86cbab1676402c7a364e269f8e53e931a Mon Sep 17 00:00:00 2001 From: Andrea Spacca Date: Sat, 23 Jun 2018 18:46:28 +0200 Subject: [PATCH] ISSUE-92 added http basic auth handler for upload --- README.md | 2 ++ cmd/cmd.go | 19 +++++++++++++++++-- server/handlers.go | 24 ++++++++++++++++++++++++ server/server.go | 19 +++++++++++++++---- 4 files changed, 58 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index e56ced4..a66d1b2 100644 --- a/README.md +++ b/README.md @@ -59,6 +59,8 @@ force-https | redirect to https | false | tls-listener | port to use for https (:443) | | tls-cert-file | path to tls certificate | | tls-private-key | path to tls private key | | +http-auth-user | user for basic http auth on upload | | +http-auth-pass | pass for basic http auth on upload | | temp-path | path to temp folder | system temp | web-path | path to static web files (for development) | | provider | which storage provider to use | (s3, grdrive or local) | diff --git a/cmd/cmd.go b/cmd/cmd.go index f92c1cf..1c29fad 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -105,13 +105,11 @@ var globalFlags = []cli.Flag{ Name: "gdrive-client-json-filepath", Usage: "", Value: "", - EnvVar: "", }, cli.StringFlag{ Name: "gdrive-local-config-path", Usage: "", Value: "", - EnvVar: "", }, cli.IntFlag{ Name: "rate-limit", @@ -151,6 +149,16 @@ var globalFlags = []cli.Flag{ Name: "profiler", Usage: "enable profiling", }, + cli.StringFlag{ + Name: "http-auth-user", + Usage: "user for http basic auth", + Value: "", + }, + cli.StringFlag{ + Name: "http-auth-pass", + Usage: "pass for http basic auth", + Value: "", + }, } type Cmd struct { @@ -232,6 +240,13 @@ func New() *Cmd { options = append(options, server.ForceHTTPs()) } + if httpAuthUser := c.String("http-auth-user"); httpAuthUser == "" { + } else if httpAuthPass := c.String("http-auth-pass"); httpAuthPass == "" { + } else { + options = append(options, server.HttpAuthCredentials(httpAuthUser, httpAuthPass)) + } + + switch provider := c.String("provider"); provider { case "s3": if accessKey := c.String("aws-access-key"); accessKey == "" { diff --git a/server/handlers.go b/server/handlers.go index 1576fa5..3472dc8 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -775,3 +775,27 @@ func LoveHandler(h http.Handler) http.HandlerFunc { h.ServeHTTP(w, r) } } + +func (s *Server) BasicAuthHandler(h http.Handler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if s.AuthUser == "" || s.AuthPass == "" { + h.ServeHTTP(w, r) + return + } + + w.Header().Set("WWW-Authenticate", "Basic realm=\"Restricted\"") + + username, password, authOK := r.BasicAuth() + if authOK == false { + http.Error(w, "Not authorized", 401) + return + } + + if username != s.AuthUser || password != s.AuthPass { + http.Error(w, "Not authorized", 401) + return + } + + h.ServeHTTP(w, r) + } +} diff --git a/server/server.go b/server/server.go index c750b46..78e5290 100644 --- a/server/server.go +++ b/server/server.go @@ -181,7 +181,18 @@ func TLSConfig(cert, pk string) OptionFn { } } + +func HttpAuthCredentials(user string, pass string) OptionFn { + return func(srvr *Server) { + srvr.AuthUser = user + srvr.AuthPass = pass + } +} + type Server struct { + AuthUser string + AuthPass string + tlsConfig *tls.Config profilerEnabled bool @@ -317,10 +328,10 @@ func (s *Server) Run() { r.HandleFunc("/{filename}/virustotal", s.virusTotalHandler).Methods("PUT") r.HandleFunc("/{filename}/scan", s.scanHandler).Methods("PUT") - r.HandleFunc("/put/{filename}", s.putHandler).Methods("PUT") - r.HandleFunc("/upload/{filename}", s.putHandler).Methods("PUT") - r.HandleFunc("/{filename}", s.putHandler).Methods("PUT") - r.HandleFunc("/", s.postHandler).Methods("POST") + r.HandleFunc("/put/{filename}", s.BasicAuthHandler(http.HandlerFunc(s.putHandler))).Methods("PUT") + r.HandleFunc("/upload/{filename}", s.BasicAuthHandler(http.HandlerFunc(s.putHandler))).Methods("PUT") + r.HandleFunc("/{filename}", s.BasicAuthHandler(http.HandlerFunc(s.putHandler))).Methods("PUT") + r.HandleFunc("/", s.BasicAuthHandler(http.HandlerFunc(s.putHandler))).Methods("POST") // r.HandleFunc("/{page}", viewHandler).Methods("GET") r.NotFoundHandler = http.HandlerFunc(s.notFoundHandler)