From efbde3a3ff5518f3740c7803033ae3f0ec3365d3 Mon Sep 17 00:00:00 2001 From: Fionera Date: Sun, 7 Feb 2021 03:35:24 +0100 Subject: [PATCH] refactor: Fix ZST Support --- main.go | 4 ++ server/handlers.go | 145 +++++++++++++++++++++++---------------------- 2 files changed, 78 insertions(+), 71 deletions(-) diff --git a/main.go b/main.go index 8e37b69..544be73 100644 --- a/main.go +++ b/main.go @@ -1,10 +1,14 @@ package main import ( + "log" + "github.com/dutchcoders/transfer.sh/cmd" ) func main() { + log.SetFlags(log.Lshortfile | log.LstdFlags) + app := cmd.New() app.RunAndExitOnError() } diff --git a/server/handlers.go b/server/handlers.go index 42aa518..0971e33 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -32,6 +32,7 @@ import ( "archive/zip" "bytes" "compress/gzip" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -42,6 +43,7 @@ import ( "log" "math/rand" "mime" + "net" "net/http" "net/url" "os" @@ -53,14 +55,11 @@ import ( text_template "text/template" "time" - "net" - web "github.com/dutchcoders/transfer.sh-web" "github.com/gorilla/mux" "github.com/russross/blackfriday" - "encoding/base64" - qrcode "github.com/skip2/go-qrcode" + "github.com/skip2/go-qrcode" "github.com/klauspost/compress/zstd" ) @@ -559,9 +558,9 @@ func (s *Server) CheckMetadata(token, filename string) error { r, _, _, err := s.storage.Get(token, fmt.Sprintf("%s.metadata", filename)) //if s.storage.IsNotExist(err) { - // return nil + // return nil //} else if err != nil { - if err != nil { + if err != nil { return err } @@ -860,6 +859,62 @@ func (s *Server) headHandler(w http.ResponseWriter, r *http.Request) { w.Header().Set("Connection", "close") } +func (s *Server) getHandlerZst(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + + action := vars["action"] + token := vars["token"] + filename := vars["filename"] + + if err := s.CheckMetadata(token, filename+".zst"); err != nil { + log.Printf("Error metadata: %v", err) + http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) + return + } + + reader, contentType, _, err := s.storage.Get(token, filename+".zst") + if s.storage.IsNotExist(err) { + log.Println(err) + http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) + return + } + if err != nil { + log.Printf("Failed to get from storage: %s", err.Error()) + http.Error(w, "Could not retrieve file.", 500) + return + } + defer reader.Close() + contentType = mime.TypeByExtension(filepath.Ext(filename)) + + d, err := zstd.NewReader(reader) + if err != nil { + log.Printf("Failed to create zstd reader: %s", err.Error()) + http.Error(w, "Could not retrieve file.", 500) + return + } + defer d.Close() + + var disposition string + if action == "inline" { + disposition = "inline" + } else { + disposition = "attachment" + } + + w.Header().Set("Content-Type", contentType) + w.Header().Set("Transfer-Encoding", "chunked") + w.Header().Set("Content-Disposition", fmt.Sprintf("%s; filename=\"%s\"", disposition, filename)) + w.Header().Set("Connection", "keep-alive") + + if w.Header().Get("Range") != "" { + log.Printf("Range request with decompression") + http.Error(w, "Range requests with decompression are not supported", 400) + return + } + + _, _ = io.Copy(w, d) +} + func (s *Server) getHandler(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) @@ -868,45 +923,23 @@ func (s *Server) getHandler(w http.ResponseWriter, r *http.Request) { filename := vars["filename"] if err := s.CheckMetadata(token, filename); err != nil { - if err2 := s.CheckMetadata(token, filename + ".zst"); err2 != nil { - log.Printf("Error metadata: %s and %s", err.Error(), err2.Error()) - http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) - return - } + log.Printf("Error metadata: %v; trying with .zst", err) + s.getHandlerZst(w, r) + return } reader, contentType, contentLength, err := s.storage.Get(token, filename) - isZstd := false - var d zstd.Decoder - _ = d // Only used when isZstd is true; silence compiler if s.storage.IsNotExist(err) { - reader, _, _, err := s.storage.Get(token, filename + ".zst") - if s.storage.IsNotExist(err) { - http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) - return - } else if err != nil { - log.Printf("Failed to get .zst from storage: %s", err.Error()) - http.Error(w, "Could not retrieve file.", 500) - return - } - defer reader.Close() - d, err := zstd.NewReader(reader) - if err != nil { - log.Printf("Failed to create zstd reader: %s", err.Error()) - http.Error(w, "Could not retrieve file.", 500) - return - } - defer d.Close() - isZstd = true - contentType = mime.TypeByExtension(filepath.Ext(filename)) + http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) + return } else if err != nil { - log.Printf("Failed to get from storage: %s", err.Error()) + log.Printf("%s", err.Error()) http.Error(w, "Could not retrieve file.", 500) return - } else { - defer reader.Close() } + defer reader.Close() + var disposition string if action == "inline" { @@ -916,47 +949,17 @@ func (s *Server) getHandler(w http.ResponseWriter, r *http.Request) { } w.Header().Set("Content-Type", contentType) - if !isZstd { - w.Header().Set("Content-Length", strconv.FormatUint(contentLength, 10)) - } else { - w.Header().Set("Transfer-Encoding", "chunked") - } + w.Header().Set("Content-Length", strconv.FormatUint(contentLength, 10)) w.Header().Set("Content-Disposition", fmt.Sprintf("%s; filename=\"%s\"", disposition, filename)) w.Header().Set("Connection", "keep-alive") if w.Header().Get("Range") == "" { - if !isZstd { - if _, err = io.Copy(w, reader); err != nil { - log.Printf("%s", err.Error()) - http.Error(w, "Error occurred copying to output stream", 500) - return - } - } else { - buffer := make([]byte, 1024) - for { - log.Printf("Reading from decoded stream") - n, err := d.Read(buffer) - log.Printf("Read from decoded stream") - if err != nil && err != io.EOF { - log.Printf("Failed to read from file: %s", err.Error()) - panic("Error reading data") - } - log.Printf("Writing to HTTP") - w.Write(buffer[0:n]) - log.Printf("Trying to flush") - if f, ok := w.(http.Flusher); ok { - f.Flush() - } - if err == io.EOF { - break - } - } + if _, err = io.Copy(w, reader); err != nil { + log.Printf("%s", err.Error()) + http.Error(w, "Error occurred copying to output stream", 500) + return } - return - } else if isZstd { - log.Printf("Range request with decompression") - http.Error(w, "Range requests with decompression are not supported", 400) return }