// Copyright 2016 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package gensupport import ( "context" "fmt" "io" "io/ioutil" "net/http" "reflect" "strings" "testing" ) type unexpectedReader struct{} func (unexpectedReader) Read([]byte) (int, error) { return 0, fmt.Errorf("unexpected read in test") } // event is an expected request/response pair type event struct { // the byte range header that should be present in a request. byteRange string // the http status code to send in response. responseStatus int } // interruptibleTransport is configured with a canned set of requests/responses. // It records the incoming data, unless the corresponding event is configured to return // http.StatusServiceUnavailable. type interruptibleTransport struct { events []event buf []byte bodies bodyTracker } // bodyTracker keeps track of response bodies that have not been closed. type bodyTracker map[io.ReadCloser]struct{} func (bt bodyTracker) Add(body io.ReadCloser) { bt[body] = struct{}{} } func (bt bodyTracker) Close(body io.ReadCloser) { delete(bt, body) } type trackingCloser struct { io.Reader tracker bodyTracker } func (tc *trackingCloser) Close() error { tc.tracker.Close(tc) return nil } func (tc *trackingCloser) Open() { tc.tracker.Add(tc) } func (t *interruptibleTransport) RoundTrip(req *http.Request) (*http.Response, error) { ev := t.events[0] t.events = t.events[1:] if got, want := req.Header.Get("Content-Range"), ev.byteRange; got != want { return nil, fmt.Errorf("byte range: got %s; want %s", got, want) } if ev.responseStatus != http.StatusServiceUnavailable { buf, err := ioutil.ReadAll(req.Body) if err != nil { return nil, fmt.Errorf("error reading from request data: %v", err) } t.buf = append(t.buf, buf...) } tc := &trackingCloser{unexpectedReader{}, t.bodies} tc.Open() h := http.Header{} status := ev.responseStatus // Support "X-GUploader-No-308" like Google: if status == 308 && req.Header.Get("X-GUploader-No-308") == "yes" { status = 200 h.Set("X-Http-Status-Code-Override", "308") } res := &http.Response{ StatusCode: status, Header: h, Body: tc, } return res, nil } // progressRecorder records updates, and calls f for every invocation of ProgressUpdate. type progressRecorder struct { updates []int64 f func() } func (pr *progressRecorder) ProgressUpdate(current int64) { pr.updates = append(pr.updates, current) if pr.f != nil { pr.f() } } func TestInterruptedTransferChunks(t *testing.T) { type testCase struct { data string chunkSize int events []event wantProgress []int64 } for _, tc := range []testCase{ { data: strings.Repeat("a", 300), chunkSize: 90, events: []event{ {"bytes 0-89/*", http.StatusServiceUnavailable}, {"bytes 0-89/*", 308}, {"bytes 90-179/*", 308}, {"bytes 180-269/*", http.StatusServiceUnavailable}, {"bytes 180-269/*", 308}, {"bytes 270-299/300", 200}, }, wantProgress: []int64{90, 180, 270, 300}, }, { data: strings.Repeat("a", 20), chunkSize: 10, events: []event{ {"bytes 0-9/*", http.StatusServiceUnavailable}, {"bytes 0-9/*", 308}, {"bytes 10-19/*", http.StatusServiceUnavailable}, {"bytes 10-19/*", 308}, // 0 byte final request demands a byte range with leading asterix. {"bytes */20", http.StatusServiceUnavailable}, {"bytes */20", 200}, }, wantProgress: []int64{10, 20}, }, } { media := strings.NewReader(tc.data) tr := &interruptibleTransport{ buf: make([]byte, 0, len(tc.data)), events: tc.events, bodies: bodyTracker{}, } pr := progressRecorder{} rx := &ResumableUpload{ Client: &http.Client{Transport: tr}, Media: NewMediaBuffer(media, tc.chunkSize), MediaType: "text/plain", Callback: pr.ProgressUpdate, Backoff: NoPauseStrategy, } res, err := rx.Upload(context.Background()) if err == nil { res.Body.Close() } if err != nil || res == nil || res.StatusCode != http.StatusOK { if res == nil { t.Errorf("Upload not successful, res=nil: %v", err) } else { t.Errorf("Upload not successful, statusCode=%v: %v", res.StatusCode, err) } } if !reflect.DeepEqual(tr.buf, []byte(tc.data)) { t.Errorf("transferred contents:\ngot %s\nwant %s", tr.buf, tc.data) } if !reflect.DeepEqual(pr.updates, tc.wantProgress) { t.Errorf("progress updates: got %v, want %v", pr.updates, tc.wantProgress) } if len(tr.events) > 0 { t.Errorf("did not observe all expected events. leftover events: %v", tr.events) } if len(tr.bodies) > 0 { t.Errorf("unclosed request bodies: %v", tr.bodies) } } } func TestCancelUploadFast(t *testing.T) { const ( chunkSize = 90 mediaSize = 300 ) media := strings.NewReader(strings.Repeat("a", mediaSize)) tr := &interruptibleTransport{ buf: make([]byte, 0, mediaSize), } pr := progressRecorder{} rx := &ResumableUpload{ Client: &http.Client{Transport: tr}, Media: NewMediaBuffer(media, chunkSize), MediaType: "text/plain", Callback: pr.ProgressUpdate, Backoff: NoPauseStrategy, } ctx, cancelFunc := context.WithCancel(context.Background()) cancelFunc() // stop the upload that hasn't started yet res, err := rx.Upload(ctx) if err != context.Canceled { t.Errorf("Upload err: got: %v; want: context cancelled", err) } if res != nil { t.Errorf("Upload result: got: %v; want: nil", res) } if pr.updates != nil { t.Errorf("progress updates: got %v; want: nil", pr.updates) } } func TestCancelUpload(t *testing.T) { const ( chunkSize = 90 mediaSize = 300 ) media := strings.NewReader(strings.Repeat("a", mediaSize)) tr := &interruptibleTransport{ buf: make([]byte, 0, mediaSize), events: []event{ {"bytes 0-89/*", http.StatusServiceUnavailable}, {"bytes 0-89/*", 308}, {"bytes 90-179/*", 308}, {"bytes 180-269/*", 308}, // Upload should be cancelled before this event. }, bodies: bodyTracker{}, } ctx, cancelFunc := context.WithCancel(context.Background()) numUpdates := 0 pr := progressRecorder{f: func() { numUpdates++ if numUpdates >= 2 { cancelFunc() } }} rx := &ResumableUpload{ Client: &http.Client{Transport: tr}, Media: NewMediaBuffer(media, chunkSize), MediaType: "text/plain", Callback: pr.ProgressUpdate, Backoff: NoPauseStrategy, } res, err := rx.Upload(ctx) if err != context.Canceled { t.Errorf("Upload err: got: %v; want: context cancelled", err) } if res != nil { t.Errorf("Upload result: got: %v; want: nil", res) } if got, want := tr.buf, []byte(strings.Repeat("a", chunkSize*2)); !reflect.DeepEqual(got, want) { t.Errorf("transferred contents:\ngot %s\nwant %s", got, want) } if got, want := pr.updates, []int64{chunkSize, chunkSize * 2}; !reflect.DeepEqual(got, want) { t.Errorf("progress updates: got %v; want: %v", got, want) } if len(tr.bodies) > 0 { t.Errorf("unclosed request bodies: %v", tr.bodies) } }