diff --git a/README.md b/README.md index f7f8400..b439a2f 100644 --- a/README.md +++ b/README.md @@ -170,6 +170,10 @@ Contributions are welcome. **Uvis Grinfelds** +## Maintainer + +**Andrea Spacca** + ## Copyright and license Code and documentation copyright 2011-2018 Remco Verhoef. diff --git a/cmd/cmd.go b/cmd/cmd.go index 3c820e1..a4cf881 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -10,6 +10,7 @@ import ( "github.com/dutchcoders/transfer.sh/server" "github.com/fatih/color" "github.com/minio/cli" + "log" ) var Version = "0.1" @@ -184,6 +185,8 @@ func VersionAction(c *cli.Context) { } func New() *Cmd { + logger := log.New(os.Stdout, "[transfer.sh]", log.LstdFlags) + app := cli.NewApp() app.Name = "transfer.sh" app.Author = "" @@ -235,6 +238,12 @@ func New() *Cmd { options = append(options, server.TempPath(v)) } + if v := c.String("log"); v != "" { + options = append(options, server.LogFile(logger, v)) + } else { + options = append(options, server.Logger(logger)) + } + if v := c.String("lets-encrypt-hosts"); v != "" { options = append(options, server.UseLetsEncrypt(strings.Split(v, ","))) } @@ -279,7 +288,7 @@ func New() *Cmd { panic("secret-key not set.") } else if bucket := c.String("bucket"); bucket == "" { panic("bucket not set.") - } else if storage, err := server.NewS3Storage(accessKey, secretKey, bucket, c.String("s3-endpoint")); err != nil { + } else if storage, err := server.NewS3Storage(accessKey, secretKey, bucket, c.String("s3-endpoint"), logger); err != nil { panic(err) } else { options = append(options, server.UseStorage(storage)) @@ -291,7 +300,7 @@ func New() *Cmd { panic("local-config-path not set.") } else if basedir := c.String("basedir"); basedir == "" { panic("basedir not set.") - } else if storage, err := server.NewGDriveStorage(clientJsonFilepath, localConfigPath, basedir); err != nil { + } else if storage, err := server.NewGDriveStorage(clientJsonFilepath, localConfigPath, basedir, logger); err != nil { panic(err) } else { options = append(options, server.UseStorage(storage)) @@ -299,7 +308,7 @@ func New() *Cmd { case "local": if v := c.String("basedir"); v == "" { panic("basedir not set.") - } else if storage, err := server.NewLocalStorage(v); err != nil { + } else if storage, err := server.NewLocalStorage(v, logger); err != nil { panic(err) } else { options = append(options, server.UseStorage(storage)) @@ -313,7 +322,7 @@ func New() *Cmd { ) if err != nil { - fmt.Println(color.RedString("Error starting server: %s", err.Error())) + logger.Println(color.RedString("Error starting server: %s", err.Error())) return } diff --git a/server/clamav.go b/server/clamav.go index e3d18da..4352fbc 100644 --- a/server/clamav.go +++ b/server/clamav.go @@ -47,7 +47,7 @@ func (s *Server) scanHandler(w http.ResponseWriter, r *http.Request) { contentLength := r.ContentLength contentType := r.Header.Get("Content-Type") - log.Printf("Scanning %s %d %s", filename, contentLength, contentType) + s.logger.Printf("Scanning %s %d %s", filename, contentLength, contentType) var reader io.Reader diff --git a/server/server.go b/server/server.go index 9e780ae..281f693 100644 --- a/server/server.go +++ b/server/server.go @@ -26,7 +26,6 @@ package server import ( "errors" - "fmt" "log" "math/rand" "mime" @@ -129,14 +128,21 @@ func TempPath(s string) OptionFn { } } -func LogFile(s string) OptionFn { +func LogFile(logger *log.Logger, s string) OptionFn { return func(srvr *Server) { f, err := os.OpenFile(s, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666) if err != nil { log.Fatalf("error opening file: %v", err) } - log.SetOutput(f) + logger.SetOutput(f) + srvr.logger = logger + } +} + +func Logger(logger *log.Logger) OptionFn { + return func(srvr *Server) { + srvr.logger = logger } } @@ -214,6 +220,8 @@ type Server struct { AuthUser string AuthPass string + logger *log.Logger + tlsConfig *tls.Config profilerEnabled bool @@ -269,7 +277,7 @@ func (s *Server) Run() { listening = true go func() { - fmt.Println("Profiled listening at: :6060") + s.logger.Println("Profiled listening at: :6060") http.ListenAndServe(":6060", nil) }() @@ -280,7 +288,7 @@ func (s *Server) Run() { var fs http.FileSystem if s.webPath != "" { - log.Println("Using static file path: ", s.webPath) + s.logger.Println("Using static file path: ", s.webPath) fs = http.Dir(s.webPath) @@ -299,7 +307,7 @@ func (s *Server) Run() { for _, path := range web.AssetNames() { bytes, err := web.Asset(path) if err != nil { - log.Panicf("Unable to parse: path=%s, err=%s", path, err) + s.logger.Panicf("Unable to parse: path=%s, err=%s", path, err) } htmlTemplates.New(stripPrefix(path)).Parse(string(bytes)) @@ -341,7 +349,7 @@ func (s *Server) Run() { u, err := url.Parse(r.Referer()) if err != nil { - log.Fatal(err) + s.logger.Fatal(err) return } @@ -371,9 +379,9 @@ func (s *Server) Run() { mime.AddExtensionType(".md", "text/x-markdown") - log.Printf("Transfer.sh server started.\nusing temp folder: %s\nusing storage provider: %s", s.tempPath, s.storage.Type()) + s.logger.Printf("Transfer.sh server started.\nusing temp folder: %s\nusing storage provider: %s", s.tempPath, s.storage.Type()) - h := handlers.PanicHandler(handlers.LogHandler(LoveHandler(s.RedirectHandler(r)), handlers.NewLogOptions(log.Printf, "_default_")), nil) + h := handlers.PanicHandler(handlers.LogHandler(LoveHandler(s.RedirectHandler(r)), handlers.NewLogOptions(s.logger.Printf, "_default_")), nil) if !s.TLSListenerOnly { srvr := &http.Server{ @@ -382,7 +390,7 @@ func (s *Server) Run() { } listening = true - log.Printf("listening on port: %v\n", s.ListenerString) + s.logger.Printf("listening on port: %v\n", s.ListenerString) go func() { srvr.ListenAndServe() @@ -391,7 +399,7 @@ func (s *Server) Run() { if s.TLSListenerString != "" { listening = true - log.Printf("listening on port: %v\n", s.TLSListenerString) + s.logger.Printf("listening on port: %v\n", s.TLSListenerString) go func() { s := &http.Server{ @@ -406,7 +414,7 @@ func (s *Server) Run() { }() } - log.Printf("---------------------------") + s.logger.Printf("---------------------------") term := make(chan os.Signal, 1) signal.Notify(term, os.Interrupt) @@ -415,8 +423,8 @@ func (s *Server) Run() { if listening { <-term } else { - log.Printf("No listener active.") + s.logger.Printf("No listener active.") } - log.Printf("Server stopped.") + s.logger.Printf("Server stopped.") } diff --git a/server/storage.go b/server/storage.go index 732af46..61d3d86 100644 --- a/server/storage.go +++ b/server/storage.go @@ -37,10 +37,11 @@ type Storage interface { type LocalStorage struct { Storage basedir string + logger *log.Logger } -func NewLocalStorage(basedir string) (*LocalStorage, error) { - return &LocalStorage{basedir: basedir}, nil +func NewLocalStorage(basedir string, logger *log.Logger) (*LocalStorage, error) { + return &LocalStorage{basedir: basedir, logger: logger}, nil } func (s *LocalStorage) Type() string { @@ -110,7 +111,6 @@ func (s *LocalStorage) Put(token string, filename string, reader io.Reader, cont } if f, err = os.OpenFile(filepath.Join(path, filename), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600); err != nil { - fmt.Printf("%s", err) return err } @@ -126,15 +126,16 @@ func (s *LocalStorage) Put(token string, filename string, reader io.Reader, cont type S3Storage struct { Storage bucket *s3.Bucket + logger *log.Logger } -func NewS3Storage(accessKey, secretKey, bucketName, endpoint string) (*S3Storage, error) { +func NewS3Storage(accessKey, secretKey, bucketName, endpoint string, logger *log.Logger) (*S3Storage, error) { bucket, err := getBucket(accessKey, secretKey, bucketName, endpoint) if err != nil { return nil, err } - return &S3Storage{bucket: bucket}, nil + return &S3Storage{bucket: bucket, logger: logger}, nil } func (s *S3Storage) Type() string { @@ -165,7 +166,7 @@ func (s *S3Storage) IsNotExist(err error) bool { return false } - log.Printf("IsNotExist: %s, %#v", err.Error(), err) + s.logger.Printf("IsNotExist: %s, %#v", err.Error(), err) b := (err.Error() == "The specified key does not exist.") b = b || (err.Error() == "Access Denied") @@ -210,7 +211,7 @@ func (s *S3Storage) Put(token string, filename string, reader io.Reader, content ) if multi, err = s.bucket.InitMulti(key, contentType, s3.Private); err != nil { - log.Printf(err.Error()) + s.logger.Printf(err.Error()) return } @@ -234,13 +235,13 @@ func (s *S3Storage) Put(token string, filename string, reader io.Reader, content // Amazon expects parts of at least 5MB, except for the last one if count, err = io.ReadAtLeast(reader, buffer, (1<<20)*5); err != nil && err != io.ErrUnexpectedEOF && err != io.EOF { - log.Printf(err.Error()) + s.logger.Printf(err.Error()) return } // always send minimal 1 part if err == io.EOF && index > 1 { - log.Printf("Waiting for all parts to finish uploading.") + s.logger.Printf("Waiting for all parts to finish uploading.") // wait for all parts to be finished uploading wg.Wait() @@ -257,10 +258,10 @@ func (s *S3Storage) Put(token string, filename string, reader io.Reader, content // using goroutines because of retries when upload fails go func(multi *s3.Multi, buffer []byte, index int) { - log.Printf("Uploading part %d %d", index, len(buffer)) + s.logger.Printf("Uploading part %d %d", index, len(buffer)) defer func() { - log.Printf("Finished part %d %d", index, len(buffer)) + s.logger.Printf("Finished part %d %d", index, len(buffer)) wg.Done() @@ -272,12 +273,12 @@ func (s *S3Storage) Put(token string, filename string, reader io.Reader, content var part s3.Part if part, err = multi.PutPart(index, partReader); err != nil { - log.Printf("Error while uploading part %d %d %s", index, len(buffer), err.Error()) + s.logger.Printf("Error while uploading part %d %d %s", index, len(buffer), err.Error()) partsChan <- err return } - log.Printf("Finished uploading part %d %d", index, len(buffer)) + s.logger.Printf("Finished uploading part %d %d", index, len(buffer)) partsChan <- part @@ -294,7 +295,7 @@ func (s *S3Storage) Put(token string, filename string, reader io.Reader, content parts = append(parts, part.(s3.Part)) case error: // abort multi upload - log.Printf("Error during upload, aborting %s.", part.(error).Error()) + s.logger.Printf("Error during upload, aborting %s.", part.(error).Error()) err = part.(error) multi.Abort() @@ -303,14 +304,14 @@ func (s *S3Storage) Put(token string, filename string, reader io.Reader, content } - log.Printf("Completing upload %d parts", len(parts)) + s.logger.Printf("Completing upload %d parts", len(parts)) if err = multi.Complete(parts); err != nil { - log.Printf("Error during completing upload %d parts %s", len(parts), err.Error()) + s.logger.Printf("Error during completing upload %d parts %s", len(parts), err.Error()) return } - log.Printf("Completed uploading %d", len(parts)) + s.logger.Printf("Completed uploading %d", len(parts)) return } @@ -320,9 +321,10 @@ type GDrive struct { rootId string basedir string localConfigPath string + logger *log.Logger } -func NewGDriveStorage(clientJsonFilepath string, localConfigPath string, basedir string) (*GDrive, error) { +func NewGDriveStorage(clientJsonFilepath string, localConfigPath string, basedir string, logger *log.Logger) (*GDrive, error) { b, err := ioutil.ReadFile(clientJsonFilepath) if err != nil { return nil, err @@ -334,12 +336,12 @@ func NewGDriveStorage(clientJsonFilepath string, localConfigPath string, basedir return nil, err } - srv, err := drive.New(getGDriveClient(config, localConfigPath)) + srv, err := drive.New(getGDriveClient(config, localConfigPath, logger)) if err != nil { return nil, err } - storage := &GDrive{service: srv, basedir: basedir, rootId: "", localConfigPath: localConfigPath} + storage := &GDrive{service: srv, basedir: basedir, rootId: "", localConfigPath: localConfigPath, logger: logger} err = storage.setupRoot() if err != nil { return nil, err @@ -570,31 +572,31 @@ func (s *GDrive) Put(token string, filename string, reader io.Reader, contentTyp } // Retrieve a token, saves the token, then returns the generated client. -func getGDriveClient(config *oauth2.Config, localConfigPath string) *http.Client { +func getGDriveClient(config *oauth2.Config, localConfigPath string, logger *log.Logger) *http.Client { tokenFile := filepath.Join(localConfigPath, GDriveTokenJsonFile) tok, err := gDriveTokenFromFile(tokenFile) if err != nil { - tok = getGDriveTokenFromWeb(config) - saveGDriveToken(tokenFile, tok) + tok = getGDriveTokenFromWeb(config, logger) + saveGDriveToken(tokenFile, tok, logger) } return config.Client(context.Background(), tok) } // Request a token from the web, then returns the retrieved token. -func getGDriveTokenFromWeb(config *oauth2.Config) *oauth2.Token { +func getGDriveTokenFromWeb(config *oauth2.Config, logger *log.Logger) *oauth2.Token { authURL := config.AuthCodeURL("state-token", oauth2.AccessTypeOffline) fmt.Printf("Go to the following link in your browser then type the "+ "authorization code: \n%v\n", authURL) var authCode string if _, err := fmt.Scan(&authCode); err != nil { - log.Fatalf("Unable to read authorization code %v", err) + logger.Fatalf("Unable to read authorization code %v", err) } tok, err := config.Exchange(context.TODO(), authCode) if err != nil { - log.Fatalf("Unable to retrieve token from web %v", err) + logger.Fatalf("Unable to retrieve token from web %v", err) } return tok } @@ -612,12 +614,13 @@ func gDriveTokenFromFile(file string) (*oauth2.Token, error) { } // Saves a token to a file path. -func saveGDriveToken(path string, token *oauth2.Token) { - fmt.Printf("Saving credential file to: %s\n", path) +func saveGDriveToken(path string, token *oauth2.Token, logger *log.Logger) { + logger.Printf("Saving credential file to: %s\n", path) f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) defer f.Close() if err != nil { - log.Fatalf("Unable to cache oauth token: %v", err) + logger.Fatalf("Unable to cache oauth token: %v", err) } + json.NewEncoder(f).Encode(token) } diff --git a/server/virustotal.go b/server/virustotal.go index 61c81d2..3e0f618 100644 --- a/server/virustotal.go +++ b/server/virustotal.go @@ -27,7 +27,6 @@ package server import ( "fmt" "io" - "log" "net/http" _ "github.com/PuerkitoBio/ghost/handlers" @@ -44,7 +43,7 @@ func (s *Server) virusTotalHandler(w http.ResponseWriter, r *http.Request) { contentLength := r.ContentLength contentType := r.Header.Get("Content-Type") - log.Printf("Submitting to VirusTotal: %s %d %s", filename, contentLength, contentType) + s.logger.Printf("Submitting to VirusTotal: %s %d %s", filename, contentLength, contentType) vt, err := virustotal.NewVirusTotal(s.VirusTotalKey) if err != nil { @@ -60,6 +59,6 @@ func (s *Server) virusTotalHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), 500) } - log.Println(result) + s.logger.Println(result) w.Write([]byte(fmt.Sprintf("%v\n", result.Permalink))) }