You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

269 lines
5.8 KiB

  1. package main
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "log"
  7. "mime"
  8. "os"
  9. "path/filepath"
  10. "strconv"
  11. "sync"
  12. "github.com/goamz/goamz/s3"
  13. )
  // Storage is the backend that uploaded files are kept in. Every stored
// object is addressed by a (token, filename) pair.
type Storage interface {
	// Put stores the data read from reader under token/filename.
	Put(token string, filename string, reader io.Reader, contentType string, contentLength uint64) error

	// Get returns a reader over the stored object together with its content
	// type and length. The caller is responsible for closing the reader.
	Get(token string, filename string) (reader io.ReadCloser, contentType string, contentLength uint64, err error)

	// Head returns the content type and length of the stored object without
	// returning its contents.
	Head(token string, filename string) (contentType string, contentLength uint64, err error)

	// IsNotExist reports whether err, as returned by this Storage, should be
	// treated as "object not found".
	IsNotExist(err error) bool
}
  20. type LocalStorage struct {
  21. Storage
  22. basedir string
  23. }
  24. func NewLocalStorage(basedir string) (*LocalStorage, error) {
  25. return &LocalStorage{basedir: basedir}, nil
  26. }
  27. func (s *LocalStorage) Head(token string, filename string) (contentType string, contentLength uint64, err error) {
  28. path := filepath.Join(s.basedir, token, filename)
  29. var fi os.FileInfo
  30. if fi, err = os.Lstat(path); err != nil {
  31. return
  32. }
  33. contentLength = uint64(fi.Size())
  34. contentType = mime.TypeByExtension(filepath.Ext(filename))
  35. return
  36. }
  37. func (s *LocalStorage) Get(token string, filename string) (reader io.ReadCloser, contentType string, contentLength uint64, err error) {
  38. path := filepath.Join(s.basedir, token, filename)
  39. // content type , content length
  40. if reader, err = os.Open(path); err != nil {
  41. return
  42. }
  43. var fi os.FileInfo
  44. if fi, err = os.Lstat(path); err != nil {
  45. return
  46. }
  47. contentLength = uint64(fi.Size())
  48. contentType = mime.TypeByExtension(filepath.Ext(filename))
  49. return
  50. }
  51. func (s *LocalStorage) IsNotExist(err error) bool {
  52. return os.IsNotExist(err)
  53. }
  54. func (s *LocalStorage) Put(token string, filename string, reader io.Reader, contentType string, contentLength uint64) error {
  55. var f io.WriteCloser
  56. var err error
  57. path := filepath.Join(s.basedir, token)
  58. if err = os.Mkdir(path, 0700); err != nil && !os.IsExist(err) {
  59. return err
  60. }
  61. if f, err = os.OpenFile(filepath.Join(path, filename), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600); err != nil {
  62. fmt.Printf("%s", err)
  63. return err
  64. }
  65. defer f.Close()
  66. if _, err = io.Copy(f, reader); err != nil {
  67. return err
  68. }
  69. return nil
  70. }
  71. type S3Storage struct {
  72. Storage
  73. bucket *s3.Bucket
  74. }
  75. func NewS3Storage() (*S3Storage, error) {
  76. bucket, err := getBucket()
  77. if err != nil {
  78. return nil, err
  79. }
  80. return &S3Storage{bucket: bucket}, nil
  81. }
  82. func (s *S3Storage) Head(token string, filename string) (contentType string, contentLength uint64, err error) {
  83. key := fmt.Sprintf("%s/%s", token, filename)
  84. // content type , content length
  85. response, err := s.bucket.Head(key, map[string][]string{})
  86. if err != nil {
  87. return
  88. }
  89. contentType = response.Header.Get("Content-Type")
  90. contentLength, err = strconv.ParseUint(response.Header.Get("Content-Length"), 10, 0)
  91. if err != nil {
  92. return
  93. }
  94. return
  95. }
  96. func (s *S3Storage) IsNotExist(err error) bool {
  97. log.Printf("IsNotExist: %s, %#v", err.Error(), err)
  98. b := (err.Error() == "The specified key does not exist.")
  99. b = b || (err.Error() == "Access Denied")
  100. return b
  101. }
  102. func (s *S3Storage) Get(token string, filename string) (reader io.ReadCloser, contentType string, contentLength uint64, err error) {
  103. key := fmt.Sprintf("%s/%s", token, filename)
  104. // content type , content length
  105. response, err := s.bucket.GetResponse(key)
  106. if err != nil {
  107. return
  108. }
  109. contentType = response.Header.Get("Content-Type")
  110. contentLength, err = strconv.ParseUint(response.Header.Get("Content-Length"), 10, 0)
  111. if err != nil {
  112. return
  113. }
  114. reader = response.Body
  115. return
  116. }
  117. func (s *S3Storage) Put(token string, filename string, reader io.Reader, contentType string, contentLength uint64) (err error) {
  118. key := fmt.Sprintf("%s/%s", token, filename)
  119. var (
  120. multi *s3.Multi
  121. parts []s3.Part
  122. )
  123. if multi, err = s.bucket.InitMulti(key, contentType, s3.Private); err != nil {
  124. log.Printf(err.Error())
  125. return
  126. }
  127. // 20 mb parts
  128. partsChan := make(chan interface{})
  129. // partsChan := make(chan s3.Part)
  130. go func() {
  131. // maximize to 20 threads
  132. sem := make(chan int, 20)
  133. index := 1
  134. var wg sync.WaitGroup
  135. for {
  136. // buffered in memory because goamz s3 multi needs seekable reader
  137. var (
  138. buffer []byte = make([]byte, (1<<20)*10)
  139. count int
  140. err error
  141. )
  142. // Amazon expects parts of at least 5MB, except for the last one
  143. if count, err = io.ReadAtLeast(reader, buffer, (1<<20)*5); err != nil && err != io.ErrUnexpectedEOF && err != io.EOF {
  144. log.Printf(err.Error())
  145. return
  146. }
  147. // always send minimal 1 part
  148. if err == io.EOF && index > 1 {
  149. log.Printf("Waiting for all parts to finish uploading.")
  150. // wait for all parts to be finished uploading
  151. wg.Wait()
  152. // and close the channel
  153. close(partsChan)
  154. return
  155. }
  156. wg.Add(1)
  157. sem <- 1
  158. // using goroutines because of retries when upload fails
  159. go func(multi *s3.Multi, buffer []byte, index int) {
  160. log.Printf("Uploading part %d %d", index, len(buffer))
  161. defer func() {
  162. log.Printf("Finished part %d %d", index, len(buffer))
  163. wg.Done()
  164. <-sem
  165. }()
  166. partReader := bytes.NewReader(buffer)
  167. var part s3.Part
  168. if part, err = multi.PutPart(index, partReader); err != nil {
  169. log.Printf("Error while uploading part %d %d %s", index, len(buffer), err.Error())
  170. partsChan <- err
  171. return
  172. }
  173. log.Printf("Finished uploading part %d %d", index, len(buffer))
  174. partsChan <- part
  175. }(multi, buffer[:count], index)
  176. index++
  177. }
  178. }()
  179. // wait for all parts to be uploaded
  180. for part := range partsChan {
  181. switch part.(type) {
  182. case s3.Part:
  183. parts = append(parts, part.(s3.Part))
  184. case error:
  185. // abort multi upload
  186. log.Printf("Error during upload, aborting %s.", part.(error).Error())
  187. err = part.(error)
  188. multi.Abort()
  189. return
  190. }
  191. }
  192. log.Printf("Completing upload %d parts", len(parts))
  193. if err = multi.Complete(parts); err != nil {
  194. log.Printf("Error during completing upload %d parts %s", len(parts), err.Error())
  195. return
  196. }
  197. log.Printf("Completed uploading %d", len(parts))
  198. return
  199. }