You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

150 lines
3.7 KiB

  1. // Copyright 2017 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package trace
  15. import (
  16. "io/ioutil"
  17. "net/http"
  18. "net/http/httptest"
  19. "strings"
  20. "testing"
  21. )
  22. type recorderTransport struct {
  23. ch chan *http.Request
  24. }
  25. func (rt *recorderTransport) RoundTrip(req *http.Request) (*http.Response, error) {
  26. rt.ch <- req
  27. resp := &http.Response{
  28. Status: "200 OK",
  29. StatusCode: 200,
  30. Body: ioutil.NopCloser(strings.NewReader("{}")),
  31. }
  32. return resp, nil
  33. }
  34. func TestNewHTTPClient(t *testing.T) {
  35. rt := &recorderTransport{
  36. ch: make(chan *http.Request, 1),
  37. }
  38. tc := newTestClient(&noopTransport{})
  39. client := &http.Client{
  40. Transport: &Transport{
  41. Base: rt,
  42. },
  43. }
  44. req, _ := http.NewRequest("GET", "http://example.com", nil)
  45. t.Run("NoTrace", func(t *testing.T) {
  46. _, err := client.Do(req)
  47. if err != nil {
  48. t.Error(err)
  49. }
  50. outgoing := <-rt.ch
  51. if got, want := outgoing.Header.Get(httpHeader), ""; want != got {
  52. t.Errorf("got trace header = %q; want none", got)
  53. }
  54. })
  55. t.Run("Trace", func(t *testing.T) {
  56. span := tc.NewSpan("/foo")
  57. req = req.WithContext(NewContext(req.Context(), span))
  58. _, err := client.Do(req)
  59. if err != nil {
  60. t.Error(err)
  61. }
  62. outgoing := <-rt.ch
  63. s := tc.SpanFromHeader("/foo", outgoing.Header.Get(httpHeader))
  64. if got, want := s.TraceID(), span.TraceID(); got != want {
  65. t.Errorf("trace ID = %q; want %q", got, want)
  66. }
  67. })
  68. }
  69. func TestHTTPHandlerNoTrace(t *testing.T) {
  70. tc := newTestClient(&noopTransport{})
  71. client := &http.Client{
  72. Transport: &Transport{},
  73. }
  74. handler := tc.HTTPHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  75. span := FromContext(r.Context())
  76. if span == nil {
  77. t.Errorf("span is nil; want non-nil span")
  78. }
  79. }))
  80. ts := httptest.NewServer(handler)
  81. defer ts.Close()
  82. req, _ := http.NewRequest("GET", ts.URL, nil)
  83. _, err := client.Do(req)
  84. if err != nil {
  85. t.Fatal(err)
  86. }
  87. }
  88. func TestHTTPHandler_response(t *testing.T) {
  89. tc := newTestClient(&noopTransport{})
  90. p, _ := NewLimitedSampler(1, 1<<32) // all
  91. tc.SetSamplingPolicy(p)
  92. handler := tc.HTTPHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
  93. ts := httptest.NewServer(handler)
  94. defer ts.Close()
  95. tests := []struct {
  96. name string
  97. traceHeader string
  98. wantTraceHeader string
  99. }{
  100. {
  101. name: "no global",
  102. traceHeader: "0123456789ABCDEF0123456789ABCDEF/123",
  103. wantTraceHeader: "0123456789ABCDEF0123456789ABCDEF/123;o=1",
  104. },
  105. {
  106. name: "global=1",
  107. traceHeader: "0123456789ABCDEF0123456789ABCDEF/123;o=1",
  108. wantTraceHeader: "",
  109. },
  110. {
  111. name: "global=0",
  112. traceHeader: "0123456789ABCDEF0123456789ABCDEF/123;o=0",
  113. wantTraceHeader: "",
  114. },
  115. {
  116. name: "no trace context",
  117. traceHeader: "",
  118. wantTraceHeader: "",
  119. },
  120. }
  121. for _, tt := range tests {
  122. req, _ := http.NewRequest("GET", ts.URL, nil)
  123. req.Header.Set(httpHeader, tt.traceHeader)
  124. res, err := http.DefaultClient.Do(req)
  125. if err != nil {
  126. t.Errorf("failed to request: %v", err)
  127. }
  128. if got, want := res.Header.Get(httpHeader), tt.wantTraceHeader; got != want {
  129. t.Errorf("%v: response context header = %q; want %q", tt.name, got, want)
  130. }
  131. }
  132. }