|
- package ratelimit
-
- import (
- "net/http"
- "time"
- )
-
- func DownloadSpeed(keyFn KeyFn) *downloadBuilder {
- return &downloadBuilder{
- keyFn: keyFn,
- }
- }
-
- type downloadBuilder struct {
- keyFn KeyFn
- rate int
- window time.Duration
- }
-
- func (b *downloadBuilder) Rate(rate int, window time.Duration) *downloadBuilder {
- b.rate = rate
- b.window = window
- return b
- }
-
// TODO: Allow configuring a custom burst size, e.g. via a builder method:
// func (b *downloadBuilder) Burst(burst int) *downloadBuilder {}
-
- func (b *downloadBuilder) LimitBy(store TokenBucketStore, fallbackStores ...TokenBucketStore) func(http.Handler) http.Handler {
- store.InitRate(b.rate, b.window)
- for _, store := range fallbackStores {
- store.InitRate(b.rate, b.window)
- }
-
- downloadLimiter := downloadLimiter{
- downloadBuilder: b,
- store: store,
- fallbackStores: fallbackStores,
- }
-
- return func(next http.Handler) http.Handler {
- fn := func(w http.ResponseWriter, r *http.Request) {
- key := downloadLimiter.keyFn(r)
- if key == "" {
- next.ServeHTTP(w, r)
- return
- }
-
- lw := &limitWriter{
- ResponseWriter: w,
- downloadLimiter: &downloadLimiter,
- key: key,
- }
-
- next.ServeHTTP(lw, r)
- }
- return http.HandlerFunc(fn)
- }
- }
-
- type downloadLimiter struct {
- *downloadBuilder
-
- next http.Handler
- store TokenBucketStore
- fallbackStores []TokenBucketStore
- }
-
- type limitWriter struct {
- http.ResponseWriter
- *downloadLimiter
-
- key string
- wroteHeader bool
- canWrite int64
- }
-
- func (w *limitWriter) Write(buf []byte) (int, error) {
- total := 0
- for {
- if w.canWrite < 1024 {
- ok, _, _, err := w.downloadLimiter.store.Take("download:" + w.key)
- if err != nil {
- for _, store := range w.fallbackStores {
- ok, _, _, err = store.Take("download:" + w.key)
- if err == nil {
- break
- }
- }
- }
- if err != nil {
- return total, err
- }
- if ok {
- w.canWrite += 1024
- }
- }
- if w.canWrite == 0 {
- continue
- }
-
- max := len(buf) - total
- if int(w.canWrite) < max {
- max = int(w.canWrite)
- }
- if max == 0 {
- return total, nil
- }
-
- n, err := w.ResponseWriter.Write(buf[total : total+max])
- w.canWrite -= int64(n)
- total += n
- if err != nil {
- return total, err
- }
- }
- }
|