|
- // Copyright 2015 Google Inc. All rights reserved.
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
-
- package header
-
- import (
- "crypto/rand"
- "fmt"
- "io"
- "net/http"
- "regexp"
- "strings"
-
- "github.com/google/martian"
- )
-
const (
	// viaLoopKey is the martian context key under which ModifyRequest stores
	// the loop-detection error so that ModifyResponse can retrieve it.
	viaLoopKey = "via.LoopDetection"
)

var (
	// whitespace matches a run of one or more spaces or tabs; hasLoop uses it
	// to split a single Via entry into its fields.
	whitespace = regexp.MustCompile("[\t ]+")
)
-
// ViaModifier is a header modifier that appends this proxy to the Via header
// and checks for proxy redirect loops.
type ViaModifier struct {
	// requestedBy and boundary are joined as "requestedBy-boundary" to form
	// this proxy's received-by entry in the Via header.
	requestedBy, boundary string
}
-
- // NewViaModifier returns a new Via modifier.
- func NewViaModifier(requestedBy string) *ViaModifier {
- return &ViaModifier{
- requestedBy: requestedBy,
- boundary: randomBoundary(),
- }
- }
-
- // ModifyRequest sets the Via header and provides loop-detection. If Via is
- // already present, it will be appended to the existing value. If a loop is
- // detected an error is added to the context and the request round trip is
- // skipped.
- //
- // http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-14#section-9.9
- func (m *ViaModifier) ModifyRequest(req *http.Request) error {
- via := fmt.Sprintf("%d.%d %s-%s", req.ProtoMajor, req.ProtoMinor, m.requestedBy, m.boundary)
-
- if v := req.Header.Get("Via"); v != "" {
- if m.hasLoop(v) {
- err := fmt.Errorf("via: detected request loop, header contains %s", via)
-
- ctx := martian.NewContext(req)
- ctx.Set(viaLoopKey, err)
- ctx.SkipRoundTrip()
-
- return err
- }
-
- via = fmt.Sprintf("%s, %s", v, via)
- }
-
- req.Header.Set("Via", via)
-
- return nil
- }
-
- // ModifyResponse sets the status code to 400 Bad Request if a loop was
- // detected in the request.
- func (m *ViaModifier) ModifyResponse(res *http.Response) error {
- ctx := martian.NewContext(res.Request)
-
- if err, _ := ctx.Get(viaLoopKey); err != nil {
- res.StatusCode = 400
- res.Status = http.StatusText(400)
-
- return err.(error)
- }
-
- return nil
- }
-
- // hasLoop parses via and attempts to match requestedBy against the contained
- // pseudonyms/host:port pairs.
- func (m *ViaModifier) hasLoop(via string) bool {
- for _, v := range strings.Split(via, ",") {
- parts := whitespace.Split(strings.TrimSpace(v), 3)
-
- // No pseudonym or host:port, assume there is no loop.
- if len(parts) < 2 {
- continue
- }
-
- if fmt.Sprintf("%s-%s", m.requestedBy, m.boundary) == parts[1] {
- return true
- }
- }
-
- return false
- }
-
- // SetBoundary sets the boundary string (random 10 character by default) used to
- // disabiguate Martians that are chained together with identical requestedBy values.
- // This should only be used for testing.
- func (m *ViaModifier) SetBoundary(boundary string) {
- m.boundary = boundary
- }
-
// randomBoundary returns 10 bytes read from crypto/rand, encoded as a
// 20-character lowercase hex string. Joining it with requestedBy keeps Martians
// that are chained together with the same requestedBy value from colliding.
//
// It panics if io.ReadFull fails to read from crypto/rand.
func randomBoundary() string {
	var buf [10]byte
	if _, err := io.ReadFull(rand.Reader, buf[:]); err != nil {
		panic(err)
	}
	return fmt.Sprintf("%x", buf[:])
}
|