You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

281 lines
6.9 KiB

  1. // Copyright 2016 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package gensupport
  5. import (
  6. "context"
  7. "fmt"
  8. "io"
  9. "io/ioutil"
  10. "net/http"
  11. "reflect"
  12. "strings"
  13. "testing"
  14. )
  15. type unexpectedReader struct{}
  16. func (unexpectedReader) Read([]byte) (int, error) {
  17. return 0, fmt.Errorf("unexpected read in test")
  18. }
  19. // event is an expected request/response pair
  20. type event struct {
  21. // the byte range header that should be present in a request.
  22. byteRange string
  23. // the http status code to send in response.
  24. responseStatus int
  25. }
  26. // interruptibleTransport is configured with a canned set of requests/responses.
  27. // It records the incoming data, unless the corresponding event is configured to return
  28. // http.StatusServiceUnavailable.
  29. type interruptibleTransport struct {
  30. events []event
  31. buf []byte
  32. bodies bodyTracker
  33. }
  34. // bodyTracker keeps track of response bodies that have not been closed.
  35. type bodyTracker map[io.ReadCloser]struct{}
  36. func (bt bodyTracker) Add(body io.ReadCloser) {
  37. bt[body] = struct{}{}
  38. }
  39. func (bt bodyTracker) Close(body io.ReadCloser) {
  40. delete(bt, body)
  41. }
  42. type trackingCloser struct {
  43. io.Reader
  44. tracker bodyTracker
  45. }
  46. func (tc *trackingCloser) Close() error {
  47. tc.tracker.Close(tc)
  48. return nil
  49. }
  50. func (tc *trackingCloser) Open() {
  51. tc.tracker.Add(tc)
  52. }
  53. func (t *interruptibleTransport) RoundTrip(req *http.Request) (*http.Response, error) {
  54. ev := t.events[0]
  55. t.events = t.events[1:]
  56. if got, want := req.Header.Get("Content-Range"), ev.byteRange; got != want {
  57. return nil, fmt.Errorf("byte range: got %s; want %s", got, want)
  58. }
  59. if ev.responseStatus != http.StatusServiceUnavailable {
  60. buf, err := ioutil.ReadAll(req.Body)
  61. if err != nil {
  62. return nil, fmt.Errorf("error reading from request data: %v", err)
  63. }
  64. t.buf = append(t.buf, buf...)
  65. }
  66. tc := &trackingCloser{unexpectedReader{}, t.bodies}
  67. tc.Open()
  68. h := http.Header{}
  69. status := ev.responseStatus
  70. // Support "X-GUploader-No-308" like Google:
  71. if status == 308 && req.Header.Get("X-GUploader-No-308") == "yes" {
  72. status = 200
  73. h.Set("X-Http-Status-Code-Override", "308")
  74. }
  75. res := &http.Response{
  76. StatusCode: status,
  77. Header: h,
  78. Body: tc,
  79. }
  80. return res, nil
  81. }
  82. // progressRecorder records updates, and calls f for every invocation of ProgressUpdate.
  83. type progressRecorder struct {
  84. updates []int64
  85. f func()
  86. }
  87. func (pr *progressRecorder) ProgressUpdate(current int64) {
  88. pr.updates = append(pr.updates, current)
  89. if pr.f != nil {
  90. pr.f()
  91. }
  92. }
  93. func TestInterruptedTransferChunks(t *testing.T) {
  94. type testCase struct {
  95. data string
  96. chunkSize int
  97. events []event
  98. wantProgress []int64
  99. }
  100. for _, tc := range []testCase{
  101. {
  102. data: strings.Repeat("a", 300),
  103. chunkSize: 90,
  104. events: []event{
  105. {"bytes 0-89/*", http.StatusServiceUnavailable},
  106. {"bytes 0-89/*", 308},
  107. {"bytes 90-179/*", 308},
  108. {"bytes 180-269/*", http.StatusServiceUnavailable},
  109. {"bytes 180-269/*", 308},
  110. {"bytes 270-299/300", 200},
  111. },
  112. wantProgress: []int64{90, 180, 270, 300},
  113. },
  114. {
  115. data: strings.Repeat("a", 20),
  116. chunkSize: 10,
  117. events: []event{
  118. {"bytes 0-9/*", http.StatusServiceUnavailable},
  119. {"bytes 0-9/*", 308},
  120. {"bytes 10-19/*", http.StatusServiceUnavailable},
  121. {"bytes 10-19/*", 308},
  122. // 0 byte final request demands a byte range with leading asterix.
  123. {"bytes */20", http.StatusServiceUnavailable},
  124. {"bytes */20", 200},
  125. },
  126. wantProgress: []int64{10, 20},
  127. },
  128. } {
  129. media := strings.NewReader(tc.data)
  130. tr := &interruptibleTransport{
  131. buf: make([]byte, 0, len(tc.data)),
  132. events: tc.events,
  133. bodies: bodyTracker{},
  134. }
  135. pr := progressRecorder{}
  136. rx := &ResumableUpload{
  137. Client: &http.Client{Transport: tr},
  138. Media: NewMediaBuffer(media, tc.chunkSize),
  139. MediaType: "text/plain",
  140. Callback: pr.ProgressUpdate,
  141. Backoff: NoPauseStrategy,
  142. }
  143. res, err := rx.Upload(context.Background())
  144. if err == nil {
  145. res.Body.Close()
  146. }
  147. if err != nil || res == nil || res.StatusCode != http.StatusOK {
  148. if res == nil {
  149. t.Errorf("Upload not successful, res=nil: %v", err)
  150. } else {
  151. t.Errorf("Upload not successful, statusCode=%v: %v", res.StatusCode, err)
  152. }
  153. }
  154. if !reflect.DeepEqual(tr.buf, []byte(tc.data)) {
  155. t.Errorf("transferred contents:\ngot %s\nwant %s", tr.buf, tc.data)
  156. }
  157. if !reflect.DeepEqual(pr.updates, tc.wantProgress) {
  158. t.Errorf("progress updates: got %v, want %v", pr.updates, tc.wantProgress)
  159. }
  160. if len(tr.events) > 0 {
  161. t.Errorf("did not observe all expected events. leftover events: %v", tr.events)
  162. }
  163. if len(tr.bodies) > 0 {
  164. t.Errorf("unclosed request bodies: %v", tr.bodies)
  165. }
  166. }
  167. }
  168. func TestCancelUploadFast(t *testing.T) {
  169. const (
  170. chunkSize = 90
  171. mediaSize = 300
  172. )
  173. media := strings.NewReader(strings.Repeat("a", mediaSize))
  174. tr := &interruptibleTransport{
  175. buf: make([]byte, 0, mediaSize),
  176. }
  177. pr := progressRecorder{}
  178. rx := &ResumableUpload{
  179. Client: &http.Client{Transport: tr},
  180. Media: NewMediaBuffer(media, chunkSize),
  181. MediaType: "text/plain",
  182. Callback: pr.ProgressUpdate,
  183. Backoff: NoPauseStrategy,
  184. }
  185. ctx, cancelFunc := context.WithCancel(context.Background())
  186. cancelFunc() // stop the upload that hasn't started yet
  187. res, err := rx.Upload(ctx)
  188. if err != context.Canceled {
  189. t.Errorf("Upload err: got: %v; want: context cancelled", err)
  190. }
  191. if res != nil {
  192. t.Errorf("Upload result: got: %v; want: nil", res)
  193. }
  194. if pr.updates != nil {
  195. t.Errorf("progress updates: got %v; want: nil", pr.updates)
  196. }
  197. }
  198. func TestCancelUpload(t *testing.T) {
  199. const (
  200. chunkSize = 90
  201. mediaSize = 300
  202. )
  203. media := strings.NewReader(strings.Repeat("a", mediaSize))
  204. tr := &interruptibleTransport{
  205. buf: make([]byte, 0, mediaSize),
  206. events: []event{
  207. {"bytes 0-89/*", http.StatusServiceUnavailable},
  208. {"bytes 0-89/*", 308},
  209. {"bytes 90-179/*", 308},
  210. {"bytes 180-269/*", 308}, // Upload should be cancelled before this event.
  211. },
  212. bodies: bodyTracker{},
  213. }
  214. ctx, cancelFunc := context.WithCancel(context.Background())
  215. numUpdates := 0
  216. pr := progressRecorder{f: func() {
  217. numUpdates++
  218. if numUpdates >= 2 {
  219. cancelFunc()
  220. }
  221. }}
  222. rx := &ResumableUpload{
  223. Client: &http.Client{Transport: tr},
  224. Media: NewMediaBuffer(media, chunkSize),
  225. MediaType: "text/plain",
  226. Callback: pr.ProgressUpdate,
  227. Backoff: NoPauseStrategy,
  228. }
  229. res, err := rx.Upload(ctx)
  230. if err != context.Canceled {
  231. t.Errorf("Upload err: got: %v; want: context cancelled", err)
  232. }
  233. if res != nil {
  234. t.Errorf("Upload result: got: %v; want: nil", res)
  235. }
  236. if got, want := tr.buf, []byte(strings.Repeat("a", chunkSize*2)); !reflect.DeepEqual(got, want) {
  237. t.Errorf("transferred contents:\ngot %s\nwant %s", got, want)
  238. }
  239. if got, want := pr.updates, []int64{chunkSize, chunkSize * 2}; !reflect.DeepEqual(got, want) {
  240. t.Errorf("progress updates: got %v; want: %v", got, want)
  241. }
  242. if len(tr.bodies) > 0 {
  243. t.Errorf("unclosed request bodies: %v", tr.bodies)
  244. }
  245. }