No puede seleccionar más de 25 temas Los temas deben comenzar con una letra o número, pueden incluir guiones ('-') y pueden tener hasta 35 caracteres de largo.
 
 
 

152 líneas
3.7 KiB

  1. // Copyright 2017 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // +build go1.7
  15. package trace
  16. import (
  17. "io/ioutil"
  18. "net/http"
  19. "net/http/httptest"
  20. "strings"
  21. "testing"
  22. )
  23. type recorderTransport struct {
  24. ch chan *http.Request
  25. }
  26. func (rt *recorderTransport) RoundTrip(req *http.Request) (*http.Response, error) {
  27. rt.ch <- req
  28. resp := &http.Response{
  29. Status: "200 OK",
  30. StatusCode: 200,
  31. Body: ioutil.NopCloser(strings.NewReader("{}")),
  32. }
  33. return resp, nil
  34. }
  35. func TestNewHTTPClient(t *testing.T) {
  36. rt := &recorderTransport{
  37. ch: make(chan *http.Request, 1),
  38. }
  39. tc := newTestClient(&noopTransport{})
  40. client := &http.Client{
  41. Transport: &Transport{
  42. Base: rt,
  43. },
  44. }
  45. req, _ := http.NewRequest("GET", "http://example.com", nil)
  46. t.Run("NoTrace", func(t *testing.T) {
  47. _, err := client.Do(req)
  48. if err != nil {
  49. t.Error(err)
  50. }
  51. outgoing := <-rt.ch
  52. if got, want := outgoing.Header.Get(httpHeader), ""; want != got {
  53. t.Errorf("got trace header = %q; want none", got)
  54. }
  55. })
  56. t.Run("Trace", func(t *testing.T) {
  57. span := tc.NewSpan("/foo")
  58. req = req.WithContext(NewContext(req.Context(), span))
  59. _, err := client.Do(req)
  60. if err != nil {
  61. t.Error(err)
  62. }
  63. outgoing := <-rt.ch
  64. s := tc.SpanFromHeader("/foo", outgoing.Header.Get(httpHeader))
  65. if got, want := s.TraceID(), span.TraceID(); got != want {
  66. t.Errorf("trace ID = %q; want %q", got, want)
  67. }
  68. })
  69. }
  70. func TestHTTPHandlerNoTrace(t *testing.T) {
  71. tc := newTestClient(&noopTransport{})
  72. client := &http.Client{
  73. Transport: &Transport{},
  74. }
  75. handler := tc.HTTPHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  76. span := FromContext(r.Context())
  77. if span == nil {
  78. t.Errorf("span is nil; want non-nil span")
  79. }
  80. }))
  81. ts := httptest.NewServer(handler)
  82. defer ts.Close()
  83. req, _ := http.NewRequest("GET", ts.URL, nil)
  84. _, err := client.Do(req)
  85. if err != nil {
  86. t.Fatal(err)
  87. }
  88. }
  89. func TestHTTPHandler_response(t *testing.T) {
  90. tc := newTestClient(&noopTransport{})
  91. p, _ := NewLimitedSampler(1, 1<<32) // all
  92. tc.SetSamplingPolicy(p)
  93. handler := tc.HTTPHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
  94. ts := httptest.NewServer(handler)
  95. defer ts.Close()
  96. tests := []struct {
  97. name string
  98. traceHeader string
  99. wantTraceHeader string
  100. }{
  101. {
  102. name: "no global",
  103. traceHeader: "0123456789ABCDEF0123456789ABCDEF/123",
  104. wantTraceHeader: "0123456789ABCDEF0123456789ABCDEF/123;o=1",
  105. },
  106. {
  107. name: "global=1",
  108. traceHeader: "0123456789ABCDEF0123456789ABCDEF/123;o=1",
  109. wantTraceHeader: "",
  110. },
  111. {
  112. name: "global=0",
  113. traceHeader: "0123456789ABCDEF0123456789ABCDEF/123;o=0",
  114. wantTraceHeader: "",
  115. },
  116. {
  117. name: "no trace context",
  118. traceHeader: "",
  119. wantTraceHeader: "",
  120. },
  121. }
  122. for _, tt := range tests {
  123. req, _ := http.NewRequest("GET", ts.URL, nil)
  124. req.Header.Set(httpHeader, tt.traceHeader)
  125. res, err := http.DefaultClient.Do(req)
  126. if err != nil {
  127. t.Errorf("failed to request: %v", err)
  128. }
  129. if got, want := res.Header.Get(httpHeader), tt.wantTraceHeader; got != want {
  130. t.Errorf("%v: response context header = %q; want %q", tt.name, got, want)
  131. }
  132. }
  133. }