You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

56 lines
1.6 KiB

  1. package handlers
  2. import (
  3. "net/http"
  4. )
  5. // Structure that holds the context map and exposes the ResponseWriter interface.
  6. type contextResponseWriter struct {
  7. http.ResponseWriter
  8. m map[interface{}]interface{}
  9. }
  10. // Implement the WrapWriter interface.
  11. func (this *contextResponseWriter) WrappedWriter() http.ResponseWriter {
  12. return this.ResponseWriter
  13. }
  14. // ContextHandlerFunc is the same as ContextHandler, it is just a convenience
  15. // signature that accepts a func(http.ResponseWriter, *http.Request) instead of
  16. // a http.Handler interface. It saves the boilerplate http.HandlerFunc() cast.
  17. func ContextHandlerFunc(h http.HandlerFunc, cap int) http.HandlerFunc {
  18. return ContextHandler(h, cap)
  19. }
  20. // ContextHandler gives a context storage that lives only for the duration of
  21. // the request, with no locking involved.
  22. func ContextHandler(h http.Handler, cap int) http.HandlerFunc {
  23. return func(w http.ResponseWriter, r *http.Request) {
  24. if _, ok := GetContext(w); ok {
  25. // Self-awareness, context handler is already set up
  26. h.ServeHTTP(w, r)
  27. return
  28. }
  29. // Create the context-providing ResponseWriter replacement.
  30. ctxw := &contextResponseWriter{
  31. w,
  32. make(map[interface{}]interface{}, cap),
  33. }
  34. // Call the wrapped handler with the context-aware writer
  35. h.ServeHTTP(ctxw, r)
  36. }
  37. }
  38. // Helper function to retrieve the context map from the ResponseWriter interface.
  39. func GetContext(w http.ResponseWriter) (map[interface{}]interface{}, bool) {
  40. ctxw, ok := GetResponseWriter(w, func(tst http.ResponseWriter) bool {
  41. _, ok := tst.(*contextResponseWriter)
  42. return ok
  43. })
  44. if ok {
  45. return ctxw.(*contextResponseWriter).m, true
  46. }
  47. return nil, false
  48. }