소스 검색

refactor: Fix ZST Support

zstd-klauspost-attempt
Fionera 3 년 전
부모
커밋
efbde3a3ff
2개의 변경된 파일78개의 추가작업 그리고 71개의 파일을 삭제
  1. +4
    -0
      main.go
  2. +74
    -71
      server/handlers.go

+ 4
- 0
main.go 파일 보기

@@ -1,10 +1,14 @@
package main package main


import ( import (
"log"

"github.com/dutchcoders/transfer.sh/cmd" "github.com/dutchcoders/transfer.sh/cmd"
) )


func main() { func main() {
log.SetFlags(log.Lshortfile | log.LstdFlags)

app := cmd.New() app := cmd.New()
app.RunAndExitOnError() app.RunAndExitOnError()
} }

+ 74
- 71
server/handlers.go 파일 보기

@@ -32,6 +32,7 @@ import (
"archive/zip" "archive/zip"
"bytes" "bytes"
"compress/gzip" "compress/gzip"
"encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@@ -42,6 +43,7 @@ import (
"log" "log"
"math/rand" "math/rand"
"mime" "mime"
"net"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
@@ -53,14 +55,11 @@ import (
text_template "text/template" text_template "text/template"
"time" "time"


"net"

web "github.com/dutchcoders/transfer.sh-web" web "github.com/dutchcoders/transfer.sh-web"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/russross/blackfriday" "github.com/russross/blackfriday"


"encoding/base64"
qrcode "github.com/skip2/go-qrcode"
"github.com/skip2/go-qrcode"


"github.com/klauspost/compress/zstd" "github.com/klauspost/compress/zstd"
) )
@@ -559,9 +558,9 @@ func (s *Server) CheckMetadata(token, filename string) error {


r, _, _, err := s.storage.Get(token, fmt.Sprintf("%s.metadata", filename)) r, _, _, err := s.storage.Get(token, fmt.Sprintf("%s.metadata", filename))
//if s.storage.IsNotExist(err) { //if s.storage.IsNotExist(err) {
// return nil
// return nil
//} else if err != nil { //} else if err != nil {
if err != nil {
if err != nil {
return err return err
} }


@@ -860,6 +859,62 @@ func (s *Server) headHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Connection", "close") w.Header().Set("Connection", "close")
} }


func (s *Server) getHandlerZst(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)

action := vars["action"]
token := vars["token"]
filename := vars["filename"]

if err := s.CheckMetadata(token, filename+".zst"); err != nil {
log.Printf("Error metadata: %v", err)
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
return
}

reader, contentType, _, err := s.storage.Get(token, filename+".zst")
if s.storage.IsNotExist(err) {
log.Println(err)
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
return
}
if err != nil {
log.Printf("Failed to get from storage: %s", err.Error())
http.Error(w, "Could not retrieve file.", 500)
return
}
defer reader.Close()
contentType = mime.TypeByExtension(filepath.Ext(filename))

d, err := zstd.NewReader(reader)
if err != nil {
log.Printf("Failed to create zstd reader: %s", err.Error())
http.Error(w, "Could not retrieve file.", 500)
return
}
defer d.Close()

var disposition string
if action == "inline" {
disposition = "inline"
} else {
disposition = "attachment"
}

w.Header().Set("Content-Type", contentType)
w.Header().Set("Transfer-Encoding", "chunked")
w.Header().Set("Content-Disposition", fmt.Sprintf("%s; filename=\"%s\"", disposition, filename))
w.Header().Set("Connection", "keep-alive")

if w.Header().Get("Range") != "" {
log.Printf("Range request with decompression")
http.Error(w, "Range requests with decompression are not supported", 400)
return
}

_, _ = io.Copy(w, d)
}

func (s *Server) getHandler(w http.ResponseWriter, r *http.Request) { func (s *Server) getHandler(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r) vars := mux.Vars(r)


@@ -868,45 +923,23 @@ func (s *Server) getHandler(w http.ResponseWriter, r *http.Request) {
filename := vars["filename"] filename := vars["filename"]


if err := s.CheckMetadata(token, filename); err != nil { if err := s.CheckMetadata(token, filename); err != nil {
if err2 := s.CheckMetadata(token, filename + ".zst"); err2 != nil {
log.Printf("Error metadata: %s and %s", err.Error(), err2.Error())
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
return
}
log.Printf("Error metadata: %v; trying with .zst", err)
s.getHandlerZst(w, r)
return
} }


reader, contentType, contentLength, err := s.storage.Get(token, filename) reader, contentType, contentLength, err := s.storage.Get(token, filename)
isZstd := false
var d zstd.Decoder
_ = d // Only used when isZstd is true; silence compiler
if s.storage.IsNotExist(err) { if s.storage.IsNotExist(err) {
reader, _, _, err := s.storage.Get(token, filename + ".zst")
if s.storage.IsNotExist(err) {
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
return
} else if err != nil {
log.Printf("Failed to get .zst from storage: %s", err.Error())
http.Error(w, "Could not retrieve file.", 500)
return
}
defer reader.Close()
d, err := zstd.NewReader(reader)
if err != nil {
log.Printf("Failed to create zstd reader: %s", err.Error())
http.Error(w, "Could not retrieve file.", 500)
return
}
defer d.Close()
isZstd = true
contentType = mime.TypeByExtension(filepath.Ext(filename))
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
return
} else if err != nil { } else if err != nil {
log.Printf("Failed to get from storage: %s", err.Error())
log.Printf("%s", err.Error())
http.Error(w, "Could not retrieve file.", 500) http.Error(w, "Could not retrieve file.", 500)
return return
} else {
defer reader.Close()
} }


defer reader.Close()

var disposition string var disposition string


if action == "inline" { if action == "inline" {
@@ -916,47 +949,17 @@ func (s *Server) getHandler(w http.ResponseWriter, r *http.Request) {
} }


w.Header().Set("Content-Type", contentType) w.Header().Set("Content-Type", contentType)
if !isZstd {
w.Header().Set("Content-Length", strconv.FormatUint(contentLength, 10))
} else {
w.Header().Set("Transfer-Encoding", "chunked")
}
w.Header().Set("Content-Length", strconv.FormatUint(contentLength, 10))
w.Header().Set("Content-Disposition", fmt.Sprintf("%s; filename=\"%s\"", disposition, filename)) w.Header().Set("Content-Disposition", fmt.Sprintf("%s; filename=\"%s\"", disposition, filename))
w.Header().Set("Connection", "keep-alive") w.Header().Set("Connection", "keep-alive")


if w.Header().Get("Range") == "" { if w.Header().Get("Range") == "" {
if !isZstd {
if _, err = io.Copy(w, reader); err != nil {
log.Printf("%s", err.Error())
http.Error(w, "Error occurred copying to output stream", 500)
return
}
} else {
buffer := make([]byte, 1024)
for {
log.Printf("Reading from decoded stream")
n, err := d.Read(buffer)
log.Printf("Read from decoded stream")
if err != nil && err != io.EOF {
log.Printf("Failed to read from file: %s", err.Error())
panic("Error reading data")
}
log.Printf("Writing to HTTP")
w.Write(buffer[0:n])
log.Printf("Trying to flush")
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
if err == io.EOF {
break
}
}
if _, err = io.Copy(w, reader); err != nil {
log.Printf("%s", err.Error())
http.Error(w, "Error occurred copying to output stream", 500)
return
} }


return
} else if isZstd {
log.Printf("Range request with decompression")
http.Error(w, "Range requests with decompression are not supported", 400)
return return
} }




불러오는 중...
취소
저장