You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

96 lines
2.2 KiB

  1. package ratelimit
  2. import (
  3. "fmt"
  4. "net/http"
  5. "time"
  6. )
  7. func Request(keyFn KeyFn) *requestBuilder {
  8. return &requestBuilder{
  9. keyFn: keyFn,
  10. }
  11. }
  12. type requestBuilder struct {
  13. keyFn KeyFn
  14. rate int
  15. window time.Duration
  16. rateHeader string
  17. resetHeader string
  18. }
  19. func (b *requestBuilder) Rate(rate int, window time.Duration) *requestBuilder {
  20. b.rate = rate
  21. b.window = window
  22. b.rateHeader = fmt.Sprintf("%v", float32(rate)*float32(window/time.Second))
  23. b.resetHeader = fmt.Sprintf("%d", time.Now().Unix())
  24. return b
  25. }
  26. // TODO: Custom burst?
  27. // func (b *requestBuilder) Burst(burst int) *requestBuilder {}
  28. func (b *requestBuilder) LimitBy(store TokenBucketStore, fallbackStores ...TokenBucketStore) func(http.Handler) http.Handler {
  29. store.InitRate(b.rate, b.window)
  30. for _, store := range fallbackStores {
  31. store.InitRate(b.rate, b.window)
  32. }
  33. limiter := requestLimiter{
  34. requestBuilder: b,
  35. store: store,
  36. fallbackStores: fallbackStores,
  37. }
  38. fn := func(next http.Handler) http.Handler {
  39. limiter.next = next
  40. return &limiter
  41. }
  42. return fn
  43. }
  44. type requestLimiter struct {
  45. *requestBuilder
  46. next http.Handler
  47. store TokenBucketStore
  48. fallbackStores []TokenBucketStore
  49. }
  50. // ServeHTTPC implements http.Handler interface.
  51. func (l *requestLimiter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  52. key := l.keyFn(r)
  53. if key == "" {
  54. l.next.ServeHTTP(w, r)
  55. return
  56. }
  57. ok, remaining, reset, err := l.store.Take("request:" + key)
  58. if err != nil {
  59. for _, store := range l.fallbackStores {
  60. ok, remaining, reset, err = store.Take("request:" + key)
  61. if err == nil {
  62. break
  63. }
  64. }
  65. }
  66. if err != nil {
  67. l.next.ServeHTTP(w, r)
  68. return
  69. }
  70. if !ok {
  71. w.Header().Add("Retry-After", reset.Format(http.TimeFormat))
  72. http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
  73. return
  74. }
  75. w.Header().Add("X-RateLimit-Key", key)
  76. w.Header().Add("X-RateLimit-Rate", l.rateHeader)
  77. w.Header().Add("X-RateLimit-Limit", fmt.Sprintf("%d", l.rate))
  78. w.Header().Add("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining))
  79. w.Header().Add("X-RateLimit-Reset", fmt.Sprintf("%d", reset.Unix()))
  80. w.Header().Add("Retry-After", reset.Format(http.TimeFormat))
  81. l.next.ServeHTTP(w, r)
  82. }