@@ -0,0 +1,8 @@ | |||
sudo: false | |||
language: go | |||
go: | |||
- 1.7rc2 | |||
script: | |||
- go test ./... |
@@ -0,0 +1,22 @@ | |||
MIT License | |||
Copyright (c) 2016 Vojtech Vitek | |||
Permission is hereby granted, free of charge, to any person obtaining | |||
a copy of this software and associated documentation files (the | |||
"Software"), to deal in the Software without restriction, including | |||
without limitation the rights to use, copy, modify, merge, publish, | |||
distribute, sublicense, and/or sell copies of the Software, and to | |||
permit persons to whom the Software is furnished to do so, subject to | |||
the following conditions: | |||
The above copyright notice and this permission notice shall be | |||
included in all copies or substantial portions of the Software. | |||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, | |||
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF | |||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND | |||
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE | |||
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION | |||
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION | |||
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. |
@@ -0,0 +1,22 @@ | |||
# Rate Limit HTTP middleware | |||
[![GoDoc Widget]][GoDoc] [![Travis Widget]][Travis] | |||
[Golang](http://golang.org/) package for rate limiting HTTP endpoints based on context and request headers. | |||
[GoDoc]: https://godoc.org/github.com/VojtechVitek/ratelimit | |||
[GoDoc Widget]: https://godoc.org/github.com/VojtechVitek/ratelimit?status.svg | |||
[Travis]: https://travis-ci.org/VojtechVitek/ratelimit | |||
[Travis Widget]: https://travis-ci.org/VojtechVitek/ratelimit.svg?branch=master | |||
# Under development | |||
# Goals | |||
- Simple but powerful API | |||
- Token Bucket algorithm (rate + burst) | |||
- Storage independent (Redis, In-Memory or any other K/V store) | |||
# License | |||
Copyright (c) 2016 Vojtech Vitek | |||
Licensed under the [MIT License](./LICENSE). |
@@ -0,0 +1,117 @@ | |||
package ratelimit | |||
import ( | |||
"net/http" | |||
"time" | |||
) | |||
func DownloadSpeed(keyFn KeyFn) *downloadBuilder { | |||
return &downloadBuilder{ | |||
keyFn: keyFn, | |||
} | |||
} | |||
type downloadBuilder struct { | |||
keyFn KeyFn | |||
rate int | |||
window time.Duration | |||
} | |||
func (b *downloadBuilder) Rate(rate int, window time.Duration) *downloadBuilder { | |||
b.rate = rate | |||
b.window = window | |||
return b | |||
} | |||
// TODO: Custom burst? | |||
// func (b *downloadBuilder) Burst(burst int) *downloadBuilder {} | |||
func (b *downloadBuilder) LimitBy(store TokenBucketStore, fallbackStores ...TokenBucketStore) func(http.Handler) http.Handler { | |||
store.InitRate(b.rate, b.window) | |||
for _, store := range fallbackStores { | |||
store.InitRate(b.rate, b.window) | |||
} | |||
downloadLimiter := downloadLimiter{ | |||
downloadBuilder: b, | |||
store: store, | |||
fallbackStores: fallbackStores, | |||
} | |||
return func(next http.Handler) http.Handler { | |||
fn := func(w http.ResponseWriter, r *http.Request) { | |||
key := downloadLimiter.keyFn(r) | |||
if key == "" { | |||
next.ServeHTTP(w, r) | |||
return | |||
} | |||
lw := &limitWriter{ | |||
ResponseWriter: w, | |||
downloadLimiter: &downloadLimiter, | |||
key: key, | |||
} | |||
next.ServeHTTP(lw, r) | |||
} | |||
return http.HandlerFunc(fn) | |||
} | |||
} | |||
type downloadLimiter struct { | |||
*downloadBuilder | |||
next http.Handler | |||
store TokenBucketStore | |||
fallbackStores []TokenBucketStore | |||
} | |||
type limitWriter struct { | |||
http.ResponseWriter | |||
*downloadLimiter | |||
key string | |||
wroteHeader bool | |||
canWrite int64 | |||
} | |||
func (w *limitWriter) Write(buf []byte) (int, error) { | |||
total := 0 | |||
for { | |||
if w.canWrite < 1024 { | |||
ok, _, _, err := w.downloadLimiter.store.Take("download:" + w.key) | |||
if err != nil { | |||
for _, store := range w.fallbackStores { | |||
ok, _, _, err = store.Take("download:" + w.key) | |||
if err == nil { | |||
break | |||
} | |||
} | |||
} | |||
if err != nil { | |||
return total, err | |||
} | |||
if ok { | |||
w.canWrite += 1024 | |||
} | |||
} | |||
if w.canWrite == 0 { | |||
continue | |||
} | |||
max := len(buf) - total | |||
if int(w.canWrite) < max { | |||
max = int(w.canWrite) | |||
} | |||
if max == 0 { | |||
return total, nil | |||
} | |||
n, err := w.ResponseWriter.Write(buf[total : total+max]) | |||
w.canWrite -= int64(n) | |||
total += n | |||
if err != nil { | |||
return total, err | |||
} | |||
} | |||
} |
@@ -0,0 +1,21 @@ | |||
package ratelimit_test | |||
import ( | |||
"net/http" | |||
"time" | |||
"github.com/VojtechVitek/ratelimit" | |||
"github.com/VojtechVitek/ratelimit/memory" | |||
) | |||
// Watch the download speed with | |||
// wget http://localhost:3333/file -q --show-progress | |||
func ExampleDownloadSpeed() { | |||
middleware := ratelimit.DownloadSpeed(ratelimit.IP).Rate(1024, time.Second).LimitBy(memory.New()) | |||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |||
http.ServeFile(w, r, "/dev/random") | |||
}) | |||
http.ListenAndServe(":3333", middleware(handler)) | |||
} |
@@ -0,0 +1,46 @@ | |||
package main | |||
import ( | |||
"net/http" | |||
"time" | |||
"github.com/VojtechVitek/ratelimit" | |||
"github.com/VojtechVitek/ratelimit/memory" | |||
"github.com/VojtechVitek/ratelimit/redis" | |||
redigo "github.com/garyburd/redigo/redis" | |||
"github.com/pressly/chi" | |||
"github.com/pressly/chi/middleware" | |||
) | |||
var pool = &redigo.Pool{ | |||
MaxIdle: 10, | |||
MaxActive: 50, | |||
IdleTimeout: 300 * time.Second, | |||
Wait: false, // Important | |||
Dial: func() (redigo.Conn, error) { | |||
c, err := redigo.DialTimeout("tcp", "127.0.0.1:6379", 200*time.Millisecond, 100*time.Millisecond, 100*time.Millisecond) | |||
if err != nil { | |||
return nil, err | |||
} | |||
return c, err | |||
}, | |||
TestOnBorrow: func(c redigo.Conn, t time.Time) error { | |||
_, err := c.Do("PING") | |||
return err | |||
}, | |||
} | |||
// wget http://localhost:3333 -q --show-progress | |||
func main() { | |||
r := chi.NewRouter() | |||
r.Use(middleware.Logger) | |||
r.Use(ratelimit.DownloadSpeed(ratelimit.IP).Rate(1024, time.Second).LimitBy(redis.New(pool), memory.New())) | |||
r.Get("/", ServeVideo) | |||
http.ListenAndServe(":3333", r) | |||
} | |||
func ServeVideo(w http.ResponseWriter, r *http.Request) { | |||
http.ServeFile(w, r, "/Users/vojtechvitek/Desktop/govideo.mov") | |||
} |
@@ -0,0 +1,55 @@ | |||
package main | |||
import ( | |||
"net/http" | |||
"time" | |||
"github.com/VojtechVitek/ratelimit" | |||
"github.com/VojtechVitek/ratelimit/memory" | |||
"github.com/VojtechVitek/ratelimit/redis" | |||
redigo "github.com/garyburd/redigo/redis" | |||
"github.com/pressly/chi" | |||
"github.com/pressly/chi/middleware" | |||
) | |||
var pool = &redigo.Pool{ | |||
MaxIdle: 10, | |||
MaxActive: 50, | |||
IdleTimeout: 300 * time.Second, | |||
Wait: false, // Important | |||
Dial: func() (redigo.Conn, error) { | |||
c, err := redigo.DialTimeout("tcp", "127.0.0.1:6379", 200*time.Millisecond, 100*time.Millisecond, 100*time.Millisecond) | |||
if err != nil { | |||
return nil, err | |||
} | |||
return c, err | |||
}, | |||
TestOnBorrow: func(c redigo.Conn, t time.Time) error { | |||
_, err := c.Do("PING") | |||
return err | |||
}, | |||
} | |||
// while :; do curl -v localhost:3333; sleep 0.1; done | |||
func main() { | |||
r := chi.NewRouter() | |||
r.Use(middleware.Logger) | |||
//r.Use(ratelimit.Request(ratelimit.IP).Rate(1, time.Second).LimitBy(redis.New(pool))) | |||
r.Use(ratelimit.Request(ratelimit.IP).Rate(5, 5*time.Second).LimitBy(redis.New(pool), memory.New())) | |||
r.Get("/", Hello) | |||
http.ListenAndServe(":3333", r) | |||
} | |||
func Hello(w http.ResponseWriter, r *http.Request) { | |||
w.Write([]byte("Hello World!\n")) | |||
//w.Write([]byte("Hello user_id=" + r.URL.Query().Get("user_id") + "\n")) | |||
} | |||
func UserKey(r *http.Request) string { | |||
user := r.URL.Query().Get("user_id") | |||
// user, _ := r.Context().Value("session.user_id").(string) | |||
return user | |||
} |
@@ -0,0 +1,24 @@ | |||
package main | |||
import ( | |||
"net/http" | |||
"time" | |||
"github.com/VojtechVitek/ratelimit" | |||
) | |||
// curl -v http://localhost:3333 | |||
func main() { | |||
middleware := ratelimit.Throttle(1) | |||
http.ListenAndServe(":3333", middleware(http.HandlerFunc(Work))) | |||
} | |||
func Work(w http.ResponseWriter, r *http.Request) { | |||
w.Write([]byte("working hard...\n\n")) | |||
if f, ok := w.(http.Flusher); ok { | |||
f.Flush() | |||
} | |||
time.Sleep(10 * time.Second) | |||
w.Write([]byte("done")) | |||
} |
@@ -0,0 +1,27 @@ | |||
package ratelimit | |||
import ( | |||
"net" | |||
"net/http" | |||
"strings" | |||
) | |||
// IP returns unique key per request IP. | |||
func IP(r *http.Request) string { | |||
ip, _, _ := net.SplitHostPort(r.RemoteAddr) | |||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" { | |||
if i := strings.IndexAny(xff, ",;"); i != -1 { | |||
xff = xff[:i] | |||
} | |||
ip += "," + xff | |||
} | |||
if xrip := r.Header.Get("X-Real-IP"); xrip != "" { | |||
ip += "," + xrip | |||
} | |||
return ip | |||
} | |||
// NOP returns empty key for each request. | |||
func NOP(r *http.Request) string { | |||
return "" | |||
} |
@@ -0,0 +1,62 @@ | |||
package memory | |||
import ( | |||
"sync" | |||
"time" | |||
) | |||
type token struct{} | |||
type bucketStore struct { | |||
sync.Mutex // guards buckets | |||
buckets map[string]chan token | |||
bucketLen int | |||
reset time.Time | |||
} | |||
// New creates new in-memory token bucket store. | |||
func New() *bucketStore { | |||
return &bucketStore{ | |||
buckets: map[string]chan token{}, | |||
} | |||
} | |||
func (s *bucketStore) InitRate(rate int, window time.Duration) { | |||
s.bucketLen = rate | |||
s.reset = time.Now() | |||
go func() { | |||
interval := time.Duration(int(window) / rate) | |||
tick := time.NewTicker(interval) | |||
for t := range tick.C { | |||
s.Lock() | |||
s.reset = t.Add(interval) | |||
for key, bucket := range s.buckets { | |||
select { | |||
case <-bucket: | |||
default: | |||
delete(s.buckets, key) | |||
} | |||
} | |||
s.Unlock() | |||
} | |||
}() | |||
} | |||
// Take implements TokenBucketStore interface. It takes token from a bucket | |||
// referenced by a given key, if available. | |||
func (s *bucketStore) Take(key string) (bool, int, time.Time, error) { | |||
s.Lock() | |||
bucket, ok := s.buckets[key] | |||
if !ok { | |||
bucket = make(chan token, s.bucketLen) | |||
s.buckets[key] = bucket | |||
} | |||
s.Unlock() | |||
select { | |||
case bucket <- token{}: | |||
return true, cap(bucket) - len(bucket), s.reset, nil | |||
default: | |||
return false, 0, s.reset, nil | |||
} | |||
} |
@@ -0,0 +1,16 @@ | |||
package ratelimit | |||
import ( | |||
"net/http" | |||
"time" | |||
) | |||
// TokenBucketStore is an interface for for any storage implementing | |||
// Token Bucket algorithm. | |||
type TokenBucketStore interface { | |||
InitRate(rate int, window time.Duration) | |||
Take(key string) (taken bool, remaining int, reset time.Time, err error) | |||
} | |||
// KeyFn is a function returning bucket key depending on request data. | |||
type KeyFn func(r *http.Request) string |
@@ -0,0 +1,94 @@ | |||
package redis | |||
import ( | |||
"errors" | |||
"time" | |||
"github.com/garyburd/redigo/redis" | |||
) | |||
var ( | |||
PrefixKey = "ratelimit:" | |||
ErrUnreachable = errors.New("redis is unreachable") | |||
RetryAfter = time.Second | |||
) | |||
const skipOnUnhealthy = 1000 | |||
type bucketStore struct { | |||
pool *redis.Pool | |||
rate int | |||
windowSeconds int | |||
retryAfter *time.Time | |||
} | |||
// New creates new in-memory token bucket store. | |||
func New(pool *redis.Pool) *bucketStore { | |||
return &bucketStore{ | |||
pool: pool, | |||
} | |||
} | |||
func (s *bucketStore) InitRate(rate int, window time.Duration) { | |||
s.rate = rate | |||
s.windowSeconds = int(window / time.Second) | |||
if s.windowSeconds <= 1 { | |||
s.windowSeconds = 1 | |||
} | |||
} | |||
// Take implements TokenBucketStore interface. It takes token from a bucket | |||
// referenced by a given key, if available. | |||
func (s *bucketStore) Take(key string) (bool, int, time.Time, error) { | |||
if s.retryAfter != nil { | |||
if s.retryAfter.After(time.Now()) { | |||
return false, 0, time.Time{}, ErrUnreachable | |||
} | |||
s.retryAfter = nil | |||
} | |||
c := s.pool.Get() | |||
defer c.Close() | |||
// Number of tokens in the bucket. | |||
bucketLen, err := redis.Int(c.Do("LLEN", PrefixKey+key)) | |||
if err != nil { | |||
next := time.Now().Add(time.Second) | |||
s.retryAfter = &next | |||
return false, 0, time.Time{}, err | |||
} | |||
// Bucket is full. | |||
if bucketLen >= s.rate { | |||
return false, 0, time.Time{}, nil | |||
} | |||
if bucketLen > 0 { | |||
// Bucket most probably exists, try to push a new token into it. | |||
// If RPUSHX returns 0 (ie. key expired between LLEN and RPUSHX), we need | |||
// to fall-back to RPUSH without returning error. | |||
c.Send("MULTI") | |||
c.Send("RPUSHX", PrefixKey+key, "") | |||
reply, err := redis.Ints(c.Do("EXEC")) | |||
if err != nil { | |||
next := time.Now().Add(time.Second) | |||
s.retryAfter = &next | |||
return false, 0, time.Time{}, err | |||
} | |||
bucketLen = reply[0] | |||
if bucketLen > 0 { | |||
return true, s.rate - bucketLen - 1, time.Time{}, nil | |||
} | |||
} | |||
c.Send("MULTI") | |||
c.Send("RPUSH", PrefixKey+key, "") | |||
c.Send("EXPIRE", PrefixKey+key, s.windowSeconds) | |||
if _, err := c.Do("EXEC"); err != nil { | |||
next := time.Now().Add(time.Second) | |||
s.retryAfter = &next | |||
return false, 0, time.Time{}, err | |||
} | |||
return true, s.rate - bucketLen - 1, time.Time{}, nil | |||
} |
@@ -0,0 +1,95 @@ | |||
package ratelimit | |||
import ( | |||
"fmt" | |||
"net/http" | |||
"time" | |||
) | |||
func Request(keyFn KeyFn) *requestBuilder { | |||
return &requestBuilder{ | |||
keyFn: keyFn, | |||
} | |||
} | |||
type requestBuilder struct { | |||
keyFn KeyFn | |||
rate int | |||
window time.Duration | |||
rateHeader string | |||
resetHeader string | |||
} | |||
func (b *requestBuilder) Rate(rate int, window time.Duration) *requestBuilder { | |||
b.rate = rate | |||
b.window = window | |||
b.rateHeader = fmt.Sprintf("%v", float32(rate)*float32(window/time.Second)) | |||
b.resetHeader = fmt.Sprintf("%d", time.Now().Unix()) | |||
return b | |||
} | |||
// TODO: Custom burst? | |||
// func (b *requestBuilder) Burst(burst int) *requestBuilder {} | |||
func (b *requestBuilder) LimitBy(store TokenBucketStore, fallbackStores ...TokenBucketStore) func(http.Handler) http.Handler { | |||
store.InitRate(b.rate, b.window) | |||
for _, store := range fallbackStores { | |||
store.InitRate(b.rate, b.window) | |||
} | |||
limiter := requestLimiter{ | |||
requestBuilder: b, | |||
store: store, | |||
fallbackStores: fallbackStores, | |||
} | |||
fn := func(next http.Handler) http.Handler { | |||
limiter.next = next | |||
return &limiter | |||
} | |||
return fn | |||
} | |||
type requestLimiter struct { | |||
*requestBuilder | |||
next http.Handler | |||
store TokenBucketStore | |||
fallbackStores []TokenBucketStore | |||
} | |||
// ServeHTTPC implements http.Handler interface. | |||
func (l *requestLimiter) ServeHTTP(w http.ResponseWriter, r *http.Request) { | |||
key := l.keyFn(r) | |||
if key == "" { | |||
l.next.ServeHTTP(w, r) | |||
return | |||
} | |||
ok, remaining, reset, err := l.store.Take("request:" + key) | |||
if err != nil { | |||
for _, store := range l.fallbackStores { | |||
ok, remaining, reset, err = store.Take("request:" + key) | |||
if err == nil { | |||
break | |||
} | |||
} | |||
} | |||
if err != nil { | |||
l.next.ServeHTTP(w, r) | |||
return | |||
} | |||
if !ok { | |||
w.Header().Add("Retry-After", reset.Format(http.TimeFormat)) | |||
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) | |||
return | |||
} | |||
w.Header().Add("X-RateLimit-Key", key) | |||
w.Header().Add("X-RateLimit-Rate", l.rateHeader) | |||
w.Header().Add("X-RateLimit-Limit", fmt.Sprintf("%d", l.rate)) | |||
w.Header().Add("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining)) | |||
w.Header().Add("X-RateLimit-Reset", fmt.Sprintf("%d", reset.Unix())) | |||
w.Header().Add("Retry-After", reset.Format(http.TimeFormat)) | |||
l.next.ServeHTTP(w, r) | |||
} |
@@ -0,0 +1,19 @@ | |||
package ratelimit_test | |||
import ( | |||
"net/http" | |||
"time" | |||
"github.com/VojtechVitek/ratelimit" | |||
"github.com/VojtechVitek/ratelimit/memory" | |||
) | |||
func ExampleRequest() { | |||
middleware := ratelimit.Request(ratelimit.IP).Rate(30, time.Minute).LimitBy(memory.New()) | |||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |||
w.Write([]byte("Hello World!")) | |||
}) | |||
http.ListenAndServe(":3333", middleware(handler)) | |||
} |
@@ -0,0 +1,47 @@ | |||
package ratelimit | |||
import "net/http" | |||
// Throttle is a middleware that limits number of currently | |||
// processed requests at a time. | |||
func Throttle(limit int) func(http.Handler) http.Handler { | |||
if limit <= 0 { | |||
panic("Throttle expects limit > 0") | |||
} | |||
t := throttler{ | |||
tokens: make(chan token, limit), | |||
} | |||
for i := 0; i < limit; i++ { | |||
t.tokens <- token{} | |||
} | |||
fn := func(h http.Handler) http.Handler { | |||
t.h = h | |||
return &t | |||
} | |||
return fn | |||
} | |||
// token represents a request that is being processed. | |||
type token struct{} | |||
// throttler limits number of currently processed requests at a time. | |||
type throttler struct { | |||
h http.Handler | |||
tokens chan token | |||
} | |||
// ServeHTTP implements http.Handler interface. | |||
func (t *throttler) ServeHTTP(w http.ResponseWriter, r *http.Request) { | |||
select { | |||
case <-r.Context().Done(): | |||
return | |||
case tok := <-t.tokens: | |||
defer func() { | |||
t.tokens <- tok | |||
}() | |||
t.h.ServeHTTP(w, r) | |||
} | |||
} |
@@ -0,0 +1,23 @@ | |||
package ratelimit_test | |||
import ( | |||
"net/http" | |||
"time" | |||
"github.com/VojtechVitek/ratelimit" | |||
) | |||
func ExampleThrottle() { | |||
middleware := ratelimit.Throttle(1) | |||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |||
w.Write([]byte("working hard...\n\n")) | |||
if f, ok := w.(http.Flusher); ok { | |||
f.Flush() | |||
} | |||
time.Sleep(10 * time.Second) | |||
w.Write([]byte("done")) | |||
}) | |||
http.ListenAndServe(":3333", middleware(handler)) | |||
} |