You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

665 lines
15 KiB

  1. package martian
  2. import (
  3. "bufio"
  4. "bytes"
  5. "io/ioutil"
  6. "net"
  7. "net/http"
  8. "net/http/httptest"
  9. "strings"
  10. "testing"
  11. "time"
  12. "github.com/google/martian/log"
  13. "github.com/google/martian/martiantest"
  14. "github.com/google/martian/trafficshape"
  15. )
  16. // Tests that sending data of length 600 bytes with max bandwidth of 100 bytes/s takes
  17. // atleast 4.9s. Uses the Close Connection action to immediately close the connection
  18. // upon the proxy writing 600 bytes. (4.9s ~ 5s = 600b /100b/s - 1s)
  19. func TestConstantThrottleAndClose(t *testing.T) {
  20. log.SetLevel(log.Info)
  21. l, err := net.Listen("tcp", "[::]:0")
  22. if err != nil {
  23. t.Fatalf("net.Listen(): got %v, want no error", err)
  24. }
  25. tsl := trafficshape.NewListener(l)
  26. tsh := trafficshape.NewHandler(tsl)
  27. // This is the data to be sent.
  28. testString := strings.Repeat("0", 600)
  29. // Traffic shaping config request.
  30. jsonString :=
  31. `{
  32. "trafficshape": {
  33. "shapes": [
  34. {
  35. "url_regex": "http://example/example",
  36. "throttles": [
  37. {
  38. "bytes": "0-",
  39. "bandwidth": 100
  40. }
  41. ],
  42. "close_connections": [
  43. {
  44. "byte": 600,
  45. "count": 1
  46. }
  47. ]
  48. }
  49. ]
  50. }
  51. }`
  52. tsReq, err := http.NewRequest("POST", "test", bytes.NewBufferString(jsonString))
  53. rw := httptest.NewRecorder()
  54. tsh.ServeHTTP(rw, tsReq)
  55. res := rw.Result()
  56. if got, want := res.StatusCode, 200; got != want {
  57. t.Fatalf("res.StatusCode: got %d, want %d", got, want)
  58. }
  59. p := NewProxy()
  60. defer p.Close()
  61. p.SetRequestModifier(nil)
  62. p.SetResponseModifier(nil)
  63. tr := martiantest.NewTransport()
  64. p.SetRoundTripper(tr)
  65. p.SetTimeout(15 * time.Second)
  66. tm := martiantest.NewModifier()
  67. tm.RequestFunc(func(req *http.Request) {
  68. ctx := NewContext(req)
  69. ctx.SkipRoundTrip()
  70. })
  71. tm.ResponseFunc(func(res *http.Response) {
  72. res.StatusCode = http.StatusOK
  73. res.Body = ioutil.NopCloser(bytes.NewBufferString(testString))
  74. })
  75. p.SetRequestModifier(tm)
  76. p.SetResponseModifier(tm)
  77. go p.Serve(tsl)
  78. c1 := make(chan string)
  79. conn, err := net.Dial("tcp", l.Addr().String())
  80. defer conn.Close()
  81. if err != nil {
  82. t.Fatalf("net.Dial(): got %v, want no error", err)
  83. }
  84. go func() {
  85. req, err := http.NewRequest("GET", "http://example/example", nil)
  86. if err != nil {
  87. t.Fatalf("http.NewRequest(): got %v, want no error", err)
  88. }
  89. if err := req.WriteProxy(conn); err != nil {
  90. t.Fatalf("req.WriteProxy(): got %v, want no error", err)
  91. }
  92. res, err := http.ReadResponse(bufio.NewReader(conn), req)
  93. if err != nil {
  94. t.Fatalf("http.ReadResponse(): got %v, want no error", err)
  95. }
  96. body, _ := ioutil.ReadAll(res.Body)
  97. bodystr := string(body)
  98. c1 <- bodystr
  99. }()
  100. var bodystr string
  101. select {
  102. case bodystringc := <-c1:
  103. t.Errorf("took < 4.9s, should take at least 4.9s")
  104. bodystr = bodystringc
  105. case <-time.After(4900 * time.Millisecond):
  106. bodystringc := <-c1
  107. bodystr = bodystringc
  108. }
  109. if bodystr != testString {
  110. t.Errorf("res.Body: got %s, want %s", bodystr, testString)
  111. }
  112. }
  113. // Tests that sleeping for 5s and then closing the connection
  114. // upon reading 200 bytes, with a bandwidth of 5000 bytes/s
  115. // takes at least 4.9s, and results in a correctly trimmed
  116. // response body. (200 0s instead of 500 0s)
  117. func TestSleepAndClose(t *testing.T) {
  118. log.SetLevel(log.Info)
  119. l, err := net.Listen("tcp", "[::]:0")
  120. if err != nil {
  121. t.Fatalf("net.Listen(): got %v, want no error", err)
  122. }
  123. tsl := trafficshape.NewListener(l)
  124. tsh := trafficshape.NewHandler(tsl)
  125. // This is the data to be sent.
  126. testString := strings.Repeat("0", 500)
  127. // Traffic shaping config request.
  128. jsonString :=
  129. `{
  130. "trafficshape": {
  131. "shapes": [
  132. {
  133. "url_regex": "http://example/example",
  134. "throttles": [
  135. {
  136. "bytes": "0-",
  137. "bandwidth": 5000
  138. }
  139. ],
  140. "halts": [
  141. {
  142. "byte": 100,
  143. "duration": 5000,
  144. "count": 1
  145. }
  146. ],
  147. "close_connections": [
  148. {
  149. "byte": 200,
  150. "count": 1
  151. }
  152. ]
  153. }
  154. ]
  155. }
  156. }`
  157. tsReq, err := http.NewRequest("POST", "test", bytes.NewBufferString(jsonString))
  158. rw := httptest.NewRecorder()
  159. tsh.ServeHTTP(rw, tsReq)
  160. res := rw.Result()
  161. if got, want := res.StatusCode, 200; got != want {
  162. t.Fatalf("res.StatusCode: got %d, want %d", got, want)
  163. }
  164. p := NewProxy()
  165. defer p.Close()
  166. p.SetRequestModifier(nil)
  167. p.SetResponseModifier(nil)
  168. tr := martiantest.NewTransport()
  169. p.SetRoundTripper(tr)
  170. p.SetTimeout(15 * time.Second)
  171. tm := martiantest.NewModifier()
  172. tm.RequestFunc(func(req *http.Request) {
  173. ctx := NewContext(req)
  174. ctx.SkipRoundTrip()
  175. })
  176. tm.ResponseFunc(func(res *http.Response) {
  177. res.StatusCode = http.StatusOK
  178. res.Body = ioutil.NopCloser(bytes.NewBufferString(testString))
  179. })
  180. p.SetRequestModifier(tm)
  181. p.SetResponseModifier(tm)
  182. go p.Serve(tsl)
  183. c1 := make(chan string)
  184. conn, err := net.Dial("tcp", l.Addr().String())
  185. defer conn.Close()
  186. if err != nil {
  187. t.Fatalf("net.Dial(): got %v, want no error", err)
  188. }
  189. go func() {
  190. req, err := http.NewRequest("GET", "http://example/example", nil)
  191. if err != nil {
  192. t.Fatalf("http.NewRequest(): got %v, want no error", err)
  193. }
  194. if err := req.WriteProxy(conn); err != nil {
  195. t.Fatalf("req.WriteProxy(): got %v, want no error", err)
  196. }
  197. res, err := http.ReadResponse(bufio.NewReader(conn), req)
  198. if err != nil {
  199. t.Fatalf("http.ReadResponse(): got %v, want no error", err)
  200. }
  201. body, _ := ioutil.ReadAll(res.Body)
  202. bodystr := string(body)
  203. c1 <- bodystr
  204. }()
  205. var bodystr string
  206. select {
  207. case bodystringc := <-c1:
  208. t.Errorf("took < 4.9s, should take at least 4.9s")
  209. bodystr = bodystringc
  210. case <-time.After(4900 * time.Millisecond):
  211. bodystringc := <-c1
  212. bodystr = bodystringc
  213. }
  214. if want := strings.Repeat("0", 200); bodystr != want {
  215. t.Errorf("res.Body: got %s, want %s", bodystr, want)
  216. }
  217. }
  218. // Similar to TestConstantThrottleAndClose, except that it applies
  219. // the throttle only in a specific byte range, and modifies the
  220. // the response to lie in the byte range.
  221. func TestConstantThrottleAndCloseByteRange(t *testing.T) {
  222. log.SetLevel(log.Info)
  223. l, err := net.Listen("tcp", "[::]:0")
  224. if err != nil {
  225. t.Fatalf("net.Listen(): got %v, want no error", err)
  226. }
  227. tsl := trafficshape.NewListener(l)
  228. tsh := trafficshape.NewHandler(tsl)
  229. // This is the data to be sent.
  230. testString := strings.Repeat("0", 600)
  231. // Traffic shaping config request.
  232. jsonString :=
  233. `{
  234. "trafficshape": {
  235. "shapes": [
  236. {
  237. "url_regex": "http://example/example",
  238. "throttles": [
  239. {
  240. "bytes": "500-",
  241. "bandwidth": 100
  242. }
  243. ],
  244. "close_connections": [
  245. {
  246. "byte": 1100,
  247. "count": 1
  248. }
  249. ]
  250. }
  251. ]
  252. }
  253. }`
  254. tsReq, err := http.NewRequest("POST", "test", bytes.NewBufferString(jsonString))
  255. rw := httptest.NewRecorder()
  256. tsh.ServeHTTP(rw, tsReq)
  257. res := rw.Result()
  258. if got, want := res.StatusCode, 200; got != want {
  259. t.Fatalf("res.StatusCode: got %d, want %d", got, want)
  260. }
  261. p := NewProxy()
  262. defer p.Close()
  263. p.SetRequestModifier(nil)
  264. p.SetResponseModifier(nil)
  265. tr := martiantest.NewTransport()
  266. p.SetRoundTripper(tr)
  267. p.SetTimeout(15 * time.Second)
  268. tm := martiantest.NewModifier()
  269. tm.RequestFunc(func(req *http.Request) {
  270. ctx := NewContext(req)
  271. ctx.SkipRoundTrip()
  272. })
  273. tm.ResponseFunc(func(res *http.Response) {
  274. res.StatusCode = http.StatusPartialContent
  275. res.Body = ioutil.NopCloser(bytes.NewBufferString(testString))
  276. res.Header.Set("Content-Range", "bytes 500-1100/1100")
  277. })
  278. p.SetRequestModifier(tm)
  279. p.SetResponseModifier(tm)
  280. go p.Serve(tsl)
  281. c1 := make(chan string)
  282. conn, err := net.Dial("tcp", l.Addr().String())
  283. defer conn.Close()
  284. if err != nil {
  285. t.Fatalf("net.Dial(): got %v, want no error", err)
  286. }
  287. go func() {
  288. req, err := http.NewRequest("GET", "http://example/example", nil)
  289. if err != nil {
  290. t.Fatalf("http.NewRequest(): got %v, want no error", err)
  291. }
  292. if err := req.WriteProxy(conn); err != nil {
  293. t.Fatalf("req.WriteProxy(): got %v, want no error", err)
  294. }
  295. res, err := http.ReadResponse(bufio.NewReader(conn), req)
  296. if err != nil {
  297. t.Fatalf("http.ReadResponse(): got %v, want no error", err)
  298. }
  299. body, _ := ioutil.ReadAll(res.Body)
  300. bodystr := string(body)
  301. c1 <- bodystr
  302. }()
  303. var bodystr string
  304. select {
  305. case bodystringc := <-c1:
  306. t.Errorf("took < 4.9s, should take at least 4.9s")
  307. bodystr = bodystringc
  308. case <-time.After(4900 * time.Millisecond):
  309. bodystringc := <-c1
  310. bodystr = bodystringc
  311. }
  312. if bodystr != testString {
  313. t.Errorf("res.Body: got %s, want %s", bodystr, testString)
  314. }
  315. }
  316. // Opens up 5 concurrent connections, and sets the
  317. // max global bandwidth for the url regex to be 250b/s.
  318. // Every connection tries to read 500b of data, but since
  319. // the global bandwidth for the particular regex is 250,
  320. // it should take at least 5 * 500b / 250b/s -1s = 9s to read
  321. // everything.
  322. func TestMaxBandwidth(t *testing.T) {
  323. log.SetLevel(log.Info)
  324. l, err := net.Listen("tcp", "[::]:0")
  325. if err != nil {
  326. t.Fatalf("net.Listen(): got %v, want no error", err)
  327. }
  328. tsl := trafficshape.NewListener(l)
  329. tsh := trafficshape.NewHandler(tsl)
  330. // This is the data to be sent.
  331. testString := strings.Repeat("0", 500)
  332. // Traffic shaping config request.
  333. jsonString :=
  334. `{
  335. "trafficshape": {
  336. "shapes": [
  337. {
  338. "url_regex": "http://example/example",
  339. "max_global_bandwidth": 250,
  340. "close_connections": [
  341. {
  342. "byte": 500,
  343. "count": 5
  344. }
  345. ]
  346. }
  347. ]
  348. }
  349. }`
  350. tsReq, err := http.NewRequest("POST", "test", bytes.NewBufferString(jsonString))
  351. rw := httptest.NewRecorder()
  352. tsh.ServeHTTP(rw, tsReq)
  353. res := rw.Result()
  354. if got, want := res.StatusCode, 200; got != want {
  355. t.Fatalf("res.StatusCode: got %d, want %d", got, want)
  356. }
  357. p := NewProxy()
  358. defer p.Close()
  359. p.SetRequestModifier(nil)
  360. p.SetResponseModifier(nil)
  361. tr := martiantest.NewTransport()
  362. p.SetRoundTripper(tr)
  363. p.SetTimeout(20 * time.Second)
  364. tm := martiantest.NewModifier()
  365. tm.RequestFunc(func(req *http.Request) {
  366. ctx := NewContext(req)
  367. ctx.SkipRoundTrip()
  368. })
  369. tm.ResponseFunc(func(res *http.Response) {
  370. res.StatusCode = http.StatusOK
  371. res.Body = ioutil.NopCloser(bytes.NewBufferString(testString))
  372. })
  373. p.SetRequestModifier(tm)
  374. p.SetResponseModifier(tm)
  375. go p.Serve(tsl)
  376. numChannels := 5
  377. channels := make([]chan string, numChannels)
  378. for i := 0; i < numChannels; i++ {
  379. channels[i] = make(chan string)
  380. }
  381. for i := 0; i < numChannels; i++ {
  382. go func(i int) {
  383. conn, err := net.Dial("tcp", l.Addr().String())
  384. defer conn.Close()
  385. if err != nil {
  386. t.Fatalf("net.Dial(): got %v, want no error", err)
  387. }
  388. req, err := http.NewRequest("GET", "http://example/example", nil)
  389. if err != nil {
  390. t.Fatalf("http.NewRequest(): got %v, want no error", err)
  391. }
  392. if err := req.WriteProxy(conn); err != nil {
  393. t.Fatalf("req.WriteProxy(): got %v, want no error", err)
  394. }
  395. res, err := http.ReadResponse(bufio.NewReader(conn), req)
  396. if err != nil {
  397. t.Fatalf("http.ReadResponse(): got %v, want no error", err)
  398. }
  399. body, _ := ioutil.ReadAll(res.Body)
  400. bodystr := string(body)
  401. if i != 0 {
  402. <-channels[i-1]
  403. }
  404. channels[i] <- bodystr
  405. }(i)
  406. }
  407. var bodystr string
  408. select {
  409. case bodystringc := <-channels[numChannels-1]:
  410. t.Errorf("took < 8.9s, should take at least 8.9s")
  411. bodystr = bodystringc
  412. case <-time.After(8900 * time.Millisecond):
  413. bodystringc := <-channels[numChannels-1]
  414. bodystr = bodystringc
  415. }
  416. if bodystr != testString {
  417. t.Errorf("res.Body: got %s, want %s", bodystr, testString)
  418. }
  419. }
  420. // Makes 2 requests, with the first one having a byte range starting
  421. // at byte 250, and adds a close connection action at byte 450.
  422. // The first request should hit the action sooner,
  423. // and delete it. The second request should read the whole
  424. // data (500b)
  425. func TestConcurrentResponseActions(t *testing.T) {
  426. log.SetLevel(log.Info)
  427. l, err := net.Listen("tcp", "[::]:0")
  428. if err != nil {
  429. t.Fatalf("net.Listen(): got %v, want no error", err)
  430. }
  431. tsl := trafficshape.NewListener(l)
  432. tsh := trafficshape.NewHandler(tsl)
  433. // This is the data to be sent.
  434. testString := strings.Repeat("0", 500)
  435. // Traffic shaping config request.
  436. jsonString :=
  437. `{
  438. "trafficshape": {
  439. "shapes": [
  440. {
  441. "url_regex": "http://example/example",
  442. "throttles": [
  443. {
  444. "bytes": "-",
  445. "bandwidth": 250
  446. }
  447. ],
  448. "close_connections": [
  449. {
  450. "byte": 450,
  451. "count": 1
  452. },
  453. {
  454. "byte": 500,
  455. "count": 1
  456. }
  457. ]
  458. }
  459. ]
  460. }
  461. }`
  462. tsReq, err := http.NewRequest("POST", "test", bytes.NewBufferString(jsonString))
  463. rw := httptest.NewRecorder()
  464. tsh.ServeHTTP(rw, tsReq)
  465. res := rw.Result()
  466. if got, want := res.StatusCode, 200; got != want {
  467. t.Fatalf("res.StatusCode: got %d, want %d", got, want)
  468. }
  469. p := NewProxy()
  470. defer p.Close()
  471. p.SetRequestModifier(nil)
  472. p.SetResponseModifier(nil)
  473. tr := martiantest.NewTransport()
  474. p.SetRoundTripper(tr)
  475. p.SetTimeout(20 * time.Second)
  476. tm := martiantest.NewModifier()
  477. tm.RequestFunc(func(req *http.Request) {
  478. ctx := NewContext(req)
  479. ctx.SkipRoundTrip()
  480. })
  481. tm.ResponseFunc(func(res *http.Response) {
  482. cr := res.Request.Header.Get("ContentRange")
  483. res.StatusCode = http.StatusOK
  484. res.Body = ioutil.NopCloser(bytes.NewBufferString(testString))
  485. if cr != "" {
  486. res.StatusCode = http.StatusPartialContent
  487. res.Header.Set("Content-Range", cr)
  488. }
  489. })
  490. p.SetRequestModifier(tm)
  491. p.SetResponseModifier(tm)
  492. go p.Serve(tsl)
  493. c1 := make(chan string)
  494. c2 := make(chan string)
  495. go func() {
  496. conn, err := net.Dial("tcp", l.Addr().String())
  497. defer conn.Close()
  498. if err != nil {
  499. t.Fatalf("net.Dial(): got %v, want no error", err)
  500. }
  501. req, err := http.NewRequest("GET", "http://example/example", nil)
  502. req.Header.Set("ContentRange", "bytes 250-1000/1000")
  503. if err != nil {
  504. t.Fatalf("http.NewRequest(): got %v, want no error", err)
  505. }
  506. if err := req.WriteProxy(conn); err != nil {
  507. t.Fatalf("req.WriteProxy(): got %v, want no error", err)
  508. }
  509. res, err := http.ReadResponse(bufio.NewReader(conn), req)
  510. if err != nil {
  511. t.Fatalf("http.ReadResponse(): got %v, want no error", err)
  512. }
  513. body, _ := ioutil.ReadAll(res.Body)
  514. bodystr := string(body)
  515. c1 <- bodystr
  516. }()
  517. go func() {
  518. conn, err := net.Dial("tcp", l.Addr().String())
  519. defer conn.Close()
  520. if err != nil {
  521. t.Fatalf("net.Dial(): got %v, want no error", err)
  522. }
  523. req, err := http.NewRequest("GET", "http://example/example", nil)
  524. if err != nil {
  525. t.Fatalf("http.NewRequest(): got %v, want no error", err)
  526. }
  527. if err := req.WriteProxy(conn); err != nil {
  528. t.Fatalf("req.WriteProxy(): got %v, want no error", err)
  529. }
  530. res, err := http.ReadResponse(bufio.NewReader(conn), req)
  531. if err != nil {
  532. t.Fatalf("http.ReadResponse(): got %v, want no error", err)
  533. }
  534. body, _ := ioutil.ReadAll(res.Body)
  535. bodystr := string(body)
  536. c2 <- bodystr
  537. }()
  538. bodystr1 := <-c1
  539. bodystr2 := <-c2
  540. if want1 := strings.Repeat("0", 200); bodystr1 != want1 {
  541. t.Errorf("res.Body: got %s, want %s", bodystr1, want1)
  542. }
  543. if want2 := strings.Repeat("0", 500); bodystr2 != want2 {
  544. t.Errorf("res.Body: got %s, want %s", bodystr2, want2)
  545. }
  546. }