You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

244 lines
6.5 KiB

  1. // Copyright 2015 Google Inc. All rights reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package priority
  15. import (
  16. "errors"
  17. "net/http"
  18. "reflect"
  19. "testing"
  20. "github.com/google/martian/martiantest"
  21. "github.com/google/martian/parse"
  22. "github.com/google/martian/proxyutil"
  23. // Import to register header.Modifier with JSON parser.
  24. _ "github.com/google/martian/header"
  25. )
  26. func TestPriorityGroupModifyRequest(t *testing.T) {
  27. var order []string
  28. pg := NewGroup()
  29. tm50 := martiantest.NewModifier()
  30. tm50.RequestFunc(func(*http.Request) {
  31. order = append(order, "tm50")
  32. })
  33. pg.AddRequestModifier(tm50, 50)
  34. tm100a := martiantest.NewModifier()
  35. tm100a.RequestFunc(func(*http.Request) {
  36. order = append(order, "tm100a")
  37. })
  38. pg.AddRequestModifier(tm100a, 100)
  39. tm100b := martiantest.NewModifier()
  40. tm100b.RequestFunc(func(*http.Request) {
  41. order = append(order, "tm100b")
  42. })
  43. pg.AddRequestModifier(tm100b, 100)
  44. tm75 := martiantest.NewModifier()
  45. tm75.RequestFunc(func(*http.Request) {
  46. order = append(order, "tm75")
  47. })
  48. if err := pg.RemoveRequestModifier(tm75); err != ErrModifierNotFound {
  49. t.Fatalf("RemoveRequestModifier(): got %v, want ErrModifierNotFound", err)
  50. }
  51. pg.AddRequestModifier(tm75, 100)
  52. if err := pg.RemoveRequestModifier(tm75); err != nil {
  53. t.Fatalf("RemoveRequestModifier(): got %v, want no error", err)
  54. }
  55. req, err := http.NewRequest("GET", "http://example.com/", nil)
  56. if err != nil {
  57. t.Fatalf("http.NewRequest(): got %v, want no error", err)
  58. }
  59. if err := pg.ModifyRequest(req); err != nil {
  60. t.Fatalf("ModifyRequest(): got %v, want no error", err)
  61. }
  62. if got, want := order, []string{"tm100b", "tm100a", "tm50"}; !reflect.DeepEqual(got, want) {
  63. t.Fatalf("reflect.DeepEqual(%v, %v): got false, want true", got, want)
  64. }
  65. }
  66. func TestPriorityGroupModifyRequestHaltsOnError(t *testing.T) {
  67. pg := NewGroup()
  68. reqerr := errors.New("request error")
  69. tm := martiantest.NewModifier()
  70. tm.RequestError(reqerr)
  71. pg.AddRequestModifier(tm, 100)
  72. tm2 := martiantest.NewModifier()
  73. pg.AddRequestModifier(tm2, 75)
  74. req, err := http.NewRequest("GET", "http://example.com/", nil)
  75. if err != nil {
  76. t.Fatalf("http.NewRequest(): got %v, want no error", err)
  77. }
  78. if err := pg.ModifyRequest(req); err != reqerr {
  79. t.Fatalf("ModifyRequest(): got %v, want %v", err, reqerr)
  80. }
  81. if tm2.RequestModified() {
  82. t.Error("tm2.RequestModified(): got true, want false")
  83. }
  84. }
  85. func TestPriorityGroupModifyResponse(t *testing.T) {
  86. var order []string
  87. pg := NewGroup()
  88. tm50 := martiantest.NewModifier()
  89. tm50.ResponseFunc(func(*http.Response) {
  90. order = append(order, "tm50")
  91. })
  92. pg.AddResponseModifier(tm50, 50)
  93. tm100a := martiantest.NewModifier()
  94. tm100a.ResponseFunc(func(*http.Response) {
  95. order = append(order, "tm100a")
  96. })
  97. pg.AddResponseModifier(tm100a, 100)
  98. tm100b := martiantest.NewModifier()
  99. tm100b.ResponseFunc(func(*http.Response) {
  100. order = append(order, "tm100b")
  101. })
  102. pg.AddResponseModifier(tm100b, 100)
  103. tm75 := martiantest.NewModifier()
  104. tm75.ResponseFunc(func(*http.Response) {
  105. order = append(order, "tm75")
  106. })
  107. if err := pg.RemoveResponseModifier(tm75); err != ErrModifierNotFound {
  108. t.Fatalf("RemoveResponseModifier(): got %v, want ErrModifierNotFound", err)
  109. }
  110. pg.AddResponseModifier(tm75, 100)
  111. if err := pg.RemoveResponseModifier(tm75); err != nil {
  112. t.Fatalf("RemoveResponseModifier(): got %v, want no error", err)
  113. }
  114. res := proxyutil.NewResponse(200, nil, nil)
  115. if err := pg.ModifyResponse(res); err != nil {
  116. t.Fatalf("ModifyResponse(): got %v, want no error", err)
  117. }
  118. if got, want := order, []string{"tm100b", "tm100a", "tm50"}; !reflect.DeepEqual(got, want) {
  119. t.Fatalf("reflect.DeepEqual(%v, %v): got false, want true", got, want)
  120. }
  121. }
  122. func TestPriorityGroupModifyResponseHaltsOnError(t *testing.T) {
  123. pg := NewGroup()
  124. reserr := errors.New("response error")
  125. tm := martiantest.NewModifier()
  126. tm.ResponseError(reserr)
  127. pg.AddResponseModifier(tm, 100)
  128. tm2 := martiantest.NewModifier()
  129. pg.AddResponseModifier(tm2, 75)
  130. res := proxyutil.NewResponse(200, nil, nil)
  131. if err := pg.ModifyResponse(res); err != reserr {
  132. t.Fatalf("ModifyRequest(): got %v, want %v", err, reserr)
  133. }
  134. if tm2.ResponseModified() {
  135. t.Error("tm2.ResponseModified(): got true, want false")
  136. }
  137. }
  138. func TestGroupFromJSON(t *testing.T) {
  139. msg := []byte(`{
  140. "priority.Group": {
  141. "scope": ["request", "response"],
  142. "modifiers": [
  143. {
  144. "priority": 100,
  145. "modifier": {
  146. "header.Modifier": {
  147. "scope": ["request", "response"],
  148. "name": "X-Testing",
  149. "value": "true"
  150. }
  151. }
  152. },
  153. {
  154. "priority": 0,
  155. "modifier": {
  156. "header.Modifier": {
  157. "scope": ["request", "response"],
  158. "name": "Y-Testing",
  159. "value": "true"
  160. }
  161. }
  162. }
  163. ]
  164. }
  165. }`)
  166. r, err := parse.FromJSON(msg)
  167. if err != nil {
  168. t.Fatalf("parse.FromJSON(): got %v, want no error", err)
  169. }
  170. reqmod := r.RequestModifier()
  171. if reqmod == nil {
  172. t.Fatal("reqmod: got nil, want not nil")
  173. }
  174. req, err := http.NewRequest("GET", "http://example.com", nil)
  175. if err != nil {
  176. t.Fatalf("http.NewRequest(): got %v, want no error", err)
  177. }
  178. if err := reqmod.ModifyRequest(req); err != nil {
  179. t.Fatalf("ModifyRequest(): got %v, want no error", err)
  180. }
  181. if got, want := req.Header.Get("X-Testing"), "true"; got != want {
  182. t.Errorf("req.Header.Get(%q): got %q, want %q", "X-Testing", got, want)
  183. }
  184. if got, want := req.Header.Get("Y-Testing"), "true"; got != want {
  185. t.Errorf("req.Header.Get(%q): got %q, want %q", "Y-Testing", got, want)
  186. }
  187. resmod := r.ResponseModifier()
  188. if resmod == nil {
  189. t.Fatal("resmod: got nil, want not nil")
  190. }
  191. res := proxyutil.NewResponse(200, nil, req)
  192. if err := resmod.ModifyResponse(res); err != nil {
  193. t.Fatalf("ModifyResponse(): got %v, want no error", err)
  194. }
  195. if got, want := res.Header.Get("X-Testing"), "true"; got != want {
  196. t.Errorf("res.Header.Get(%q): got %q, want %q", "X-Testing", got, want)
  197. }
  198. if got, want := res.Header.Get("Y-Testing"), "true"; got != want {
  199. t.Errorf("res.Header.Get(%q): got %q, want %q", "Y-Testing", got, want)
  200. }
  201. }