// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package martian import ( "bufio" "crypto/rand" "encoding/hex" "fmt" "net" "net/http" "sync" ) // Context provides information and storage for a single request/response pair. // Contexts are linked to shared session that is used for multiple requests on // a single connection. type Context struct { session *Session id string mu sync.RWMutex vals map[string]interface{} skipRoundTrip bool skipLogging bool apiRequest bool } // Session provides information and storage about a connection. type Session struct { mu sync.RWMutex id string secure bool hijacked bool conn net.Conn brw *bufio.ReadWriter vals map[string]interface{} } var ( ctxmu sync.RWMutex ctxs = make(map[*http.Request]*Context) ) // NewContext returns a context for the in-flight HTTP request. func NewContext(req *http.Request) *Context { ctxmu.RLock() defer ctxmu.RUnlock() return ctxs[req] } // TestContext builds a new session and associated context and returns the // context and a function to remove the associated context. If it fails to // generate either a new session or a new context it will return an error. // Intended for tests only. func TestContext(req *http.Request, conn net.Conn, bw *bufio.ReadWriter) (ctx *Context, remove func(), err error) { ctxmu.Lock() defer ctxmu.Unlock() ctx, ok := ctxs[req] if ok { return ctx, func() { unlink(req) }, nil } s, err := newSession(conn, bw) if err != nil { return nil, nil, err } ctx, err = withSession(s) if err != nil { return nil, nil, err } ctxs[req] = ctx return ctx, func() { unlink(req) }, nil } // ID returns the session ID. func (s *Session) ID() string { s.mu.RLock() defer s.mu.RUnlock() return s.id } // IsSecure returns whether the current session is from a secure connection, // such as when receiving requests from a TLS connection that has been MITM'd. func (s *Session) IsSecure() bool { s.mu.RLock() defer s.mu.RUnlock() return s.secure } // MarkSecure marks the session as secure. func (s *Session) MarkSecure() { s.mu.Lock() defer s.mu.Unlock() s.secure = true } // MarkInsecure marks the session as insecure. func (s *Session) MarkInsecure() { s.mu.Lock() defer s.mu.Unlock() s.secure = false } // Hijack takes control of the connection from the proxy. No further action // will be taken by the proxy and the connection will be closed following the // return of the hijacker. func (s *Session) Hijack() (net.Conn, *bufio.ReadWriter, error) { s.mu.Lock() defer s.mu.Unlock() if s.hijacked { return nil, nil, fmt.Errorf("martian: session has already been hijacked") } s.hijacked = true return s.conn, s.brw, nil } // Hijacked returns whether the connection has been hijacked. func (s *Session) Hijacked() bool { s.mu.RLock() defer s.mu.RUnlock() return s.hijacked } // setConn resets the underlying connection and bufio.ReadWriter of the // session. Used by the proxy when the connection is upgraded to TLS. func (s *Session) setConn(conn net.Conn, brw *bufio.ReadWriter) { s.mu.Lock() defer s.mu.Unlock() s.conn = conn s.brw = brw } // Get takes key and returns the associated value from the session. func (s *Session) Get(key string) (interface{}, bool) { s.mu.RLock() defer s.mu.RUnlock() val, ok := s.vals[key] return val, ok } // Set takes a key and associates it with val in the session. The value is // persisted for the entire session across multiple requests and responses. func (s *Session) Set(key string, val interface{}) { s.mu.Lock() defer s.mu.Unlock() s.vals[key] = val } // Session returns the session for the context. func (ctx *Context) Session() *Session { return ctx.session } // ID returns the context ID. func (ctx *Context) ID() string { return ctx.id } // Get takes key and returns the associated value from the context. func (ctx *Context) Get(key string) (interface{}, bool) { ctx.mu.RLock() defer ctx.mu.RUnlock() val, ok := ctx.vals[key] return val, ok } // Set takes a key and associates it with val in the context. The value is // persisted for the duration of the request and is removed on the following // request. func (ctx *Context) Set(key string, val interface{}) { ctx.mu.Lock() defer ctx.mu.Unlock() ctx.vals[key] = val } // SkipRoundTrip skips the round trip for the current request. func (ctx *Context) SkipRoundTrip() { ctx.mu.Lock() defer ctx.mu.Unlock() ctx.skipRoundTrip = true } // SkippingRoundTrip returns whether the current round trip will be skipped. func (ctx *Context) SkippingRoundTrip() bool { ctx.mu.RLock() defer ctx.mu.RUnlock() return ctx.skipRoundTrip } // SkipLogging skips logging by Martian loggers for the current request. func (ctx *Context) SkipLogging() { ctx.mu.Lock() defer ctx.mu.Unlock() ctx.skipLogging = true } // SkippingLogging returns whether the current request / response pair will be logged. func (ctx *Context) SkippingLogging() bool { ctx.mu.RLock() defer ctx.mu.RUnlock() return ctx.skipLogging } // APIRequest marks the requests as a request to the proxy API. func (ctx *Context) APIRequest() { ctx.mu.Lock() defer ctx.mu.Unlock() ctx.apiRequest = true } // IsAPIRequest returns true when the request patterns matches a pattern in the proxy // mux. The mux is usually defined as a parameter to the api.Forwarder, which uses // http.DefaultServeMux by default. func (ctx *Context) IsAPIRequest() bool { ctx.mu.RLock() defer ctx.mu.RUnlock() return ctx.apiRequest } // newID creates a new 16 character random hex ID; note these are not UUIDs. func newID() (string, error) { src := make([]byte, 8) if _, err := rand.Read(src); err != nil { return "", err } return hex.EncodeToString(src), nil } // link associates the context with request. func link(req *http.Request, ctx *Context) { ctxmu.Lock() defer ctxmu.Unlock() ctxs[req] = ctx } // unlink removes the context for request. func unlink(req *http.Request) { ctxmu.Lock() defer ctxmu.Unlock() delete(ctxs, req) } // newSession builds a new session. func newSession(conn net.Conn, brw *bufio.ReadWriter) (*Session, error) { sid, err := newID() if err != nil { return nil, err } return &Session{ id: sid, conn: conn, brw: brw, vals: make(map[string]interface{}), }, nil } // withSession builds a new context from an existing session. Session must be // non-nil. func withSession(s *Session) (*Context, error) { cid, err := newID() if err != nil { return nil, err } return &Context{ session: s, id: cid, vals: make(map[string]interface{}), }, nil }