@@ -0,0 +1,8 @@ | |||||
sudo: false | |||||
language: go | |||||
go: | |||||
- 1.7rc2 | |||||
script: | |||||
- go test ./... |
@@ -0,0 +1,22 @@ | |||||
MIT License | |||||
Copyright (c) 2016 Vojtech Vitek | |||||
Permission is hereby granted, free of charge, to any person obtaining | |||||
a copy of this software and associated documentation files (the | |||||
"Software"), to deal in the Software without restriction, including | |||||
without limitation the rights to use, copy, modify, merge, publish, | |||||
distribute, sublicense, and/or sell copies of the Software, and to | |||||
permit persons to whom the Software is furnished to do so, subject to | |||||
the following conditions: | |||||
The above copyright notice and this permission notice shall be | |||||
included in all copies or substantial portions of the Software. | |||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, | |||||
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF | |||||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND | |||||
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE | |||||
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION | |||||
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION | |||||
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. |
@@ -0,0 +1,22 @@ | |||||
# Rate Limit HTTP middleware | |||||
[![GoDoc Widget]][GoDoc] [![Travis Widget]][Travis] | |||||
[Golang](http://golang.org/) package for rate limiting HTTP endpoints based on context and request headers. | |||||
[GoDoc]: https://godoc.org/github.com/VojtechVitek/ratelimit | |||||
[GoDoc Widget]: https://godoc.org/github.com/VojtechVitek/ratelimit?status.svg | |||||
[Travis]: https://travis-ci.org/VojtechVitek/ratelimit | |||||
[Travis Widget]: https://travis-ci.org/VojtechVitek/ratelimit.svg?branch=master | |||||
# Under development | |||||
# Goals | |||||
- Simple but powerful API | |||||
- Token Bucket algorithm (rate + burst) | |||||
- Storage independent (Redis, In-Memory or any other K/V store) | |||||
# License | |||||
Copyright (c) 2016 Vojtech Vitek | |||||
Licensed under the [MIT License](./LICENSE). |
@@ -0,0 +1,117 @@ | |||||
package ratelimit | |||||
import ( | |||||
"net/http" | |||||
"time" | |||||
) | |||||
func DownloadSpeed(keyFn KeyFn) *downloadBuilder { | |||||
return &downloadBuilder{ | |||||
keyFn: keyFn, | |||||
} | |||||
} | |||||
type downloadBuilder struct { | |||||
keyFn KeyFn | |||||
rate int | |||||
window time.Duration | |||||
} | |||||
func (b *downloadBuilder) Rate(rate int, window time.Duration) *downloadBuilder { | |||||
b.rate = rate | |||||
b.window = window | |||||
return b | |||||
} | |||||
// TODO: Custom burst? | |||||
// func (b *downloadBuilder) Burst(burst int) *downloadBuilder {} | |||||
func (b *downloadBuilder) LimitBy(store TokenBucketStore, fallbackStores ...TokenBucketStore) func(http.Handler) http.Handler { | |||||
store.InitRate(b.rate, b.window) | |||||
for _, store := range fallbackStores { | |||||
store.InitRate(b.rate, b.window) | |||||
} | |||||
downloadLimiter := downloadLimiter{ | |||||
downloadBuilder: b, | |||||
store: store, | |||||
fallbackStores: fallbackStores, | |||||
} | |||||
return func(next http.Handler) http.Handler { | |||||
fn := func(w http.ResponseWriter, r *http.Request) { | |||||
key := downloadLimiter.keyFn(r) | |||||
if key == "" { | |||||
next.ServeHTTP(w, r) | |||||
return | |||||
} | |||||
lw := &limitWriter{ | |||||
ResponseWriter: w, | |||||
downloadLimiter: &downloadLimiter, | |||||
key: key, | |||||
} | |||||
next.ServeHTTP(lw, r) | |||||
} | |||||
return http.HandlerFunc(fn) | |||||
} | |||||
} | |||||
type downloadLimiter struct { | |||||
*downloadBuilder | |||||
next http.Handler | |||||
store TokenBucketStore | |||||
fallbackStores []TokenBucketStore | |||||
} | |||||
type limitWriter struct { | |||||
http.ResponseWriter | |||||
*downloadLimiter | |||||
key string | |||||
wroteHeader bool | |||||
canWrite int64 | |||||
} | |||||
func (w *limitWriter) Write(buf []byte) (int, error) { | |||||
total := 0 | |||||
for { | |||||
if w.canWrite < 1024 { | |||||
ok, _, _, err := w.downloadLimiter.store.Take("download:" + w.key) | |||||
if err != nil { | |||||
for _, store := range w.fallbackStores { | |||||
ok, _, _, err = store.Take("download:" + w.key) | |||||
if err == nil { | |||||
break | |||||
} | |||||
} | |||||
} | |||||
if err != nil { | |||||
return total, err | |||||
} | |||||
if ok { | |||||
w.canWrite += 1024 | |||||
} | |||||
} | |||||
if w.canWrite == 0 { | |||||
continue | |||||
} | |||||
max := len(buf) - total | |||||
if int(w.canWrite) < max { | |||||
max = int(w.canWrite) | |||||
} | |||||
if max == 0 { | |||||
return total, nil | |||||
} | |||||
n, err := w.ResponseWriter.Write(buf[total : total+max]) | |||||
w.canWrite -= int64(n) | |||||
total += n | |||||
if err != nil { | |||||
return total, err | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,21 @@ | |||||
package ratelimit_test | |||||
import ( | |||||
"net/http" | |||||
"time" | |||||
"github.com/VojtechVitek/ratelimit" | |||||
"github.com/VojtechVitek/ratelimit/memory" | |||||
) | |||||
// Watch the download speed with | |||||
// wget http://localhost:3333/file -q --show-progress | |||||
func ExampleDownloadSpeed() { | |||||
middleware := ratelimit.DownloadSpeed(ratelimit.IP).Rate(1024, time.Second).LimitBy(memory.New()) | |||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |||||
http.ServeFile(w, r, "/dev/random") | |||||
}) | |||||
http.ListenAndServe(":3333", middleware(handler)) | |||||
} |
@@ -0,0 +1,46 @@ | |||||
package main | |||||
import ( | |||||
"net/http" | |||||
"time" | |||||
"github.com/VojtechVitek/ratelimit" | |||||
"github.com/VojtechVitek/ratelimit/memory" | |||||
"github.com/VojtechVitek/ratelimit/redis" | |||||
redigo "github.com/garyburd/redigo/redis" | |||||
"github.com/pressly/chi" | |||||
"github.com/pressly/chi/middleware" | |||||
) | |||||
var pool = &redigo.Pool{ | |||||
MaxIdle: 10, | |||||
MaxActive: 50, | |||||
IdleTimeout: 300 * time.Second, | |||||
Wait: false, // Important | |||||
Dial: func() (redigo.Conn, error) { | |||||
c, err := redigo.DialTimeout("tcp", "127.0.0.1:6379", 200*time.Millisecond, 100*time.Millisecond, 100*time.Millisecond) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
return c, err | |||||
}, | |||||
TestOnBorrow: func(c redigo.Conn, t time.Time) error { | |||||
_, err := c.Do("PING") | |||||
return err | |||||
}, | |||||
} | |||||
// wget http://localhost:3333 -q --show-progress | |||||
func main() { | |||||
r := chi.NewRouter() | |||||
r.Use(middleware.Logger) | |||||
r.Use(ratelimit.DownloadSpeed(ratelimit.IP).Rate(1024, time.Second).LimitBy(redis.New(pool), memory.New())) | |||||
r.Get("/", ServeVideo) | |||||
http.ListenAndServe(":3333", r) | |||||
} | |||||
func ServeVideo(w http.ResponseWriter, r *http.Request) { | |||||
http.ServeFile(w, r, "/Users/vojtechvitek/Desktop/govideo.mov") | |||||
} |
@@ -0,0 +1,55 @@ | |||||
package main | |||||
import ( | |||||
"net/http" | |||||
"time" | |||||
"github.com/VojtechVitek/ratelimit" | |||||
"github.com/VojtechVitek/ratelimit/memory" | |||||
"github.com/VojtechVitek/ratelimit/redis" | |||||
redigo "github.com/garyburd/redigo/redis" | |||||
"github.com/pressly/chi" | |||||
"github.com/pressly/chi/middleware" | |||||
) | |||||
var pool = &redigo.Pool{ | |||||
MaxIdle: 10, | |||||
MaxActive: 50, | |||||
IdleTimeout: 300 * time.Second, | |||||
Wait: false, // Important | |||||
Dial: func() (redigo.Conn, error) { | |||||
c, err := redigo.DialTimeout("tcp", "127.0.0.1:6379", 200*time.Millisecond, 100*time.Millisecond, 100*time.Millisecond) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
return c, err | |||||
}, | |||||
TestOnBorrow: func(c redigo.Conn, t time.Time) error { | |||||
_, err := c.Do("PING") | |||||
return err | |||||
}, | |||||
} | |||||
// while :; do curl -v localhost:3333; sleep 0.1; done | |||||
func main() { | |||||
r := chi.NewRouter() | |||||
r.Use(middleware.Logger) | |||||
//r.Use(ratelimit.Request(ratelimit.IP).Rate(1, time.Second).LimitBy(redis.New(pool))) | |||||
r.Use(ratelimit.Request(ratelimit.IP).Rate(5, 5*time.Second).LimitBy(redis.New(pool), memory.New())) | |||||
r.Get("/", Hello) | |||||
http.ListenAndServe(":3333", r) | |||||
} | |||||
func Hello(w http.ResponseWriter, r *http.Request) { | |||||
w.Write([]byte("Hello World!\n")) | |||||
//w.Write([]byte("Hello user_id=" + r.URL.Query().Get("user_id") + "\n")) | |||||
} | |||||
func UserKey(r *http.Request) string { | |||||
user := r.URL.Query().Get("user_id") | |||||
// user, _ := r.Context().Value("session.user_id").(string) | |||||
return user | |||||
} |
@@ -0,0 +1,24 @@ | |||||
package main | |||||
import ( | |||||
"net/http" | |||||
"time" | |||||
"github.com/VojtechVitek/ratelimit" | |||||
) | |||||
// curl -v http://localhost:3333 | |||||
func main() { | |||||
middleware := ratelimit.Throttle(1) | |||||
http.ListenAndServe(":3333", middleware(http.HandlerFunc(Work))) | |||||
} | |||||
func Work(w http.ResponseWriter, r *http.Request) { | |||||
w.Write([]byte("working hard...\n\n")) | |||||
if f, ok := w.(http.Flusher); ok { | |||||
f.Flush() | |||||
} | |||||
time.Sleep(10 * time.Second) | |||||
w.Write([]byte("done")) | |||||
} |
@@ -0,0 +1,27 @@ | |||||
package ratelimit | |||||
import ( | |||||
"net" | |||||
"net/http" | |||||
"strings" | |||||
) | |||||
// IP returns unique key per request IP. | |||||
func IP(r *http.Request) string { | |||||
ip, _, _ := net.SplitHostPort(r.RemoteAddr) | |||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" { | |||||
if i := strings.IndexAny(xff, ",;"); i != -1 { | |||||
xff = xff[:i] | |||||
} | |||||
ip += "," + xff | |||||
} | |||||
if xrip := r.Header.Get("X-Real-IP"); xrip != "" { | |||||
ip += "," + xrip | |||||
} | |||||
return ip | |||||
} | |||||
// NOP returns empty key for each request. | |||||
func NOP(r *http.Request) string { | |||||
return "" | |||||
} |
@@ -0,0 +1,62 @@ | |||||
package memory | |||||
import ( | |||||
"sync" | |||||
"time" | |||||
) | |||||
type token struct{} | |||||
type bucketStore struct { | |||||
sync.Mutex // guards buckets | |||||
buckets map[string]chan token | |||||
bucketLen int | |||||
reset time.Time | |||||
} | |||||
// New creates new in-memory token bucket store. | |||||
func New() *bucketStore { | |||||
return &bucketStore{ | |||||
buckets: map[string]chan token{}, | |||||
} | |||||
} | |||||
func (s *bucketStore) InitRate(rate int, window time.Duration) { | |||||
s.bucketLen = rate | |||||
s.reset = time.Now() | |||||
go func() { | |||||
interval := time.Duration(int(window) / rate) | |||||
tick := time.NewTicker(interval) | |||||
for t := range tick.C { | |||||
s.Lock() | |||||
s.reset = t.Add(interval) | |||||
for key, bucket := range s.buckets { | |||||
select { | |||||
case <-bucket: | |||||
default: | |||||
delete(s.buckets, key) | |||||
} | |||||
} | |||||
s.Unlock() | |||||
} | |||||
}() | |||||
} | |||||
// Take implements TokenBucketStore interface. It takes token from a bucket | |||||
// referenced by a given key, if available. | |||||
func (s *bucketStore) Take(key string) (bool, int, time.Time, error) { | |||||
s.Lock() | |||||
bucket, ok := s.buckets[key] | |||||
if !ok { | |||||
bucket = make(chan token, s.bucketLen) | |||||
s.buckets[key] = bucket | |||||
} | |||||
s.Unlock() | |||||
select { | |||||
case bucket <- token{}: | |||||
return true, cap(bucket) - len(bucket), s.reset, nil | |||||
default: | |||||
return false, 0, s.reset, nil | |||||
} | |||||
} |
@@ -0,0 +1,16 @@ | |||||
package ratelimit | |||||
import ( | |||||
"net/http" | |||||
"time" | |||||
) | |||||
// TokenBucketStore is an interface for for any storage implementing | |||||
// Token Bucket algorithm. | |||||
type TokenBucketStore interface { | |||||
InitRate(rate int, window time.Duration) | |||||
Take(key string) (taken bool, remaining int, reset time.Time, err error) | |||||
} | |||||
// KeyFn is a function returning bucket key depending on request data. | |||||
type KeyFn func(r *http.Request) string |
@@ -0,0 +1,94 @@ | |||||
package redis | |||||
import ( | |||||
"errors" | |||||
"time" | |||||
"github.com/garyburd/redigo/redis" | |||||
) | |||||
var ( | |||||
PrefixKey = "ratelimit:" | |||||
ErrUnreachable = errors.New("redis is unreachable") | |||||
RetryAfter = time.Second | |||||
) | |||||
const skipOnUnhealthy = 1000 | |||||
type bucketStore struct { | |||||
pool *redis.Pool | |||||
rate int | |||||
windowSeconds int | |||||
retryAfter *time.Time | |||||
} | |||||
// New creates new in-memory token bucket store. | |||||
func New(pool *redis.Pool) *bucketStore { | |||||
return &bucketStore{ | |||||
pool: pool, | |||||
} | |||||
} | |||||
func (s *bucketStore) InitRate(rate int, window time.Duration) { | |||||
s.rate = rate | |||||
s.windowSeconds = int(window / time.Second) | |||||
if s.windowSeconds <= 1 { | |||||
s.windowSeconds = 1 | |||||
} | |||||
} | |||||
// Take implements TokenBucketStore interface. It takes token from a bucket | |||||
// referenced by a given key, if available. | |||||
func (s *bucketStore) Take(key string) (bool, int, time.Time, error) { | |||||
if s.retryAfter != nil { | |||||
if s.retryAfter.After(time.Now()) { | |||||
return false, 0, time.Time{}, ErrUnreachable | |||||
} | |||||
s.retryAfter = nil | |||||
} | |||||
c := s.pool.Get() | |||||
defer c.Close() | |||||
// Number of tokens in the bucket. | |||||
bucketLen, err := redis.Int(c.Do("LLEN", PrefixKey+key)) | |||||
if err != nil { | |||||
next := time.Now().Add(time.Second) | |||||
s.retryAfter = &next | |||||
return false, 0, time.Time{}, err | |||||
} | |||||
// Bucket is full. | |||||
if bucketLen >= s.rate { | |||||
return false, 0, time.Time{}, nil | |||||
} | |||||
if bucketLen > 0 { | |||||
// Bucket most probably exists, try to push a new token into it. | |||||
// If RPUSHX returns 0 (ie. key expired between LLEN and RPUSHX), we need | |||||
// to fall-back to RPUSH without returning error. | |||||
c.Send("MULTI") | |||||
c.Send("RPUSHX", PrefixKey+key, "") | |||||
reply, err := redis.Ints(c.Do("EXEC")) | |||||
if err != nil { | |||||
next := time.Now().Add(time.Second) | |||||
s.retryAfter = &next | |||||
return false, 0, time.Time{}, err | |||||
} | |||||
bucketLen = reply[0] | |||||
if bucketLen > 0 { | |||||
return true, s.rate - bucketLen - 1, time.Time{}, nil | |||||
} | |||||
} | |||||
c.Send("MULTI") | |||||
c.Send("RPUSH", PrefixKey+key, "") | |||||
c.Send("EXPIRE", PrefixKey+key, s.windowSeconds) | |||||
if _, err := c.Do("EXEC"); err != nil { | |||||
next := time.Now().Add(time.Second) | |||||
s.retryAfter = &next | |||||
return false, 0, time.Time{}, err | |||||
} | |||||
return true, s.rate - bucketLen - 1, time.Time{}, nil | |||||
} |
@@ -0,0 +1,95 @@ | |||||
package ratelimit | |||||
import ( | |||||
"fmt" | |||||
"net/http" | |||||
"time" | |||||
) | |||||
func Request(keyFn KeyFn) *requestBuilder { | |||||
return &requestBuilder{ | |||||
keyFn: keyFn, | |||||
} | |||||
} | |||||
type requestBuilder struct { | |||||
keyFn KeyFn | |||||
rate int | |||||
window time.Duration | |||||
rateHeader string | |||||
resetHeader string | |||||
} | |||||
func (b *requestBuilder) Rate(rate int, window time.Duration) *requestBuilder { | |||||
b.rate = rate | |||||
b.window = window | |||||
b.rateHeader = fmt.Sprintf("%v", float32(rate)*float32(window/time.Second)) | |||||
b.resetHeader = fmt.Sprintf("%d", time.Now().Unix()) | |||||
return b | |||||
} | |||||
// TODO: Custom burst? | |||||
// func (b *requestBuilder) Burst(burst int) *requestBuilder {} | |||||
func (b *requestBuilder) LimitBy(store TokenBucketStore, fallbackStores ...TokenBucketStore) func(http.Handler) http.Handler { | |||||
store.InitRate(b.rate, b.window) | |||||
for _, store := range fallbackStores { | |||||
store.InitRate(b.rate, b.window) | |||||
} | |||||
limiter := requestLimiter{ | |||||
requestBuilder: b, | |||||
store: store, | |||||
fallbackStores: fallbackStores, | |||||
} | |||||
fn := func(next http.Handler) http.Handler { | |||||
limiter.next = next | |||||
return &limiter | |||||
} | |||||
return fn | |||||
} | |||||
type requestLimiter struct { | |||||
*requestBuilder | |||||
next http.Handler | |||||
store TokenBucketStore | |||||
fallbackStores []TokenBucketStore | |||||
} | |||||
// ServeHTTPC implements http.Handler interface. | |||||
func (l *requestLimiter) ServeHTTP(w http.ResponseWriter, r *http.Request) { | |||||
key := l.keyFn(r) | |||||
if key == "" { | |||||
l.next.ServeHTTP(w, r) | |||||
return | |||||
} | |||||
ok, remaining, reset, err := l.store.Take("request:" + key) | |||||
if err != nil { | |||||
for _, store := range l.fallbackStores { | |||||
ok, remaining, reset, err = store.Take("request:" + key) | |||||
if err == nil { | |||||
break | |||||
} | |||||
} | |||||
} | |||||
if err != nil { | |||||
l.next.ServeHTTP(w, r) | |||||
return | |||||
} | |||||
if !ok { | |||||
w.Header().Add("Retry-After", reset.Format(http.TimeFormat)) | |||||
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) | |||||
return | |||||
} | |||||
w.Header().Add("X-RateLimit-Key", key) | |||||
w.Header().Add("X-RateLimit-Rate", l.rateHeader) | |||||
w.Header().Add("X-RateLimit-Limit", fmt.Sprintf("%d", l.rate)) | |||||
w.Header().Add("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining)) | |||||
w.Header().Add("X-RateLimit-Reset", fmt.Sprintf("%d", reset.Unix())) | |||||
w.Header().Add("Retry-After", reset.Format(http.TimeFormat)) | |||||
l.next.ServeHTTP(w, r) | |||||
} |
@@ -0,0 +1,19 @@ | |||||
package ratelimit_test | |||||
import ( | |||||
"net/http" | |||||
"time" | |||||
"github.com/VojtechVitek/ratelimit" | |||||
"github.com/VojtechVitek/ratelimit/memory" | |||||
) | |||||
func ExampleRequest() { | |||||
middleware := ratelimit.Request(ratelimit.IP).Rate(30, time.Minute).LimitBy(memory.New()) | |||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |||||
w.Write([]byte("Hello World!")) | |||||
}) | |||||
http.ListenAndServe(":3333", middleware(handler)) | |||||
} |
@@ -0,0 +1,47 @@ | |||||
package ratelimit | |||||
import "net/http" | |||||
// Throttle is a middleware that limits number of currently | |||||
// processed requests at a time. | |||||
func Throttle(limit int) func(http.Handler) http.Handler { | |||||
if limit <= 0 { | |||||
panic("Throttle expects limit > 0") | |||||
} | |||||
t := throttler{ | |||||
tokens: make(chan token, limit), | |||||
} | |||||
for i := 0; i < limit; i++ { | |||||
t.tokens <- token{} | |||||
} | |||||
fn := func(h http.Handler) http.Handler { | |||||
t.h = h | |||||
return &t | |||||
} | |||||
return fn | |||||
} | |||||
// token represents a request that is being processed. | |||||
type token struct{} | |||||
// throttler limits number of currently processed requests at a time. | |||||
type throttler struct { | |||||
h http.Handler | |||||
tokens chan token | |||||
} | |||||
// ServeHTTP implements http.Handler interface. | |||||
func (t *throttler) ServeHTTP(w http.ResponseWriter, r *http.Request) { | |||||
select { | |||||
case <-r.Context().Done(): | |||||
return | |||||
case tok := <-t.tokens: | |||||
defer func() { | |||||
t.tokens <- tok | |||||
}() | |||||
t.h.ServeHTTP(w, r) | |||||
} | |||||
} |
@@ -0,0 +1,23 @@ | |||||
package ratelimit_test | |||||
import ( | |||||
"net/http" | |||||
"time" | |||||
"github.com/VojtechVitek/ratelimit" | |||||
) | |||||
func ExampleThrottle() { | |||||
middleware := ratelimit.Throttle(1) | |||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |||||
w.Write([]byte("working hard...\n\n")) | |||||
if f, ok := w.(http.Flusher); ok { | |||||
f.Flush() | |||||
} | |||||
time.Sleep(10 * time.Second) | |||||
w.Write([]byte("done")) | |||||
}) | |||||
http.ListenAndServe(":3333", middleware(handler)) | |||||
} |