You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

1317 lines
34 KiB

  1. // Copyright 2015 Google Inc. All rights reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package martian
  15. import (
  16. "bufio"
  17. "bytes"
  18. "crypto/tls"
  19. "crypto/x509"
  20. "errors"
  21. "fmt"
  22. "io"
  23. "io/ioutil"
  24. "net"
  25. "net/http"
  26. "net/url"
  27. "os"
  28. "strings"
  29. "testing"
  30. "time"
  31. "github.com/google/martian/log"
  32. "github.com/google/martian/martiantest"
  33. "github.com/google/martian/mitm"
  34. "github.com/google/martian/proxyutil"
  35. )
  36. type tempError struct{}
  37. func (e *tempError) Error() string { return "temporary" }
  38. func (e *tempError) Timeout() bool { return true }
  39. func (e *tempError) Temporary() bool { return true }
  40. type timeoutListener struct {
  41. net.Listener
  42. errCount int
  43. err error
  44. }
  45. func newTimeoutListener(l net.Listener, errCount int) net.Listener {
  46. return &timeoutListener{
  47. Listener: l,
  48. errCount: errCount,
  49. err: &tempError{},
  50. }
  51. }
  52. func (l *timeoutListener) Accept() (net.Conn, error) {
  53. if l.errCount > 0 {
  54. l.errCount--
  55. return nil, l.err
  56. }
  57. return l.Listener.Accept()
  58. }
  59. func TestIntegrationTemporaryTimeout(t *testing.T) {
  60. t.Parallel()
  61. l, err := net.Listen("tcp", "[::]:0")
  62. if err != nil {
  63. t.Fatalf("net.Listen(): got %v, want no error", err)
  64. }
  65. p := NewProxy()
  66. defer p.Close()
  67. tr := martiantest.NewTransport()
  68. p.SetRoundTripper(tr)
  69. p.SetTimeout(200 * time.Millisecond)
  70. // Start the proxy with a listener that will return a temporary error on
  71. // Accept() three times.
  72. go p.Serve(newTimeoutListener(l, 3))
  73. conn, err := net.Dial("tcp", l.Addr().String())
  74. if err != nil {
  75. t.Fatalf("net.Dial(): got %v, want no error", err)
  76. }
  77. defer conn.Close()
  78. req, err := http.NewRequest("GET", "http://example.com", nil)
  79. if err != nil {
  80. t.Fatalf("http.NewRequest(): got %v, want no error", err)
  81. }
  82. req.Header.Set("Connection", "close")
  83. // GET http://example.com/ HTTP/1.1
  84. // Host: example.com
  85. if err := req.WriteProxy(conn); err != nil {
  86. t.Fatalf("req.WriteProxy(): got %v, want no error", err)
  87. }
  88. res, err := http.ReadResponse(bufio.NewReader(conn), req)
  89. if err != nil {
  90. t.Fatalf("http.ReadResponse(): got %v, want no error", err)
  91. }
  92. defer res.Body.Close()
  93. if got, want := res.StatusCode, 200; got != want {
  94. t.Errorf("res.StatusCode: got %d, want %d", got, want)
  95. }
  96. }
  97. func TestIntegrationHTTP(t *testing.T) {
  98. t.Parallel()
  99. l, err := net.Listen("tcp", "[::]:0")
  100. if err != nil {
  101. t.Fatalf("net.Listen(): got %v, want no error", err)
  102. }
  103. p := NewProxy()
  104. defer p.Close()
  105. p.SetRequestModifier(nil)
  106. p.SetResponseModifier(nil)
  107. tr := martiantest.NewTransport()
  108. p.SetRoundTripper(tr)
  109. p.SetTimeout(200 * time.Millisecond)
  110. tm := martiantest.NewModifier()
  111. tm.RequestFunc(func(req *http.Request) {
  112. ctx := NewContext(req)
  113. ctx.Set("martian.test", "true")
  114. })
  115. tm.ResponseFunc(func(res *http.Response) {
  116. ctx := NewContext(res.Request)
  117. v, _ := ctx.Get("martian.test")
  118. res.Header.Set("Martian-Test", v.(string))
  119. })
  120. p.SetRequestModifier(tm)
  121. p.SetResponseModifier(tm)
  122. go p.Serve(l)
  123. conn, err := net.Dial("tcp", l.Addr().String())
  124. if err != nil {
  125. t.Fatalf("net.Dial(): got %v, want no error", err)
  126. }
  127. defer conn.Close()
  128. req, err := http.NewRequest("GET", "http://example.com", nil)
  129. if err != nil {
  130. t.Fatalf("http.NewRequest(): got %v, want no error", err)
  131. }
  132. // GET http://example.com/ HTTP/1.1
  133. // Host: example.com
  134. if err := req.WriteProxy(conn); err != nil {
  135. t.Fatalf("req.WriteProxy(): got %v, want no error", err)
  136. }
  137. res, err := http.ReadResponse(bufio.NewReader(conn), req)
  138. if err != nil {
  139. t.Fatalf("http.ReadResponse(): got %v, want no error", err)
  140. }
  141. if got, want := res.StatusCode, 200; got != want {
  142. t.Fatalf("res.StatusCode: got %d, want %d", got, want)
  143. }
  144. if got, want := res.Header.Get("Martian-Test"), "true"; got != want {
  145. t.Errorf("res.Header.Get(%q): got %q, want %q", "Martian-Test", got, want)
  146. }
  147. }
  148. func TestIntegrationHTTP100Continue(t *testing.T) {
  149. t.Parallel()
  150. l, err := net.Listen("tcp", "[::]:0")
  151. if err != nil {
  152. t.Fatalf("net.Listen(): got %v, want no error", err)
  153. }
  154. p := NewProxy()
  155. defer p.Close()
  156. p.SetTimeout(2 * time.Second)
  157. sl, err := net.Listen("tcp", "[::]:0")
  158. if err != nil {
  159. t.Fatalf("net.Listen(): got %v, want no error", err)
  160. }
  161. go func() {
  162. conn, err := sl.Accept()
  163. if err != nil {
  164. log.Errorf("proxy_test: failed to accept connection: %v", err)
  165. return
  166. }
  167. defer conn.Close()
  168. log.Infof("proxy_test: accepted connection: %s", conn.RemoteAddr())
  169. req, err := http.ReadRequest(bufio.NewReader(conn))
  170. if err != nil {
  171. log.Errorf("proxy_test: failed to read request: %v", err)
  172. return
  173. }
  174. if req.Header.Get("Expect") == "100-continue" {
  175. log.Infof("proxy_test: received 100-continue request")
  176. conn.Write([]byte("HTTP/1.1 100 Continue\r\n\r\n"))
  177. log.Infof("proxy_test: sent 100-continue response")
  178. } else {
  179. log.Infof("proxy_test: received non 100-continue request")
  180. res := proxyutil.NewResponse(417, nil, req)
  181. res.Header.Set("Connection", "close")
  182. res.Write(conn)
  183. return
  184. }
  185. res := proxyutil.NewResponse(200, req.Body, req)
  186. res.Header.Set("Connection", "close")
  187. res.Write(conn)
  188. log.Infof("proxy_test: sent 200 response")
  189. }()
  190. tm := martiantest.NewModifier()
  191. p.SetRequestModifier(tm)
  192. p.SetResponseModifier(tm)
  193. go p.Serve(l)
  194. conn, err := net.Dial("tcp", l.Addr().String())
  195. if err != nil {
  196. t.Fatalf("net.Dial(): got %v, want no error", err)
  197. }
  198. defer conn.Close()
  199. host := sl.Addr().String()
  200. raw := fmt.Sprintf("POST http://%s/ HTTP/1.1\r\n"+
  201. "Host: %s\r\n"+
  202. "Content-Length: 12\r\n"+
  203. "Expect: 100-continue\r\n\r\n", host, host)
  204. if _, err := conn.Write([]byte(raw)); err != nil {
  205. t.Fatalf("conn.Write(headers): got %v, want no error", err)
  206. }
  207. go func() {
  208. select {
  209. case <-time.After(time.Second):
  210. conn.Write([]byte("body content"))
  211. }
  212. }()
  213. res, err := http.ReadResponse(bufio.NewReader(conn), nil)
  214. if err != nil {
  215. t.Fatalf("http.ReadResponse(): got %v, want no error", err)
  216. }
  217. defer res.Body.Close()
  218. if got, want := res.StatusCode, 200; got != want {
  219. t.Fatalf("res.StatusCode: got %d, want %d", got, want)
  220. }
  221. got, err := ioutil.ReadAll(res.Body)
  222. if err != nil {
  223. t.Fatalf("ioutil.ReadAll(): got %v, want no error", err)
  224. }
  225. if want := []byte("body content"); !bytes.Equal(got, want) {
  226. t.Errorf("res.Body: got %q, want %q", got, want)
  227. }
  228. if !tm.RequestModified() {
  229. t.Error("tm.RequestModified(): got false, want true")
  230. }
  231. if !tm.ResponseModified() {
  232. t.Error("tm.ResponseModified(): got false, want true")
  233. }
  234. }
  235. func TestIntegrationHTTPDownstreamProxy(t *testing.T) {
  236. t.Parallel()
  237. // Start first proxy to use as downstream.
  238. dl, err := net.Listen("tcp", "[::]:0")
  239. if err != nil {
  240. t.Fatalf("net.Listen(): got %v, want no error", err)
  241. }
  242. downstream := NewProxy()
  243. defer downstream.Close()
  244. dtr := martiantest.NewTransport()
  245. dtr.Respond(299)
  246. downstream.SetRoundTripper(dtr)
  247. downstream.SetTimeout(600 * time.Millisecond)
  248. go downstream.Serve(dl)
  249. // Start second proxy as upstream proxy, will write to downstream proxy.
  250. ul, err := net.Listen("tcp", "[::]:0")
  251. if err != nil {
  252. t.Fatalf("net.Listen(): got %v, want no error", err)
  253. }
  254. upstream := NewProxy()
  255. defer upstream.Close()
  256. // Set upstream proxy's downstream proxy to the host:port of the first proxy.
  257. upstream.SetDownstreamProxy(&url.URL{
  258. Host: dl.Addr().String(),
  259. })
  260. upstream.SetTimeout(600 * time.Millisecond)
  261. go upstream.Serve(ul)
  262. // Open connection to upstream proxy.
  263. conn, err := net.Dial("tcp", ul.Addr().String())
  264. if err != nil {
  265. t.Fatalf("net.Dial(): got %v, want no error", err)
  266. }
  267. defer conn.Close()
  268. req, err := http.NewRequest("GET", "http://example.com", nil)
  269. if err != nil {
  270. t.Fatalf("http.NewRequest(): got %v, want no error", err)
  271. }
  272. // GET http://example.com/ HTTP/1.1
  273. // Host: example.com
  274. if err := req.WriteProxy(conn); err != nil {
  275. t.Fatalf("req.WriteProxy(): got %v, want no error", err)
  276. }
  277. // Response from downstream proxy.
  278. res, err := http.ReadResponse(bufio.NewReader(conn), req)
  279. if err != nil {
  280. t.Fatalf("http.ReadResponse(): got %v, want no error", err)
  281. }
  282. if got, want := res.StatusCode, 299; got != want {
  283. t.Fatalf("res.StatusCode: got %d, want %d", got, want)
  284. }
  285. }
  286. func TestIntegrationHTTPDownstreamProxyError(t *testing.T) {
  287. t.Parallel()
  288. l, err := net.Listen("tcp", "[::]:0")
  289. if err != nil {
  290. t.Fatalf("net.Listen(): got %v, want no error", err)
  291. }
  292. p := NewProxy()
  293. defer p.Close()
  294. // Set proxy's downstream proxy to invalid host:port to force failure.
  295. p.SetDownstreamProxy(&url.URL{
  296. Host: "[::]:0",
  297. })
  298. p.SetTimeout(600 * time.Millisecond)
  299. tm := martiantest.NewModifier()
  300. reserr := errors.New("response error")
  301. tm.ResponseError(reserr)
  302. p.SetResponseModifier(tm)
  303. go p.Serve(l)
  304. // Open connection to upstream proxy.
  305. conn, err := net.Dial("tcp", l.Addr().String())
  306. if err != nil {
  307. t.Fatalf("net.Dial(): got %v, want no error", err)
  308. }
  309. defer conn.Close()
  310. req, err := http.NewRequest("CONNECT", "//example.com:443", nil)
  311. if err != nil {
  312. t.Fatalf("http.NewRequest(): got %v, want no error", err)
  313. }
  314. // CONNECT example.com:443 HTTP/1.1
  315. // Host: example.com
  316. if err := req.Write(conn); err != nil {
  317. t.Fatalf("req.Write(): got %v, want no error", err)
  318. }
  319. // Response from upstream proxy, assuming downstream proxy failed to CONNECT.
  320. res, err := http.ReadResponse(bufio.NewReader(conn), req)
  321. if err != nil {
  322. t.Fatalf("http.ReadResponse(): got %v, want no error", err)
  323. }
  324. if got, want := res.StatusCode, 502; got != want {
  325. t.Fatalf("res.StatusCode: got %d, want %d", got, want)
  326. }
  327. if got, want := res.Header["Warning"][1], reserr.Error(); !strings.Contains(got, want) {
  328. t.Errorf("res.Header.get(%q): got %q, want to contain %q", "Warning", got, want)
  329. }
  330. }
  331. func TestIntegrationTLSHandshakeErrorCallback(t *testing.T) {
  332. t.Parallel()
  333. l, err := net.Listen("tcp", "[::]:0")
  334. if err != nil {
  335. t.Fatalf("net.Listen(): got %v, want no error", err)
  336. }
  337. p := NewProxy()
  338. defer p.Close()
  339. // Test TLS server.
  340. ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", time.Hour)
  341. if err != nil {
  342. t.Fatalf("mitm.NewAuthority(): got %v, want no error", err)
  343. }
  344. mc, err := mitm.NewConfig(ca, priv)
  345. if err != nil {
  346. t.Fatalf("mitm.NewConfig(): got %v, want no error", err)
  347. }
  348. var herr error
  349. mc.SetHandshakeErrorCallback(func(_ *http.Request, err error) { herr = fmt.Errorf("handshake error") })
  350. p.SetMITM(mc)
  351. tl, err := net.Listen("tcp", "[::]:0")
  352. if err != nil {
  353. t.Fatalf("tls.Listen(): got %v, want no error", err)
  354. }
  355. tl = tls.NewListener(tl, mc.TLS())
  356. go http.Serve(tl, http.HandlerFunc(
  357. func(rw http.ResponseWriter, req *http.Request) {
  358. rw.WriteHeader(200)
  359. }))
  360. tm := martiantest.NewModifier()
  361. // Force the CONNECT request to dial the local TLS server.
  362. tm.RequestFunc(func(req *http.Request) {
  363. req.URL.Host = tl.Addr().String()
  364. })
  365. go p.Serve(l)
  366. conn, err := net.Dial("tcp", l.Addr().String())
  367. if err != nil {
  368. t.Fatalf("net.Dial(): got %v, want no error", err)
  369. }
  370. defer conn.Close()
  371. req, err := http.NewRequest("CONNECT", "//example.com:443", nil)
  372. if err != nil {
  373. t.Fatalf("http.NewRequest(): got %v, want no error", err)
  374. }
  375. // CONNECT example.com:443 HTTP/1.1
  376. // Host: example.com
  377. //
  378. // Rewritten to CONNECT to host:port in CONNECT request modifier.
  379. if err := req.Write(conn); err != nil {
  380. t.Fatalf("req.Write(): got %v, want no error", err)
  381. }
  382. // CONNECT response after establishing tunnel.
  383. if _, err := http.ReadResponse(bufio.NewReader(conn), req); err != nil {
  384. t.Fatalf("http.ReadResponse(): got %v, want no error", err)
  385. }
  386. tlsconn := tls.Client(conn, &tls.Config{
  387. ServerName: "example.com",
  388. // Client has no cert so it will get "x509: certificate signed by unknown authority" from the
  389. // handshake and send "remote error: bad certificate" to the server.
  390. RootCAs: x509.NewCertPool(),
  391. })
  392. defer tlsconn.Close()
  393. req, err = http.NewRequest("GET", "https://example.com", nil)
  394. if err != nil {
  395. t.Fatalf("http.NewRequest(): got %v, want no error", err)
  396. }
  397. req.Header.Set("Connection", "close")
  398. if got, want := req.Write(tlsconn), "x509: certificate signed by unknown authority"; !strings.Contains(got.Error(), want) {
  399. t.Fatalf("Got incorrect error from Client Handshake(), got: %v, want: %v", got, want)
  400. }
  401. // TODO: herr is not being asserted against. It should be pushed on to a channel
  402. // of err, and the assertion should pull off of it and assert. That design resulted in the test
  403. // hanging for unknown reasons.
  404. t.Skip("skipping assertion of handshake error callback error due to mysterious deadlock")
  405. if got, want := herr, "remote error: bad certificate"; !strings.Contains(got.Error(), want) {
  406. t.Fatalf("Got incorrect error from Server Handshake(), got: %v, want: %v", got, want)
  407. }
  408. }
  409. func TestIntegrationConnect(t *testing.T) {
  410. t.Parallel()
  411. l, err := net.Listen("tcp", "[::]:0")
  412. if err != nil {
  413. t.Fatalf("net.Listen(): got %v, want no error", err)
  414. }
  415. p := NewProxy()
  416. defer p.Close()
  417. // Test TLS server.
  418. ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", time.Hour)
  419. if err != nil {
  420. t.Fatalf("mitm.NewAuthority(): got %v, want no error", err)
  421. }
  422. mc, err := mitm.NewConfig(ca, priv)
  423. if err != nil {
  424. t.Fatalf("mitm.NewConfig(): got %v, want no error", err)
  425. }
  426. tl, err := net.Listen("tcp", "[::]:0")
  427. if err != nil {
  428. t.Fatalf("tls.Listen(): got %v, want no error", err)
  429. }
  430. tl = tls.NewListener(tl, mc.TLS())
  431. go http.Serve(tl, http.HandlerFunc(
  432. func(rw http.ResponseWriter, req *http.Request) {
  433. rw.WriteHeader(299)
  434. }))
  435. tm := martiantest.NewModifier()
  436. reqerr := errors.New("request error")
  437. reserr := errors.New("response error")
  438. // Force the CONNECT request to dial the local TLS server.
  439. tm.RequestFunc(func(req *http.Request) {
  440. req.URL.Host = tl.Addr().String()
  441. })
  442. tm.RequestError(reqerr)
  443. tm.ResponseError(reserr)
  444. p.SetRequestModifier(tm)
  445. p.SetResponseModifier(tm)
  446. go p.Serve(l)
  447. conn, err := net.Dial("tcp", l.Addr().String())
  448. if err != nil {
  449. t.Fatalf("net.Dial(): got %v, want no error", err)
  450. }
  451. defer conn.Close()
  452. req, err := http.NewRequest("CONNECT", "//example.com:443", nil)
  453. if err != nil {
  454. t.Fatalf("http.NewRequest(): got %v, want no error", err)
  455. }
  456. // CONNECT example.com:443 HTTP/1.1
  457. // Host: example.com
  458. //
  459. // Rewritten to CONNECT to host:port in CONNECT request modifier.
  460. if err := req.Write(conn); err != nil {
  461. t.Fatalf("req.Write(): got %v, want no error", err)
  462. }
  463. // CONNECT response after establishing tunnel.
  464. res, err := http.ReadResponse(bufio.NewReader(conn), req)
  465. if err != nil {
  466. t.Fatalf("http.ReadResponse(): got %v, want no error", err)
  467. }
  468. if got, want := res.StatusCode, 200; got != want {
  469. t.Fatalf("res.StatusCode: got %d, want %d", got, want)
  470. }
  471. if !tm.RequestModified() {
  472. t.Error("tm.RequestModified(): got false, want true")
  473. }
  474. if !tm.ResponseModified() {
  475. t.Error("tm.ResponseModified(): got false, want true")
  476. }
  477. if got, want := res.Header.Get("Warning"), reserr.Error(); !strings.Contains(got, want) {
  478. t.Errorf("res.Header.Get(%q): got %q, want to contain %q", "Warning", got, want)
  479. }
  480. roots := x509.NewCertPool()
  481. roots.AddCert(ca)
  482. tlsconn := tls.Client(conn, &tls.Config{
  483. ServerName: "example.com",
  484. RootCAs: roots,
  485. })
  486. defer tlsconn.Close()
  487. req, err = http.NewRequest("GET", "https://example.com", nil)
  488. if err != nil {
  489. t.Fatalf("http.NewRequest(): got %v, want no error", err)
  490. }
  491. req.Header.Set("Connection", "close")
  492. // GET / HTTP/1.1
  493. // Host: example.com
  494. // Connection: close
  495. if err := req.Write(tlsconn); err != nil {
  496. t.Fatalf("req.Write(): got %v, want no error", err)
  497. }
  498. res, err = http.ReadResponse(bufio.NewReader(tlsconn), req)
  499. if err != nil {
  500. t.Fatalf("http.ReadResponse(): got %v, want no error", err)
  501. }
  502. defer res.Body.Close()
  503. if got, want := res.StatusCode, 299; got != want {
  504. t.Fatalf("res.StatusCode: got %d, want %d", got, want)
  505. }
  506. if got, want := res.Header.Get("Warning"), reserr.Error(); strings.Contains(got, want) {
  507. t.Errorf("res.Header.Get(%q): got %s, want to not contain %s", "Warning", got, want)
  508. }
  509. }
  510. func TestIntegrationConnectDownstreamProxy(t *testing.T) {
  511. t.Parallel()
  512. // Start first proxy to use as downstream.
  513. dl, err := net.Listen("tcp", "[::]:0")
  514. if err != nil {
  515. t.Fatalf("net.Listen(): got %v, want no error", err)
  516. }
  517. downstream := NewProxy()
  518. defer downstream.Close()
  519. dtr := martiantest.NewTransport()
  520. dtr.Respond(299)
  521. downstream.SetRoundTripper(dtr)
  522. ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour)
  523. if err != nil {
  524. t.Fatalf("mitm.NewAuthority(): got %v, want no error", err)
  525. }
  526. mc, err := mitm.NewConfig(ca, priv)
  527. if err != nil {
  528. t.Fatalf("mitm.NewConfig(): got %v, want no error", err)
  529. }
  530. downstream.SetMITM(mc)
  531. go downstream.Serve(dl)
  532. // Start second proxy as upstream proxy, will CONNECT to downstream proxy.
  533. ul, err := net.Listen("tcp", "[::]:0")
  534. if err != nil {
  535. t.Fatalf("net.Listen(): got %v, want no error", err)
  536. }
  537. upstream := NewProxy()
  538. defer upstream.Close()
  539. // Set upstream proxy's downstream proxy to the host:port of the first proxy.
  540. upstream.SetDownstreamProxy(&url.URL{
  541. Host: dl.Addr().String(),
  542. })
  543. go upstream.Serve(ul)
  544. // Open connection to upstream proxy.
  545. conn, err := net.Dial("tcp", ul.Addr().String())
  546. if err != nil {
  547. t.Fatalf("net.Dial(): got %v, want no error", err)
  548. }
  549. defer conn.Close()
  550. req, err := http.NewRequest("CONNECT", "//example.com:443", nil)
  551. if err != nil {
  552. t.Fatalf("http.NewRequest(): got %v, want no error", err)
  553. }
  554. // CONNECT example.com:443 HTTP/1.1
  555. // Host: example.com
  556. if err := req.Write(conn); err != nil {
  557. t.Fatalf("req.Write(): got %v, want no error", err)
  558. }
  559. // Response from downstream proxy starting MITM.
  560. res, err := http.ReadResponse(bufio.NewReader(conn), req)
  561. if err != nil {
  562. t.Fatalf("http.ReadResponse(): got %v, want no error", err)
  563. }
  564. if got, want := res.StatusCode, 200; got != want {
  565. t.Fatalf("res.StatusCode: got %d, want %d", got, want)
  566. }
  567. roots := x509.NewCertPool()
  568. roots.AddCert(ca)
  569. tlsconn := tls.Client(conn, &tls.Config{
  570. // Validate the hostname.
  571. ServerName: "example.com",
  572. // The certificate will have been MITM'd, verify using the MITM CA
  573. // certificate.
  574. RootCAs: roots,
  575. })
  576. defer tlsconn.Close()
  577. req, err = http.NewRequest("GET", "https://example.com", nil)
  578. if err != nil {
  579. t.Fatalf("http.NewRequest(): got %v, want no error", err)
  580. }
  581. // GET / HTTP/1.1
  582. // Host: example.com
  583. if err := req.Write(tlsconn); err != nil {
  584. t.Fatalf("req.Write(): got %v, want no error", err)
  585. }
  586. // Response from MITM in downstream proxy.
  587. res, err = http.ReadResponse(bufio.NewReader(tlsconn), req)
  588. if err != nil {
  589. t.Fatalf("http.ReadResponse(): got %v, want no error", err)
  590. }
  591. defer res.Body.Close()
  592. if got, want := res.StatusCode, 299; got != want {
  593. t.Fatalf("res.StatusCode: got %d, want %d", got, want)
  594. }
  595. }
  596. func TestIntegrationMITM(t *testing.T) {
  597. t.Parallel()
  598. l, err := net.Listen("tcp", "[::]:0")
  599. if err != nil {
  600. t.Fatalf("net.Listen(): got %v, want no error", err)
  601. }
  602. p := NewProxy()
  603. defer p.Close()
  604. tr := martiantest.NewTransport()
  605. tr.Func(func(req *http.Request) (*http.Response, error) {
  606. res := proxyutil.NewResponse(200, nil, req)
  607. res.Header.Set("Request-Scheme", req.URL.Scheme)
  608. return res, nil
  609. })
  610. p.SetRoundTripper(tr)
  611. p.SetTimeout(600 * time.Millisecond)
  612. ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour)
  613. if err != nil {
  614. t.Fatalf("mitm.NewAuthority(): got %v, want no error", err)
  615. }
  616. mc, err := mitm.NewConfig(ca, priv)
  617. if err != nil {
  618. t.Fatalf("mitm.NewConfig(): got %v, want no error", err)
  619. }
  620. p.SetMITM(mc)
  621. tm := martiantest.NewModifier()
  622. reqerr := errors.New("request error")
  623. reserr := errors.New("response error")
  624. tm.RequestError(reqerr)
  625. tm.ResponseError(reserr)
  626. p.SetRequestModifier(tm)
  627. p.SetResponseModifier(tm)
  628. go p.Serve(l)
  629. conn, err := net.Dial("tcp", l.Addr().String())
  630. if err != nil {
  631. t.Fatalf("net.Dial(): got %v, want no error", err)
  632. }
  633. defer conn.Close()
  634. req, err := http.NewRequest("CONNECT", "//example.com:443", nil)
  635. if err != nil {
  636. t.Fatalf("http.NewRequest(): got %v, want no error", err)
  637. }
  638. // CONNECT example.com:443 HTTP/1.1
  639. // Host: example.com
  640. if err := req.Write(conn); err != nil {
  641. t.Fatalf("req.Write(): got %v, want no error", err)
  642. }
  643. // Response MITM'd from proxy.
  644. res, err := http.ReadResponse(bufio.NewReader(conn), req)
  645. if err != nil {
  646. t.Fatalf("http.ReadResponse(): got %v, want no error", err)
  647. }
  648. if got, want := res.StatusCode, 200; got != want {
  649. t.Errorf("res.StatusCode: got %d, want %d", got, want)
  650. }
  651. if got, want := res.Header.Get("Warning"), reserr.Error(); !strings.Contains(got, want) {
  652. t.Errorf("res.Header.Get(%q): got %q, want to contain %q", "Warning", got, want)
  653. }
  654. roots := x509.NewCertPool()
  655. roots.AddCert(ca)
  656. tlsconn := tls.Client(conn, &tls.Config{
  657. ServerName: "example.com",
  658. RootCAs: roots,
  659. })
  660. defer tlsconn.Close()
  661. req, err = http.NewRequest("GET", "https://example.com", nil)
  662. if err != nil {
  663. t.Fatalf("http.NewRequest(): got %v, want no error", err)
  664. }
  665. // GET / HTTP/1.1
  666. // Host: example.com
  667. if err := req.Write(tlsconn); err != nil {
  668. t.Fatalf("req.Write(): got %v, want no error", err)
  669. }
  670. // Response from MITM proxy.
  671. res, err = http.ReadResponse(bufio.NewReader(tlsconn), req)
  672. if err != nil {
  673. t.Fatalf("http.ReadResponse(): got %v, want no error", err)
  674. }
  675. defer res.Body.Close()
  676. if got, want := res.StatusCode, 200; got != want {
  677. t.Errorf("res.StatusCode: got %d, want %d", got, want)
  678. }
  679. if got, want := res.Header.Get("Request-Scheme"), "https"; got != want {
  680. t.Errorf("res.Header.Get(%q): got %q, want %q", "Request-Scheme", got, want)
  681. }
  682. if got, want := res.Header.Get("Warning"), reserr.Error(); !strings.Contains(got, want) {
  683. t.Errorf("res.Header.Get(%q): got %q, want to contain %q", "Warning", got, want)
  684. }
  685. }
  686. func TestIntegrationTransparentHTTP(t *testing.T) {
  687. t.Parallel()
  688. l, err := net.Listen("tcp", "[::]:0")
  689. if err != nil {
  690. t.Fatalf("net.Listen(): got %v, want no error", err)
  691. }
  692. p := NewProxy()
  693. defer p.Close()
  694. tr := martiantest.NewTransport()
  695. p.SetRoundTripper(tr)
  696. p.SetTimeout(200 * time.Millisecond)
  697. tm := martiantest.NewModifier()
  698. p.SetRequestModifier(tm)
  699. p.SetResponseModifier(tm)
  700. go p.Serve(l)
  701. conn, err := net.Dial("tcp", l.Addr().String())
  702. if err != nil {
  703. t.Fatalf("net.Dial(): got %v, want no error", err)
  704. }
  705. defer conn.Close()
  706. req, err := http.NewRequest("GET", "http://example.com", nil)
  707. if err != nil {
  708. t.Fatalf("http.NewRequest(): got %v, want no error", err)
  709. }
  710. // GET / HTTP/1.1
  711. // Host: www.example.com
  712. if err := req.Write(conn); err != nil {
  713. t.Fatalf("req.Write(): got %v, want no error", err)
  714. }
  715. res, err := http.ReadResponse(bufio.NewReader(conn), req)
  716. if err != nil {
  717. t.Fatalf("http.ReadResponse(): got %v, want no error", err)
  718. }
  719. if got, want := res.StatusCode, 200; got != want {
  720. t.Fatalf("res.StatusCode: got %d, want %d", got, want)
  721. }
  722. if !tm.RequestModified() {
  723. t.Error("tm.RequestModified(): got false, want true")
  724. }
  725. if !tm.ResponseModified() {
  726. t.Error("tm.ResponseModified(): got false, want true")
  727. }
  728. }
  729. func TestIntegrationTransparentMITM(t *testing.T) {
  730. t.Parallel()
  731. ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour)
  732. if err != nil {
  733. t.Fatalf("mitm.NewAuthority(): got %v, want no error", err)
  734. }
  735. mc, err := mitm.NewConfig(ca, priv)
  736. if err != nil {
  737. t.Fatalf("mitm.NewConfig(): got %v, want no error", err)
  738. }
  739. // Start TLS listener with config that will generate certificates based on
  740. // SNI from connection.
  741. //
  742. // BUG: tls.Listen will not accept a tls.Config where Certificates is empty,
  743. // even though it is supported by tls.Server when GetCertificate is not nil.
  744. l, err := net.Listen("tcp", "[::]:0")
  745. if err != nil {
  746. t.Fatalf("net.Listen(): got %v, want no error", err)
  747. }
  748. l = tls.NewListener(l, mc.TLS())
  749. p := NewProxy()
  750. defer p.Close()
  751. tr := martiantest.NewTransport()
  752. tr.Func(func(req *http.Request) (*http.Response, error) {
  753. res := proxyutil.NewResponse(200, nil, req)
  754. res.Header.Set("Request-Scheme", req.URL.Scheme)
  755. return res, nil
  756. })
  757. p.SetRoundTripper(tr)
  758. tm := martiantest.NewModifier()
  759. p.SetRequestModifier(tm)
  760. p.SetResponseModifier(tm)
  761. go p.Serve(l)
  762. roots := x509.NewCertPool()
  763. roots.AddCert(ca)
  764. tlsconn, err := tls.Dial("tcp", l.Addr().String(), &tls.Config{
  765. // Verify the hostname is example.com.
  766. ServerName: "example.com",
  767. // The certificate will have been generated during MITM, so we need to
  768. // verify it with the generated CA certificate.
  769. RootCAs: roots,
  770. })
  771. if err != nil {
  772. t.Fatalf("tls.Dial(): got %v, want no error", err)
  773. }
  774. defer tlsconn.Close()
  775. req, err := http.NewRequest("GET", "https://example.com", nil)
  776. if err != nil {
  777. t.Fatalf("http.NewRequest(): got %v, want no error", err)
  778. }
  779. // Write Encrypted request directly, no CONNECT.
  780. // GET / HTTP/1.1
  781. // Host: example.com
  782. if err := req.Write(tlsconn); err != nil {
  783. t.Fatalf("req.Write(): got %v, want no error", err)
  784. }
  785. res, err := http.ReadResponse(bufio.NewReader(tlsconn), req)
  786. if err != nil {
  787. t.Fatalf("http.ReadResponse(): got %v, want no error", err)
  788. }
  789. defer res.Body.Close()
  790. if got, want := res.StatusCode, 200; got != want {
  791. t.Fatalf("res.StatusCode: got %d, want %d", got, want)
  792. }
  793. if got, want := res.Header.Get("Request-Scheme"), "https"; got != want {
  794. t.Errorf("res.Header.Get(%q): got %q, want %q", "Request-Scheme", got, want)
  795. }
  796. if !tm.RequestModified() {
  797. t.Errorf("tm.RequestModified(): got false, want true")
  798. }
  799. if !tm.ResponseModified() {
  800. t.Errorf("tm.ResponseModified(): got false, want true")
  801. }
  802. }
  803. func TestIntegrationFailedRoundTrip(t *testing.T) {
  804. t.Parallel()
  805. l, err := net.Listen("tcp", "[::]:0")
  806. if err != nil {
  807. t.Fatalf("net.Listen(): got %v, want no error", err)
  808. }
  809. p := NewProxy()
  810. defer p.Close()
  811. tr := martiantest.NewTransport()
  812. trerr := errors.New("round trip error")
  813. tr.RespondError(trerr)
  814. p.SetRoundTripper(tr)
  815. p.SetTimeout(200 * time.Millisecond)
  816. go p.Serve(l)
  817. conn, err := net.Dial("tcp", l.Addr().String())
  818. if err != nil {
  819. t.Fatalf("net.Dial(): got %v, want no error", err)
  820. }
  821. defer conn.Close()
  822. req, err := http.NewRequest("GET", "http://example.com", nil)
  823. if err != nil {
  824. t.Fatalf("http.NewRequest(): got %v, want no error", err)
  825. }
  826. // GET http://example.com/ HTTP/1.1
  827. // Host: example.com
  828. if err := req.WriteProxy(conn); err != nil {
  829. t.Fatalf("req.WriteProxy(): got %v, want no error", err)
  830. }
  831. // Response from failed round trip.
  832. res, err := http.ReadResponse(bufio.NewReader(conn), req)
  833. if err != nil {
  834. t.Fatalf("http.ReadResponse(): got %v, want no error", err)
  835. }
  836. defer res.Body.Close()
  837. if got, want := res.StatusCode, 502; got != want {
  838. t.Errorf("res.StatusCode: got %d, want %d", got, want)
  839. }
  840. if got, want := res.Header.Get("Warning"), trerr.Error(); !strings.Contains(got, want) {
  841. t.Errorf("res.Header.Get(%q): got %q, want to contain %q", "Warning", got, want)
  842. }
  843. }
  844. func TestIntegrationSkipRoundTrip(t *testing.T) {
  845. t.Parallel()
  846. l, err := net.Listen("tcp", "[::]:0")
  847. if err != nil {
  848. t.Fatalf("net.Listen(): got %v, want no error", err)
  849. }
  850. p := NewProxy()
  851. defer p.Close()
  852. // Transport will be skipped, no 500.
  853. tr := martiantest.NewTransport()
  854. tr.Respond(500)
  855. p.SetRoundTripper(tr)
  856. p.SetTimeout(200 * time.Millisecond)
  857. tm := martiantest.NewModifier()
  858. tm.RequestFunc(func(req *http.Request) {
  859. ctx := NewContext(req)
  860. ctx.SkipRoundTrip()
  861. })
  862. p.SetRequestModifier(tm)
  863. go p.Serve(l)
  864. conn, err := net.Dial("tcp", l.Addr().String())
  865. if err != nil {
  866. t.Fatalf("net.Dial(): got %v, want no error", err)
  867. }
  868. defer conn.Close()
  869. req, err := http.NewRequest("GET", "http://example.com", nil)
  870. if err != nil {
  871. t.Fatalf("http.NewRequest(): got %v, want no error", err)
  872. }
  873. // GET http://example.com/ HTTP/1.1
  874. // Host: example.com
  875. if err := req.WriteProxy(conn); err != nil {
  876. t.Fatalf("req.WriteProxy(): got %v, want no error", err)
  877. }
  878. // Response from skipped round trip.
  879. res, err := http.ReadResponse(bufio.NewReader(conn), req)
  880. if err != nil {
  881. t.Fatalf("http.ReadResponse(): got %v, want no error", err)
  882. }
  883. defer res.Body.Close()
  884. if got, want := res.StatusCode, 200; got != want {
  885. t.Errorf("res.StatusCode: got %d, want %d", got, want)
  886. }
  887. }
  888. func TestHTTPThroughConnectWithMITM(t *testing.T) {
  889. t.Parallel()
  890. l, err := net.Listen("tcp", "[::]:0")
  891. if err != nil {
  892. t.Fatalf("net.Listen(): got %v, want no error", err)
  893. }
  894. p := NewProxy()
  895. defer p.Close()
  896. tm := martiantest.NewModifier()
  897. tm.RequestFunc(func(req *http.Request) {
  898. ctx := NewContext(req)
  899. ctx.SkipRoundTrip()
  900. if req.Method != "GET" && req.Method != "CONNECT" {
  901. t.Errorf("unexpected method on request handler: %v", req.Method)
  902. }
  903. })
  904. p.SetRequestModifier(tm)
  905. ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour)
  906. if err != nil {
  907. t.Fatalf("mitm.NewAuthority(): got %v, want no error", err)
  908. }
  909. mc, err := mitm.NewConfig(ca, priv)
  910. if err != nil {
  911. t.Fatalf("mitm.NewConfig(): got %v, want no error", err)
  912. }
  913. p.SetMITM(mc)
  914. go p.Serve(l)
  915. conn, err := net.Dial("tcp", l.Addr().String())
  916. if err != nil {
  917. t.Fatalf("net.Dial(): got %v, want no error", err)
  918. }
  919. defer conn.Close()
  920. req, err := http.NewRequest("CONNECT", "//example.com:80", nil)
  921. if err != nil {
  922. t.Fatalf("http.NewRequest(): got %v, want no error", err)
  923. }
  924. // CONNECT example.com:80 HTTP/1.1
  925. // Host: example.com
  926. if err := req.Write(conn); err != nil {
  927. t.Fatalf("req.Write(): got %v, want no error", err)
  928. }
  929. // Response skipped round trip.
  930. res, err := http.ReadResponse(bufio.NewReader(conn), req)
  931. if err != nil {
  932. t.Fatalf("http.ReadResponse(): got %v, want no error", err)
  933. }
  934. res.Body.Close()
  935. if got, want := res.StatusCode, 200; got != want {
  936. t.Errorf("res.StatusCode: got %d, want %d", got, want)
  937. }
  938. req, err = http.NewRequest("GET", "http://example.com", nil)
  939. if err != nil {
  940. t.Fatalf("http.NewRequest(): got %v, want no error", err)
  941. }
  942. // GET http://example.com/ HTTP/1.1
  943. // Host: example.com
  944. if err := req.WriteProxy(conn); err != nil {
  945. t.Fatalf("req.WriteProxy(): got %v, want no error", err)
  946. }
  947. // Response from skipped round trip.
  948. res, err = http.ReadResponse(bufio.NewReader(conn), req)
  949. if err != nil {
  950. t.Fatalf("http.ReadResponse(): got %v, want no error", err)
  951. }
  952. res.Body.Close()
  953. if got, want := res.StatusCode, 200; got != want {
  954. t.Errorf("res.StatusCode: got %d, want %d", got, want)
  955. }
  956. req, err = http.NewRequest("GET", "http://example.com", nil)
  957. if err != nil {
  958. t.Fatalf("http.NewRequest(): got %v, want no error", err)
  959. }
  960. // GET http://example.com/ HTTP/1.1
  961. // Host: example.com
  962. if err := req.WriteProxy(conn); err != nil {
  963. t.Fatalf("req.WriteProxy(): got %v, want no error", err)
  964. }
  965. // Response from skipped round trip.
  966. res, err = http.ReadResponse(bufio.NewReader(conn), req)
  967. if err != nil {
  968. t.Fatalf("http.ReadResponse(): got %v, want no error", err)
  969. }
  970. res.Body.Close()
  971. if got, want := res.StatusCode, 200; got != want {
  972. t.Errorf("res.StatusCode: got %d, want %d", got, want)
  973. }
  974. }
  975. func TestServerClosesConnection(t *testing.T) {
  976. t.Parallel()
  977. dstl, err := net.Listen("tcp", "[::]:0")
  978. if err != nil {
  979. t.Fatalf("Failed to create http listener: %v", err)
  980. }
  981. defer dstl.Close()
  982. go func() {
  983. t.Logf("Waiting for server side connection")
  984. conn, err := dstl.Accept()
  985. if err != nil {
  986. t.Fatalf("Got error while accepting connection on destination listener: %v", err)
  987. }
  988. t.Logf("Accepted server side connection")
  989. buf := make([]byte, 16384)
  990. if _, err := conn.Read(buf); err != nil {
  991. t.Fatalf("Error reading: %v", err)
  992. }
  993. _, err = conn.Write([]byte("HTTP/1.1 301 MOVED PERMANENTLY\r\n" +
  994. "Server: \r\n" +
  995. "Date: \r\n" +
  996. "Referer: \r\n" +
  997. "Location: http://www.foo.com/\r\n" +
  998. "Content-type: text/html\r\n" +
  999. "Connection: close\r\n\r\n"))
  1000. if err != nil {
  1001. t.Fatalf("Got error while writting to connection on destination listener: %v", err)
  1002. }
  1003. conn.Close()
  1004. }()
  1005. l, err := net.Listen("tcp", "[::]:0")
  1006. if err != nil {
  1007. t.Fatalf("net.Listen(): got %v, want no error", err)
  1008. }
  1009. ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour)
  1010. if err != nil {
  1011. t.Fatalf("mitm.NewAuthority(): got %v, want no error", err)
  1012. }
  1013. mc, err := mitm.NewConfig(ca, priv)
  1014. if err != nil {
  1015. t.Fatalf("mitm.NewConfig(): got %v, want no error", err)
  1016. }
  1017. p := NewProxy()
  1018. p.SetMITM(mc)
  1019. defer p.Close()
  1020. // Start the proxy with a listener that will return a temporary error on
  1021. // Accept() three times.
  1022. go p.Serve(newTimeoutListener(l, 3))
  1023. conn, err := net.Dial("tcp", l.Addr().String())
  1024. if err != nil {
  1025. t.Fatalf("net.Dial(): got %v, want no error", err)
  1026. }
  1027. defer conn.Close()
  1028. req, err := http.NewRequest("CONNECT", fmt.Sprintf("//%s", dstl.Addr().String()), nil)
  1029. if err != nil {
  1030. t.Fatalf("http.NewRequest(): got %v, want no error", err)
  1031. }
  1032. // CONNECT example.com:443 HTTP/1.1
  1033. // Host: example.com
  1034. if err := req.Write(conn); err != nil {
  1035. t.Fatalf("req.Write(): got %v, want no error", err)
  1036. }
  1037. res, err := http.ReadResponse(bufio.NewReader(conn), req)
  1038. if err != nil {
  1039. t.Fatalf("http.ReadResponse(): got %v, want no error", err)
  1040. }
  1041. res.Body.Close()
  1042. _, err = conn.Write([]byte("GET / HTTP/1.1\r\n" +
  1043. "User-Agent: curl/7.35.0\r\n" +
  1044. fmt.Sprintf("Host: %s\r\n", dstl.Addr()) +
  1045. "Accept: */*\r\n\r\n"))
  1046. if err != nil {
  1047. t.Fatalf("Error while writing GET request: %v", err)
  1048. }
  1049. res, err = http.ReadResponse(bufio.NewReader(io.TeeReader(conn, os.Stderr)), req)
  1050. if err != nil {
  1051. t.Fatalf("http.ReadResponse(): got %v, want no error", err)
  1052. }
  1053. _, err = ioutil.ReadAll(res.Body)
  1054. if err != nil {
  1055. t.Fatalf("error while ReadAll: %v", err)
  1056. }
  1057. defer res.Body.Close()
  1058. }