|
- // Copyright 2015 Google Inc. All rights reserved.
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
-
- package main
-
- import (
- "crypto/tls"
- "crypto/x509"
- "encoding/base64"
- "fmt"
- "io/ioutil"
- "net"
- "net/http"
- "net/url"
- "os"
- "os/exec"
- "path/filepath"
- "strings"
- "testing"
- "time"
-
- "github.com/google/martian/mitm"
- )
-
// waitForProxy polls apiURL through c until a response arrives, giving the
// proxy up to five seconds to start. A request error (for example, connection
// refused while the process is still starting) is retried every 200ms. The
// first response received must be 200 OK, otherwise the test fails
// immediately. If no response arrives before the deadline, the test fails and
// reports the last request error.
func waitForProxy(t *testing.T, c *http.Client, apiURL string) {
	t.Helper()

	const (
		timeout  = 5 * time.Second
		interval = 200 * time.Millisecond
	)

	deadline := time.Now().Add(timeout)
	var lastErr error
	for time.Now().Before(deadline) {
		res, err := c.Get(apiURL)
		if err != nil {
			lastErr = err
			time.Sleep(interval)
			continue
		}
		// Only the status matters; close right away instead of deferring
		// inside the loop.
		res.Body.Close()
		if got, want := res.StatusCode, http.StatusOK; got != want {
			t.Fatalf("waitForProxy: c.Get(%q): got status %d, want %d", apiURL, got, want)
		}
		return
	}
	t.Fatalf("waitForProxy: did not start up within %.1f seconds (last error: %v)", timeout.Seconds(), lastErr)
}
-
// getFreePort returns a TCP port that was unused at the time of the call, as
// a string preceded by a colon, e.g. ":1234". The result can be appended to a
// host name or passed as a listen address.
//
// The port is obtained by letting the OS choose one for a temporary listener,
// which is closed before returning. Another process could claim the port
// before the caller binds it, so this is inherently racy.
func getFreePort(t *testing.T) string {
	t.Helper()
	l, err := net.Listen("tcp", ":0")
	if err != nil {
		t.Fatalf("getFreePort: could not get free port: %v", err)
	}
	defer l.Close()

	_, port, err := net.SplitHostPort(l.Addr().String())
	if err != nil {
		t.Fatalf("getFreePort: could not parse listener address %q: %v", l.Addr(), err)
	}
	return ":" + port
}
-
// parseURL parses u with url.Parse and fails the test if it is malformed,
// so callers can use the result inline.
func parseURL(t *testing.T, u string) *url.URL {
	t.Helper()
	p, err := url.Parse(u)
	if err != nil {
		t.Fatalf("url.Parse(%q): got error %v, want no error", u, err)
	}
	return p
}
-
- func TestProxyMain(t *testing.T) {
- tempDir, err := ioutil.TempDir("", t.Name())
- if err != nil {
- t.Fatal(err)
- }
- defer os.RemoveAll(tempDir)
-
- // Build proxy binary
- binPath := filepath.Join(tempDir, "proxy")
- cmd := exec.Command("go", "build", "-o", binPath)
- cmd.Stdout = os.Stdout
- cmd.Stderr = os.Stderr
- if err := cmd.Run(); err != nil {
- t.Fatal(err)
- }
-
- t.Run("Http", func(t *testing.T) {
- // Start proxy
- proxyPort := getFreePort(t)
- apiPort := getFreePort(t)
- cmd := exec.Command(binPath, "-addr="+proxyPort, "-api-addr="+apiPort)
- cmd.Stdout = os.Stdout
- cmd.Stderr = os.Stderr
- if err := cmd.Start(); err != nil {
- t.Fatal(err)
- }
- defer cmd.Wait()
- defer cmd.Process.Signal(os.Interrupt)
-
- proxyURL := "http://localhost" + proxyPort
- apiURL := "http://localhost" + apiPort
- configureURL := "http://martian.proxy/configure"
-
- // TODO: Make using API hostport directly work on Travis.
- apiClient := &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(parseURL(t, apiURL))}}
- waitForProxy(t, apiClient, configureURL)
-
- // Configure modifiers
- config := strings.NewReader(`
- {
- "fifo.Group": {
- "scope": ["request", "response"],
- "modifiers": [
- {
- "status.Modifier": {
- "scope": ["response"],
- "statusCode": 418
- }
- },
- {
- "skip.RoundTrip": {}
- }
- ]
- }
- }`)
- res, err := apiClient.Post(configureURL, "application/json", config)
- if err != nil {
- t.Fatalf("apiClient.Post(%q): got error %v, want no error", configureURL, err)
- }
- defer res.Body.Close()
- if got, want := res.StatusCode, http.StatusOK; got != want {
- t.Fatalf("apiClient.Post(%q): got status %d, want %d", configureURL, got, want)
- }
-
- // Exercise proxy
- client := &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(parseURL(t, proxyURL))}}
-
- testURL := "http://super.fake.domain/"
- res, err = client.Get(testURL)
- if err != nil {
- t.Fatalf("client.Get(%q): got error %v, want no error", testURL, err)
- }
- defer res.Body.Close()
- if got, want := res.StatusCode, http.StatusTeapot; got != want {
- t.Errorf("client.Get(%q): got status %d, want %d", testURL, got, want)
- }
- })
-
- t.Run("HttpsGenerateCert", func(t *testing.T) {
- // Create test certificate for test TLS server
- certName := "martian.proxy"
- certOrg := "Martian Authority"
- certExpiry := 90 * time.Minute
- servCert, servPriv, err := mitm.NewAuthority(certName, certOrg, certExpiry)
- if err != nil {
- t.Fatalf("mitm.NewAuthority(%q, %q, %q): got error %v, want no error", certName, certOrg, certExpiry, err)
- }
- mc, err := mitm.NewConfig(servCert, servPriv)
- if err != nil {
- t.Fatalf("mitm.NewConfig(%p, %q): got error %v, want no error", servCert, servPriv, err)
- }
- sc := mc.TLS()
-
- // Configure and start test TLS server
- servPort := getFreePort(t)
- l, err := tls.Listen("tcp", servPort, sc)
- if err != nil {
- t.Fatalf("tls.Listen(\"tcp\", %q, %p): got error %v, want no error", servPort, sc, err)
- }
- defer l.Close()
-
- server := &http.Server{
- Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusTeapot)
- w.Write([]byte("Hello!"))
- }),
- }
- go server.Serve(l)
- defer server.Close()
-
- // Start proxy
- proxyPort := getFreePort(t)
- apiPort := getFreePort(t)
- cmd := exec.Command(binPath, "-addr="+proxyPort, "-api-addr="+apiPort, "-generate-ca-cert", "-skip-tls-verify")
- cmd.Stdout = os.Stdout
- cmd.Stderr = os.Stderr
- if err := cmd.Start(); err != nil {
- t.Fatal(err)
- }
- defer cmd.Wait()
- defer cmd.Process.Signal(os.Interrupt)
-
- proxyURL := "http://localhost" + proxyPort
- apiURL := "http://localhost" + apiPort
- configureURL := "http://martian.proxy/configure"
-
- // TODO: Make using API hostport directly work on Travis.
- apiClient := &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(parseURL(t, apiURL))}}
- waitForProxy(t, apiClient, configureURL)
-
- // Configure modifiers
- config := strings.NewReader(fmt.Sprintf(`
- {
- "body.Modifier": {
- "scope": ["response"],
- "contentType": "text/plain",
- "body": "%s"
- }
- }`, base64.StdEncoding.EncodeToString([]byte("茶壺"))))
- res, err := apiClient.Post(configureURL, "application/json", config)
- if err != nil {
- t.Fatalf("apiClient.Post(%q): got error %v, want no error", configureURL, err)
- }
- defer res.Body.Close()
- if got, want := res.StatusCode, http.StatusOK; got != want {
- t.Fatalf("apiClient.Post(%q): got status %d, want %d", configureURL, got, want)
- }
-
- // Install proxy's CA cert into http client
- caCertURL := "http://martian.proxy/authority.cer"
- res, err = apiClient.Get(caCertURL)
- if err != nil {
- t.Fatalf("apiClient.Get(%q): got error %v, want no error", caCertURL, err)
- }
- defer res.Body.Close()
- caCert, err := ioutil.ReadAll(res.Body)
- if err != nil {
- t.Fatalf("ioutil.ReadAll(res.Body): got error %v, want no error", err)
- }
- caCertPool := x509.NewCertPool()
- caCertPool.AppendCertsFromPEM(caCert)
-
- // Exercise proxy
- client := &http.Client{Transport: &http.Transport{
- Proxy: http.ProxyURL(parseURL(t, proxyURL)),
- TLSClientConfig: &tls.Config{
- RootCAs: caCertPool,
- },
- }}
-
- testURL := "https://localhost" + servPort
- res, err = client.Get(testURL)
- if err != nil {
- t.Fatalf("client.Get(%q): got error %v, want no error", testURL, err)
- }
- defer res.Body.Close()
- if got, want := res.StatusCode, http.StatusTeapot; got != want {
- t.Fatalf("client.Get(%q): got status %d, want %d", testURL, got, want)
- }
- body, err := ioutil.ReadAll(res.Body)
- if err != nil {
- t.Fatalf("ioutil.ReadAll(res.Body): got error %v, want no error", err)
- }
- if got, want := string(body), "茶壺"; got != want {
- t.Fatalf("modified response body: got %s, want %s", got, want)
- }
- })
-
- t.Run("DownstreamProxy", func(t *testing.T) {
- // Start downstream proxy
- dsProxyPort := getFreePort(t)
- dsAPIPort := getFreePort(t)
- cmd := exec.Command(binPath, "-addr="+dsProxyPort, "-api-addr="+dsAPIPort)
- cmd.Stdout = os.Stdout
- cmd.Stderr = os.Stderr
- if err := cmd.Start(); err != nil {
- t.Fatal(err)
- }
- defer cmd.Wait()
- defer cmd.Process.Signal(os.Interrupt)
-
- dsProxyURL := "http://localhost" + dsProxyPort
- dsAPIURL := "http://localhost" + dsAPIPort
- configureURL := "http://martian.proxy/configure"
-
- // TODO: Make using API hostport directly work on Travis.
- dsAPIClient := &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(parseURL(t, dsAPIURL))}}
- waitForProxy(t, dsAPIClient, configureURL)
-
- // Configure modifiers
- config := strings.NewReader(`
- {
- "fifo.Group": {
- "scope": ["request", "response"],
- "modifiers": [
- {
- "status.Modifier": {
- "scope": ["response"],
- "statusCode": 418
- }
- },
- {
- "skip.RoundTrip": {}
- }
- ]
- }
- }`)
- res, err := dsAPIClient.Post(configureURL, "application/json", config)
- if err != nil {
- t.Fatalf("dsApiClient.Post(%q): got error %v, want no error", configureURL, err)
- }
- defer res.Body.Close()
- if got, want := res.StatusCode, http.StatusOK; got != want {
- t.Fatalf("dsApiClient.Post(%q): got status %d, want %d", configureURL, got, want)
- }
-
- // Start main proxy
- proxyPort := getFreePort(t)
- apiPort := getFreePort(t)
- cmd = exec.Command(binPath, "-addr="+proxyPort, "-api-addr="+apiPort, "-downstream-proxy-url="+dsProxyURL)
- cmd.Stdout = os.Stdout
- cmd.Stderr = os.Stderr
- if err := cmd.Start(); err != nil {
- t.Fatal(err)
- }
- defer cmd.Wait()
- defer cmd.Process.Signal(os.Interrupt)
-
- proxyURL := "http://localhost" + proxyPort
- apiURL := "http://localhost" + apiPort
-
- // TODO: Make using API hostport directly work on Travis.
- apiClient := &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(parseURL(t, apiURL))}}
- waitForProxy(t, apiClient, configureURL)
-
- // Configure modifiers
- // Setting a different Via header value to circumvent loop detection.
- config = strings.NewReader(fmt.Sprintf(`
- {
- "fifo.Group": {
- "scope": ["request", "response"],
- "modifiers": [
- {
- "header.Modifier": {
- "scope": ["request"],
- "name": "Via",
- "value": "martian_1"
- }
- },
- {
- "body.Modifier": {
- "scope": ["response"],
- "contentType": "text/plain",
- "body": "%s"
- }
- }
- ]
- }
- }`, base64.StdEncoding.EncodeToString([]byte("茶壺"))))
- res, err = apiClient.Post(configureURL, "application/json", config)
- if err != nil {
- t.Fatalf("apiClient.Post(%q): got error %v, want no error", configureURL, err)
- }
- defer res.Body.Close()
- if got, want := res.StatusCode, http.StatusOK; got != want {
- t.Fatalf("apiClient.Post(%q): got status %d, want %d", configureURL, got, want)
- }
-
- // Exercise proxy
- client := &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(parseURL(t, proxyURL))}}
-
- testURL := "http://super.fake.domain/"
- res, err = client.Get(testURL)
- if err != nil {
- t.Fatalf("client.Get(%q): got error %v, want no error", testURL, err)
- }
- defer res.Body.Close()
- if got, want := res.StatusCode, http.StatusTeapot; got != want {
- t.Errorf("client.Get(%q): got status %d, want %d", testURL, got, want)
- }
- body, err := ioutil.ReadAll(res.Body)
- if err != nil {
- t.Fatalf("ioutil.ReadAll(res.Body): got error %v, want no error", err)
- }
- if got, want := string(body), "茶壺"; got != want {
- t.Fatalf("modified response body: got %s, want %s", got, want)
- }
- })
- }
|