You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

135 lines
3.7 KiB

  1. // Copyright 2015 The Go Authors. All rights reserved.
  2. //
  3. // Use of this source code is governed by a BSD-style
  4. // license that can be found in the LICENSE file or at
  5. // https://developers.google.com/open-source/licenses/bsd.
  6. // This file implements an http.RoundTripper that authenticates
  7. // requests issued against the api.github.com endpoint.
  8. package httputil
  9. import (
  10. "log"
  11. "net/http"
  12. "net/url"
  13. "os"
  14. "cloud.google.com/go/compute/metadata"
  15. )
  // AuthTransport is an http.RoundTripper that identifies and authenticates
// requests to the GitHub API before handing them to an underlying transport.
//
// When UserAgent is non-empty it is set as the User-Agent header of every
// request. Requests to api.github.com may additionally be given credentials
// (see RoundTrip); when both a token and client credentials are configured,
// the client credentials are preferred.
//
// The zero value forwards requests unchanged to http.DefaultTransport.
type AuthTransport struct {
	// UserAgent, when non-empty, replaces the User-Agent header of each request.
	UserAgent string

	// Token is a GitHub access token, sent as "Authorization: token <Token>".
	Token string

	// ClientID and ClientSecret are OAuth application credentials; they are
	// used only when both are non-empty.
	ClientID     string
	ClientSecret string

	// Base is the transport requests are delegated to; nil means
	// http.DefaultTransport.
	Base http.RoundTripper
}
  27. // NewAuthTransport gives new AuthTransport created with GitHub credentials
  28. // read from GCE metadata when the metadata server is accessible (we're on GCE)
  29. // or read from environment varialbes otherwise.
  30. func NewAuthTransport(base http.RoundTripper) *AuthTransport {
  31. if metadata.OnGCE() {
  32. return NewAuthTransportFromMetadata(base)
  33. }
  34. return NewAuthTransportFromEnvironment(base)
  35. }
  36. // NewAuthTransportFromEnvironment gives new AuthTransport created with GitHub
  37. // credentials read from environment variables.
  38. func NewAuthTransportFromEnvironment(base http.RoundTripper) *AuthTransport {
  39. return &AuthTransport{
  40. UserAgent: os.Getenv("USER_AGENT"),
  41. Token: os.Getenv("GITHUB_TOKEN"),
  42. ClientID: os.Getenv("GITHUB_CLIENT_ID"),
  43. ClientSecret: os.Getenv("GITHUB_CLIENT_SECRET"),
  44. Base: base,
  45. }
  46. }
  47. // NewAuthTransportFromMetadata gives new AuthTransport created with GitHub
  48. // credentials read from GCE metadata.
  49. func NewAuthTransportFromMetadata(base http.RoundTripper) *AuthTransport {
  50. return &AuthTransport{
  51. UserAgent: gceAttr("user-agent"),
  52. Token: gceAttr("github-token"),
  53. ClientID: gceAttr("github-client-id"),
  54. ClientSecret: gceAttr("github-client-secret"),
  55. Base: base,
  56. }
  57. }
  58. // RoundTrip implements the http.RoundTripper interface.
  59. func (t *AuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
  60. var reqCopy *http.Request
  61. if t.UserAgent != "" {
  62. reqCopy = copyRequest(req)
  63. reqCopy.Header.Set("User-Agent", t.UserAgent)
  64. }
  65. if req.URL.Host == "api.github.com" {
  66. switch {
  67. case t.ClientID != "" && t.ClientSecret != "":
  68. if reqCopy == nil {
  69. reqCopy = copyRequest(req)
  70. }
  71. if reqCopy.URL.RawQuery == "" {
  72. reqCopy.URL.RawQuery = "client_id=" + t.ClientID + "&client_secret=" + t.ClientSecret
  73. } else {
  74. reqCopy.URL.RawQuery += "&client_id=" + t.ClientID + "&client_secret=" + t.ClientSecret
  75. }
  76. case t.Token != "":
  77. if reqCopy == nil {
  78. reqCopy = copyRequest(req)
  79. }
  80. reqCopy.Header.Set("Authorization", "token "+t.Token)
  81. }
  82. }
  83. if reqCopy != nil {
  84. return t.base().RoundTrip(reqCopy)
  85. }
  86. return t.base().RoundTrip(req)
  87. }
  88. // CancelRequest cancels an in-flight request by closing its connection.
  89. func (t *AuthTransport) CancelRequest(req *http.Request) {
  90. type canceler interface {
  91. CancelRequest(req *http.Request)
  92. }
  93. if cr, ok := t.base().(canceler); ok {
  94. cr.CancelRequest(req)
  95. }
  96. }
  97. func (t *AuthTransport) base() http.RoundTripper {
  98. if t.Base != nil {
  99. return t.Base
  100. }
  101. return http.DefaultTransport
  102. }
  103. func gceAttr(name string) string {
  104. s, err := metadata.ProjectAttributeValue(name)
  105. if err != nil {
  106. log.Printf("error querying metadata for %q: %s", name, err)
  107. return ""
  108. }
  109. return s
  110. }
  111. func copyRequest(req *http.Request) *http.Request {
  112. req2 := new(http.Request)
  113. *req2 = *req
  114. req2.URL = new(url.URL)
  115. *req2.URL = *req.URL
  116. req2.Header = make(http.Header, len(req.Header))
  117. for k, s := range req.Header {
  118. req2.Header[k] = append([]string(nil), s...)
  119. }
  120. return req2
  121. }