You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

496 lines
14 KiB

  1. // Copyright 2012 The Gorilla Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package mux
  5. import (
  6. "errors"
  7. "fmt"
  8. "net/http"
  9. "path"
  10. "regexp"
  11. )
  12. // NewRouter returns a new router instance.
  13. func NewRouter() *Router {
  14. return &Router{namedRoutes: make(map[string]*Route), KeepContext: false}
  15. }
  16. // Router registers routes to be matched and dispatches a handler.
  17. //
  18. // It implements the http.Handler interface, so it can be registered to serve
  19. // requests:
  20. //
  21. // var router = mux.NewRouter()
  22. //
  23. // func main() {
  24. // http.Handle("/", router)
  25. // }
  26. //
  27. // Or, for Google App Engine, register it in a init() function:
  28. //
  29. // func init() {
  30. // http.Handle("/", router)
  31. // }
  32. //
  33. // This will send all incoming requests to the router.
  34. type Router struct {
  35. // Configurable Handler to be used when no route matches.
  36. NotFoundHandler http.Handler
  37. // Parent route, if this is a subrouter.
  38. parent parentRoute
  39. // Routes to be matched, in order.
  40. routes []*Route
  41. // Routes by name for URL building.
  42. namedRoutes map[string]*Route
  43. // See Router.StrictSlash(). This defines the flag for new routes.
  44. strictSlash bool
  45. // See Router.SkipClean(). This defines the flag for new routes.
  46. skipClean bool
  47. // If true, do not clear the request context after handling the request.
  48. // This has no effect when go1.7+ is used, since the context is stored
  49. // on the request itself.
  50. KeepContext bool
  51. }
  52. // Match matches registered routes against the request.
  53. func (r *Router) Match(req *http.Request, match *RouteMatch) bool {
  54. for _, route := range r.routes {
  55. if route.Match(req, match) {
  56. return true
  57. }
  58. }
  59. // Closest match for a router (includes sub-routers)
  60. if r.NotFoundHandler != nil {
  61. match.Handler = r.NotFoundHandler
  62. return true
  63. }
  64. return false
  65. }
  66. // ServeHTTP dispatches the handler registered in the matched route.
  67. //
  68. // When there is a match, the route variables can be retrieved calling
  69. // mux.Vars(request).
  70. func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
  71. if !r.skipClean {
  72. // Clean path to canonical form and redirect.
  73. if p := cleanPath(req.URL.Path); p != req.URL.Path {
  74. // Added 3 lines (Philip Schlump) - It was dropping the query string and #whatever from query.
  75. // This matches with fix in go 1.2 r.c. 4 for same problem. Go Issue:
  76. // http://code.google.com/p/go/issues/detail?id=5252
  77. url := *req.URL
  78. url.Path = p
  79. p = url.String()
  80. w.Header().Set("Location", p)
  81. w.WriteHeader(http.StatusMovedPermanently)
  82. return
  83. }
  84. }
  85. var match RouteMatch
  86. var handler http.Handler
  87. if r.Match(req, &match) {
  88. handler = match.Handler
  89. req = setVars(req, match.Vars)
  90. req = setCurrentRoute(req, match.Route)
  91. }
  92. if handler == nil {
  93. handler = http.NotFoundHandler()
  94. }
  95. if !r.KeepContext {
  96. defer contextClear(req)
  97. }
  98. handler.ServeHTTP(w, req)
  99. }
  100. // Get returns a route registered with the given name.
  101. func (r *Router) Get(name string) *Route {
  102. return r.getNamedRoutes()[name]
  103. }
  104. // GetRoute returns a route registered with the given name. This method
  105. // was renamed to Get() and remains here for backwards compatibility.
  106. func (r *Router) GetRoute(name string) *Route {
  107. return r.getNamedRoutes()[name]
  108. }
  109. // StrictSlash defines the trailing slash behavior for new routes. The initial
  110. // value is false.
  111. //
  112. // When true, if the route path is "/path/", accessing "/path" will redirect
  113. // to the former and vice versa. In other words, your application will always
  114. // see the path as specified in the route.
  115. //
  116. // When false, if the route path is "/path", accessing "/path/" will not match
  117. // this route and vice versa.
  118. //
  119. // Special case: when a route sets a path prefix using the PathPrefix() method,
  120. // strict slash is ignored for that route because the redirect behavior can't
  121. // be determined from a prefix alone. However, any subrouters created from that
  122. // route inherit the original StrictSlash setting.
  123. func (r *Router) StrictSlash(value bool) *Router {
  124. r.strictSlash = value
  125. return r
  126. }
  127. // SkipClean defines the path cleaning behaviour for new routes. The initial
  128. // value is false. Users should be careful about which routes are not cleaned
  129. //
  130. // When true, if the route path is "/path//to", it will remain with the double
  131. // slash. This is helpful if you have a route like: /fetch/http://xkcd.com/534/
  132. //
  133. // When false, the path will be cleaned, so /fetch/http://xkcd.com/534/ will
  134. // become /fetch/http/xkcd.com/534
  135. func (r *Router) SkipClean(value bool) *Router {
  136. r.skipClean = value
  137. return r
  138. }
  139. // ----------------------------------------------------------------------------
  140. // parentRoute
  141. // ----------------------------------------------------------------------------
  142. // getNamedRoutes returns the map where named routes are registered.
  143. func (r *Router) getNamedRoutes() map[string]*Route {
  144. if r.namedRoutes == nil {
  145. if r.parent != nil {
  146. r.namedRoutes = r.parent.getNamedRoutes()
  147. } else {
  148. r.namedRoutes = make(map[string]*Route)
  149. }
  150. }
  151. return r.namedRoutes
  152. }
  153. // getRegexpGroup returns regexp definitions from the parent route, if any.
  154. func (r *Router) getRegexpGroup() *routeRegexpGroup {
  155. if r.parent != nil {
  156. return r.parent.getRegexpGroup()
  157. }
  158. return nil
  159. }
  160. func (r *Router) buildVars(m map[string]string) map[string]string {
  161. if r.parent != nil {
  162. m = r.parent.buildVars(m)
  163. }
  164. return m
  165. }
  166. // ----------------------------------------------------------------------------
  167. // Route factories
  168. // ----------------------------------------------------------------------------
  169. // NewRoute registers an empty route.
  170. func (r *Router) NewRoute() *Route {
  171. route := &Route{parent: r, strictSlash: r.strictSlash, skipClean: r.skipClean}
  172. r.routes = append(r.routes, route)
  173. return route
  174. }
  175. // Handle registers a new route with a matcher for the URL path.
  176. // See Route.Path() and Route.Handler().
  177. func (r *Router) Handle(path string, handler http.Handler) *Route {
  178. return r.NewRoute().Path(path).Handler(handler)
  179. }
  180. // HandleFunc registers a new route with a matcher for the URL path.
  181. // See Route.Path() and Route.HandlerFunc().
  182. func (r *Router) HandleFunc(path string, f func(http.ResponseWriter,
  183. *http.Request)) *Route {
  184. return r.NewRoute().Path(path).HandlerFunc(f)
  185. }
  186. // Headers registers a new route with a matcher for request header values.
  187. // See Route.Headers().
  188. func (r *Router) Headers(pairs ...string) *Route {
  189. return r.NewRoute().Headers(pairs...)
  190. }
  191. // Host registers a new route with a matcher for the URL host.
  192. // See Route.Host().
  193. func (r *Router) Host(tpl string) *Route {
  194. return r.NewRoute().Host(tpl)
  195. }
  196. // MatcherFunc registers a new route with a custom matcher function.
  197. // See Route.MatcherFunc().
  198. func (r *Router) MatcherFunc(f MatcherFunc) *Route {
  199. return r.NewRoute().MatcherFunc(f)
  200. }
  201. // Methods registers a new route with a matcher for HTTP methods.
  202. // See Route.Methods().
  203. func (r *Router) Methods(methods ...string) *Route {
  204. return r.NewRoute().Methods(methods...)
  205. }
  206. // Path registers a new route with a matcher for the URL path.
  207. // See Route.Path().
  208. func (r *Router) Path(tpl string) *Route {
  209. return r.NewRoute().Path(tpl)
  210. }
  211. // PathPrefix registers a new route with a matcher for the URL path prefix.
  212. // See Route.PathPrefix().
  213. func (r *Router) PathPrefix(tpl string) *Route {
  214. return r.NewRoute().PathPrefix(tpl)
  215. }
  216. // Queries registers a new route with a matcher for URL query values.
  217. // See Route.Queries().
  218. func (r *Router) Queries(pairs ...string) *Route {
  219. return r.NewRoute().Queries(pairs...)
  220. }
  221. // Schemes registers a new route with a matcher for URL schemes.
  222. // See Route.Schemes().
  223. func (r *Router) Schemes(schemes ...string) *Route {
  224. return r.NewRoute().Schemes(schemes...)
  225. }
  226. // BuildVarsFunc registers a new route with a custom function for modifying
  227. // route variables before building a URL.
  228. func (r *Router) BuildVarsFunc(f BuildVarsFunc) *Route {
  229. return r.NewRoute().BuildVarsFunc(f)
  230. }
  231. // Walk walks the router and all its sub-routers, calling walkFn for each route
  232. // in the tree. The routes are walked in the order they were added. Sub-routers
  233. // are explored depth-first.
  234. func (r *Router) Walk(walkFn WalkFunc) error {
  235. return r.walk(walkFn, []*Route{})
  236. }
  // SkipRouter is a sentinel a WalkFunc may return to tell the walk not to
  // descend into the router it is about to enter. The walk compares the
  // returned error against this value directly.
  var SkipRouter = errors.New("skip this router")
  240. // WalkFunc is the type of the function called for each route visited by Walk.
  241. // At every invocation, it is given the current route, and the current router,
  242. // and a list of ancestor routes that lead to the current route.
  243. type WalkFunc func(route *Route, router *Router, ancestors []*Route) error
  244. func (r *Router) walk(walkFn WalkFunc, ancestors []*Route) error {
  245. for _, t := range r.routes {
  246. if t.regexp == nil || t.regexp.path == nil || t.regexp.path.template == "" {
  247. continue
  248. }
  249. err := walkFn(t, r, ancestors)
  250. if err == SkipRouter {
  251. continue
  252. }
  253. for _, sr := range t.matchers {
  254. if h, ok := sr.(*Router); ok {
  255. err := h.walk(walkFn, ancestors)
  256. if err != nil {
  257. return err
  258. }
  259. }
  260. }
  261. if h, ok := t.handler.(*Router); ok {
  262. ancestors = append(ancestors, t)
  263. err := h.walk(walkFn, ancestors)
  264. if err != nil {
  265. return err
  266. }
  267. ancestors = ancestors[:len(ancestors)-1]
  268. }
  269. }
  270. return nil
  271. }
  272. // ----------------------------------------------------------------------------
  273. // Context
  274. // ----------------------------------------------------------------------------
  275. // RouteMatch stores information about a matched route.
  276. type RouteMatch struct {
  277. Route *Route
  278. Handler http.Handler
  279. Vars map[string]string
  280. }
  // contextKey is the private key type for values this package stores in
  // the request context.
  type contextKey int

  const (
  	// varsKey is the key under which the route variables are stored.
  	varsKey contextKey = 0
  	// routeKey is the key under which the matched route is stored.
  	routeKey contextKey = 1
  )
  286. // Vars returns the route variables for the current request, if any.
  287. func Vars(r *http.Request) map[string]string {
  288. if rv := contextGet(r, varsKey); rv != nil {
  289. return rv.(map[string]string)
  290. }
  291. return nil
  292. }
  293. // CurrentRoute returns the matched route for the current request, if any.
  294. // This only works when called inside the handler of the matched route
  295. // because the matched route is stored in the request context which is cleared
  296. // after the handler returns, unless the KeepContext option is set on the
  297. // Router.
  298. func CurrentRoute(r *http.Request) *Route {
  299. if rv := contextGet(r, routeKey); rv != nil {
  300. return rv.(*Route)
  301. }
  302. return nil
  303. }
  304. func setVars(r *http.Request, val interface{}) *http.Request {
  305. return contextSet(r, varsKey, val)
  306. }
  307. func setCurrentRoute(r *http.Request, val interface{}) *http.Request {
  308. return contextSet(r, routeKey, val)
  309. }
  310. // ----------------------------------------------------------------------------
  311. // Helpers
  312. // ----------------------------------------------------------------------------
  // cleanPath returns the canonical form of p: rooted at "/", with "." and
  // ".." elements and repeated slashes removed. An empty p yields "/", and
  // a trailing slash on the input is kept on the result.
  // Borrowed from the net/http package.
  func cleanPath(p string) string {
  	if len(p) == 0 {
  		return "/"
  	}
  	rooted := p
  	if rooted[0] != '/' {
  		rooted = "/" + rooted
  	}
  	cleaned := path.Clean(rooted)
  	// path.Clean drops a trailing slash everywhere except on the root, so
  	// restore it when the input had one.
  	hadTrailingSlash := rooted[len(rooted)-1] == '/'
  	if hadTrailingSlash && cleaned != "/" {
  		return cleaned + "/"
  	}
  	return cleaned
  }
  // uniqueVars returns an error naming the first element of s1 that also
  // occurs in s2, or nil when the two slices have no string in common.
  func uniqueVars(s1, s2 []string) error {
  	inSecond := make(map[string]struct{}, len(s2))
  	for _, name := range s2 {
  		inSecond[name] = struct{}{}
  	}
  	for _, name := range s1 {
  		if _, dup := inSecond[name]; dup {
  			return fmt.Errorf("mux: duplicated route variable %q", name)
  		}
  	}
  	return nil
  }
  // checkPairs returns the number of strings passed in, together with an
  // error when that number is odd. The count is returned in both cases.
  func checkPairs(pairs ...string) (int, error) {
  	count := len(pairs)
  	if count%2 == 0 {
  		return count, nil
  	}
  	return count, fmt.Errorf(
  		"mux: number of parameters must be multiple of 2, got %v", pairs)
  }
  351. // mapFromPairsToString converts variadic string parameters to a
  352. // string to string map.
  353. func mapFromPairsToString(pairs ...string) (map[string]string, error) {
  354. length, err := checkPairs(pairs...)
  355. if err != nil {
  356. return nil, err
  357. }
  358. m := make(map[string]string, length/2)
  359. for i := 0; i < length; i += 2 {
  360. m[pairs[i]] = pairs[i+1]
  361. }
  362. return m, nil
  363. }
  364. // mapFromPairsToRegex converts variadic string paramers to a
  365. // string to regex map.
  366. func mapFromPairsToRegex(pairs ...string) (map[string]*regexp.Regexp, error) {
  367. length, err := checkPairs(pairs...)
  368. if err != nil {
  369. return nil, err
  370. }
  371. m := make(map[string]*regexp.Regexp, length/2)
  372. for i := 0; i < length; i += 2 {
  373. regex, err := regexp.Compile(pairs[i+1])
  374. if err != nil {
  375. return nil, err
  376. }
  377. m[pairs[i]] = regex
  378. }
  379. return m, nil
  380. }
  // matchInArray reports whether value is equal to some element of arr.
  func matchInArray(arr []string, value string) bool {
  	for i := 0; i < len(arr); i++ {
  		if arr[i] == value {
  			return true
  		}
  	}
  	return false
  }
  // matchMapWithString reports whether every key/value pair of toCheck is
  // satisfied by toMatch. A key is satisfied when toMatch holds a non-nil
  // value list for it and, if the wanted value is non-empty, that list
  // contains the value. When canonicalKey is true, each key is passed
  // through http.CanonicalHeaderKey before the lookup.
  func matchMapWithString(toCheck map[string]string, toMatch map[string][]string, canonicalKey bool) bool {
  	for key, want := range toCheck {
  		if canonicalKey {
  			key = http.CanonicalHeaderKey(key)
  		}
  		candidates := toMatch[key]
  		if candidates == nil {
  			return false
  		}
  		// An empty wanted value only requires the key to be present.
  		if want == "" {
  			continue
  		}
  		found := false
  		for i := range candidates {
  			if candidates[i] == want {
  				found = true
  				break
  			}
  		}
  		if !found {
  			return false
  		}
  	}
  	return true
  }
  // matchMapWithRegex reports whether every key/regexp pair of toCheck is
  // satisfied by toMatch. A key is satisfied when toMatch holds a non-nil
  // value list for it and, if the regexp is non-nil, at least one value in
  // that list matches the regexp. When canonicalKey is true, each key is
  // passed through http.CanonicalHeaderKey before the lookup.
  func matchMapWithRegex(toCheck map[string]*regexp.Regexp, toMatch map[string][]string, canonicalKey bool) bool {
  	for key, pattern := range toCheck {
  		if canonicalKey {
  			key = http.CanonicalHeaderKey(key)
  		}
  		candidates := toMatch[key]
  		if candidates == nil {
  			return false
  		}
  		// A nil pattern only requires the key to be present.
  		if pattern == nil {
  			continue
  		}
  		found := false
  		for i := range candidates {
  			if pattern.MatchString(candidates[i]) {
  				found = true
  				break
  			}
  		}
  		if !found {
  			return false
  		}
  	}
  	return true
  }