Nie możesz wybrać więcej niż 25 tematów. Tematy muszą zaczynać się od litery lub cyfry, mogą zawierać myślniki ('-') i mogą mieć do 35 znaków.
 
 
 

169 wierszy
4.0 KiB

  1. package oauth2
  2. import (
  3. "errors"
  4. "io"
  5. "net/http"
  6. "net/http/httptest"
  7. "testing"
  8. "time"
  9. )
  10. type tokenSource struct{ token *Token }
  11. func (t *tokenSource) Token() (*Token, error) {
  12. return t.token, nil
  13. }
  14. func TestTransportNilTokenSource(t *testing.T) {
  15. tr := &Transport{}
  16. server := newMockServer(func(w http.ResponseWriter, r *http.Request) {})
  17. defer server.Close()
  18. client := &http.Client{Transport: tr}
  19. resp, err := client.Get(server.URL)
  20. if err == nil {
  21. t.Errorf("got no errors, want an error with nil token source")
  22. }
  23. if resp != nil {
  24. t.Errorf("Response = %v; want nil", resp)
  25. }
  26. }
  27. type readCloseCounter struct {
  28. CloseCount int
  29. ReadErr error
  30. }
  31. func (r *readCloseCounter) Read(b []byte) (int, error) {
  32. return 0, r.ReadErr
  33. }
  34. func (r *readCloseCounter) Close() error {
  35. r.CloseCount++
  36. return nil
  37. }
  38. func TestTransportCloseRequestBody(t *testing.T) {
  39. tr := &Transport{}
  40. server := newMockServer(func(w http.ResponseWriter, r *http.Request) {})
  41. defer server.Close()
  42. client := &http.Client{Transport: tr}
  43. body := &readCloseCounter{
  44. ReadErr: errors.New("readCloseCounter.Read not implemented"),
  45. }
  46. resp, err := client.Post(server.URL, "application/json", body)
  47. if err == nil {
  48. t.Errorf("got no errors, want an error with nil token source")
  49. }
  50. if resp != nil {
  51. t.Errorf("Response = %v; want nil", resp)
  52. }
  53. if expected := 1; body.CloseCount != expected {
  54. t.Errorf("Body was closed %d times, expected %d", body.CloseCount, expected)
  55. }
  56. }
  57. func TestTransportCloseRequestBodySuccess(t *testing.T) {
  58. tr := &Transport{
  59. Source: StaticTokenSource(&Token{
  60. AccessToken: "abc",
  61. }),
  62. }
  63. server := newMockServer(func(w http.ResponseWriter, r *http.Request) {})
  64. defer server.Close()
  65. client := &http.Client{Transport: tr}
  66. body := &readCloseCounter{
  67. ReadErr: io.EOF,
  68. }
  69. resp, err := client.Post(server.URL, "application/json", body)
  70. if err != nil {
  71. t.Errorf("got error %v; expected none", err)
  72. }
  73. if resp == nil {
  74. t.Errorf("Response is nil; expected non-nil")
  75. }
  76. if expected := 1; body.CloseCount != expected {
  77. t.Errorf("Body was closed %d times, expected %d", body.CloseCount, expected)
  78. }
  79. }
  80. func TestTransportTokenSource(t *testing.T) {
  81. ts := &tokenSource{
  82. token: &Token{
  83. AccessToken: "abc",
  84. },
  85. }
  86. tr := &Transport{
  87. Source: ts,
  88. }
  89. server := newMockServer(func(w http.ResponseWriter, r *http.Request) {
  90. if got, want := r.Header.Get("Authorization"), "Bearer abc"; got != want {
  91. t.Errorf("Authorization header = %q; want %q", got, want)
  92. }
  93. })
  94. defer server.Close()
  95. client := &http.Client{Transport: tr}
  96. res, err := client.Get(server.URL)
  97. if err != nil {
  98. t.Fatal(err)
  99. }
  100. res.Body.Close()
  101. }
  102. // Test for case-sensitive token types, per https://github.com/golang/oauth2/issues/113
  103. func TestTransportTokenSourceTypes(t *testing.T) {
  104. const val = "abc"
  105. tests := []struct {
  106. key string
  107. val string
  108. want string
  109. }{
  110. {key: "bearer", val: val, want: "Bearer abc"},
  111. {key: "mac", val: val, want: "MAC abc"},
  112. {key: "basic", val: val, want: "Basic abc"},
  113. }
  114. for _, tc := range tests {
  115. ts := &tokenSource{
  116. token: &Token{
  117. AccessToken: tc.val,
  118. TokenType: tc.key,
  119. },
  120. }
  121. tr := &Transport{
  122. Source: ts,
  123. }
  124. server := newMockServer(func(w http.ResponseWriter, r *http.Request) {
  125. if got, want := r.Header.Get("Authorization"), tc.want; got != want {
  126. t.Errorf("Authorization header (%q) = %q; want %q", val, got, want)
  127. }
  128. })
  129. defer server.Close()
  130. client := &http.Client{Transport: tr}
  131. res, err := client.Get(server.URL)
  132. if err != nil {
  133. t.Fatal(err)
  134. }
  135. res.Body.Close()
  136. }
  137. }
  138. func TestTokenValidNoAccessToken(t *testing.T) {
  139. token := &Token{}
  140. if token.Valid() {
  141. t.Errorf("got valid with no access token; want invalid")
  142. }
  143. }
  144. func TestExpiredWithExpiry(t *testing.T) {
  145. token := &Token{
  146. Expiry: time.Now().Add(-5 * time.Hour),
  147. }
  148. if token.Valid() {
  149. t.Errorf("got valid with expired token; want invalid")
  150. }
  151. }
  152. func newMockServer(handler func(w http.ResponseWriter, r *http.Request)) *httptest.Server {
  153. return httptest.NewServer(http.HandlerFunc(handler))
  154. }