|
- // Copyright 2015 Google Inc. All rights reserved.
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
-
- package proxyutil
-
- import (
- "fmt"
- "net/http"
- "strconv"
- )
-
// Header is a generic representation of a set of HTTP headers for requests and
// responses.
//
// Host, Content-Length and Transfer-Encoding are not kept in the header map;
// they are read and written through the closure pairs below, which the
// constructors bind to the corresponding request/response fields. Every other
// header lives in h.
type Header struct {
	// h holds all headers except Host, Content-Length and Transfer-Encoding.
	h http.Header

	// Host getter/setter.
	host    func() string
	setHost func(string)

	// Content-Length getter/setter.
	cl    func() int64
	setCL func(int64)

	// Transfer-Encoding getter/setter.
	te    func() []string
	setTE func([]string)
}
-
- // RequestHeader returns a new set of headers from a request.
- func RequestHeader(req *http.Request) *Header {
- return &Header{
- h: req.Header,
- host: func() string { return req.Host },
- cl: func() int64 { return req.ContentLength },
- te: func() []string { return req.TransferEncoding },
- setHost: func(host string) { req.Host = host },
- setCL: func(cl int64) { req.ContentLength = cl },
- setTE: func(te []string) { req.TransferEncoding = te },
- }
- }
-
- // ResponseHeader returns a new set of headers from a request.
- func ResponseHeader(res *http.Response) *Header {
- return &Header{
- h: res.Header,
- host: func() string { return "" },
- cl: func() int64 { return res.ContentLength },
- te: func() []string { return res.TransferEncoding },
- setHost: func(string) {},
- setCL: func(cl int64) { res.ContentLength = cl },
- setTE: func(te []string) { res.TransferEncoding = te },
- }
- }
-
- // Set sets value at header name for the request or response.
- func (h *Header) Set(name, value string) error {
- switch http.CanonicalHeaderKey(name) {
- case "Host":
- h.setHost(value)
- case "Content-Length":
- cl, err := strconv.ParseInt(value, 10, 64)
- if err != nil {
- return err
- }
-
- h.setCL(cl)
- case "Transfer-Encoding":
- h.setTE([]string{value})
- default:
- h.h.Set(name, value)
- }
-
- return nil
- }
-
- // Add appends the value to the existing header at name for the request or
- // response.
- func (h *Header) Add(name, value string) error {
- switch http.CanonicalHeaderKey(name) {
- case "Host":
- if h.host() != "" {
- return fmt.Errorf("proxyutil: illegal header multiple: %s", "Host")
- }
-
- return h.Set(name, value)
- case "Content-Length":
- if h.cl() > 0 {
- return fmt.Errorf("proxyutil: illegal header multiple: %s", "Content-Length")
- }
-
- return h.Set(name, value)
- case "Transfer-Encoding":
- h.setTE(append(h.te(), value))
- default:
- h.h.Add(name, value)
- }
-
- return nil
- }
-
- // Get returns the first value at header name for the request or response.
- func (h *Header) Get(name string) string {
- switch http.CanonicalHeaderKey(name) {
- case "Host":
- return h.host()
- case "Content-Length":
- if h.cl() < 0 {
- return ""
- }
-
- return strconv.FormatInt(h.cl(), 10)
- case "Transfer-Encoding":
- if len(h.te()) < 1 {
- return ""
- }
-
- return h.te()[0]
- default:
- return h.h.Get(name)
- }
- }
-
- // All returns all the values for header name. If the header does not exist it
- // returns nil, false.
- func (h *Header) All(name string) ([]string, bool) {
- switch http.CanonicalHeaderKey(name) {
- case "Host":
- if h.host() == "" {
- return nil, false
- }
-
- return []string{h.host()}, true
- case "Content-Length":
- if h.cl() <= 0 {
- return nil, false
- }
-
- return []string{strconv.FormatInt(h.cl(), 10)}, true
- case "Transfer-Encoding":
- if h.te() == nil {
- return nil, false
- }
-
- return h.te(), true
- default:
- vs, ok := h.h[http.CanonicalHeaderKey(name)]
- return vs, ok
- }
- }
-
- // Del deletes the header at name for the request or response.
- func (h *Header) Del(name string) {
- switch http.CanonicalHeaderKey(name) {
- case "Host":
- h.setHost("")
- case "Content-Length":
- h.setCL(-1)
- case "Transfer-Encoding":
- h.setTE(nil)
- default:
- h.h.Del(name)
- }
- }
-
- // Map returns an http.Header that includes Host, Content-Length, and
- // Transfer-Encoding.
- func (h *Header) Map() http.Header {
- hm := make(http.Header)
-
- for k, vs := range h.h {
- hm[k] = vs
- }
-
- for _, k := range []string{
- "Host",
- "Content-Length",
- "Transfer-Encoding",
- } {
- vs, ok := h.All(k)
- if !ok {
- continue
- }
-
- hm[k] = vs
- }
-
- return hm
- }
|