You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

118 lines
2.2 KiB

  1. package ratelimit
  2. import (
  3. "net/http"
  4. "time"
  5. )
  6. func DownloadSpeed(keyFn KeyFn) *downloadBuilder {
  7. return &downloadBuilder{
  8. keyFn: keyFn,
  9. }
  10. }
// downloadBuilder collects the configuration for a download-speed limiter.
// It is created by DownloadSpeed, configured with Rate, and turned into
// middleware by LimitBy.
type downloadBuilder struct {
	// keyFn maps a request to its rate-limit key; LimitBy does not throttle
	// requests for which it returns "".
	keyFn KeyFn
	// rate and window are handed to every store's InitRate by LimitBy.
	// Their exact meaning is defined by the TokenBucketStore implementation
	// (presumably `rate` tokens per `window` — TODO confirm).
	rate   int
	window time.Duration
}
  16. func (b *downloadBuilder) Rate(rate int, window time.Duration) *downloadBuilder {
  17. b.rate = rate
  18. b.window = window
  19. return b
  20. }
  21. // TODO: Custom burst?
  22. // func (b *downloadBuilder) Burst(burst int) *downloadBuilder {}
  23. func (b *downloadBuilder) LimitBy(store TokenBucketStore, fallbackStores ...TokenBucketStore) func(http.Handler) http.Handler {
  24. store.InitRate(b.rate, b.window)
  25. for _, store := range fallbackStores {
  26. store.InitRate(b.rate, b.window)
  27. }
  28. downloadLimiter := downloadLimiter{
  29. downloadBuilder: b,
  30. store: store,
  31. fallbackStores: fallbackStores,
  32. }
  33. return func(next http.Handler) http.Handler {
  34. fn := func(w http.ResponseWriter, r *http.Request) {
  35. key := downloadLimiter.keyFn(r)
  36. if key == "" {
  37. next.ServeHTTP(w, r)
  38. return
  39. }
  40. lw := &limitWriter{
  41. ResponseWriter: w,
  42. downloadLimiter: &downloadLimiter,
  43. key: key,
  44. }
  45. next.ServeHTTP(lw, r)
  46. }
  47. return http.HandlerFunc(fn)
  48. }
  49. }
// downloadLimiter is the state LimitBy builds once and shares, by pointer,
// with every limitWriter it creates. It holds the builder's configuration
// plus the primary and fallback token-bucket stores.
type downloadLimiter struct {
	*downloadBuilder
	// next is neither assigned by LimitBy nor read by any code visible here.
	// NOTE(review): possibly dead — confirm before removing.
	next           http.Handler
	store          TokenBucketStore   // primary store consulted by limitWriter.Write
	fallbackStores []TokenBucketStore // tried in order when store.Take returns an error
}
// limitWriter wraps a ResponseWriter so that body writes are paced by the
// limiter's token-bucket stores. LimitBy creates one per request whose key is
// non-empty. Header and WriteHeader pass straight through to the embedded
// ResponseWriter.
//
// NOTE(review): only the http.ResponseWriter interface is embedded, so optional
// interfaces of the underlying writer (http.Flusher, http.Hijacker, ...) are
// not visible through a *limitWriter.
type limitWriter struct {
	http.ResponseWriter
	*downloadLimiter
	key string // rate-limit key from keyFn; Write prefixes it with "download:" when calling Take
	// wroteHeader is not read or written by any code visible here.
	// NOTE(review): possibly unused — confirm before removing.
	wroteHeader bool
	// canWrite is the remaining byte budget. Each granted token adds 1024 and
	// each byte written subtracts one. The budget persists across Write calls.
	canWrite int64
}
  63. func (w *limitWriter) Write(buf []byte) (int, error) {
  64. total := 0
  65. for {
  66. if w.canWrite < 1024 {
  67. ok, _, _, err := w.downloadLimiter.store.Take("download:" + w.key)
  68. if err != nil {
  69. for _, store := range w.fallbackStores {
  70. ok, _, _, err = store.Take("download:" + w.key)
  71. if err == nil {
  72. break
  73. }
  74. }
  75. }
  76. if err != nil {
  77. return total, err
  78. }
  79. if ok {
  80. w.canWrite += 1024
  81. }
  82. }
  83. if w.canWrite == 0 {
  84. continue
  85. }
  86. max := len(buf) - total
  87. if int(w.canWrite) < max {
  88. max = int(w.canWrite)
  89. }
  90. if max == 0 {
  91. return total, nil
  92. }
  93. n, err := w.ResponseWriter.Write(buf[total : total+max])
  94. w.canWrite -= int64(n)
  95. total += n
  96. if err != nil {
  97. return total, err
  98. }
  99. }
  100. }