You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

438 lines
11 KiB

  1. package mux
  2. import (
  3. "bytes"
  4. "net/http"
  5. "net/http/httptest"
  6. "testing"
  7. )
  8. type testMiddleware struct {
  9. timesCalled uint
  10. }
  11. func (tm *testMiddleware) Middleware(h http.Handler) http.Handler {
  12. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  13. tm.timesCalled++
  14. h.ServeHTTP(w, r)
  15. })
  16. }
  17. func dummyHandler(w http.ResponseWriter, r *http.Request) {}
  18. func TestMiddlewareAdd(t *testing.T) {
  19. router := NewRouter()
  20. router.HandleFunc("/", dummyHandler).Methods("GET")
  21. mw := &testMiddleware{}
  22. router.useInterface(mw)
  23. if len(router.middlewares) != 1 || router.middlewares[0] != mw {
  24. t.Fatal("Middleware was not added correctly")
  25. }
  26. router.Use(mw.Middleware)
  27. if len(router.middlewares) != 2 {
  28. t.Fatal("MiddlewareFunc method was not added correctly")
  29. }
  30. banalMw := func(handler http.Handler) http.Handler {
  31. return handler
  32. }
  33. router.Use(banalMw)
  34. if len(router.middlewares) != 3 {
  35. t.Fatal("MiddlewareFunc method was not added correctly")
  36. }
  37. }
  38. func TestMiddleware(t *testing.T) {
  39. router := NewRouter()
  40. router.HandleFunc("/", dummyHandler).Methods("GET")
  41. mw := &testMiddleware{}
  42. router.useInterface(mw)
  43. rw := NewRecorder()
  44. req := newRequest("GET", "/")
  45. // Test regular middleware call
  46. router.ServeHTTP(rw, req)
  47. if mw.timesCalled != 1 {
  48. t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
  49. }
  50. // Middleware should not be called for 404
  51. req = newRequest("GET", "/not/found")
  52. router.ServeHTTP(rw, req)
  53. if mw.timesCalled != 1 {
  54. t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
  55. }
  56. // Middleware should not be called if there is a method mismatch
  57. req = newRequest("POST", "/")
  58. router.ServeHTTP(rw, req)
  59. if mw.timesCalled != 1 {
  60. t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
  61. }
  62. // Add the middleware again as function
  63. router.Use(mw.Middleware)
  64. req = newRequest("GET", "/")
  65. router.ServeHTTP(rw, req)
  66. if mw.timesCalled != 3 {
  67. t.Fatalf("Expected %d calls, but got only %d", 3, mw.timesCalled)
  68. }
  69. }
  70. func TestMiddlewareSubrouter(t *testing.T) {
  71. router := NewRouter()
  72. router.HandleFunc("/", dummyHandler).Methods("GET")
  73. subrouter := router.PathPrefix("/sub").Subrouter()
  74. subrouter.HandleFunc("/x", dummyHandler).Methods("GET")
  75. mw := &testMiddleware{}
  76. subrouter.useInterface(mw)
  77. rw := NewRecorder()
  78. req := newRequest("GET", "/")
  79. router.ServeHTTP(rw, req)
  80. if mw.timesCalled != 0 {
  81. t.Fatalf("Expected %d calls, but got only %d", 0, mw.timesCalled)
  82. }
  83. req = newRequest("GET", "/sub/")
  84. router.ServeHTTP(rw, req)
  85. if mw.timesCalled != 0 {
  86. t.Fatalf("Expected %d calls, but got only %d", 0, mw.timesCalled)
  87. }
  88. req = newRequest("GET", "/sub/x")
  89. router.ServeHTTP(rw, req)
  90. if mw.timesCalled != 1 {
  91. t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
  92. }
  93. req = newRequest("GET", "/sub/not/found")
  94. router.ServeHTTP(rw, req)
  95. if mw.timesCalled != 1 {
  96. t.Fatalf("Expected %d calls, but got only %d", 1, mw.timesCalled)
  97. }
  98. router.useInterface(mw)
  99. req = newRequest("GET", "/")
  100. router.ServeHTTP(rw, req)
  101. if mw.timesCalled != 2 {
  102. t.Fatalf("Expected %d calls, but got only %d", 2, mw.timesCalled)
  103. }
  104. req = newRequest("GET", "/sub/x")
  105. router.ServeHTTP(rw, req)
  106. if mw.timesCalled != 4 {
  107. t.Fatalf("Expected %d calls, but got only %d", 4, mw.timesCalled)
  108. }
  109. }
  110. func TestMiddlewareExecution(t *testing.T) {
  111. mwStr := []byte("Middleware\n")
  112. handlerStr := []byte("Logic\n")
  113. router := NewRouter()
  114. router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) {
  115. w.Write(handlerStr)
  116. })
  117. rw := NewRecorder()
  118. req := newRequest("GET", "/")
  119. // Test handler-only call
  120. router.ServeHTTP(rw, req)
  121. if !bytes.Equal(rw.Body.Bytes(), handlerStr) {
  122. t.Fatal("Handler response is not what it should be")
  123. }
  124. // Test middleware call
  125. rw = NewRecorder()
  126. router.Use(func(h http.Handler) http.Handler {
  127. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  128. w.Write(mwStr)
  129. h.ServeHTTP(w, r)
  130. })
  131. })
  132. router.ServeHTTP(rw, req)
  133. if !bytes.Equal(rw.Body.Bytes(), append(mwStr, handlerStr...)) {
  134. t.Fatal("Middleware + handler response is not what it should be")
  135. }
  136. }
  137. func TestMiddlewareNotFound(t *testing.T) {
  138. mwStr := []byte("Middleware\n")
  139. handlerStr := []byte("Logic\n")
  140. router := NewRouter()
  141. router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) {
  142. w.Write(handlerStr)
  143. })
  144. router.Use(func(h http.Handler) http.Handler {
  145. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  146. w.Write(mwStr)
  147. h.ServeHTTP(w, r)
  148. })
  149. })
  150. // Test not found call with default handler
  151. rw := NewRecorder()
  152. req := newRequest("GET", "/notfound")
  153. router.ServeHTTP(rw, req)
  154. if bytes.Contains(rw.Body.Bytes(), mwStr) {
  155. t.Fatal("Middleware was called for a 404")
  156. }
  157. // Test not found call with custom handler
  158. rw = NewRecorder()
  159. req = newRequest("GET", "/notfound")
  160. router.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
  161. rw.Write([]byte("Custom 404 handler"))
  162. })
  163. router.ServeHTTP(rw, req)
  164. if bytes.Contains(rw.Body.Bytes(), mwStr) {
  165. t.Fatal("Middleware was called for a custom 404")
  166. }
  167. }
  168. func TestMiddlewareMethodMismatch(t *testing.T) {
  169. mwStr := []byte("Middleware\n")
  170. handlerStr := []byte("Logic\n")
  171. router := NewRouter()
  172. router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) {
  173. w.Write(handlerStr)
  174. }).Methods("GET")
  175. router.Use(func(h http.Handler) http.Handler {
  176. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  177. w.Write(mwStr)
  178. h.ServeHTTP(w, r)
  179. })
  180. })
  181. // Test method mismatch
  182. rw := NewRecorder()
  183. req := newRequest("POST", "/")
  184. router.ServeHTTP(rw, req)
  185. if bytes.Contains(rw.Body.Bytes(), mwStr) {
  186. t.Fatal("Middleware was called for a method mismatch")
  187. }
  188. // Test not found call
  189. rw = NewRecorder()
  190. req = newRequest("POST", "/")
  191. router.MethodNotAllowedHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
  192. rw.Write([]byte("Method not allowed"))
  193. })
  194. router.ServeHTTP(rw, req)
  195. if bytes.Contains(rw.Body.Bytes(), mwStr) {
  196. t.Fatal("Middleware was called for a method mismatch")
  197. }
  198. }
  199. func TestMiddlewareNotFoundSubrouter(t *testing.T) {
  200. mwStr := []byte("Middleware\n")
  201. handlerStr := []byte("Logic\n")
  202. router := NewRouter()
  203. router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) {
  204. w.Write(handlerStr)
  205. })
  206. subrouter := router.PathPrefix("/sub/").Subrouter()
  207. subrouter.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) {
  208. w.Write(handlerStr)
  209. })
  210. router.Use(func(h http.Handler) http.Handler {
  211. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  212. w.Write(mwStr)
  213. h.ServeHTTP(w, r)
  214. })
  215. })
  216. // Test not found call for default handler
  217. rw := NewRecorder()
  218. req := newRequest("GET", "/sub/notfound")
  219. router.ServeHTTP(rw, req)
  220. if bytes.Contains(rw.Body.Bytes(), mwStr) {
  221. t.Fatal("Middleware was called for a 404")
  222. }
  223. // Test not found call with custom handler
  224. rw = NewRecorder()
  225. req = newRequest("GET", "/sub/notfound")
  226. subrouter.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
  227. rw.Write([]byte("Custom 404 handler"))
  228. })
  229. router.ServeHTTP(rw, req)
  230. if bytes.Contains(rw.Body.Bytes(), mwStr) {
  231. t.Fatal("Middleware was called for a custom 404")
  232. }
  233. }
  234. func TestMiddlewareMethodMismatchSubrouter(t *testing.T) {
  235. mwStr := []byte("Middleware\n")
  236. handlerStr := []byte("Logic\n")
  237. router := NewRouter()
  238. router.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) {
  239. w.Write(handlerStr)
  240. })
  241. subrouter := router.PathPrefix("/sub/").Subrouter()
  242. subrouter.HandleFunc("/", func(w http.ResponseWriter, e *http.Request) {
  243. w.Write(handlerStr)
  244. }).Methods("GET")
  245. router.Use(func(h http.Handler) http.Handler {
  246. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  247. w.Write(mwStr)
  248. h.ServeHTTP(w, r)
  249. })
  250. })
  251. // Test method mismatch without custom handler
  252. rw := NewRecorder()
  253. req := newRequest("POST", "/sub/")
  254. router.ServeHTTP(rw, req)
  255. if bytes.Contains(rw.Body.Bytes(), mwStr) {
  256. t.Fatal("Middleware was called for a method mismatch")
  257. }
  258. // Test method mismatch with custom handler
  259. rw = NewRecorder()
  260. req = newRequest("POST", "/sub/")
  261. router.MethodNotAllowedHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
  262. rw.Write([]byte("Method not allowed"))
  263. })
  264. router.ServeHTTP(rw, req)
  265. if bytes.Contains(rw.Body.Bytes(), mwStr) {
  266. t.Fatal("Middleware was called for a method mismatch")
  267. }
  268. }
  269. func TestCORSMethodMiddleware(t *testing.T) {
  270. router := NewRouter()
  271. cases := []struct {
  272. path string
  273. response string
  274. method string
  275. testURL string
  276. expectedAllowedMethods string
  277. }{
  278. {"/g/{o}", "a", "POST", "/g/asdf", "POST,PUT,GET,OPTIONS"},
  279. {"/g/{o}", "b", "PUT", "/g/bla", "POST,PUT,GET,OPTIONS"},
  280. {"/g/{o}", "c", "GET", "/g/orilla", "POST,PUT,GET,OPTIONS"},
  281. {"/g", "d", "POST", "/g", "POST,OPTIONS"},
  282. }
  283. for _, tt := range cases {
  284. router.HandleFunc(tt.path, stringHandler(tt.response)).Methods(tt.method)
  285. }
  286. router.Use(CORSMethodMiddleware(router))
  287. for _, tt := range cases {
  288. rr := httptest.NewRecorder()
  289. req := newRequest(tt.method, tt.testURL)
  290. router.ServeHTTP(rr, req)
  291. if rr.Body.String() != tt.response {
  292. t.Errorf("Expected body '%s', found '%s'", tt.response, rr.Body.String())
  293. }
  294. allowedMethods := rr.Header().Get("Access-Control-Allow-Methods")
  295. if allowedMethods != tt.expectedAllowedMethods {
  296. t.Errorf("Expected Access-Control-Allow-Methods '%s', found '%s'", tt.expectedAllowedMethods, allowedMethods)
  297. }
  298. }
  299. }
  300. func TestMiddlewareOnMultiSubrouter(t *testing.T) {
  301. first := "first"
  302. second := "second"
  303. notFound := "404 not found"
  304. router := NewRouter()
  305. firstSubRouter := router.PathPrefix("/").Subrouter()
  306. secondSubRouter := router.PathPrefix("/").Subrouter()
  307. router.NotFoundHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
  308. rw.Write([]byte(notFound))
  309. })
  310. firstSubRouter.HandleFunc("/first", func(w http.ResponseWriter, r *http.Request) {
  311. })
  312. secondSubRouter.HandleFunc("/second", func(w http.ResponseWriter, r *http.Request) {
  313. })
  314. firstSubRouter.Use(func(h http.Handler) http.Handler {
  315. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  316. w.Write([]byte(first))
  317. h.ServeHTTP(w, r)
  318. })
  319. })
  320. secondSubRouter.Use(func(h http.Handler) http.Handler {
  321. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  322. w.Write([]byte(second))
  323. h.ServeHTTP(w, r)
  324. })
  325. })
  326. rw := NewRecorder()
  327. req := newRequest("GET", "/first")
  328. router.ServeHTTP(rw, req)
  329. if rw.Body.String() != first {
  330. t.Fatalf("Middleware did not run: expected %s middleware to write a response (got %s)", first, rw.Body.String())
  331. }
  332. rw = NewRecorder()
  333. req = newRequest("GET", "/second")
  334. router.ServeHTTP(rw, req)
  335. if rw.Body.String() != second {
  336. t.Fatalf("Middleware did not run: expected %s middleware to write a response (got %s)", second, rw.Body.String())
  337. }
  338. rw = NewRecorder()
  339. req = newRequest("GET", "/second/not-exist")
  340. router.ServeHTTP(rw, req)
  341. if rw.Body.String() != notFound {
  342. t.Fatalf("Notfound handler did not run: expected %s for not-exist, (got %s)", notFound, rw.Body.String())
  343. }
  344. }