|
- package ratelimit
-
- import (
- "fmt"
- "net/http"
- "time"
- )
-
- func Request(keyFn KeyFn) *requestBuilder {
- return &requestBuilder{
- keyFn: keyFn,
- }
- }
-
- type requestBuilder struct {
- keyFn KeyFn
- rate int
- window time.Duration
- rateHeader string
- resetHeader string
- }
-
- func (b *requestBuilder) Rate(rate int, window time.Duration) *requestBuilder {
- b.rate = rate
- b.window = window
- b.rateHeader = fmt.Sprintf("%v", float32(rate)*float32(window/time.Second))
- b.resetHeader = fmt.Sprintf("%d", time.Now().Unix())
- return b
- }
-
- // TODO: Custom burst?
- // func (b *requestBuilder) Burst(burst int) *requestBuilder {}
-
- func (b *requestBuilder) LimitBy(store TokenBucketStore, fallbackStores ...TokenBucketStore) func(http.Handler) http.Handler {
- store.InitRate(b.rate, b.window)
- for _, store := range fallbackStores {
- store.InitRate(b.rate, b.window)
- }
-
- limiter := requestLimiter{
- requestBuilder: b,
- store: store,
- fallbackStores: fallbackStores,
- }
-
- fn := func(next http.Handler) http.Handler {
- limiter.next = next
- return &limiter
- }
-
- return fn
- }
-
- type requestLimiter struct {
- *requestBuilder
-
- next http.Handler
- store TokenBucketStore
- fallbackStores []TokenBucketStore
- }
-
- // ServeHTTPC implements http.Handler interface.
- func (l *requestLimiter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
- key := l.keyFn(r)
- if key == "" {
- l.next.ServeHTTP(w, r)
- return
- }
-
- ok, remaining, reset, err := l.store.Take("request:" + key)
- if err != nil {
- for _, store := range l.fallbackStores {
- ok, remaining, reset, err = store.Take("request:" + key)
- if err == nil {
- break
- }
- }
- }
- if err != nil {
- l.next.ServeHTTP(w, r)
- return
- }
- if !ok {
- w.Header().Add("Retry-After", reset.Format(http.TimeFormat))
- http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
- return
- }
- w.Header().Add("X-RateLimit-Key", key)
- w.Header().Add("X-RateLimit-Rate", l.rateHeader)
- w.Header().Add("X-RateLimit-Limit", fmt.Sprintf("%d", l.rate))
- w.Header().Add("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining))
- w.Header().Add("X-RateLimit-Reset", fmt.Sprintf("%d", reset.Unix()))
- w.Header().Add("Retry-After", reset.Format(http.TimeFormat))
- l.next.ServeHTTP(w, r)
- }
|