package server

import (
	"bytes"
	"fmt"
	"io"
	"log"
	"mime"
	"os"
	"path/filepath"
	"strconv"
	"sync"

	"github.com/goamz/goamz/s3"
)
// Storage abstracts the backend used to store and retrieve uploaded files,
// addressed by an upload token and a filename.
type Storage interface {
	Get(token string, filename string) (reader io.ReadCloser, contentType string, contentLength uint64, err error)
	Head(token string, filename string) (contentType string, contentLength uint64, err error)
	Put(token string, filename string, reader io.Reader, contentType string, contentLength uint64) error
	IsNotExist(err error) bool

	Type() string
}
// LocalStorage stores files on the local filesystem under basedir/token/filename.
type LocalStorage struct {
	Storage
	basedir string
}

func NewLocalStorage(basedir string) (*LocalStorage, error) {
	return &LocalStorage{basedir: basedir}, nil
}

func (s *LocalStorage) Type() string {
	return "local"
}
func (s *LocalStorage) Head(token string, filename string) (contentType string, contentLength uint64, err error) {
	path := filepath.Join(s.basedir, token, filename)

	var fi os.FileInfo
	if fi, err = os.Lstat(path); err != nil {
		return
	}

	contentLength = uint64(fi.Size())
	contentType = mime.TypeByExtension(filepath.Ext(filename))

	return
}
func (s *LocalStorage) Get(token string, filename string) (reader io.ReadCloser, contentType string, contentLength uint64, err error) {
	path := filepath.Join(s.basedir, token, filename)

	// open the file, then derive content length from its size and
	// content type from the filename extension
	if reader, err = os.Open(path); err != nil {
		return
	}

	var fi os.FileInfo
	if fi, err = os.Lstat(path); err != nil {
		return
	}

	contentLength = uint64(fi.Size())
	contentType = mime.TypeByExtension(filepath.Ext(filename))

	return
}
func (s *LocalStorage) IsNotExist(err error) bool {
	if err == nil {
		return false
	}

	return os.IsNotExist(err)
}
func (s *LocalStorage) Put(token string, filename string, reader io.Reader, contentType string, contentLength uint64) error {
	var f io.WriteCloser
	var err error

	// make sure the per-token directory exists
	path := filepath.Join(s.basedir, token)
	if err = os.Mkdir(path, 0700); err != nil && !os.IsExist(err) {
		return err
	}

	if f, err = os.OpenFile(filepath.Join(path, filename), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600); err != nil {
		fmt.Printf("%s", err)
		return err
	}

	defer f.Close()

	if _, err = io.Copy(f, reader); err != nil {
		return err
	}

	return nil
}
// S3Storage stores files in an Amazon S3 (or compatible) bucket, using
// token/filename as the object key.
type S3Storage struct {
	Storage
	bucket *s3.Bucket
}

func NewS3Storage(accessKey, secretKey, bucketName, endpoint string) (*S3Storage, error) {
	// getBucket is expected to be defined elsewhere in this package
	bucket, err := getBucket(accessKey, secretKey, bucketName, endpoint)
	if err != nil {
		return nil, err
	}

	return &S3Storage{bucket: bucket}, nil
}

func (s *S3Storage) Type() string {
	return "s3"
}
func (s *S3Storage) Head(token string, filename string) (contentType string, contentLength uint64, err error) {
	key := fmt.Sprintf("%s/%s", token, filename)

	// read content type and content length from the response headers
	response, err := s.bucket.Head(key, map[string][]string{})
	if err != nil {
		return
	}

	contentType = response.Header.Get("Content-Type")

	contentLength, err = strconv.ParseUint(response.Header.Get("Content-Length"), 10, 0)
	if err != nil {
		return
	}

	return
}
func (s *S3Storage) IsNotExist(err error) bool {
	if err == nil {
		return false
	}

	log.Printf("IsNotExist: %s, %#v", err.Error(), err)

	b := (err.Error() == "The specified key does not exist.")
	b = b || (err.Error() == "Access Denied")
	return b
}
func (s *S3Storage) Get(token string, filename string) (reader io.ReadCloser, contentType string, contentLength uint64, err error) {
	key := fmt.Sprintf("%s/%s", token, filename)

	// read content type and content length from the response headers
	response, err := s.bucket.GetResponse(key)
	if err != nil {
		return
	}

	contentType = response.Header.Get("Content-Type")

	contentLength, err = strconv.ParseUint(response.Header.Get("Content-Length"), 10, 0)
	if err != nil {
		return
	}

	reader = response.Body
	return
}
func (s *S3Storage) Put(token string, filename string, reader io.Reader, contentType string, contentLength uint64) (err error) {
	key := fmt.Sprintf("%s/%s", token, filename)

	var (
		multi *s3.Multi
		parts []s3.Part
	)

	if multi, err = s.bucket.InitMulti(key, contentType, s3.Private); err != nil {
		log.Print(err.Error())
		return
	}

	// upload in ~10 MB parts; each result (s3.Part or error) is sent back on partsChan
	partsChan := make(chan interface{})
	// partsChan := make(chan s3.Part)

	go func() {
		// limit to 20 concurrent part uploads
		sem := make(chan int, 20)
		index := 1
		var wg sync.WaitGroup

		for {
			// buffered in memory because goamz s3 multi needs a seekable reader
			var (
				buffer []byte = make([]byte, (1<<20)*10)
				count  int
				err    error
			)

			// Amazon expects parts of at least 5MB, except for the last one
			if count, err = io.ReadAtLeast(reader, buffer, (1<<20)*5); err != nil && err != io.ErrUnexpectedEOF && err != io.EOF {
				log.Print(err.Error())
				return
			}

			// always send at least one part
			if err == io.EOF && index > 1 {
				log.Printf("Waiting for all parts to finish uploading.")

				// wait for all parts to finish uploading
				wg.Wait()

				// and close the channel
				close(partsChan)

				return
			}

			wg.Add(1)
			sem <- 1

			// using goroutines because of retries when upload fails
			go func(multi *s3.Multi, buffer []byte, index int) {
				log.Printf("Uploading part %d %d", index, len(buffer))

				defer func() {
					log.Printf("Finished part %d %d", index, len(buffer))

					wg.Done()
					<-sem
				}()

				partReader := bytes.NewReader(buffer)

				var part s3.Part

				if part, err = multi.PutPart(index, partReader); err != nil {
					log.Printf("Error while uploading part %d %d %s", index, len(buffer), err.Error())
					partsChan <- err
					return
				}

				log.Printf("Finished uploading part %d %d", index, len(buffer))
				partsChan <- part
			}(multi, buffer[:count], index)

			index++
		}
	}()

	// collect uploaded parts until the channel is closed
	for part := range partsChan {
		switch part.(type) {
		case s3.Part:
			parts = append(parts, part.(s3.Part))
		case error:
			// abort the multipart upload
			log.Printf("Error during upload, aborting %s.", part.(error).Error())
			err = part.(error)

			multi.Abort()
			return
		}
	}

	log.Printf("Completing upload %d parts", len(parts))

	if err = multi.Complete(parts); err != nil {
		log.Printf("Error during completing upload %d parts %s", len(parts), err.Error())
		return
	}

	log.Printf("Completed uploading %d", len(parts))

	return
}
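
// Usage sketch (illustrative only, not part of the original file): a caller
// might pick a backend from configuration and then work only through the
// Storage interface. The cfg field names below are assumptions.
//
//	var store Storage
//	var err error
//
//	if cfg.Provider == "s3" {
//		store, err = NewS3Storage(cfg.AWSAccessKey, cfg.AWSSecretKey, cfg.Bucket, cfg.S3Endpoint)
//	} else {
//		store, err = NewLocalStorage(cfg.BaseDir)
//	}
//
//	if err != nil {
//		log.Fatal(err)
//	}
//
//	log.Printf("Using storage provider: %s", store.Type())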