You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

381 lines
11 KiB

  1. // Copyright 2015 Google Inc. All rights reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package main
  15. import (
  16. "crypto/tls"
  17. "crypto/x509"
  18. "encoding/base64"
  19. "fmt"
  20. "io/ioutil"
  21. "net"
  22. "net/http"
  23. "net/url"
  24. "os"
  25. "os/exec"
  26. "path/filepath"
  27. "strings"
  28. "testing"
  29. "time"
  30. "github.com/google/martian/mitm"
  31. )
  32. func waitForProxy(t *testing.T, c *http.Client, apiURL string) {
  33. timeout := 5 * time.Second
  34. deadline := time.Now().Add(timeout)
  35. for time.Now().Before(deadline) {
  36. res, err := c.Get(apiURL)
  37. if err != nil {
  38. time.Sleep(200 * time.Millisecond)
  39. continue
  40. }
  41. defer res.Body.Close()
  42. if got, want := res.StatusCode, http.StatusOK; got != want {
  43. t.Fatalf("waitForProxy: c.Get(%q): got status %d, want %d", apiURL, got, want)
  44. }
  45. return
  46. }
  47. t.Fatalf("waitForProxy: did not start up within %.1f seconds", timeout.Seconds())
  48. }
  49. // getFreePort returns a port string preceded by a colon, e.g. ":1234"
  50. func getFreePort(t *testing.T) string {
  51. l, err := net.Listen("tcp", ":")
  52. if err != nil {
  53. t.Fatalf("getFreePort: could not get free port: %v", err)
  54. }
  55. defer l.Close()
  56. return l.Addr().String()[strings.LastIndex(l.Addr().String(), ":"):]
  57. }
  58. func parseURL(t *testing.T, u string) *url.URL {
  59. p, err := url.Parse(u)
  60. if err != nil {
  61. t.Fatalf("url.Parse(%q): got error %v, want no error", u, err)
  62. }
  63. return p
  64. }
  65. func TestProxyMain(t *testing.T) {
  66. tempDir, err := ioutil.TempDir("", t.Name())
  67. if err != nil {
  68. t.Fatal(err)
  69. }
  70. defer os.RemoveAll(tempDir)
  71. // Build proxy binary
  72. binPath := filepath.Join(tempDir, "proxy")
  73. cmd := exec.Command("go", "build", "-o", binPath)
  74. cmd.Stdout = os.Stdout
  75. cmd.Stderr = os.Stderr
  76. if err := cmd.Run(); err != nil {
  77. t.Fatal(err)
  78. }
  79. t.Run("Http", func(t *testing.T) {
  80. // Start proxy
  81. proxyPort := getFreePort(t)
  82. apiPort := getFreePort(t)
  83. cmd := exec.Command(binPath, "-addr="+proxyPort, "-api-addr="+apiPort)
  84. cmd.Stdout = os.Stdout
  85. cmd.Stderr = os.Stderr
  86. if err := cmd.Start(); err != nil {
  87. t.Fatal(err)
  88. }
  89. defer cmd.Wait()
  90. defer cmd.Process.Signal(os.Interrupt)
  91. proxyURL := "http://localhost" + proxyPort
  92. apiURL := "http://localhost" + apiPort
  93. configureURL := "http://martian.proxy/configure"
  94. // TODO: Make using API hostport directly work on Travis.
  95. apiClient := &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(parseURL(t, apiURL))}}
  96. waitForProxy(t, apiClient, configureURL)
  97. // Configure modifiers
  98. config := strings.NewReader(`
  99. {
  100. "fifo.Group": {
  101. "scope": ["request", "response"],
  102. "modifiers": [
  103. {
  104. "status.Modifier": {
  105. "scope": ["response"],
  106. "statusCode": 418
  107. }
  108. },
  109. {
  110. "skip.RoundTrip": {}
  111. }
  112. ]
  113. }
  114. }`)
  115. res, err := apiClient.Post(configureURL, "application/json", config)
  116. if err != nil {
  117. t.Fatalf("apiClient.Post(%q): got error %v, want no error", configureURL, err)
  118. }
  119. defer res.Body.Close()
  120. if got, want := res.StatusCode, http.StatusOK; got != want {
  121. t.Fatalf("apiClient.Post(%q): got status %d, want %d", configureURL, got, want)
  122. }
  123. // Exercise proxy
  124. client := &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(parseURL(t, proxyURL))}}
  125. testURL := "http://super.fake.domain/"
  126. res, err = client.Get(testURL)
  127. if err != nil {
  128. t.Fatalf("client.Get(%q): got error %v, want no error", testURL, err)
  129. }
  130. defer res.Body.Close()
  131. if got, want := res.StatusCode, http.StatusTeapot; got != want {
  132. t.Errorf("client.Get(%q): got status %d, want %d", testURL, got, want)
  133. }
  134. })
  135. t.Run("HttpsGenerateCert", func(t *testing.T) {
  136. // Create test certificate for test TLS server
  137. certName := "martian.proxy"
  138. certOrg := "Martian Authority"
  139. certExpiry := 90 * time.Minute
  140. servCert, servPriv, err := mitm.NewAuthority(certName, certOrg, certExpiry)
  141. if err != nil {
  142. t.Fatalf("mitm.NewAuthority(%q, %q, %q): got error %v, want no error", certName, certOrg, certExpiry, err)
  143. }
  144. mc, err := mitm.NewConfig(servCert, servPriv)
  145. if err != nil {
  146. t.Fatalf("mitm.NewConfig(%p, %q): got error %v, want no error", servCert, servPriv, err)
  147. }
  148. sc := mc.TLS()
  149. // Configure and start test TLS server
  150. servPort := getFreePort(t)
  151. l, err := tls.Listen("tcp", servPort, sc)
  152. if err != nil {
  153. t.Fatalf("tls.Listen(\"tcp\", %q, %p): got error %v, want no error", servPort, sc, err)
  154. }
  155. defer l.Close()
  156. server := &http.Server{
  157. Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  158. w.WriteHeader(http.StatusTeapot)
  159. w.Write([]byte("Hello!"))
  160. }),
  161. }
  162. go server.Serve(l)
  163. defer server.Close()
  164. // Start proxy
  165. proxyPort := getFreePort(t)
  166. apiPort := getFreePort(t)
  167. cmd := exec.Command(binPath, "-addr="+proxyPort, "-api-addr="+apiPort, "-generate-ca-cert", "-skip-tls-verify")
  168. cmd.Stdout = os.Stdout
  169. cmd.Stderr = os.Stderr
  170. if err := cmd.Start(); err != nil {
  171. t.Fatal(err)
  172. }
  173. defer cmd.Wait()
  174. defer cmd.Process.Signal(os.Interrupt)
  175. proxyURL := "http://localhost" + proxyPort
  176. apiURL := "http://localhost" + apiPort
  177. configureURL := "http://martian.proxy/configure"
  178. // TODO: Make using API hostport directly work on Travis.
  179. apiClient := &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(parseURL(t, apiURL))}}
  180. waitForProxy(t, apiClient, configureURL)
  181. // Configure modifiers
  182. config := strings.NewReader(fmt.Sprintf(`
  183. {
  184. "body.Modifier": {
  185. "scope": ["response"],
  186. "contentType": "text/plain",
  187. "body": "%s"
  188. }
  189. }`, base64.StdEncoding.EncodeToString([]byte("茶壺"))))
  190. res, err := apiClient.Post(configureURL, "application/json", config)
  191. if err != nil {
  192. t.Fatalf("apiClient.Post(%q): got error %v, want no error", configureURL, err)
  193. }
  194. defer res.Body.Close()
  195. if got, want := res.StatusCode, http.StatusOK; got != want {
  196. t.Fatalf("apiClient.Post(%q): got status %d, want %d", configureURL, got, want)
  197. }
  198. // Install proxy's CA cert into http client
  199. caCertURL := "http://martian.proxy/authority.cer"
  200. res, err = apiClient.Get(caCertURL)
  201. if err != nil {
  202. t.Fatalf("apiClient.Get(%q): got error %v, want no error", caCertURL, err)
  203. }
  204. defer res.Body.Close()
  205. caCert, err := ioutil.ReadAll(res.Body)
  206. if err != nil {
  207. t.Fatalf("ioutil.ReadAll(res.Body): got error %v, want no error", err)
  208. }
  209. caCertPool := x509.NewCertPool()
  210. caCertPool.AppendCertsFromPEM(caCert)
  211. // Exercise proxy
  212. client := &http.Client{Transport: &http.Transport{
  213. Proxy: http.ProxyURL(parseURL(t, proxyURL)),
  214. TLSClientConfig: &tls.Config{
  215. RootCAs: caCertPool,
  216. },
  217. }}
  218. testURL := "https://localhost" + servPort
  219. res, err = client.Get(testURL)
  220. if err != nil {
  221. t.Fatalf("client.Get(%q): got error %v, want no error", testURL, err)
  222. }
  223. defer res.Body.Close()
  224. if got, want := res.StatusCode, http.StatusTeapot; got != want {
  225. t.Fatalf("client.Get(%q): got status %d, want %d", testURL, got, want)
  226. }
  227. body, err := ioutil.ReadAll(res.Body)
  228. if err != nil {
  229. t.Fatalf("ioutil.ReadAll(res.Body): got error %v, want no error", err)
  230. }
  231. if got, want := string(body), "茶壺"; got != want {
  232. t.Fatalf("modified response body: got %s, want %s", got, want)
  233. }
  234. })
  235. t.Run("DownstreamProxy", func(t *testing.T) {
  236. // Start downstream proxy
  237. dsProxyPort := getFreePort(t)
  238. dsAPIPort := getFreePort(t)
  239. cmd := exec.Command(binPath, "-addr="+dsProxyPort, "-api-addr="+dsAPIPort)
  240. cmd.Stdout = os.Stdout
  241. cmd.Stderr = os.Stderr
  242. if err := cmd.Start(); err != nil {
  243. t.Fatal(err)
  244. }
  245. defer cmd.Wait()
  246. defer cmd.Process.Signal(os.Interrupt)
  247. dsProxyURL := "http://localhost" + dsProxyPort
  248. dsAPIURL := "http://localhost" + dsAPIPort
  249. configureURL := "http://martian.proxy/configure"
  250. // TODO: Make using API hostport directly work on Travis.
  251. dsAPIClient := &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(parseURL(t, dsAPIURL))}}
  252. waitForProxy(t, dsAPIClient, configureURL)
  253. // Configure modifiers
  254. config := strings.NewReader(`
  255. {
  256. "fifo.Group": {
  257. "scope": ["request", "response"],
  258. "modifiers": [
  259. {
  260. "status.Modifier": {
  261. "scope": ["response"],
  262. "statusCode": 418
  263. }
  264. },
  265. {
  266. "skip.RoundTrip": {}
  267. }
  268. ]
  269. }
  270. }`)
  271. res, err := dsAPIClient.Post(configureURL, "application/json", config)
  272. if err != nil {
  273. t.Fatalf("dsApiClient.Post(%q): got error %v, want no error", configureURL, err)
  274. }
  275. defer res.Body.Close()
  276. if got, want := res.StatusCode, http.StatusOK; got != want {
  277. t.Fatalf("dsApiClient.Post(%q): got status %d, want %d", configureURL, got, want)
  278. }
  279. // Start main proxy
  280. proxyPort := getFreePort(t)
  281. apiPort := getFreePort(t)
  282. cmd = exec.Command(binPath, "-addr="+proxyPort, "-api-addr="+apiPort, "-downstream-proxy-url="+dsProxyURL)
  283. cmd.Stdout = os.Stdout
  284. cmd.Stderr = os.Stderr
  285. if err := cmd.Start(); err != nil {
  286. t.Fatal(err)
  287. }
  288. defer cmd.Wait()
  289. defer cmd.Process.Signal(os.Interrupt)
  290. proxyURL := "http://localhost" + proxyPort
  291. apiURL := "http://localhost" + apiPort
  292. // TODO: Make using API hostport directly work on Travis.
  293. apiClient := &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(parseURL(t, apiURL))}}
  294. waitForProxy(t, apiClient, configureURL)
  295. // Configure modifiers
  296. // Setting a different Via header value to circumvent loop detection.
  297. config = strings.NewReader(fmt.Sprintf(`
  298. {
  299. "fifo.Group": {
  300. "scope": ["request", "response"],
  301. "modifiers": [
  302. {
  303. "header.Modifier": {
  304. "scope": ["request"],
  305. "name": "Via",
  306. "value": "martian_1"
  307. }
  308. },
  309. {
  310. "body.Modifier": {
  311. "scope": ["response"],
  312. "contentType": "text/plain",
  313. "body": "%s"
  314. }
  315. }
  316. ]
  317. }
  318. }`, base64.StdEncoding.EncodeToString([]byte("茶壺"))))
  319. res, err = apiClient.Post(configureURL, "application/json", config)
  320. if err != nil {
  321. t.Fatalf("apiClient.Post(%q): got error %v, want no error", configureURL, err)
  322. }
  323. defer res.Body.Close()
  324. if got, want := res.StatusCode, http.StatusOK; got != want {
  325. t.Fatalf("apiClient.Post(%q): got status %d, want %d", configureURL, got, want)
  326. }
  327. // Exercise proxy
  328. client := &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(parseURL(t, proxyURL))}}
  329. testURL := "http://super.fake.domain/"
  330. res, err = client.Get(testURL)
  331. if err != nil {
  332. t.Fatalf("client.Get(%q): got error %v, want no error", testURL, err)
  333. }
  334. defer res.Body.Close()
  335. if got, want := res.StatusCode, http.StatusTeapot; got != want {
  336. t.Errorf("client.Get(%q): got status %d, want %d", testURL, got, want)
  337. }
  338. body, err := ioutil.ReadAll(res.Body)
  339. if err != nil {
  340. t.Fatalf("ioutil.ReadAll(res.Body): got error %v, want no error", err)
  341. }
  342. if got, want := string(body), "茶壺"; got != want {
  343. t.Fatalf("modified response body: got %s, want %s", got, want)
  344. }
  345. })
  346. }