// Copyright 2018, OpenCensus Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package ochttp import ( "bytes" "context" "encoding/hex" "errors" "fmt" "io" "io/ioutil" "log" "net/http" "net/http/httptest" "reflect" "strings" "testing" "time" "go.opencensus.io/plugin/ochttp/propagation/b3" "go.opencensus.io/plugin/ochttp/propagation/tracecontext" "go.opencensus.io/trace" ) type testExporter struct { spans []*trace.SpanData } func (t *testExporter) ExportSpan(s *trace.SpanData) { t.spans = append(t.spans, s) } type testTransport struct { ch chan *http.Request } func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) { t.ch <- req return nil, errors.New("noop") } type testPropagator struct{} func (t testPropagator) SpanContextFromRequest(req *http.Request) (sc trace.SpanContext, ok bool) { header := req.Header.Get("trace") buf, err := hex.DecodeString(header) if err != nil { log.Fatalf("Cannot decode trace header: %q", header) } r := bytes.NewReader(buf) r.Read(sc.TraceID[:]) r.Read(sc.SpanID[:]) opts, err := r.ReadByte() if err != nil { log.Fatalf("Cannot read trace options from trace header: %q", header) } sc.TraceOptions = trace.TraceOptions(opts) return sc, true } func (t testPropagator) SpanContextToRequest(sc trace.SpanContext, req *http.Request) { var buf bytes.Buffer buf.Write(sc.TraceID[:]) buf.Write(sc.SpanID[:]) buf.WriteByte(byte(sc.TraceOptions)) req.Header.Set("trace", hex.EncodeToString(buf.Bytes())) } func TestTransport_RoundTrip_Race(t *testing.T) { // This tests that we don't modify the request in accordance with the // specification for http.RoundTripper. // We attempt to trigger a race by reading the request from a separate // goroutine. If the request is modified by Transport, this should trigger // the race detector. transport := &testTransport{ch: make(chan *http.Request, 1)} rt := &Transport{ Propagation: &testPropagator{}, Base: transport, } req, _ := http.NewRequest("GET", "http://foo.com", nil) go func() { fmt.Println(*req) }() rt.RoundTrip(req) _ = <-transport.ch } func TestTransport_RoundTrip(t *testing.T) { _, parent := trace.StartSpan(context.Background(), "parent") tests := []struct { name string parent *trace.Span }{ { name: "no parent", parent: nil, }, { name: "parent", parent: parent, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { transport := &testTransport{ch: make(chan *http.Request, 1)} rt := &Transport{ Propagation: &testPropagator{}, Base: transport, } req, _ := http.NewRequest("GET", "http://foo.com", nil) if tt.parent != nil { req = req.WithContext(trace.NewContext(req.Context(), tt.parent)) } rt.RoundTrip(req) req = <-transport.ch span := trace.FromContext(req.Context()) if header := req.Header.Get("trace"); header == "" { t.Fatalf("Trace header = empty; want valid trace header") } if span == nil { t.Fatalf("Got no spans in req context; want one") } if tt.parent != nil { if got, want := span.SpanContext().TraceID, tt.parent.SpanContext().TraceID; got != want { t.Errorf("span.SpanContext().TraceID=%v; want %v", got, want) } } }) } } func TestHandler(t *testing.T) { traceID := [16]byte{16, 84, 69, 170, 120, 67, 188, 139, 242, 6, 177, 32, 0, 16, 0, 0} tests := []struct { header string wantTraceID trace.TraceID wantTraceOptions trace.TraceOptions }{ { header: "105445aa7843bc8bf206b12000100000000000000000000000", wantTraceID: traceID, wantTraceOptions: trace.TraceOptions(0), }, { header: "105445aa7843bc8bf206b12000100000000000000000000001", wantTraceID: traceID, wantTraceOptions: trace.TraceOptions(1), }, } for _, tt := range tests { t.Run(tt.header, func(t *testing.T) { handler := &Handler{ Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { span := trace.FromContext(r.Context()) sc := span.SpanContext() if got, want := sc.TraceID, tt.wantTraceID; got != want { t.Errorf("TraceID = %q; want %q", got, want) } if got, want := sc.TraceOptions, tt.wantTraceOptions; got != want { t.Errorf("TraceOptions = %v; want %v", got, want) } }), StartOptions: trace.StartOptions{Sampler: trace.ProbabilitySampler(0.0)}, Propagation: &testPropagator{}, } req, _ := http.NewRequest("GET", "http://foo.com", nil) req.Header.Add("trace", tt.header) handler.ServeHTTP(nil, req) }) } } var _ http.RoundTripper = (*traceTransport)(nil) type collector []*trace.SpanData func (c *collector) ExportSpan(s *trace.SpanData) { *c = append(*c, s) } func TestEndToEnd(t *testing.T) { tc := []struct { name string handler *Handler transport *Transport wantSameTraceID bool wantLinks bool // expect a link between client and server span }{ { name: "internal default propagation", handler: &Handler{}, transport: &Transport{}, wantSameTraceID: true, }, { name: "external default propagation", handler: &Handler{IsPublicEndpoint: true}, transport: &Transport{}, wantSameTraceID: false, wantLinks: true, }, { name: "internal TraceContext propagation", handler: &Handler{Propagation: &tracecontext.HTTPFormat{}}, transport: &Transport{Propagation: &tracecontext.HTTPFormat{}}, wantSameTraceID: true, }, { name: "misconfigured propagation", handler: &Handler{IsPublicEndpoint: true, Propagation: &tracecontext.HTTPFormat{}}, transport: &Transport{Propagation: &b3.HTTPFormat{}}, wantSameTraceID: false, wantLinks: false, }, } for _, tt := range tc { t.Run(tt.name, func(t *testing.T) { var spans collector trace.RegisterExporter(&spans) defer trace.UnregisterExporter(&spans) // Start the server. serverDone := make(chan struct{}) serverReturn := make(chan time.Time) tt.handler.StartOptions.Sampler = trace.AlwaysSample() url := serveHTTP(tt.handler, serverDone, serverReturn) ctx := context.Background() // Make the request. req, err := http.NewRequest( http.MethodPost, fmt.Sprintf("%s/example/url/path?qparam=val", url), strings.NewReader("expected-request-body")) if err != nil { t.Fatal(err) } req = req.WithContext(ctx) tt.transport.StartOptions.Sampler = trace.AlwaysSample() c := &http.Client{ Transport: tt.transport, } resp, err := c.Do(req) if err != nil { t.Fatal(err) } if resp.StatusCode != http.StatusOK { t.Fatalf("resp.StatusCode = %d", resp.StatusCode) } // Tell the server to return from request handling. serverReturn <- time.Now().Add(time.Millisecond) respBody, err := ioutil.ReadAll(resp.Body) if err != nil { t.Fatal(err) } if got, want := string(respBody), "expected-response"; got != want { t.Fatalf("respBody = %q; want %q", got, want) } resp.Body.Close() <-serverDone trace.UnregisterExporter(&spans) if got, want := len(spans), 2; got != want { t.Fatalf("len(spans) = %d; want %d", got, want) } var client, server *trace.SpanData for _, sp := range spans { switch sp.SpanKind { case trace.SpanKindClient: client = sp if got, want := client.Name, "/example/url/path"; got != want { t.Errorf("Span name: %q; want %q", got, want) } case trace.SpanKindServer: server = sp if got, want := server.Name, "/example/url/path"; got != want { t.Errorf("Span name: %q; want %q", got, want) } default: t.Fatalf("server or client span missing; kind = %v", sp.SpanKind) } } if tt.wantSameTraceID { if server.TraceID != client.TraceID { t.Errorf("TraceID does not match: server.TraceID=%q client.TraceID=%q", server.TraceID, client.TraceID) } if !server.HasRemoteParent { t.Errorf("server span should have remote parent") } if server.ParentSpanID != client.SpanID { t.Errorf("server span should have client span as parent") } } if !tt.wantSameTraceID { if server.TraceID == client.TraceID { t.Errorf("TraceID should not be trusted") } } if tt.wantLinks { if got, want := len(server.Links), 1; got != want { t.Errorf("len(server.Links) = %d; want %d", got, want) } else { link := server.Links[0] if got, want := link.Type, trace.LinkTypeParent; got != want { t.Errorf("link.Type = %v; want %v", got, want) } } } if server.StartTime.Before(client.StartTime) { t.Errorf("server span starts before client span") } if server.EndTime.After(client.EndTime) { t.Errorf("client span ends before server span") } }) } } func serveHTTP(handler *Handler, done chan struct{}, wait chan time.Time) string { handler.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) w.(http.Flusher).Flush() // Simulate a slow-responding server. sleepUntil := <-wait for time.Now().Before(sleepUntil) { time.Sleep(sleepUntil.Sub(time.Now())) } io.WriteString(w, "expected-response") close(done) }) server := httptest.NewServer(handler) go func() { <-done server.Close() }() return server.URL } func TestSpanNameFromURL(t *testing.T) { tests := []struct { u string want string }{ { u: "http://localhost:80/hello?q=a", want: "/hello", }, { u: "/a/b?q=c", want: "/a/b", }, } for _, tt := range tests { t.Run(tt.u, func(t *testing.T) { req, err := http.NewRequest("GET", tt.u, nil) if err != nil { t.Errorf("url issue = %v", err) } if got := spanNameFromURL(req); got != tt.want { t.Errorf("spanNameFromURL() = %v, want %v", got, tt.want) } }) } } func TestFormatSpanName(t *testing.T) { formatSpanName := func(r *http.Request) string { return r.Method + " " + r.URL.Path } handler := &Handler{ Handler: http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { resp.Write([]byte("Hello, world!")) }), FormatSpanName: formatSpanName, } server := httptest.NewServer(handler) defer server.Close() client := &http.Client{ Transport: &Transport{ FormatSpanName: formatSpanName, StartOptions: trace.StartOptions{ Sampler: trace.AlwaysSample(), }, }, } tests := []struct { u string want string }{ { u: "/hello?q=a", want: "GET /hello", }, { u: "/a/b?q=c", want: "GET /a/b", }, } for _, tt := range tests { t.Run(tt.u, func(t *testing.T) { var te testExporter trace.RegisterExporter(&te) res, err := client.Get(server.URL + tt.u) if err != nil { t.Fatalf("error creating request: %v", err) } res.Body.Close() trace.UnregisterExporter(&te) if want, got := 2, len(te.spans); want != got { t.Fatalf("got exported spans %#v, wanted two spans", te.spans) } if got := te.spans[0].Name; got != tt.want { t.Errorf("spanNameFromURL() = %v, want %v", got, tt.want) } if got := te.spans[1].Name; got != tt.want { t.Errorf("spanNameFromURL() = %v, want %v", got, tt.want) } }) } } func TestRequestAttributes(t *testing.T) { tests := []struct { name string makeReq func() *http.Request wantAttrs []trace.Attribute }{ { name: "GET example.com/hello", makeReq: func() *http.Request { req, _ := http.NewRequest("GET", "http://example.com:779/hello", nil) req.Header.Add("User-Agent", "ua") return req }, wantAttrs: []trace.Attribute{ trace.StringAttribute("http.path", "/hello"), trace.StringAttribute("http.host", "example.com:779"), trace.StringAttribute("http.method", "GET"), trace.StringAttribute("http.user_agent", "ua"), }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { req := tt.makeReq() attrs := requestAttrs(req) if got, want := attrs, tt.wantAttrs; !reflect.DeepEqual(got, want) { t.Errorf("Request attributes = %#v; want %#v", got, want) } }) } } func TestResponseAttributes(t *testing.T) { tests := []struct { name string resp *http.Response wantAttrs []trace.Attribute }{ { name: "non-zero HTTP 200 response", resp: &http.Response{StatusCode: 200}, wantAttrs: []trace.Attribute{ trace.Int64Attribute("http.status_code", 200), }, }, { name: "zero HTTP 500 response", resp: &http.Response{StatusCode: 500}, wantAttrs: []trace.Attribute{ trace.Int64Attribute("http.status_code", 500), }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { attrs := responseAttrs(tt.resp) if got, want := attrs, tt.wantAttrs; !reflect.DeepEqual(got, want) { t.Errorf("Response attributes = %#v; want %#v", got, want) } }) } } func TestStatusUnitTest(t *testing.T) { tests := []struct { in int want trace.Status }{ {200, trace.Status{Code: trace.StatusCodeOK, Message: `OK`}}, {204, trace.Status{Code: trace.StatusCodeOK, Message: `OK`}}, {100, trace.Status{Code: trace.StatusCodeUnknown, Message: `UNKNOWN`}}, {500, trace.Status{Code: trace.StatusCodeUnknown, Message: `UNKNOWN`}}, {404, trace.Status{Code: trace.StatusCodeNotFound, Message: `NOT_FOUND`}}, {600, trace.Status{Code: trace.StatusCodeUnknown, Message: `UNKNOWN`}}, {401, trace.Status{Code: trace.StatusCodeUnauthenticated, Message: `UNAUTHENTICATED`}}, {403, trace.Status{Code: trace.StatusCodePermissionDenied, Message: `PERMISSION_DENIED`}}, {301, trace.Status{Code: trace.StatusCodeOK, Message: `OK`}}, {501, trace.Status{Code: trace.StatusCodeUnimplemented, Message: `UNIMPLEMENTED`}}, } for _, tt := range tests { got, want := TraceStatus(tt.in, ""), tt.want if got != want { t.Errorf("status(%d) got = (%#v) want = (%#v)", tt.in, got, want) } } }