You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

313 lines
7.0 KiB

  1. // Copyright 2015 Google Inc. All rights reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package martian
  15. import (
  16. "bufio"
  17. "crypto/rand"
  18. "encoding/hex"
  19. "fmt"
  20. "net"
  21. "net/http"
  22. "sync"
  23. )
// Context provides information and storage for a single request/response pair.
// Contexts are linked to a shared session that is used for multiple requests on
// a single connection.
type Context struct {
	// session is the connection-level Session this context belongs to. It is
	// set by withSession and read without taking mu by (*Context).Session.
	session *Session

	// id is the context identifier (generated by newID in withSession). It is
	// read without taking mu by (*Context).ID.
	id string

	// mu guards vals and the three flags below.
	mu sync.RWMutex

	// vals holds request-scoped values accessed through Get and Set.
	vals map[string]interface{}

	// skipRoundTrip is set by SkipRoundTrip and reported by SkippingRoundTrip.
	skipRoundTrip bool

	// skipLogging is set by SkipLogging and reported by SkippingLogging.
	skipLogging bool

	// apiRequest is set by APIRequest and reported by IsAPIRequest.
	apiRequest bool
}
// Session provides information and storage about a connection. Every accessor
// method on Session takes mu before touching its fields.
type Session struct {
	// mu guards every field below.
	mu sync.RWMutex

	// id is the session identifier returned by ID (generated by newID in
	// newSession).
	id string

	// secure is set by MarkSecure, cleared by MarkInsecure, reported by IsSecure.
	secure bool

	// hijacked is set by the first successful Hijack and reported by Hijacked.
	hijacked bool

	// conn and brw are the underlying connection and its buffered
	// reader/writer; setConn replaces them and Hijack hands them out.
	conn net.Conn
	brw  *bufio.ReadWriter

	// vals holds session-scoped values accessed through Get and Set.
	vals map[string]interface{}
}
  46. var (
  47. ctxmu sync.RWMutex
  48. ctxs = make(map[*http.Request]*Context)
  49. )
  50. // NewContext returns a context for the in-flight HTTP request.
  51. func NewContext(req *http.Request) *Context {
  52. ctxmu.RLock()
  53. defer ctxmu.RUnlock()
  54. return ctxs[req]
  55. }
  56. // TestContext builds a new session and associated context and returns the
  57. // context and a function to remove the associated context. If it fails to
  58. // generate either a new session or a new context it will return an error.
  59. // Intended for tests only.
  60. func TestContext(req *http.Request, conn net.Conn, bw *bufio.ReadWriter) (ctx *Context, remove func(), err error) {
  61. ctxmu.Lock()
  62. defer ctxmu.Unlock()
  63. ctx, ok := ctxs[req]
  64. if ok {
  65. return ctx, func() { unlink(req) }, nil
  66. }
  67. s, err := newSession(conn, bw)
  68. if err != nil {
  69. return nil, nil, err
  70. }
  71. ctx, err = withSession(s)
  72. if err != nil {
  73. return nil, nil, err
  74. }
  75. ctxs[req] = ctx
  76. return ctx, func() { unlink(req) }, nil
  77. }
  78. // ID returns the session ID.
  79. func (s *Session) ID() string {
  80. s.mu.RLock()
  81. defer s.mu.RUnlock()
  82. return s.id
  83. }
  84. // IsSecure returns whether the current session is from a secure connection,
  85. // such as when receiving requests from a TLS connection that has been MITM'd.
  86. func (s *Session) IsSecure() bool {
  87. s.mu.RLock()
  88. defer s.mu.RUnlock()
  89. return s.secure
  90. }
  91. // MarkSecure marks the session as secure.
  92. func (s *Session) MarkSecure() {
  93. s.mu.Lock()
  94. defer s.mu.Unlock()
  95. s.secure = true
  96. }
  97. // MarkInsecure marks the session as insecure.
  98. func (s *Session) MarkInsecure() {
  99. s.mu.Lock()
  100. defer s.mu.Unlock()
  101. s.secure = false
  102. }
  103. // Hijack takes control of the connection from the proxy. No further action
  104. // will be taken by the proxy and the connection will be closed following the
  105. // return of the hijacker.
  106. func (s *Session) Hijack() (net.Conn, *bufio.ReadWriter, error) {
  107. s.mu.Lock()
  108. defer s.mu.Unlock()
  109. if s.hijacked {
  110. return nil, nil, fmt.Errorf("martian: session has already been hijacked")
  111. }
  112. s.hijacked = true
  113. return s.conn, s.brw, nil
  114. }
  115. // Hijacked returns whether the connection has been hijacked.
  116. func (s *Session) Hijacked() bool {
  117. s.mu.RLock()
  118. defer s.mu.RUnlock()
  119. return s.hijacked
  120. }
  121. // setConn resets the underlying connection and bufio.ReadWriter of the
  122. // session. Used by the proxy when the connection is upgraded to TLS.
  123. func (s *Session) setConn(conn net.Conn, brw *bufio.ReadWriter) {
  124. s.mu.Lock()
  125. defer s.mu.Unlock()
  126. s.conn = conn
  127. s.brw = brw
  128. }
  129. // Get takes key and returns the associated value from the session.
  130. func (s *Session) Get(key string) (interface{}, bool) {
  131. s.mu.RLock()
  132. defer s.mu.RUnlock()
  133. val, ok := s.vals[key]
  134. return val, ok
  135. }
  136. // Set takes a key and associates it with val in the session. The value is
  137. // persisted for the entire session across multiple requests and responses.
  138. func (s *Session) Set(key string, val interface{}) {
  139. s.mu.Lock()
  140. defer s.mu.Unlock()
  141. s.vals[key] = val
  142. }
  143. // Session returns the session for the context.
  144. func (ctx *Context) Session() *Session {
  145. return ctx.session
  146. }
  147. // ID returns the context ID.
  148. func (ctx *Context) ID() string {
  149. return ctx.id
  150. }
  151. // Get takes key and returns the associated value from the context.
  152. func (ctx *Context) Get(key string) (interface{}, bool) {
  153. ctx.mu.RLock()
  154. defer ctx.mu.RUnlock()
  155. val, ok := ctx.vals[key]
  156. return val, ok
  157. }
  158. // Set takes a key and associates it with val in the context. The value is
  159. // persisted for the duration of the request and is removed on the following
  160. // request.
  161. func (ctx *Context) Set(key string, val interface{}) {
  162. ctx.mu.Lock()
  163. defer ctx.mu.Unlock()
  164. ctx.vals[key] = val
  165. }
  166. // SkipRoundTrip skips the round trip for the current request.
  167. func (ctx *Context) SkipRoundTrip() {
  168. ctx.mu.Lock()
  169. defer ctx.mu.Unlock()
  170. ctx.skipRoundTrip = true
  171. }
  172. // SkippingRoundTrip returns whether the current round trip will be skipped.
  173. func (ctx *Context) SkippingRoundTrip() bool {
  174. ctx.mu.RLock()
  175. defer ctx.mu.RUnlock()
  176. return ctx.skipRoundTrip
  177. }
  178. // SkipLogging skips logging by Martian loggers for the current request.
  179. func (ctx *Context) SkipLogging() {
  180. ctx.mu.Lock()
  181. defer ctx.mu.Unlock()
  182. ctx.skipLogging = true
  183. }
  184. // SkippingLogging returns whether the current request / response pair will be logged.
  185. func (ctx *Context) SkippingLogging() bool {
  186. ctx.mu.RLock()
  187. defer ctx.mu.RUnlock()
  188. return ctx.skipLogging
  189. }
  190. // APIRequest marks the requests as a request to the proxy API.
  191. func (ctx *Context) APIRequest() {
  192. ctx.mu.Lock()
  193. defer ctx.mu.Unlock()
  194. ctx.apiRequest = true
  195. }
  196. // IsAPIRequest returns true when the request patterns matches a pattern in the proxy
  197. // mux. The mux is usually defined as a parameter to the api.Forwarder, which uses
  198. // http.DefaultServeMux by default.
  199. func (ctx *Context) IsAPIRequest() bool {
  200. ctx.mu.RLock()
  201. defer ctx.mu.RUnlock()
  202. return ctx.apiRequest
  203. }
  204. // newID creates a new 16 character random hex ID; note these are not UUIDs.
  205. func newID() (string, error) {
  206. src := make([]byte, 8)
  207. if _, err := rand.Read(src); err != nil {
  208. return "", err
  209. }
  210. return hex.EncodeToString(src), nil
  211. }
  212. // link associates the context with request.
  213. func link(req *http.Request, ctx *Context) {
  214. ctxmu.Lock()
  215. defer ctxmu.Unlock()
  216. ctxs[req] = ctx
  217. }
  218. // unlink removes the context for request.
  219. func unlink(req *http.Request) {
  220. ctxmu.Lock()
  221. defer ctxmu.Unlock()
  222. delete(ctxs, req)
  223. }
  224. // newSession builds a new session.
  225. func newSession(conn net.Conn, brw *bufio.ReadWriter) (*Session, error) {
  226. sid, err := newID()
  227. if err != nil {
  228. return nil, err
  229. }
  230. return &Session{
  231. id: sid,
  232. conn: conn,
  233. brw: brw,
  234. vals: make(map[string]interface{}),
  235. }, nil
  236. }
  237. // withSession builds a new context from an existing session. Session must be
  238. // non-nil.
  239. func withSession(s *Session) (*Context, error) {
  240. cid, err := newID()
  241. if err != nil {
  242. return nil, err
  243. }
  244. return &Context{
  245. session: s,
  246. id: cid,
  247. vals: make(map[string]interface{}),
  248. }, nil
  249. }