@@ -0,0 +1,2 @@ | |||||
.idea | |||||
dist/ |
@@ -0,0 +1,16 @@ | |||||
[servers] | |||||
[servers.main] | |||||
addr = "127.0.0.1:6379" | |||||
[servers.other] | |||||
addr = "127.0.0.1:6380" | |||||
[servers.third] | |||||
addr = "127.0.0.1:6381" | |||||
[[shovels]] | |||||
src = "main" | |||||
dst = "other" | |||||
key = "test" | |||||
[[shovels]] | |||||
src = "other" | |||||
dst = "third" | |||||
key = "foo" | |||||
dstkey = "bar" |
@@ -0,0 +1,13 @@ | |||||
module reshovel | |||||
go 1.17 | |||||
require ( | |||||
github.com/BurntSushi/toml v1.0.0 | |||||
github.com/go-redis/redis/v8 v8.11.4 | |||||
) | |||||
require ( | |||||
github.com/cespare/xxhash/v2 v2.1.2 // indirect | |||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect | |||||
) |
@@ -0,0 +1,99 @@ | |||||
github.com/BurntSushi/toml v1.0.0 h1:dtDWrepsVPfW9H/4y7dDgFc2MBUSeJhlaDtK13CxFlU= | |||||
github.com/BurntSushi/toml v1.0.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= | |||||
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= | |||||
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= | |||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= | |||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= | |||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= | |||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= | |||||
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= | |||||
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= | |||||
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= | |||||
github.com/go-redis/redis/v8 v8.11.4 h1:kHoYkfZP6+pe04aFTnhDH6GDROa5yJdHJVNxV3F46Tg= | |||||
github.com/go-redis/redis/v8 v8.11.4/go.mod h1:2Z2wHZXdQpCDXEGzqMockDpNyYvi2l4Pxt6RJr792+w= | |||||
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= | |||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= | |||||
github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= | |||||
github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= | |||||
github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= | |||||
github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= | |||||
github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= | |||||
github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= | |||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= | |||||
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= | |||||
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= | |||||
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= | |||||
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= | |||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= | |||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= | |||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= | |||||
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= | |||||
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= | |||||
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= | |||||
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= | |||||
github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= | |||||
github.com/onsi/ginkgo v1.16.4 h1:29JGrr5oVBm5ulCWet69zQkzWipVXIol6ygQUe/EzNc= | |||||
github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= | |||||
github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= | |||||
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= | |||||
github.com/onsi/gomega v1.16.0 h1:6gjqkI8iiRHMvdccRJM8rVKjCWk6ZIm6FTm3ddIe4/c= | |||||
github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY= | |||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= | |||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= | |||||
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= | |||||
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= | |||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= | |||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= | |||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= | |||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= | |||||
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= | |||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= | |||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= | |||||
golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= | |||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= | |||||
golang.org/x/net v0.0.0-20210428140749-89ef3d95e781 h1:DzZ89McO9/gWPsQXS/FVKAlG02ZjaQ6AlZRBimEYOd0= | |||||
golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= | |||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | |||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | |||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | |||||
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | |||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | |||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | |||||
golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | |||||
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | |||||
golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | |||||
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | |||||
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | |||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | |||||
golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | |||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da h1:b3NXsE2LusjYGGjL5bxEVZZORm/YEFFrWFjR8eFrw/c= | |||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | |||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= | |||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= | |||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | |||||
golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= | |||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | |||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= | |||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= | |||||
golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= | |||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= | |||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= | |||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= | |||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= | |||||
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= | |||||
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= | |||||
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= | |||||
google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= | |||||
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= | |||||
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= | |||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= | |||||
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= | |||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= | |||||
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= | |||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= | |||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= | |||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= | |||||
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= | |||||
gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= | |||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= | |||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= |
@@ -0,0 +1,164 @@ | |||||
package main | |||||
import ( | |||||
"context" | |||||
"log" | |||||
"os" | |||||
"os/signal" | |||||
"syscall" | |||||
"time" | |||||
"github.com/BurntSushi/toml" | |||||
"github.com/go-redis/redis/v8" | |||||
) | |||||
const BatchSize = 10000 | |||||
var RedisClients map[string]*redis.Client | |||||
type Config struct { | |||||
RedisServerConfigs map[string]RedisServerConfig `toml:"servers"` | |||||
ShovelConfigs []ShovelConfig `toml:"shovels"` | |||||
} | |||||
type ShovelConfig struct { | |||||
Key string `toml:"key"` | |||||
Src string `toml:"src"` | |||||
Dst string `toml:"dst"` | |||||
DstKey string `toml:"dstkey"` | |||||
} | |||||
type RedisServerConfig struct { | |||||
Network string `toml:"network"` | |||||
Addr string `toml:"addr"` | |||||
Username string `toml:"username"` | |||||
Password string `toml:"password"` | |||||
DB int `toml:"db"` | |||||
MaxRetries int `toml:"maxretries"` | |||||
MinRetryBackoff float64 `toml:"minretrybackoff"` | |||||
MaxRetryBackoff float64 `toml:"maxretrybackoff"` | |||||
DialTimeout float64 `toml:"dialtimeout"` | |||||
ReadTimeout float64 `toml:"readtimeout"` | |||||
WriteTimeout float64 `toml:"writetimeout"` | |||||
PoolFIFO bool `toml:"poolfifo"` | |||||
PoolSize int `toml:"poolsize"` | |||||
MinIdleConns int `toml:"minidleconns"` | |||||
MaxConnAge float64 `toml:"maxconnage"` | |||||
PoolTimeout float64 `toml:"pooltimeout"` | |||||
IdleTimeout float64 `toml:"idletimeout"` | |||||
IdleCheckFrequency float64 `toml:"idlecheckfrequency"` | |||||
} | |||||
func RedisConfigToRedisOptions(config RedisServerConfig) *redis.Options { | |||||
nano := float64(time.Second.Nanoseconds()) | |||||
if config.ReadTimeout == 0 { | |||||
config.ReadTimeout = 15 * time.Minute.Seconds() | |||||
} | |||||
return &redis.Options{ | |||||
Network: config.Network, | |||||
Addr: config.Addr, | |||||
Username: config.Username, | |||||
Password: config.Password, | |||||
DB: config.DB, | |||||
MaxRetries: config.MaxRetries, | |||||
MinRetryBackoff: time.Duration(config.MinRetryBackoff * nano), | |||||
MaxRetryBackoff: time.Duration(config.MaxRetryBackoff * nano), | |||||
DialTimeout: time.Duration(config.DialTimeout * nano), | |||||
ReadTimeout: time.Duration(config.ReadTimeout * nano), | |||||
WriteTimeout: time.Duration(config.WriteTimeout * nano), | |||||
PoolFIFO: config.PoolFIFO, | |||||
PoolSize: config.PoolSize, | |||||
MinIdleConns: config.MinIdleConns, | |||||
MaxConnAge: time.Duration(config.MaxConnAge * nano), | |||||
PoolTimeout: time.Duration(config.PoolTimeout * nano), | |||||
IdleTimeout: time.Duration(config.IdleTimeout * nano), | |||||
IdleCheckFrequency: time.Duration(config.IdleCheckFrequency * nano), | |||||
} | |||||
} | |||||
func StartShovelWorker(c context.Context, dc chan bool, s *redis.Client, d *redis.Client, sk string, dk string) { | |||||
defer close(dc) | |||||
if dk == "" { | |||||
dk = sk | |||||
} | |||||
var m time.Duration = 0 | |||||
for { | |||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) | |||||
items, err := s.SPopN(ctx, sk, BatchSize).Result() | |||||
cancel() | |||||
if err != nil { | |||||
log.Printf("unable to spop %s: %s", sk, err) | |||||
} else if len(items) != 0 { | |||||
var iitems []interface{} | |||||
for _, item := range items { | |||||
iitems = append(iitems, item) | |||||
} | |||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) | |||||
err = d.SAdd(ctx, dk, iitems...).Err() | |||||
cancel() | |||||
if err != nil { | |||||
log.Printf("unable to sadd %s: %s", dk, err) | |||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) | |||||
err = s.SAdd(ctx, sk, iitems...).Err() | |||||
cancel() | |||||
if err != nil { | |||||
log.Printf("unable to revert spop %s: %s", sk, err) | |||||
} | |||||
} else if len(items) >= BatchSize { | |||||
m = 0 | |||||
} | |||||
} | |||||
t := time.NewTimer(m * time.Second) | |||||
select { | |||||
case <-c.Done(): | |||||
if !t.Stop() { | |||||
<-t.C | |||||
} | |||||
return | |||||
case <-t.C: | |||||
} | |||||
if m < 60 { | |||||
m++ | |||||
} | |||||
} | |||||
} | |||||
func main() { | |||||
var config Config | |||||
_, err := toml.DecodeFile("./config.toml", &config) | |||||
if err != nil { | |||||
log.Panicf("error parsing config.toml: %s", err) | |||||
} | |||||
RedisClients = map[string]*redis.Client{} | |||||
for i, c := range config.ShovelConfigs { | |||||
if _, has := config.RedisServerConfigs[c.Src]; !has { | |||||
log.Panicf("invalid redis source: %s", c.Src) | |||||
} | |||||
if _, has := config.RedisServerConfigs[c.Dst]; !has { | |||||
log.Panicf("invalid redis destination: %s", c.Dst) | |||||
} | |||||
if c.DstKey == "" { | |||||
config.ShovelConfigs[i].DstKey = c.Key | |||||
} | |||||
} | |||||
for n, c := range config.RedisServerConfigs { | |||||
RedisClients[n] = redis.NewClient(RedisConfigToRedisOptions(c)) | |||||
} | |||||
ctx, cancel := context.WithCancel(context.Background()) | |||||
var doneChans []chan bool | |||||
for _, c := range config.ShovelConfigs { | |||||
log.Printf("starting shovel worker for %s/%s -> %s/%s", c.Src, c.Key, c.Dst, c.DstKey) | |||||
doneChan := make(chan bool) | |||||
go StartShovelWorker(ctx, doneChan, RedisClients[c.Src], RedisClients[c.Dst], c.Key, c.DstKey) | |||||
doneChans = append(doneChans, doneChan) | |||||
} | |||||
sc := make(chan os.Signal, 1) | |||||
signal.Notify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt, os.Kill) | |||||
<-sc | |||||
cancel() | |||||
log.Printf("waiting for %d workers to shut down...", len(doneChans)) | |||||
for _, c := range doneChans { | |||||
<-c | |||||
} | |||||
} |
@@ -0,0 +1,2 @@ | |||||
toml.test | |||||
/toml-test |
@@ -0,0 +1 @@ | |||||
Compatible with TOML version [v1.0.0](https://toml.io/en/v1.0.0). |
@@ -0,0 +1,21 @@ | |||||
The MIT License (MIT) | |||||
Copyright (c) 2013 TOML authors | |||||
Permission is hereby granted, free of charge, to any person obtaining a copy | |||||
of this software and associated documentation files (the "Software"), to deal | |||||
in the Software without restriction, including without limitation the rights | |||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |||||
copies of the Software, and to permit persons to whom the Software is | |||||
furnished to do so, subject to the following conditions: | |||||
The above copyright notice and this permission notice shall be included in | |||||
all copies or substantial portions of the Software. | |||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN | |||||
THE SOFTWARE. |
@@ -0,0 +1,211 @@ | |||||
TOML stands for Tom's Obvious, Minimal Language. This Go package provides a | |||||
reflection interface similar to Go's standard library `json` and `xml` | |||||
packages. | |||||
Compatible with TOML version [v1.0.0](https://toml.io/en/v1.0.0). | |||||
Documentation: https://godocs.io/github.com/BurntSushi/toml | |||||
See the [releases page](https://github.com/BurntSushi/toml/releases) for a | |||||
changelog; this information is also in the git tag annotations (e.g. `git show | |||||
v0.4.0`). | |||||
This library requires Go 1.13 or newer; install it with: | |||||
% go get github.com/BurntSushi/toml@latest | |||||
It also comes with a TOML validator CLI tool: | |||||
% go install github.com/BurntSushi/toml/cmd/tomlv@latest | |||||
% tomlv some-toml-file.toml | |||||
### Testing | |||||
This package passes all tests in [toml-test] for both the decoder and the | |||||
encoder. | |||||
[toml-test]: https://github.com/BurntSushi/toml-test | |||||
### Examples | |||||
This package works similar to how the Go standard library handles XML and JSON. | |||||
Namely, data is loaded into Go values via reflection. | |||||
For the simplest example, consider some TOML file as just a list of keys and | |||||
values: | |||||
```toml | |||||
Age = 25 | |||||
Cats = [ "Cauchy", "Plato" ] | |||||
Pi = 3.14 | |||||
Perfection = [ 6, 28, 496, 8128 ] | |||||
DOB = 1987-07-05T05:45:00Z | |||||
``` | |||||
Which could be defined in Go as: | |||||
```go | |||||
type Config struct { | |||||
Age int | |||||
Cats []string | |||||
Pi float64 | |||||
Perfection []int | |||||
DOB time.Time // requires `import time` | |||||
} | |||||
``` | |||||
And then decoded with: | |||||
```go | |||||
var conf Config | |||||
err := toml.Decode(tomlData, &conf) | |||||
// handle error | |||||
``` | |||||
You can also use struct tags if your struct field name doesn't map to a TOML | |||||
key value directly: | |||||
```toml | |||||
some_key_NAME = "wat" | |||||
``` | |||||
```go | |||||
type TOML struct { | |||||
ObscureKey string `toml:"some_key_NAME"` | |||||
} | |||||
``` | |||||
Beware that like other most other decoders **only exported fields** are | |||||
considered when encoding and decoding; private fields are silently ignored. | |||||
### Using the `Marshaler` and `encoding.TextUnmarshaler` interfaces | |||||
Here's an example that automatically parses duration strings into | |||||
`time.Duration` values: | |||||
```toml | |||||
[[song]] | |||||
name = "Thunder Road" | |||||
duration = "4m49s" | |||||
[[song]] | |||||
name = "Stairway to Heaven" | |||||
duration = "8m03s" | |||||
``` | |||||
Which can be decoded with: | |||||
```go | |||||
type song struct { | |||||
Name string | |||||
Duration duration | |||||
} | |||||
type songs struct { | |||||
Song []song | |||||
} | |||||
var favorites songs | |||||
if _, err := toml.Decode(blob, &favorites); err != nil { | |||||
log.Fatal(err) | |||||
} | |||||
for _, s := range favorites.Song { | |||||
fmt.Printf("%s (%s)\n", s.Name, s.Duration) | |||||
} | |||||
``` | |||||
And you'll also need a `duration` type that satisfies the | |||||
`encoding.TextUnmarshaler` interface: | |||||
```go | |||||
type duration struct { | |||||
time.Duration | |||||
} | |||||
func (d *duration) UnmarshalText(text []byte) error { | |||||
var err error | |||||
d.Duration, err = time.ParseDuration(string(text)) | |||||
return err | |||||
} | |||||
``` | |||||
To target TOML specifically you can implement `UnmarshalTOML` TOML interface in | |||||
a similar way. | |||||
### More complex usage | |||||
Here's an example of how to load the example from the official spec page: | |||||
```toml | |||||
# This is a TOML document. Boom. | |||||
title = "TOML Example" | |||||
[owner] | |||||
name = "Tom Preston-Werner" | |||||
organization = "GitHub" | |||||
bio = "GitHub Cofounder & CEO\nLikes tater tots and beer." | |||||
dob = 1979-05-27T07:32:00Z # First class dates? Why not? | |||||
[database] | |||||
server = "192.168.1.1" | |||||
ports = [ 8001, 8001, 8002 ] | |||||
connection_max = 5000 | |||||
enabled = true | |||||
[servers] | |||||
# You can indent as you please. Tabs or spaces. TOML don't care. | |||||
[servers.alpha] | |||||
ip = "10.0.0.1" | |||||
dc = "eqdc10" | |||||
[servers.beta] | |||||
ip = "10.0.0.2" | |||||
dc = "eqdc10" | |||||
[clients] | |||||
data = [ ["gamma", "delta"], [1, 2] ] # just an update to make sure parsers support it | |||||
# Line breaks are OK when inside arrays | |||||
hosts = [ | |||||
"alpha", | |||||
"omega" | |||||
] | |||||
``` | |||||
And the corresponding Go types are: | |||||
```go | |||||
type tomlConfig struct { | |||||
Title string | |||||
Owner ownerInfo | |||||
DB database `toml:"database"` | |||||
Servers map[string]server | |||||
Clients clients | |||||
} | |||||
type ownerInfo struct { | |||||
Name string | |||||
Org string `toml:"organization"` | |||||
Bio string | |||||
DOB time.Time | |||||
} | |||||
type database struct { | |||||
Server string | |||||
Ports []int | |||||
ConnMax int `toml:"connection_max"` | |||||
Enabled bool | |||||
} | |||||
type server struct { | |||||
IP string | |||||
DC string | |||||
} | |||||
type clients struct { | |||||
Data [][]interface{} | |||||
Hosts []string | |||||
} | |||||
``` | |||||
Note that a case insensitive match will be tried if an exact match can't be | |||||
found. | |||||
A working example of the above can be found in `_example/example.{go,toml}`. |
@@ -0,0 +1,560 @@ | |||||
package toml | |||||
import ( | |||||
"encoding" | |||||
"fmt" | |||||
"io" | |||||
"io/ioutil" | |||||
"math" | |||||
"os" | |||||
"reflect" | |||||
"strings" | |||||
) | |||||
// Unmarshaler is the interface implemented by objects that can unmarshal a | |||||
// TOML description of themselves. | |||||
type Unmarshaler interface { | |||||
UnmarshalTOML(interface{}) error | |||||
} | |||||
// Unmarshal decodes the contents of `p` in TOML format into a pointer `v`. | |||||
func Unmarshal(p []byte, v interface{}) error { | |||||
_, err := Decode(string(p), v) | |||||
return err | |||||
} | |||||
// Primitive is a TOML value that hasn't been decoded into a Go value. | |||||
// | |||||
// This type can be used for any value, which will cause decoding to be delayed. | |||||
// You can use the PrimitiveDecode() function to "manually" decode these values. | |||||
// | |||||
// NOTE: The underlying representation of a `Primitive` value is subject to | |||||
// change. Do not rely on it. | |||||
// | |||||
// NOTE: Primitive values are still parsed, so using them will only avoid the | |||||
// overhead of reflection. They can be useful when you don't know the exact type | |||||
// of TOML data until runtime. | |||||
type Primitive struct { | |||||
undecoded interface{} | |||||
context Key | |||||
} | |||||
// The significand precision for float32 and float64 is 24 and 53 bits; this is | |||||
// the range a natural number can be stored in a float without loss of data. | |||||
const ( | |||||
maxSafeFloat32Int = 16777215 // 2^24-1 | |||||
maxSafeFloat64Int = 9007199254740991 // 2^53-1 | |||||
) | |||||
// PrimitiveDecode is just like the other `Decode*` functions, except it | |||||
// decodes a TOML value that has already been parsed. Valid primitive values | |||||
// can *only* be obtained from values filled by the decoder functions, | |||||
// including this method. (i.e., `v` may contain more `Primitive` | |||||
// values.) | |||||
// | |||||
// Meta data for primitive values is included in the meta data returned by | |||||
// the `Decode*` functions with one exception: keys returned by the Undecoded | |||||
// method will only reflect keys that were decoded. Namely, any keys hidden | |||||
// behind a Primitive will be considered undecoded. Executing this method will | |||||
// update the undecoded keys in the meta data. (See the example.) | |||||
func (md *MetaData) PrimitiveDecode(primValue Primitive, v interface{}) error { | |||||
md.context = primValue.context | |||||
defer func() { md.context = nil }() | |||||
return md.unify(primValue.undecoded, rvalue(v)) | |||||
} | |||||
// Decoder decodes TOML data. | |||||
// | |||||
// TOML tables correspond to Go structs or maps (dealer's choice – they can be | |||||
// used interchangeably). | |||||
// | |||||
// TOML table arrays correspond to either a slice of structs or a slice of maps. | |||||
// | |||||
// TOML datetimes correspond to Go time.Time values. Local datetimes are parsed | |||||
// in the local timezone. | |||||
// | |||||
// All other TOML types (float, string, int, bool and array) correspond to the | |||||
// obvious Go types. | |||||
// | |||||
// An exception to the above rules is if a type implements the TextUnmarshaler | |||||
// interface, in which case any primitive TOML value (floats, strings, integers, | |||||
// booleans, datetimes) will be converted to a []byte and given to the value's | |||||
// UnmarshalText method. See the Unmarshaler example for a demonstration with | |||||
// time duration strings. | |||||
// | |||||
// Key mapping | |||||
// | |||||
// TOML keys can map to either keys in a Go map or field names in a Go struct. | |||||
// The special `toml` struct tag can be used to map TOML keys to struct fields | |||||
// that don't match the key name exactly (see the example). A case insensitive | |||||
// match to struct names will be tried if an exact match can't be found. | |||||
// | |||||
// The mapping between TOML values and Go values is loose. That is, there may | |||||
// exist TOML values that cannot be placed into your representation, and there | |||||
// may be parts of your representation that do not correspond to TOML values. | |||||
// This loose mapping can be made stricter by using the IsDefined and/or | |||||
// Undecoded methods on the MetaData returned. | |||||
// | |||||
// This decoder does not handle cyclic types. Decode will not terminate if a | |||||
// cyclic type is passed. | |||||
type Decoder struct { | |||||
r io.Reader | |||||
} | |||||
// NewDecoder creates a new Decoder. | |||||
func NewDecoder(r io.Reader) *Decoder { | |||||
return &Decoder{r: r} | |||||
} | |||||
var ( | |||||
unmarshalToml = reflect.TypeOf((*Unmarshaler)(nil)).Elem() | |||||
unmarshalText = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() | |||||
) | |||||
// Decode TOML data in to the pointer `v`. | |||||
func (dec *Decoder) Decode(v interface{}) (MetaData, error) { | |||||
rv := reflect.ValueOf(v) | |||||
if rv.Kind() != reflect.Ptr { | |||||
s := "%q" | |||||
if reflect.TypeOf(v) == nil { | |||||
s = "%v" | |||||
} | |||||
return MetaData{}, e("cannot decode to non-pointer "+s, reflect.TypeOf(v)) | |||||
} | |||||
if rv.IsNil() { | |||||
return MetaData{}, e("cannot decode to nil value of %q", reflect.TypeOf(v)) | |||||
} | |||||
// Check if this is a supported type: struct, map, interface{}, or something | |||||
// that implements UnmarshalTOML or UnmarshalText. | |||||
rv = indirect(rv) | |||||
rt := rv.Type() | |||||
if rv.Kind() != reflect.Struct && rv.Kind() != reflect.Map && | |||||
!(rv.Kind() == reflect.Interface && rv.NumMethod() == 0) && | |||||
!rt.Implements(unmarshalToml) && !rt.Implements(unmarshalText) { | |||||
return MetaData{}, e("cannot decode to type %s", rt) | |||||
} | |||||
// TODO: parser should read from io.Reader? Or at the very least, make it | |||||
// read from []byte rather than string | |||||
data, err := ioutil.ReadAll(dec.r) | |||||
if err != nil { | |||||
return MetaData{}, err | |||||
} | |||||
p, err := parse(string(data)) | |||||
if err != nil { | |||||
return MetaData{}, err | |||||
} | |||||
md := MetaData{ | |||||
mapping: p.mapping, | |||||
types: p.types, | |||||
keys: p.ordered, | |||||
decoded: make(map[string]struct{}, len(p.ordered)), | |||||
context: nil, | |||||
} | |||||
return md, md.unify(p.mapping, rv) | |||||
} | |||||
// Decode the TOML data in to the pointer v. | |||||
// | |||||
// See the documentation on Decoder for a description of the decoding process. | |||||
func Decode(data string, v interface{}) (MetaData, error) { | |||||
return NewDecoder(strings.NewReader(data)).Decode(v) | |||||
} | |||||
// DecodeFile is just like Decode, except it will automatically read the | |||||
// contents of the file at path and decode it for you. | |||||
func DecodeFile(path string, v interface{}) (MetaData, error) { | |||||
fp, err := os.Open(path) | |||||
if err != nil { | |||||
return MetaData{}, err | |||||
} | |||||
defer fp.Close() | |||||
return NewDecoder(fp).Decode(v) | |||||
} | |||||
// unify performs a sort of type unification based on the structure of `rv`, | |||||
// which is the client representation. | |||||
// | |||||
// Any type mismatch produces an error. Finding a type that we don't know | |||||
// how to handle produces an unsupported type error. | |||||
func (md *MetaData) unify(data interface{}, rv reflect.Value) error { | |||||
// Special case. Look for a `Primitive` value. | |||||
// TODO: #76 would make this superfluous after implemented. | |||||
if rv.Type() == reflect.TypeOf((*Primitive)(nil)).Elem() { | |||||
// Save the undecoded data and the key context into the primitive | |||||
// value. | |||||
context := make(Key, len(md.context)) | |||||
copy(context, md.context) | |||||
rv.Set(reflect.ValueOf(Primitive{ | |||||
undecoded: data, | |||||
context: context, | |||||
})) | |||||
return nil | |||||
} | |||||
// Special case. Unmarshaler Interface support. | |||||
if rv.CanAddr() { | |||||
if v, ok := rv.Addr().Interface().(Unmarshaler); ok { | |||||
return v.UnmarshalTOML(data) | |||||
} | |||||
} | |||||
// Special case. Look for a value satisfying the TextUnmarshaler interface. | |||||
if v, ok := rv.Interface().(encoding.TextUnmarshaler); ok { | |||||
return md.unifyText(data, v) | |||||
} | |||||
// TODO: | |||||
// The behavior here is incorrect whenever a Go type satisfies the | |||||
// encoding.TextUnmarshaler interface but also corresponds to a TOML hash or | |||||
// array. In particular, the unmarshaler should only be applied to primitive | |||||
// TOML values. But at this point, it will be applied to all kinds of values | |||||
// and produce an incorrect error whenever those values are hashes or arrays | |||||
// (including arrays of tables). | |||||
k := rv.Kind() | |||||
// laziness | |||||
if k >= reflect.Int && k <= reflect.Uint64 { | |||||
return md.unifyInt(data, rv) | |||||
} | |||||
switch k { | |||||
case reflect.Ptr: | |||||
elem := reflect.New(rv.Type().Elem()) | |||||
err := md.unify(data, reflect.Indirect(elem)) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
rv.Set(elem) | |||||
return nil | |||||
case reflect.Struct: | |||||
return md.unifyStruct(data, rv) | |||||
case reflect.Map: | |||||
return md.unifyMap(data, rv) | |||||
case reflect.Array: | |||||
return md.unifyArray(data, rv) | |||||
case reflect.Slice: | |||||
return md.unifySlice(data, rv) | |||||
case reflect.String: | |||||
return md.unifyString(data, rv) | |||||
case reflect.Bool: | |||||
return md.unifyBool(data, rv) | |||||
case reflect.Interface: | |||||
// we only support empty interfaces. | |||||
if rv.NumMethod() > 0 { | |||||
return e("unsupported type %s", rv.Type()) | |||||
} | |||||
return md.unifyAnything(data, rv) | |||||
case reflect.Float32, reflect.Float64: | |||||
return md.unifyFloat64(data, rv) | |||||
} | |||||
return e("unsupported type %s", rv.Kind()) | |||||
} | |||||
func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error { | |||||
tmap, ok := mapping.(map[string]interface{}) | |||||
if !ok { | |||||
if mapping == nil { | |||||
return nil | |||||
} | |||||
return e("type mismatch for %s: expected table but found %T", | |||||
rv.Type().String(), mapping) | |||||
} | |||||
for key, datum := range tmap { | |||||
var f *field | |||||
fields := cachedTypeFields(rv.Type()) | |||||
for i := range fields { | |||||
ff := &fields[i] | |||||
if ff.name == key { | |||||
f = ff | |||||
break | |||||
} | |||||
if f == nil && strings.EqualFold(ff.name, key) { | |||||
f = ff | |||||
} | |||||
} | |||||
if f != nil { | |||||
subv := rv | |||||
for _, i := range f.index { | |||||
subv = indirect(subv.Field(i)) | |||||
} | |||||
if isUnifiable(subv) { | |||||
md.decoded[md.context.add(key).String()] = struct{}{} | |||||
md.context = append(md.context, key) | |||||
err := md.unify(datum, subv) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
md.context = md.context[0 : len(md.context)-1] | |||||
} else if f.name != "" { | |||||
return e("cannot write unexported field %s.%s", rv.Type().String(), f.name) | |||||
} | |||||
} | |||||
} | |||||
return nil | |||||
} | |||||
func (md *MetaData) unifyMap(mapping interface{}, rv reflect.Value) error { | |||||
if k := rv.Type().Key().Kind(); k != reflect.String { | |||||
return fmt.Errorf( | |||||
"toml: cannot decode to a map with non-string key type (%s in %q)", | |||||
k, rv.Type()) | |||||
} | |||||
tmap, ok := mapping.(map[string]interface{}) | |||||
if !ok { | |||||
if tmap == nil { | |||||
return nil | |||||
} | |||||
return md.badtype("map", mapping) | |||||
} | |||||
if rv.IsNil() { | |||||
rv.Set(reflect.MakeMap(rv.Type())) | |||||
} | |||||
for k, v := range tmap { | |||||
md.decoded[md.context.add(k).String()] = struct{}{} | |||||
md.context = append(md.context, k) | |||||
rvval := reflect.Indirect(reflect.New(rv.Type().Elem())) | |||||
if err := md.unify(v, rvval); err != nil { | |||||
return err | |||||
} | |||||
md.context = md.context[0 : len(md.context)-1] | |||||
rvkey := indirect(reflect.New(rv.Type().Key())) | |||||
rvkey.SetString(k) | |||||
rv.SetMapIndex(rvkey, rvval) | |||||
} | |||||
return nil | |||||
} | |||||
func (md *MetaData) unifyArray(data interface{}, rv reflect.Value) error { | |||||
datav := reflect.ValueOf(data) | |||||
if datav.Kind() != reflect.Slice { | |||||
if !datav.IsValid() { | |||||
return nil | |||||
} | |||||
return md.badtype("slice", data) | |||||
} | |||||
if l := datav.Len(); l != rv.Len() { | |||||
return e("expected array length %d; got TOML array of length %d", rv.Len(), l) | |||||
} | |||||
return md.unifySliceArray(datav, rv) | |||||
} | |||||
func (md *MetaData) unifySlice(data interface{}, rv reflect.Value) error { | |||||
datav := reflect.ValueOf(data) | |||||
if datav.Kind() != reflect.Slice { | |||||
if !datav.IsValid() { | |||||
return nil | |||||
} | |||||
return md.badtype("slice", data) | |||||
} | |||||
n := datav.Len() | |||||
if rv.IsNil() || rv.Cap() < n { | |||||
rv.Set(reflect.MakeSlice(rv.Type(), n, n)) | |||||
} | |||||
rv.SetLen(n) | |||||
return md.unifySliceArray(datav, rv) | |||||
} | |||||
func (md *MetaData) unifySliceArray(data, rv reflect.Value) error { | |||||
l := data.Len() | |||||
for i := 0; i < l; i++ { | |||||
err := md.unify(data.Index(i).Interface(), indirect(rv.Index(i))) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
} | |||||
return nil | |||||
} | |||||
func (md *MetaData) unifyString(data interface{}, rv reflect.Value) error { | |||||
if s, ok := data.(string); ok { | |||||
rv.SetString(s) | |||||
return nil | |||||
} | |||||
return md.badtype("string", data) | |||||
} | |||||
func (md *MetaData) unifyFloat64(data interface{}, rv reflect.Value) error { | |||||
if num, ok := data.(float64); ok { | |||||
switch rv.Kind() { | |||||
case reflect.Float32: | |||||
if num < -math.MaxFloat32 || num > math.MaxFloat32 { | |||||
return e("value %f is out of range for float32", num) | |||||
} | |||||
fallthrough | |||||
case reflect.Float64: | |||||
rv.SetFloat(num) | |||||
default: | |||||
panic("bug") | |||||
} | |||||
return nil | |||||
} | |||||
if num, ok := data.(int64); ok { | |||||
switch rv.Kind() { | |||||
case reflect.Float32: | |||||
if num < -maxSafeFloat32Int || num > maxSafeFloat32Int { | |||||
return e("value %d is out of range for float32", num) | |||||
} | |||||
fallthrough | |||||
case reflect.Float64: | |||||
if num < -maxSafeFloat64Int || num > maxSafeFloat64Int { | |||||
return e("value %d is out of range for float64", num) | |||||
} | |||||
rv.SetFloat(float64(num)) | |||||
default: | |||||
panic("bug") | |||||
} | |||||
return nil | |||||
} | |||||
return md.badtype("float", data) | |||||
} | |||||
func (md *MetaData) unifyInt(data interface{}, rv reflect.Value) error { | |||||
if num, ok := data.(int64); ok { | |||||
if rv.Kind() >= reflect.Int && rv.Kind() <= reflect.Int64 { | |||||
switch rv.Kind() { | |||||
case reflect.Int, reflect.Int64: | |||||
// No bounds checking necessary. | |||||
case reflect.Int8: | |||||
if num < math.MinInt8 || num > math.MaxInt8 { | |||||
return e("value %d is out of range for int8", num) | |||||
} | |||||
case reflect.Int16: | |||||
if num < math.MinInt16 || num > math.MaxInt16 { | |||||
return e("value %d is out of range for int16", num) | |||||
} | |||||
case reflect.Int32: | |||||
if num < math.MinInt32 || num > math.MaxInt32 { | |||||
return e("value %d is out of range for int32", num) | |||||
} | |||||
} | |||||
rv.SetInt(num) | |||||
} else if rv.Kind() >= reflect.Uint && rv.Kind() <= reflect.Uint64 { | |||||
unum := uint64(num) | |||||
switch rv.Kind() { | |||||
case reflect.Uint, reflect.Uint64: | |||||
// No bounds checking necessary. | |||||
case reflect.Uint8: | |||||
if num < 0 || unum > math.MaxUint8 { | |||||
return e("value %d is out of range for uint8", num) | |||||
} | |||||
case reflect.Uint16: | |||||
if num < 0 || unum > math.MaxUint16 { | |||||
return e("value %d is out of range for uint16", num) | |||||
} | |||||
case reflect.Uint32: | |||||
if num < 0 || unum > math.MaxUint32 { | |||||
return e("value %d is out of range for uint32", num) | |||||
} | |||||
} | |||||
rv.SetUint(unum) | |||||
} else { | |||||
panic("unreachable") | |||||
} | |||||
return nil | |||||
} | |||||
return md.badtype("integer", data) | |||||
} | |||||
func (md *MetaData) unifyBool(data interface{}, rv reflect.Value) error { | |||||
if b, ok := data.(bool); ok { | |||||
rv.SetBool(b) | |||||
return nil | |||||
} | |||||
return md.badtype("boolean", data) | |||||
} | |||||
func (md *MetaData) unifyAnything(data interface{}, rv reflect.Value) error { | |||||
rv.Set(reflect.ValueOf(data)) | |||||
return nil | |||||
} | |||||
func (md *MetaData) unifyText(data interface{}, v encoding.TextUnmarshaler) error { | |||||
var s string | |||||
switch sdata := data.(type) { | |||||
case Marshaler: | |||||
text, err := sdata.MarshalTOML() | |||||
if err != nil { | |||||
return err | |||||
} | |||||
s = string(text) | |||||
case TextMarshaler: | |||||
text, err := sdata.MarshalText() | |||||
if err != nil { | |||||
return err | |||||
} | |||||
s = string(text) | |||||
case fmt.Stringer: | |||||
s = sdata.String() | |||||
case string: | |||||
s = sdata | |||||
case bool: | |||||
s = fmt.Sprintf("%v", sdata) | |||||
case int64: | |||||
s = fmt.Sprintf("%d", sdata) | |||||
case float64: | |||||
s = fmt.Sprintf("%f", sdata) | |||||
default: | |||||
return md.badtype("primitive (string-like)", data) | |||||
} | |||||
if err := v.UnmarshalText([]byte(s)); err != nil { | |||||
return err | |||||
} | |||||
return nil | |||||
} | |||||
func (md *MetaData) badtype(dst string, data interface{}) error { | |||||
return e("incompatible types: TOML key %q has type %T; destination has type %s", md.context, data, dst) | |||||
} | |||||
// rvalue returns a reflect.Value of `v`. All pointers are resolved. | |||||
func rvalue(v interface{}) reflect.Value { | |||||
return indirect(reflect.ValueOf(v)) | |||||
} | |||||
// indirect returns the value pointed to by a pointer. | |||||
// | |||||
// Pointers are followed until the value is not a pointer. New values are | |||||
// allocated for each nil pointer. | |||||
// | |||||
// An exception to this rule is if the value satisfies an interface of interest | |||||
// to us (like encoding.TextUnmarshaler). | |||||
func indirect(v reflect.Value) reflect.Value { | |||||
if v.Kind() != reflect.Ptr { | |||||
if v.CanSet() { | |||||
pv := v.Addr() | |||||
if _, ok := pv.Interface().(encoding.TextUnmarshaler); ok { | |||||
return pv | |||||
} | |||||
} | |||||
return v | |||||
} | |||||
if v.IsNil() { | |||||
v.Set(reflect.New(v.Type().Elem())) | |||||
} | |||||
return indirect(reflect.Indirect(v)) | |||||
} | |||||
func isUnifiable(rv reflect.Value) bool { | |||||
if rv.CanSet() { | |||||
return true | |||||
} | |||||
if _, ok := rv.Interface().(encoding.TextUnmarshaler); ok { | |||||
return true | |||||
} | |||||
return false | |||||
} | |||||
func e(format string, args ...interface{}) error { | |||||
return fmt.Errorf("toml: "+format, args...) | |||||
} |
@@ -0,0 +1,19 @@ | |||||
//go:build go1.16 | |||||
// +build go1.16 | |||||
package toml | |||||
import ( | |||||
"io/fs" | |||||
) | |||||
// DecodeFS is just like Decode, except it will automatically read the contents | |||||
// of the file at `path` from a fs.FS instance. | |||||
func DecodeFS(fsys fs.FS, path string, v interface{}) (MetaData, error) { | |||||
fp, err := fsys.Open(path) | |||||
if err != nil { | |||||
return MetaData{}, err | |||||
} | |||||
defer fp.Close() | |||||
return NewDecoder(fp).Decode(v) | |||||
} |
@@ -0,0 +1,21 @@ | |||||
package toml | |||||
import ( | |||||
"encoding" | |||||
"io" | |||||
) | |||||
// Deprecated: use encoding.TextMarshaler | |||||
type TextMarshaler encoding.TextMarshaler | |||||
// Deprecated: use encoding.TextUnmarshaler | |||||
type TextUnmarshaler encoding.TextUnmarshaler | |||||
// Deprecated: use MetaData.PrimitiveDecode. | |||||
func PrimitiveDecode(primValue Primitive, v interface{}) error { | |||||
md := MetaData{decoded: make(map[string]struct{})} | |||||
return md.unify(primValue.undecoded, rvalue(v)) | |||||
} | |||||
// Deprecated: use NewDecoder(reader).Decode(&value). | |||||
func DecodeReader(r io.Reader, v interface{}) (MetaData, error) { return NewDecoder(r).Decode(v) } |
@@ -0,0 +1,13 @@ | |||||
/* | |||||
Package toml implements decoding and encoding of TOML files. | |||||
This package supports TOML v1.0.0, as listed on https://toml.io | |||||
There is also support for delaying decoding with the Primitive type, and | |||||
querying the set of keys in a TOML document with the MetaData type. | |||||
The github.com/BurntSushi/toml/cmd/tomlv package implements a TOML validator, | |||||
and can be used to verify if TOML document is valid. It can also be used to | |||||
print the type of each key. | |||||
*/ | |||||
package toml |
@@ -0,0 +1,694 @@ | |||||
package toml | |||||
import ( | |||||
"bufio" | |||||
"encoding" | |||||
"errors" | |||||
"fmt" | |||||
"io" | |||||
"math" | |||||
"reflect" | |||||
"sort" | |||||
"strconv" | |||||
"strings" | |||||
"time" | |||||
"github.com/BurntSushi/toml/internal" | |||||
) | |||||
type tomlEncodeError struct{ error } | |||||
var ( | |||||
errArrayNilElement = errors.New("toml: cannot encode array with nil element") | |||||
errNonString = errors.New("toml: cannot encode a map with non-string key type") | |||||
errNoKey = errors.New("toml: top-level values must be Go maps or structs") | |||||
errAnything = errors.New("") // used in testing | |||||
) | |||||
var dblQuotedReplacer = strings.NewReplacer( | |||||
"\"", "\\\"", | |||||
"\\", "\\\\", | |||||
"\x00", `\u0000`, | |||||
"\x01", `\u0001`, | |||||
"\x02", `\u0002`, | |||||
"\x03", `\u0003`, | |||||
"\x04", `\u0004`, | |||||
"\x05", `\u0005`, | |||||
"\x06", `\u0006`, | |||||
"\x07", `\u0007`, | |||||
"\b", `\b`, | |||||
"\t", `\t`, | |||||
"\n", `\n`, | |||||
"\x0b", `\u000b`, | |||||
"\f", `\f`, | |||||
"\r", `\r`, | |||||
"\x0e", `\u000e`, | |||||
"\x0f", `\u000f`, | |||||
"\x10", `\u0010`, | |||||
"\x11", `\u0011`, | |||||
"\x12", `\u0012`, | |||||
"\x13", `\u0013`, | |||||
"\x14", `\u0014`, | |||||
"\x15", `\u0015`, | |||||
"\x16", `\u0016`, | |||||
"\x17", `\u0017`, | |||||
"\x18", `\u0018`, | |||||
"\x19", `\u0019`, | |||||
"\x1a", `\u001a`, | |||||
"\x1b", `\u001b`, | |||||
"\x1c", `\u001c`, | |||||
"\x1d", `\u001d`, | |||||
"\x1e", `\u001e`, | |||||
"\x1f", `\u001f`, | |||||
"\x7f", `\u007f`, | |||||
) | |||||
// Marshaler is the interface implemented by types that can marshal themselves | |||||
// into valid TOML. | |||||
type Marshaler interface { | |||||
MarshalTOML() ([]byte, error) | |||||
} | |||||
// Encoder encodes a Go to a TOML document. | |||||
// | |||||
// The mapping between Go values and TOML values should be precisely the same as | |||||
// for the Decode* functions. | |||||
// | |||||
// The toml.Marshaler and encoder.TextMarshaler interfaces are supported to | |||||
// encoding the value as custom TOML. | |||||
// | |||||
// If you want to write arbitrary binary data then you will need to use | |||||
// something like base64 since TOML does not have any binary types. | |||||
// | |||||
// When encoding TOML hashes (Go maps or structs), keys without any sub-hashes | |||||
// are encoded first. | |||||
// | |||||
// Go maps will be sorted alphabetically by key for deterministic output. | |||||
// | |||||
// Encoding Go values without a corresponding TOML representation will return an | |||||
// error. Examples of this includes maps with non-string keys, slices with nil | |||||
// elements, embedded non-struct types, and nested slices containing maps or | |||||
// structs. (e.g. [][]map[string]string is not allowed but []map[string]string | |||||
// is okay, as is []map[string][]string). | |||||
// | |||||
// NOTE: only exported keys are encoded due to the use of reflection. Unexported | |||||
// keys are silently discarded. | |||||
type Encoder struct { | |||||
// String to use for a single indentation level; default is two spaces. | |||||
Indent string | |||||
w *bufio.Writer | |||||
hasWritten bool // written any output to w yet? | |||||
} | |||||
// NewEncoder create a new Encoder. | |||||
func NewEncoder(w io.Writer) *Encoder { | |||||
return &Encoder{ | |||||
w: bufio.NewWriter(w), | |||||
Indent: " ", | |||||
} | |||||
} | |||||
// Encode writes a TOML representation of the Go value to the Encoder's writer. | |||||
// | |||||
// An error is returned if the value given cannot be encoded to a valid TOML | |||||
// document. | |||||
func (enc *Encoder) Encode(v interface{}) error { | |||||
rv := eindirect(reflect.ValueOf(v)) | |||||
if err := enc.safeEncode(Key([]string{}), rv); err != nil { | |||||
return err | |||||
} | |||||
return enc.w.Flush() | |||||
} | |||||
func (enc *Encoder) safeEncode(key Key, rv reflect.Value) (err error) { | |||||
defer func() { | |||||
if r := recover(); r != nil { | |||||
if terr, ok := r.(tomlEncodeError); ok { | |||||
err = terr.error | |||||
return | |||||
} | |||||
panic(r) | |||||
} | |||||
}() | |||||
enc.encode(key, rv) | |||||
return nil | |||||
} | |||||
func (enc *Encoder) encode(key Key, rv reflect.Value) { | |||||
// Special case: time needs to be in ISO8601 format. | |||||
// | |||||
// Special case: if we can marshal the type to text, then we used that. This | |||||
// prevents the encoder for handling these types as generic structs (or | |||||
// whatever the underlying type of a TextMarshaler is). | |||||
switch t := rv.Interface().(type) { | |||||
case time.Time, encoding.TextMarshaler, Marshaler: | |||||
enc.writeKeyValue(key, rv, false) | |||||
return | |||||
// TODO: #76 would make this superfluous after implemented. | |||||
case Primitive: | |||||
enc.encode(key, reflect.ValueOf(t.undecoded)) | |||||
return | |||||
} | |||||
k := rv.Kind() | |||||
switch k { | |||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, | |||||
reflect.Int64, | |||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, | |||||
reflect.Uint64, | |||||
reflect.Float32, reflect.Float64, reflect.String, reflect.Bool: | |||||
enc.writeKeyValue(key, rv, false) | |||||
case reflect.Array, reflect.Slice: | |||||
if typeEqual(tomlArrayHash, tomlTypeOfGo(rv)) { | |||||
enc.eArrayOfTables(key, rv) | |||||
} else { | |||||
enc.writeKeyValue(key, rv, false) | |||||
} | |||||
case reflect.Interface: | |||||
if rv.IsNil() { | |||||
return | |||||
} | |||||
enc.encode(key, rv.Elem()) | |||||
case reflect.Map: | |||||
if rv.IsNil() { | |||||
return | |||||
} | |||||
enc.eTable(key, rv) | |||||
case reflect.Ptr: | |||||
if rv.IsNil() { | |||||
return | |||||
} | |||||
enc.encode(key, rv.Elem()) | |||||
case reflect.Struct: | |||||
enc.eTable(key, rv) | |||||
default: | |||||
encPanic(fmt.Errorf("unsupported type for key '%s': %s", key, k)) | |||||
} | |||||
} | |||||
// eElement encodes any value that can be an array element. | |||||
func (enc *Encoder) eElement(rv reflect.Value) { | |||||
switch v := rv.Interface().(type) { | |||||
case time.Time: // Using TextMarshaler adds extra quotes, which we don't want. | |||||
format := time.RFC3339Nano | |||||
switch v.Location() { | |||||
case internal.LocalDatetime: | |||||
format = "2006-01-02T15:04:05.999999999" | |||||
case internal.LocalDate: | |||||
format = "2006-01-02" | |||||
case internal.LocalTime: | |||||
format = "15:04:05.999999999" | |||||
} | |||||
switch v.Location() { | |||||
default: | |||||
enc.wf(v.Format(format)) | |||||
case internal.LocalDatetime, internal.LocalDate, internal.LocalTime: | |||||
enc.wf(v.In(time.UTC).Format(format)) | |||||
} | |||||
return | |||||
case Marshaler: | |||||
s, err := v.MarshalTOML() | |||||
if err != nil { | |||||
encPanic(err) | |||||
} | |||||
enc.writeQuoted(string(s)) | |||||
return | |||||
case encoding.TextMarshaler: | |||||
s, err := v.MarshalText() | |||||
if err != nil { | |||||
encPanic(err) | |||||
} | |||||
enc.writeQuoted(string(s)) | |||||
return | |||||
} | |||||
switch rv.Kind() { | |||||
case reflect.String: | |||||
enc.writeQuoted(rv.String()) | |||||
case reflect.Bool: | |||||
enc.wf(strconv.FormatBool(rv.Bool())) | |||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | |||||
enc.wf(strconv.FormatInt(rv.Int(), 10)) | |||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: | |||||
enc.wf(strconv.FormatUint(rv.Uint(), 10)) | |||||
case reflect.Float32: | |||||
f := rv.Float() | |||||
if math.IsNaN(f) { | |||||
enc.wf("nan") | |||||
} else if math.IsInf(f, 0) { | |||||
enc.wf("%cinf", map[bool]byte{true: '-', false: '+'}[math.Signbit(f)]) | |||||
} else { | |||||
enc.wf(floatAddDecimal(strconv.FormatFloat(f, 'f', -1, 32))) | |||||
} | |||||
case reflect.Float64: | |||||
f := rv.Float() | |||||
if math.IsNaN(f) { | |||||
enc.wf("nan") | |||||
} else if math.IsInf(f, 0) { | |||||
enc.wf("%cinf", map[bool]byte{true: '-', false: '+'}[math.Signbit(f)]) | |||||
} else { | |||||
enc.wf(floatAddDecimal(strconv.FormatFloat(f, 'f', -1, 64))) | |||||
} | |||||
case reflect.Array, reflect.Slice: | |||||
enc.eArrayOrSliceElement(rv) | |||||
case reflect.Struct: | |||||
enc.eStruct(nil, rv, true) | |||||
case reflect.Map: | |||||
enc.eMap(nil, rv, true) | |||||
case reflect.Interface: | |||||
enc.eElement(rv.Elem()) | |||||
default: | |||||
encPanic(fmt.Errorf("unexpected primitive type: %T", rv.Interface())) | |||||
} | |||||
} | |||||
// By the TOML spec, all floats must have a decimal with at least one number on | |||||
// either side. | |||||
func floatAddDecimal(fstr string) string { | |||||
if !strings.Contains(fstr, ".") { | |||||
return fstr + ".0" | |||||
} | |||||
return fstr | |||||
} | |||||
func (enc *Encoder) writeQuoted(s string) { | |||||
enc.wf("\"%s\"", dblQuotedReplacer.Replace(s)) | |||||
} | |||||
func (enc *Encoder) eArrayOrSliceElement(rv reflect.Value) { | |||||
length := rv.Len() | |||||
enc.wf("[") | |||||
for i := 0; i < length; i++ { | |||||
elem := rv.Index(i) | |||||
enc.eElement(elem) | |||||
if i != length-1 { | |||||
enc.wf(", ") | |||||
} | |||||
} | |||||
enc.wf("]") | |||||
} | |||||
func (enc *Encoder) eArrayOfTables(key Key, rv reflect.Value) { | |||||
if len(key) == 0 { | |||||
encPanic(errNoKey) | |||||
} | |||||
for i := 0; i < rv.Len(); i++ { | |||||
trv := rv.Index(i) | |||||
if isNil(trv) { | |||||
continue | |||||
} | |||||
enc.newline() | |||||
enc.wf("%s[[%s]]", enc.indentStr(key), key) | |||||
enc.newline() | |||||
enc.eMapOrStruct(key, trv, false) | |||||
} | |||||
} | |||||
func (enc *Encoder) eTable(key Key, rv reflect.Value) { | |||||
if len(key) == 1 { | |||||
// Output an extra newline between top-level tables. | |||||
// (The newline isn't written if nothing else has been written though.) | |||||
enc.newline() | |||||
} | |||||
if len(key) > 0 { | |||||
enc.wf("%s[%s]", enc.indentStr(key), key) | |||||
enc.newline() | |||||
} | |||||
enc.eMapOrStruct(key, rv, false) | |||||
} | |||||
func (enc *Encoder) eMapOrStruct(key Key, rv reflect.Value, inline bool) { | |||||
switch rv := eindirect(rv); rv.Kind() { | |||||
case reflect.Map: | |||||
enc.eMap(key, rv, inline) | |||||
case reflect.Struct: | |||||
enc.eStruct(key, rv, inline) | |||||
default: | |||||
// Should never happen? | |||||
panic("eTable: unhandled reflect.Value Kind: " + rv.Kind().String()) | |||||
} | |||||
} | |||||
func (enc *Encoder) eMap(key Key, rv reflect.Value, inline bool) { | |||||
rt := rv.Type() | |||||
if rt.Key().Kind() != reflect.String { | |||||
encPanic(errNonString) | |||||
} | |||||
// Sort keys so that we have deterministic output. And write keys directly | |||||
// underneath this key first, before writing sub-structs or sub-maps. | |||||
var mapKeysDirect, mapKeysSub []string | |||||
for _, mapKey := range rv.MapKeys() { | |||||
k := mapKey.String() | |||||
if typeIsTable(tomlTypeOfGo(rv.MapIndex(mapKey))) { | |||||
mapKeysSub = append(mapKeysSub, k) | |||||
} else { | |||||
mapKeysDirect = append(mapKeysDirect, k) | |||||
} | |||||
} | |||||
var writeMapKeys = func(mapKeys []string, trailC bool) { | |||||
sort.Strings(mapKeys) | |||||
for i, mapKey := range mapKeys { | |||||
val := rv.MapIndex(reflect.ValueOf(mapKey)) | |||||
if isNil(val) { | |||||
continue | |||||
} | |||||
if inline { | |||||
enc.writeKeyValue(Key{mapKey}, val, true) | |||||
if trailC || i != len(mapKeys)-1 { | |||||
enc.wf(", ") | |||||
} | |||||
} else { | |||||
enc.encode(key.add(mapKey), val) | |||||
} | |||||
} | |||||
} | |||||
if inline { | |||||
enc.wf("{") | |||||
} | |||||
writeMapKeys(mapKeysDirect, len(mapKeysSub) > 0) | |||||
writeMapKeys(mapKeysSub, false) | |||||
if inline { | |||||
enc.wf("}") | |||||
} | |||||
} | |||||
const is32Bit = (32 << (^uint(0) >> 63)) == 32 | |||||
func (enc *Encoder) eStruct(key Key, rv reflect.Value, inline bool) { | |||||
// Write keys for fields directly under this key first, because if we write | |||||
// a field that creates a new table then all keys under it will be in that | |||||
// table (not the one we're writing here). | |||||
// | |||||
// Fields is a [][]int: for fieldsDirect this always has one entry (the | |||||
// struct index). For fieldsSub it contains two entries: the parent field | |||||
// index from tv, and the field indexes for the fields of the sub. | |||||
var ( | |||||
rt = rv.Type() | |||||
fieldsDirect, fieldsSub [][]int | |||||
addFields func(rt reflect.Type, rv reflect.Value, start []int) | |||||
) | |||||
addFields = func(rt reflect.Type, rv reflect.Value, start []int) { | |||||
for i := 0; i < rt.NumField(); i++ { | |||||
f := rt.Field(i) | |||||
if f.PkgPath != "" && !f.Anonymous { /// Skip unexported fields. | |||||
continue | |||||
} | |||||
frv := rv.Field(i) | |||||
// Treat anonymous struct fields with tag names as though they are | |||||
// not anonymous, like encoding/json does. | |||||
// | |||||
// Non-struct anonymous fields use the normal encoding logic. | |||||
if f.Anonymous { | |||||
t := f.Type | |||||
switch t.Kind() { | |||||
case reflect.Struct: | |||||
if getOptions(f.Tag).name == "" { | |||||
addFields(t, frv, append(start, f.Index...)) | |||||
continue | |||||
} | |||||
case reflect.Ptr: | |||||
if t.Elem().Kind() == reflect.Struct && getOptions(f.Tag).name == "" { | |||||
if !frv.IsNil() { | |||||
addFields(t.Elem(), frv.Elem(), append(start, f.Index...)) | |||||
} | |||||
continue | |||||
} | |||||
} | |||||
} | |||||
if typeIsTable(tomlTypeOfGo(frv)) { | |||||
fieldsSub = append(fieldsSub, append(start, f.Index...)) | |||||
} else { | |||||
// Copy so it works correct on 32bit archs; not clear why this | |||||
// is needed. See #314, and https://www.reddit.com/r/golang/comments/pnx8v4 | |||||
// This also works fine on 64bit, but 32bit archs are somewhat | |||||
// rare and this is a wee bit faster. | |||||
if is32Bit { | |||||
copyStart := make([]int, len(start)) | |||||
copy(copyStart, start) | |||||
fieldsDirect = append(fieldsDirect, append(copyStart, f.Index...)) | |||||
} else { | |||||
fieldsDirect = append(fieldsDirect, append(start, f.Index...)) | |||||
} | |||||
} | |||||
} | |||||
} | |||||
addFields(rt, rv, nil) | |||||
writeFields := func(fields [][]int) { | |||||
for _, fieldIndex := range fields { | |||||
fieldType := rt.FieldByIndex(fieldIndex) | |||||
fieldVal := rv.FieldByIndex(fieldIndex) | |||||
if isNil(fieldVal) { /// Don't write anything for nil fields. | |||||
continue | |||||
} | |||||
opts := getOptions(fieldType.Tag) | |||||
if opts.skip { | |||||
continue | |||||
} | |||||
keyName := fieldType.Name | |||||
if opts.name != "" { | |||||
keyName = opts.name | |||||
} | |||||
if opts.omitempty && isEmpty(fieldVal) { | |||||
continue | |||||
} | |||||
if opts.omitzero && isZero(fieldVal) { | |||||
continue | |||||
} | |||||
if inline { | |||||
enc.writeKeyValue(Key{keyName}, fieldVal, true) | |||||
if fieldIndex[0] != len(fields)-1 { | |||||
enc.wf(", ") | |||||
} | |||||
} else { | |||||
enc.encode(key.add(keyName), fieldVal) | |||||
} | |||||
} | |||||
} | |||||
if inline { | |||||
enc.wf("{") | |||||
} | |||||
writeFields(fieldsDirect) | |||||
writeFields(fieldsSub) | |||||
if inline { | |||||
enc.wf("}") | |||||
} | |||||
} | |||||
// tomlTypeOfGo returns the TOML type name of the Go value's type. | |||||
// | |||||
// It is used to determine whether the types of array elements are mixed (which | |||||
// is forbidden). If the Go value is nil, then it is illegal for it to be an | |||||
// array element, and valueIsNil is returned as true. | |||||
// | |||||
// The type may be `nil`, which means no concrete TOML type could be found. | |||||
func tomlTypeOfGo(rv reflect.Value) tomlType { | |||||
if isNil(rv) || !rv.IsValid() { | |||||
return nil | |||||
} | |||||
switch rv.Kind() { | |||||
case reflect.Bool: | |||||
return tomlBool | |||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, | |||||
reflect.Int64, | |||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, | |||||
reflect.Uint64: | |||||
return tomlInteger | |||||
case reflect.Float32, reflect.Float64: | |||||
return tomlFloat | |||||
case reflect.Array, reflect.Slice: | |||||
if typeEqual(tomlHash, tomlArrayType(rv)) { | |||||
return tomlArrayHash | |||||
} | |||||
return tomlArray | |||||
case reflect.Ptr, reflect.Interface: | |||||
return tomlTypeOfGo(rv.Elem()) | |||||
case reflect.String: | |||||
return tomlString | |||||
case reflect.Map: | |||||
return tomlHash | |||||
case reflect.Struct: | |||||
if _, ok := rv.Interface().(time.Time); ok { | |||||
return tomlDatetime | |||||
} | |||||
if isMarshaler(rv) { | |||||
return tomlString | |||||
} | |||||
return tomlHash | |||||
default: | |||||
if isMarshaler(rv) { | |||||
return tomlString | |||||
} | |||||
encPanic(errors.New("unsupported type: " + rv.Kind().String())) | |||||
panic("unreachable") | |||||
} | |||||
} | |||||
func isMarshaler(rv reflect.Value) bool { | |||||
switch rv.Interface().(type) { | |||||
case encoding.TextMarshaler: | |||||
return true | |||||
case Marshaler: | |||||
return true | |||||
} | |||||
// Someone used a pointer receiver: we can make it work for pointer values. | |||||
if rv.CanAddr() { | |||||
if _, ok := rv.Addr().Interface().(encoding.TextMarshaler); ok { | |||||
return true | |||||
} | |||||
if _, ok := rv.Addr().Interface().(Marshaler); ok { | |||||
return true | |||||
} | |||||
} | |||||
return false | |||||
} | |||||
// tomlArrayType returns the element type of a TOML array. The type returned | |||||
// may be nil if it cannot be determined (e.g., a nil slice or a zero length | |||||
// slize). This function may also panic if it finds a type that cannot be | |||||
// expressed in TOML (such as nil elements, heterogeneous arrays or directly | |||||
// nested arrays of tables). | |||||
func tomlArrayType(rv reflect.Value) tomlType { | |||||
if isNil(rv) || !rv.IsValid() || rv.Len() == 0 { | |||||
return nil | |||||
} | |||||
/// Don't allow nil. | |||||
rvlen := rv.Len() | |||||
for i := 1; i < rvlen; i++ { | |||||
if tomlTypeOfGo(rv.Index(i)) == nil { | |||||
encPanic(errArrayNilElement) | |||||
} | |||||
} | |||||
firstType := tomlTypeOfGo(rv.Index(0)) | |||||
if firstType == nil { | |||||
encPanic(errArrayNilElement) | |||||
} | |||||
return firstType | |||||
} | |||||
type tagOptions struct { | |||||
skip bool // "-" | |||||
name string | |||||
omitempty bool | |||||
omitzero bool | |||||
} | |||||
func getOptions(tag reflect.StructTag) tagOptions { | |||||
t := tag.Get("toml") | |||||
if t == "-" { | |||||
return tagOptions{skip: true} | |||||
} | |||||
var opts tagOptions | |||||
parts := strings.Split(t, ",") | |||||
opts.name = parts[0] | |||||
for _, s := range parts[1:] { | |||||
switch s { | |||||
case "omitempty": | |||||
opts.omitempty = true | |||||
case "omitzero": | |||||
opts.omitzero = true | |||||
} | |||||
} | |||||
return opts | |||||
} | |||||
func isZero(rv reflect.Value) bool { | |||||
switch rv.Kind() { | |||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | |||||
return rv.Int() == 0 | |||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: | |||||
return rv.Uint() == 0 | |||||
case reflect.Float32, reflect.Float64: | |||||
return rv.Float() == 0.0 | |||||
} | |||||
return false | |||||
} | |||||
func isEmpty(rv reflect.Value) bool { | |||||
switch rv.Kind() { | |||||
case reflect.Array, reflect.Slice, reflect.Map, reflect.String: | |||||
return rv.Len() == 0 | |||||
case reflect.Bool: | |||||
return !rv.Bool() | |||||
} | |||||
return false | |||||
} | |||||
func (enc *Encoder) newline() { | |||||
if enc.hasWritten { | |||||
enc.wf("\n") | |||||
} | |||||
} | |||||
// Write a key/value pair: | |||||
// | |||||
// key = <any value> | |||||
// | |||||
// This is also used for "k = v" in inline tables; so something like this will | |||||
// be written in three calls: | |||||
// | |||||
// ┌────────────────────┐ | |||||
// │ ┌───┐ ┌─────┐│ | |||||
// v v v v vv | |||||
// key = {k = v, k2 = v2} | |||||
// | |||||
func (enc *Encoder) writeKeyValue(key Key, val reflect.Value, inline bool) { | |||||
if len(key) == 0 { | |||||
encPanic(errNoKey) | |||||
} | |||||
enc.wf("%s%s = ", enc.indentStr(key), key.maybeQuoted(len(key)-1)) | |||||
enc.eElement(val) | |||||
if !inline { | |||||
enc.newline() | |||||
} | |||||
} | |||||
func (enc *Encoder) wf(format string, v ...interface{}) { | |||||
_, err := fmt.Fprintf(enc.w, format, v...) | |||||
if err != nil { | |||||
encPanic(err) | |||||
} | |||||
enc.hasWritten = true | |||||
} | |||||
func (enc *Encoder) indentStr(key Key) string { | |||||
return strings.Repeat(enc.Indent, len(key)-1) | |||||
} | |||||
func encPanic(err error) { | |||||
panic(tomlEncodeError{err}) | |||||
} | |||||
func eindirect(v reflect.Value) reflect.Value { | |||||
switch v.Kind() { | |||||
case reflect.Ptr, reflect.Interface: | |||||
return eindirect(v.Elem()) | |||||
default: | |||||
return v | |||||
} | |||||
} | |||||
func isNil(rv reflect.Value) bool { | |||||
switch rv.Kind() { | |||||
case reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: | |||||
return rv.IsNil() | |||||
default: | |||||
return false | |||||
} | |||||
} |
@@ -0,0 +1,229 @@ | |||||
package toml | |||||
import ( | |||||
"fmt" | |||||
"strings" | |||||
) | |||||
// ParseError is returned when there is an error parsing the TOML syntax. | |||||
// | |||||
// For example invalid syntax, duplicate keys, etc. | |||||
// | |||||
// In addition to the error message itself, you can also print detailed location | |||||
// information with context by using ErrorWithLocation(): | |||||
// | |||||
// toml: error: Key 'fruit' was already created and cannot be used as an array. | |||||
// | |||||
// At line 4, column 2-7: | |||||
// | |||||
// 2 | fruit = [] | |||||
// 3 | | |||||
// 4 | [[fruit]] # Not allowed | |||||
// ^^^^^ | |||||
// | |||||
// Furthermore, the ErrorWithUsage() can be used to print the above with some | |||||
// more detailed usage guidance: | |||||
// | |||||
// toml: error: newlines not allowed within inline tables | |||||
// | |||||
// At line 1, column 18: | |||||
// | |||||
// 1 | x = [{ key = 42 # | |||||
// ^ | |||||
// | |||||
// Error help: | |||||
// | |||||
// Inline tables must always be on a single line: | |||||
// | |||||
// table = {key = 42, second = 43} | |||||
// | |||||
// It is invalid to split them over multiple lines like so: | |||||
// | |||||
// # INVALID | |||||
// table = { | |||||
// key = 42, | |||||
// second = 43 | |||||
// } | |||||
// | |||||
// Use regular for this: | |||||
// | |||||
// [table] | |||||
// key = 42 | |||||
// second = 43 | |||||
type ParseError struct { | |||||
Message string // Short technical message. | |||||
Usage string // Longer message with usage guidance; may be blank. | |||||
Position Position // Position of the error | |||||
LastKey string // Last parsed key, may be blank. | |||||
Line int // Line the error occurred. Deprecated: use Position. | |||||
err error | |||||
input string | |||||
} | |||||
// Position of an error. | |||||
type Position struct { | |||||
Line int // Line number, starting at 1. | |||||
Start int // Start of error, as byte offset starting at 0. | |||||
Len int // Lenght in bytes. | |||||
} | |||||
func (pe ParseError) Error() string { | |||||
msg := pe.Message | |||||
if msg == "" { // Error from errorf() | |||||
msg = pe.err.Error() | |||||
} | |||||
if pe.LastKey == "" { | |||||
return fmt.Sprintf("toml: line %d: %s", pe.Position.Line, msg) | |||||
} | |||||
return fmt.Sprintf("toml: line %d (last key %q): %s", | |||||
pe.Position.Line, pe.LastKey, msg) | |||||
} | |||||
// ErrorWithUsage() returns the error with detailed location context. | |||||
// | |||||
// See the documentation on ParseError. | |||||
func (pe ParseError) ErrorWithPosition() string { | |||||
if pe.input == "" { // Should never happen, but just in case. | |||||
return pe.Error() | |||||
} | |||||
var ( | |||||
lines = strings.Split(pe.input, "\n") | |||||
col = pe.column(lines) | |||||
b = new(strings.Builder) | |||||
) | |||||
msg := pe.Message | |||||
if msg == "" { | |||||
msg = pe.err.Error() | |||||
} | |||||
// TODO: don't show control characters as literals? This may not show up | |||||
// well everywhere. | |||||
if pe.Position.Len == 1 { | |||||
fmt.Fprintf(b, "toml: error: %s\n\nAt line %d, column %d:\n\n", | |||||
msg, pe.Position.Line, col+1) | |||||
} else { | |||||
fmt.Fprintf(b, "toml: error: %s\n\nAt line %d, column %d-%d:\n\n", | |||||
msg, pe.Position.Line, col, col+pe.Position.Len) | |||||
} | |||||
if pe.Position.Line > 2 { | |||||
fmt.Fprintf(b, "% 7d | %s\n", pe.Position.Line-2, lines[pe.Position.Line-3]) | |||||
} | |||||
if pe.Position.Line > 1 { | |||||
fmt.Fprintf(b, "% 7d | %s\n", pe.Position.Line-1, lines[pe.Position.Line-2]) | |||||
} | |||||
fmt.Fprintf(b, "% 7d | %s\n", pe.Position.Line, lines[pe.Position.Line-1]) | |||||
fmt.Fprintf(b, "% 10s%s%s\n", "", strings.Repeat(" ", col), strings.Repeat("^", pe.Position.Len)) | |||||
return b.String() | |||||
} | |||||
// ErrorWithUsage() returns the error with detailed location context and usage | |||||
// guidance. | |||||
// | |||||
// See the documentation on ParseError. | |||||
func (pe ParseError) ErrorWithUsage() string { | |||||
m := pe.ErrorWithPosition() | |||||
if u, ok := pe.err.(interface{ Usage() string }); ok && u.Usage() != "" { | |||||
return m + "Error help:\n\n " + | |||||
strings.ReplaceAll(strings.TrimSpace(u.Usage()), "\n", "\n ") + | |||||
"\n" | |||||
} | |||||
return m | |||||
} | |||||
func (pe ParseError) column(lines []string) int { | |||||
var pos, col int | |||||
for i := range lines { | |||||
ll := len(lines[i]) + 1 // +1 for the removed newline | |||||
if pos+ll >= pe.Position.Start { | |||||
col = pe.Position.Start - pos | |||||
if col < 0 { // Should never happen, but just in case. | |||||
col = 0 | |||||
} | |||||
break | |||||
} | |||||
pos += ll | |||||
} | |||||
return col | |||||
} | |||||
type ( | |||||
errLexControl struct{ r rune } | |||||
errLexEscape struct{ r rune } | |||||
errLexUTF8 struct{ b byte } | |||||
errLexInvalidNum struct{ v string } | |||||
errLexInvalidDate struct{ v string } | |||||
errLexInlineTableNL struct{} | |||||
errLexStringNL struct{} | |||||
) | |||||
func (e errLexControl) Error() string { | |||||
return fmt.Sprintf("TOML files cannot contain control characters: '0x%02x'", e.r) | |||||
} | |||||
func (e errLexControl) Usage() string { return "" } | |||||
func (e errLexEscape) Error() string { return fmt.Sprintf(`invalid escape in string '\%c'`, e.r) } | |||||
func (e errLexEscape) Usage() string { return usageEscape } | |||||
func (e errLexUTF8) Error() string { return fmt.Sprintf("invalid UTF-8 byte: 0x%02x", e.b) } | |||||
func (e errLexUTF8) Usage() string { return "" } | |||||
func (e errLexInvalidNum) Error() string { return fmt.Sprintf("invalid number: %q", e.v) } | |||||
func (e errLexInvalidNum) Usage() string { return "" } | |||||
func (e errLexInvalidDate) Error() string { return fmt.Sprintf("invalid date: %q", e.v) } | |||||
func (e errLexInvalidDate) Usage() string { return "" } | |||||
func (e errLexInlineTableNL) Error() string { return "newlines not allowed within inline tables" } | |||||
func (e errLexInlineTableNL) Usage() string { return usageInlineNewline } | |||||
func (e errLexStringNL) Error() string { return "strings cannot contain newlines" } | |||||
func (e errLexStringNL) Usage() string { return usageStringNewline } | |||||
const usageEscape = ` | |||||
A '\' inside a "-delimited string is interpreted as an escape character. | |||||
The following escape sequences are supported: | |||||
\b, \t, \n, \f, \r, \", \\, \uXXXX, and \UXXXXXXXX | |||||
To prevent a '\' from being recognized as an escape character, use either: | |||||
- a ' or '''-delimited string; escape characters aren't processed in them; or | |||||
- write two backslashes to get a single backslash: '\\'. | |||||
If you're trying to add a Windows path (e.g. "C:\Users\martin") then using '/' | |||||
instead of '\' will usually also work: "C:/Users/martin". | |||||
` | |||||
const usageInlineNewline = ` | |||||
Inline tables must always be on a single line: | |||||
table = {key = 42, second = 43} | |||||
It is invalid to split them over multiple lines like so: | |||||
# INVALID | |||||
table = { | |||||
key = 42, | |||||
second = 43 | |||||
} | |||||
Use regular for this: | |||||
[table] | |||||
key = 42 | |||||
second = 43 | |||||
` | |||||
const usageStringNewline = ` | |||||
Strings must always be on a single line, and cannot span more than one line: | |||||
# INVALID | |||||
string = "Hello, | |||||
world!" | |||||
Instead use """ or ''' to split strings over multiple lines: | |||||
string = """Hello, | |||||
world!""" | |||||
` |
@@ -0,0 +1,36 @@ | |||||
package internal | |||||
import "time" | |||||
// Timezones used for local datetime, date, and time TOML types. | |||||
// | |||||
// The exact way times and dates without a timezone should be interpreted is not | |||||
// well-defined in the TOML specification and left to the implementation. These | |||||
// defaults to current local timezone offset of the computer, but this can be | |||||
// changed by changing these variables before decoding. | |||||
// | |||||
// TODO: | |||||
// Ideally we'd like to offer people the ability to configure the used timezone | |||||
// by setting Decoder.Timezone and Encoder.Timezone; however, this is a bit | |||||
// tricky: the reason we use three different variables for this is to support | |||||
// round-tripping – without these specific TZ names we wouldn't know which | |||||
// format to use. | |||||
// | |||||
// There isn't a good way to encode this right now though, and passing this sort | |||||
// of information also ties in to various related issues such as string format | |||||
// encoding, encoding of comments, etc. | |||||
// | |||||
// So, for the time being, just put this in internal until we can write a good | |||||
// comprehensive API for doing all of this. | |||||
// | |||||
// The reason they're exported is because they're referred from in e.g. | |||||
// internal/tag. | |||||
// | |||||
// Note that this behaviour is valid according to the TOML spec as the exact | |||||
// behaviour is left up to implementations. | |||||
var ( | |||||
localOffset = func() int { _, o := time.Now().Zone(); return o }() | |||||
LocalDatetime = time.FixedZone("datetime-local", localOffset) | |||||
LocalDate = time.FixedZone("date-local", localOffset) | |||||
LocalTime = time.FixedZone("time-local", localOffset) | |||||
) |
@@ -0,0 +1,120 @@ | |||||
package toml | |||||
import ( | |||||
"strings" | |||||
) | |||||
// MetaData allows access to meta information about TOML data that's not | |||||
// accessible otherwise. | |||||
// | |||||
// It allows checking if a key is defined in the TOML data, whether any keys | |||||
// were undecoded, and the TOML type of a key. | |||||
type MetaData struct { | |||||
context Key // Used only during decoding. | |||||
mapping map[string]interface{} | |||||
types map[string]tomlType | |||||
keys []Key | |||||
decoded map[string]struct{} | |||||
} | |||||
// IsDefined reports if the key exists in the TOML data. | |||||
// | |||||
// The key should be specified hierarchically, for example to access the TOML | |||||
// key "a.b.c" you would use IsDefined("a", "b", "c"). Keys are case sensitive. | |||||
// | |||||
// Returns false for an empty key. | |||||
func (md *MetaData) IsDefined(key ...string) bool { | |||||
if len(key) == 0 { | |||||
return false | |||||
} | |||||
var ( | |||||
hash map[string]interface{} | |||||
ok bool | |||||
hashOrVal interface{} = md.mapping | |||||
) | |||||
for _, k := range key { | |||||
if hash, ok = hashOrVal.(map[string]interface{}); !ok { | |||||
return false | |||||
} | |||||
if hashOrVal, ok = hash[k]; !ok { | |||||
return false | |||||
} | |||||
} | |||||
return true | |||||
} | |||||
// Type returns a string representation of the type of the key specified. | |||||
// | |||||
// Type will return the empty string if given an empty key or a key that does | |||||
// not exist. Keys are case sensitive. | |||||
func (md *MetaData) Type(key ...string) string { | |||||
if typ, ok := md.types[Key(key).String()]; ok { | |||||
return typ.typeString() | |||||
} | |||||
return "" | |||||
} | |||||
// Keys returns a slice of every key in the TOML data, including key groups. | |||||
// | |||||
// Each key is itself a slice, where the first element is the top of the | |||||
// hierarchy and the last is the most specific. The list will have the same | |||||
// order as the keys appeared in the TOML data. | |||||
// | |||||
// All keys returned are non-empty. | |||||
func (md *MetaData) Keys() []Key { | |||||
return md.keys | |||||
} | |||||
// Undecoded returns all keys that have not been decoded in the order in which | |||||
// they appear in the original TOML document. | |||||
// | |||||
// This includes keys that haven't been decoded because of a Primitive value. | |||||
// Once the Primitive value is decoded, the keys will be considered decoded. | |||||
// | |||||
// Also note that decoding into an empty interface will result in no decoding, | |||||
// and so no keys will be considered decoded. | |||||
// | |||||
// In this sense, the Undecoded keys correspond to keys in the TOML document | |||||
// that do not have a concrete type in your representation. | |||||
func (md *MetaData) Undecoded() []Key { | |||||
undecoded := make([]Key, 0, len(md.keys)) | |||||
for _, key := range md.keys { | |||||
if _, ok := md.decoded[key.String()]; !ok { | |||||
undecoded = append(undecoded, key) | |||||
} | |||||
} | |||||
return undecoded | |||||
} | |||||
// Key represents any TOML key, including key groups. Use (MetaData).Keys to get | |||||
// values of this type. | |||||
type Key []string | |||||
func (k Key) String() string { | |||||
ss := make([]string, len(k)) | |||||
for i := range k { | |||||
ss[i] = k.maybeQuoted(i) | |||||
} | |||||
return strings.Join(ss, ".") | |||||
} | |||||
func (k Key) maybeQuoted(i int) string { | |||||
if k[i] == "" { | |||||
return `""` | |||||
} | |||||
for _, c := range k[i] { | |||||
if !isBareKeyChar(c) { | |||||
return `"` + dblQuotedReplacer.Replace(k[i]) + `"` | |||||
} | |||||
} | |||||
return k[i] | |||||
} | |||||
func (k Key) add(piece string) Key { | |||||
newKey := make(Key, len(k)+1) | |||||
copy(newKey, k) | |||||
newKey[len(k)] = piece | |||||
return newKey | |||||
} |
@@ -0,0 +1,763 @@ | |||||
package toml | |||||
import ( | |||||
"fmt" | |||||
"strconv" | |||||
"strings" | |||||
"time" | |||||
"unicode/utf8" | |||||
"github.com/BurntSushi/toml/internal" | |||||
) | |||||
type parser struct { | |||||
lx *lexer | |||||
context Key // Full key for the current hash in scope. | |||||
currentKey string // Base key name for everything except hashes. | |||||
pos Position // Current position in the TOML file. | |||||
ordered []Key // List of keys in the order that they appear in the TOML data. | |||||
mapping map[string]interface{} // Map keyname → key value. | |||||
types map[string]tomlType // Map keyname → TOML type. | |||||
implicits map[string]struct{} // Record implicit keys (e.g. "key.group.names"). | |||||
} | |||||
func parse(data string) (p *parser, err error) { | |||||
defer func() { | |||||
if r := recover(); r != nil { | |||||
if pErr, ok := r.(ParseError); ok { | |||||
pErr.input = data | |||||
err = pErr | |||||
return | |||||
} | |||||
panic(r) | |||||
} | |||||
}() | |||||
// Read over BOM; do this here as the lexer calls utf8.DecodeRuneInString() | |||||
// which mangles stuff. | |||||
if strings.HasPrefix(data, "\xff\xfe") || strings.HasPrefix(data, "\xfe\xff") { | |||||
data = data[2:] | |||||
} | |||||
// Examine first few bytes for NULL bytes; this probably means it's a UTF-16 | |||||
// file (second byte in surrogate pair being NULL). Again, do this here to | |||||
// avoid having to deal with UTF-8/16 stuff in the lexer. | |||||
ex := 6 | |||||
if len(data) < 6 { | |||||
ex = len(data) | |||||
} | |||||
if i := strings.IndexRune(data[:ex], 0); i > -1 { | |||||
return nil, ParseError{ | |||||
Message: "files cannot contain NULL bytes; probably using UTF-16; TOML files must be UTF-8", | |||||
Position: Position{Line: 1, Start: i, Len: 1}, | |||||
Line: 1, | |||||
input: data, | |||||
} | |||||
} | |||||
p = &parser{ | |||||
mapping: make(map[string]interface{}), | |||||
types: make(map[string]tomlType), | |||||
lx: lex(data), | |||||
ordered: make([]Key, 0), | |||||
implicits: make(map[string]struct{}), | |||||
} | |||||
for { | |||||
item := p.next() | |||||
if item.typ == itemEOF { | |||||
break | |||||
} | |||||
p.topLevel(item) | |||||
} | |||||
return p, nil | |||||
} | |||||
func (p *parser) panicItemf(it item, format string, v ...interface{}) { | |||||
panic(ParseError{ | |||||
Message: fmt.Sprintf(format, v...), | |||||
Position: it.pos, | |||||
Line: it.pos.Len, | |||||
LastKey: p.current(), | |||||
}) | |||||
} | |||||
func (p *parser) panicf(format string, v ...interface{}) { | |||||
panic(ParseError{ | |||||
Message: fmt.Sprintf(format, v...), | |||||
Position: p.pos, | |||||
Line: p.pos.Line, | |||||
LastKey: p.current(), | |||||
}) | |||||
} | |||||
func (p *parser) next() item { | |||||
it := p.lx.nextItem() | |||||
//fmt.Printf("ITEM %-18s line %-3d │ %q\n", it.typ, it.line, it.val) | |||||
if it.typ == itemError { | |||||
if it.err != nil { | |||||
panic(ParseError{ | |||||
Position: it.pos, | |||||
Line: it.pos.Line, | |||||
LastKey: p.current(), | |||||
err: it.err, | |||||
}) | |||||
} | |||||
p.panicItemf(it, "%s", it.val) | |||||
} | |||||
return it | |||||
} | |||||
func (p *parser) nextPos() item { | |||||
it := p.next() | |||||
p.pos = it.pos | |||||
return it | |||||
} | |||||
func (p *parser) bug(format string, v ...interface{}) { | |||||
panic(fmt.Sprintf("BUG: "+format+"\n\n", v...)) | |||||
} | |||||
func (p *parser) expect(typ itemType) item { | |||||
it := p.next() | |||||
p.assertEqual(typ, it.typ) | |||||
return it | |||||
} | |||||
func (p *parser) assertEqual(expected, got itemType) { | |||||
if expected != got { | |||||
p.bug("Expected '%s' but got '%s'.", expected, got) | |||||
} | |||||
} | |||||
func (p *parser) topLevel(item item) { | |||||
switch item.typ { | |||||
case itemCommentStart: // # .. | |||||
p.expect(itemText) | |||||
case itemTableStart: // [ .. ] | |||||
name := p.nextPos() | |||||
var key Key | |||||
for ; name.typ != itemTableEnd && name.typ != itemEOF; name = p.next() { | |||||
key = append(key, p.keyString(name)) | |||||
} | |||||
p.assertEqual(itemTableEnd, name.typ) | |||||
p.addContext(key, false) | |||||
p.setType("", tomlHash) | |||||
p.ordered = append(p.ordered, key) | |||||
case itemArrayTableStart: // [[ .. ]] | |||||
name := p.nextPos() | |||||
var key Key | |||||
for ; name.typ != itemArrayTableEnd && name.typ != itemEOF; name = p.next() { | |||||
key = append(key, p.keyString(name)) | |||||
} | |||||
p.assertEqual(itemArrayTableEnd, name.typ) | |||||
p.addContext(key, true) | |||||
p.setType("", tomlArrayHash) | |||||
p.ordered = append(p.ordered, key) | |||||
case itemKeyStart: // key = .. | |||||
outerContext := p.context | |||||
/// Read all the key parts (e.g. 'a' and 'b' in 'a.b') | |||||
k := p.nextPos() | |||||
var key Key | |||||
for ; k.typ != itemKeyEnd && k.typ != itemEOF; k = p.next() { | |||||
key = append(key, p.keyString(k)) | |||||
} | |||||
p.assertEqual(itemKeyEnd, k.typ) | |||||
/// The current key is the last part. | |||||
p.currentKey = key[len(key)-1] | |||||
/// All the other parts (if any) are the context; need to set each part | |||||
/// as implicit. | |||||
context := key[:len(key)-1] | |||||
for i := range context { | |||||
p.addImplicitContext(append(p.context, context[i:i+1]...)) | |||||
} | |||||
/// Set value. | |||||
val, typ := p.value(p.next(), false) | |||||
p.set(p.currentKey, val, typ) | |||||
p.ordered = append(p.ordered, p.context.add(p.currentKey)) | |||||
/// Remove the context we added (preserving any context from [tbl] lines). | |||||
p.context = outerContext | |||||
p.currentKey = "" | |||||
default: | |||||
p.bug("Unexpected type at top level: %s", item.typ) | |||||
} | |||||
} | |||||
// Gets a string for a key (or part of a key in a table name). | |||||
func (p *parser) keyString(it item) string { | |||||
switch it.typ { | |||||
case itemText: | |||||
return it.val | |||||
case itemString, itemMultilineString, | |||||
itemRawString, itemRawMultilineString: | |||||
s, _ := p.value(it, false) | |||||
return s.(string) | |||||
default: | |||||
p.bug("Unexpected key type: %s", it.typ) | |||||
} | |||||
panic("unreachable") | |||||
} | |||||
var datetimeRepl = strings.NewReplacer( | |||||
"z", "Z", | |||||
"t", "T", | |||||
" ", "T") | |||||
// value translates an expected value from the lexer into a Go value wrapped | |||||
// as an empty interface. | |||||
func (p *parser) value(it item, parentIsArray bool) (interface{}, tomlType) { | |||||
switch it.typ { | |||||
case itemString: | |||||
return p.replaceEscapes(it, it.val), p.typeOfPrimitive(it) | |||||
case itemMultilineString: | |||||
return p.replaceEscapes(it, stripFirstNewline(stripEscapedNewlines(it.val))), p.typeOfPrimitive(it) | |||||
case itemRawString: | |||||
return it.val, p.typeOfPrimitive(it) | |||||
case itemRawMultilineString: | |||||
return stripFirstNewline(it.val), p.typeOfPrimitive(it) | |||||
case itemInteger: | |||||
return p.valueInteger(it) | |||||
case itemFloat: | |||||
return p.valueFloat(it) | |||||
case itemBool: | |||||
switch it.val { | |||||
case "true": | |||||
return true, p.typeOfPrimitive(it) | |||||
case "false": | |||||
return false, p.typeOfPrimitive(it) | |||||
default: | |||||
p.bug("Expected boolean value, but got '%s'.", it.val) | |||||
} | |||||
case itemDatetime: | |||||
return p.valueDatetime(it) | |||||
case itemArray: | |||||
return p.valueArray(it) | |||||
case itemInlineTableStart: | |||||
return p.valueInlineTable(it, parentIsArray) | |||||
default: | |||||
p.bug("Unexpected value type: %s", it.typ) | |||||
} | |||||
panic("unreachable") | |||||
} | |||||
func (p *parser) valueInteger(it item) (interface{}, tomlType) { | |||||
if !numUnderscoresOK(it.val) { | |||||
p.panicItemf(it, "Invalid integer %q: underscores must be surrounded by digits", it.val) | |||||
} | |||||
if numHasLeadingZero(it.val) { | |||||
p.panicItemf(it, "Invalid integer %q: cannot have leading zeroes", it.val) | |||||
} | |||||
num, err := strconv.ParseInt(it.val, 0, 64) | |||||
if err != nil { | |||||
// Distinguish integer values. Normally, it'd be a bug if the lexer | |||||
// provides an invalid integer, but it's possible that the number is | |||||
// out of range of valid values (which the lexer cannot determine). | |||||
// So mark the former as a bug but the latter as a legitimate user | |||||
// error. | |||||
if e, ok := err.(*strconv.NumError); ok && e.Err == strconv.ErrRange { | |||||
p.panicItemf(it, "Integer '%s' is out of the range of 64-bit signed integers.", it.val) | |||||
} else { | |||||
p.bug("Expected integer value, but got '%s'.", it.val) | |||||
} | |||||
} | |||||
return num, p.typeOfPrimitive(it) | |||||
} | |||||
func (p *parser) valueFloat(it item) (interface{}, tomlType) { | |||||
parts := strings.FieldsFunc(it.val, func(r rune) bool { | |||||
switch r { | |||||
case '.', 'e', 'E': | |||||
return true | |||||
} | |||||
return false | |||||
}) | |||||
for _, part := range parts { | |||||
if !numUnderscoresOK(part) { | |||||
p.panicItemf(it, "Invalid float %q: underscores must be surrounded by digits", it.val) | |||||
} | |||||
} | |||||
if len(parts) > 0 && numHasLeadingZero(parts[0]) { | |||||
p.panicItemf(it, "Invalid float %q: cannot have leading zeroes", it.val) | |||||
} | |||||
if !numPeriodsOK(it.val) { | |||||
// As a special case, numbers like '123.' or '1.e2', | |||||
// which are valid as far as Go/strconv are concerned, | |||||
// must be rejected because TOML says that a fractional | |||||
// part consists of '.' followed by 1+ digits. | |||||
p.panicItemf(it, "Invalid float %q: '.' must be followed by one or more digits", it.val) | |||||
} | |||||
val := strings.Replace(it.val, "_", "", -1) | |||||
if val == "+nan" || val == "-nan" { // Go doesn't support this, but TOML spec does. | |||||
val = "nan" | |||||
} | |||||
num, err := strconv.ParseFloat(val, 64) | |||||
if err != nil { | |||||
if e, ok := err.(*strconv.NumError); ok && e.Err == strconv.ErrRange { | |||||
p.panicItemf(it, "Float '%s' is out of the range of 64-bit IEEE-754 floating-point numbers.", it.val) | |||||
} else { | |||||
p.panicItemf(it, "Invalid float value: %q", it.val) | |||||
} | |||||
} | |||||
return num, p.typeOfPrimitive(it) | |||||
} | |||||
var dtTypes = []struct { | |||||
fmt string | |||||
zone *time.Location | |||||
}{ | |||||
{time.RFC3339Nano, time.Local}, | |||||
{"2006-01-02T15:04:05.999999999", internal.LocalDatetime}, | |||||
{"2006-01-02", internal.LocalDate}, | |||||
{"15:04:05.999999999", internal.LocalTime}, | |||||
} | |||||
func (p *parser) valueDatetime(it item) (interface{}, tomlType) { | |||||
it.val = datetimeRepl.Replace(it.val) | |||||
var ( | |||||
t time.Time | |||||
ok bool | |||||
err error | |||||
) | |||||
for _, dt := range dtTypes { | |||||
t, err = time.ParseInLocation(dt.fmt, it.val, dt.zone) | |||||
if err == nil { | |||||
ok = true | |||||
break | |||||
} | |||||
} | |||||
if !ok { | |||||
p.panicItemf(it, "Invalid TOML Datetime: %q.", it.val) | |||||
} | |||||
return t, p.typeOfPrimitive(it) | |||||
} | |||||
func (p *parser) valueArray(it item) (interface{}, tomlType) { | |||||
p.setType(p.currentKey, tomlArray) | |||||
// p.setType(p.currentKey, typ) | |||||
var ( | |||||
types []tomlType | |||||
// Initialize to a non-nil empty slice. This makes it consistent with | |||||
// how S = [] decodes into a non-nil slice inside something like struct | |||||
// { S []string }. See #338 | |||||
array = []interface{}{} | |||||
) | |||||
for it = p.next(); it.typ != itemArrayEnd; it = p.next() { | |||||
if it.typ == itemCommentStart { | |||||
p.expect(itemText) | |||||
continue | |||||
} | |||||
val, typ := p.value(it, true) | |||||
array = append(array, val) | |||||
types = append(types, typ) | |||||
// XXX: types isn't used here, we need it to record the accurate type | |||||
// information. | |||||
// | |||||
// Not entirely sure how to best store this; could use "key[0]", | |||||
// "key[1]" notation, or maybe store it on the Array type? | |||||
} | |||||
return array, tomlArray | |||||
} | |||||
func (p *parser) valueInlineTable(it item, parentIsArray bool) (interface{}, tomlType) { | |||||
var ( | |||||
hash = make(map[string]interface{}) | |||||
outerContext = p.context | |||||
outerKey = p.currentKey | |||||
) | |||||
p.context = append(p.context, p.currentKey) | |||||
prevContext := p.context | |||||
p.currentKey = "" | |||||
p.addImplicit(p.context) | |||||
p.addContext(p.context, parentIsArray) | |||||
/// Loop over all table key/value pairs. | |||||
for it := p.next(); it.typ != itemInlineTableEnd; it = p.next() { | |||||
if it.typ == itemCommentStart { | |||||
p.expect(itemText) | |||||
continue | |||||
} | |||||
/// Read all key parts. | |||||
k := p.nextPos() | |||||
var key Key | |||||
for ; k.typ != itemKeyEnd && k.typ != itemEOF; k = p.next() { | |||||
key = append(key, p.keyString(k)) | |||||
} | |||||
p.assertEqual(itemKeyEnd, k.typ) | |||||
/// The current key is the last part. | |||||
p.currentKey = key[len(key)-1] | |||||
/// All the other parts (if any) are the context; need to set each part | |||||
/// as implicit. | |||||
context := key[:len(key)-1] | |||||
for i := range context { | |||||
p.addImplicitContext(append(p.context, context[i:i+1]...)) | |||||
} | |||||
/// Set the value. | |||||
val, typ := p.value(p.next(), false) | |||||
p.set(p.currentKey, val, typ) | |||||
p.ordered = append(p.ordered, p.context.add(p.currentKey)) | |||||
hash[p.currentKey] = val | |||||
/// Restore context. | |||||
p.context = prevContext | |||||
} | |||||
p.context = outerContext | |||||
p.currentKey = outerKey | |||||
return hash, tomlHash | |||||
} | |||||
// numHasLeadingZero checks if this number has leading zeroes, allowing for '0', | |||||
// +/- signs, and base prefixes. | |||||
func numHasLeadingZero(s string) bool { | |||||
if len(s) > 1 && s[0] == '0' && !(s[1] == 'b' || s[1] == 'o' || s[1] == 'x') { // Allow 0b, 0o, 0x | |||||
return true | |||||
} | |||||
if len(s) > 2 && (s[0] == '-' || s[0] == '+') && s[1] == '0' { | |||||
return true | |||||
} | |||||
return false | |||||
} | |||||
// numUnderscoresOK checks whether each underscore in s is surrounded by | |||||
// characters that are not underscores. | |||||
func numUnderscoresOK(s string) bool { | |||||
switch s { | |||||
case "nan", "+nan", "-nan", "inf", "-inf", "+inf": | |||||
return true | |||||
} | |||||
accept := false | |||||
for _, r := range s { | |||||
if r == '_' { | |||||
if !accept { | |||||
return false | |||||
} | |||||
} | |||||
// isHexadecimal is a superset of all the permissable characters | |||||
// surrounding an underscore. | |||||
accept = isHexadecimal(r) | |||||
} | |||||
return accept | |||||
} | |||||
// numPeriodsOK checks whether every period in s is followed by a digit. | |||||
func numPeriodsOK(s string) bool { | |||||
period := false | |||||
for _, r := range s { | |||||
if period && !isDigit(r) { | |||||
return false | |||||
} | |||||
period = r == '.' | |||||
} | |||||
return !period | |||||
} | |||||
// Set the current context of the parser, where the context is either a hash or | |||||
// an array of hashes, depending on the value of the `array` parameter. | |||||
// | |||||
// Establishing the context also makes sure that the key isn't a duplicate, and | |||||
// will create implicit hashes automatically. | |||||
func (p *parser) addContext(key Key, array bool) { | |||||
var ok bool | |||||
// Always start at the top level and drill down for our context. | |||||
hashContext := p.mapping | |||||
keyContext := make(Key, 0) | |||||
// We only need implicit hashes for key[0:-1] | |||||
for _, k := range key[0 : len(key)-1] { | |||||
_, ok = hashContext[k] | |||||
keyContext = append(keyContext, k) | |||||
// No key? Make an implicit hash and move on. | |||||
if !ok { | |||||
p.addImplicit(keyContext) | |||||
hashContext[k] = make(map[string]interface{}) | |||||
} | |||||
// If the hash context is actually an array of tables, then set | |||||
// the hash context to the last element in that array. | |||||
// | |||||
// Otherwise, it better be a table, since this MUST be a key group (by | |||||
// virtue of it not being the last element in a key). | |||||
switch t := hashContext[k].(type) { | |||||
case []map[string]interface{}: | |||||
hashContext = t[len(t)-1] | |||||
case map[string]interface{}: | |||||
hashContext = t | |||||
default: | |||||
p.panicf("Key '%s' was already created as a hash.", keyContext) | |||||
} | |||||
} | |||||
p.context = keyContext | |||||
if array { | |||||
// If this is the first element for this array, then allocate a new | |||||
// list of tables for it. | |||||
k := key[len(key)-1] | |||||
if _, ok := hashContext[k]; !ok { | |||||
hashContext[k] = make([]map[string]interface{}, 0, 4) | |||||
} | |||||
// Add a new table. But make sure the key hasn't already been used | |||||
// for something else. | |||||
if hash, ok := hashContext[k].([]map[string]interface{}); ok { | |||||
hashContext[k] = append(hash, make(map[string]interface{})) | |||||
} else { | |||||
p.panicf("Key '%s' was already created and cannot be used as an array.", key) | |||||
} | |||||
} else { | |||||
p.setValue(key[len(key)-1], make(map[string]interface{})) | |||||
} | |||||
p.context = append(p.context, key[len(key)-1]) | |||||
} | |||||
// set calls setValue and setType. | |||||
func (p *parser) set(key string, val interface{}, typ tomlType) { | |||||
p.setValue(key, val) | |||||
p.setType(key, typ) | |||||
} | |||||
// setValue sets the given key to the given value in the current context. | |||||
// It will make sure that the key hasn't already been defined, account for | |||||
// implicit key groups. | |||||
func (p *parser) setValue(key string, value interface{}) { | |||||
var ( | |||||
tmpHash interface{} | |||||
ok bool | |||||
hash = p.mapping | |||||
keyContext Key | |||||
) | |||||
for _, k := range p.context { | |||||
keyContext = append(keyContext, k) | |||||
if tmpHash, ok = hash[k]; !ok { | |||||
p.bug("Context for key '%s' has not been established.", keyContext) | |||||
} | |||||
switch t := tmpHash.(type) { | |||||
case []map[string]interface{}: | |||||
// The context is a table of hashes. Pick the most recent table | |||||
// defined as the current hash. | |||||
hash = t[len(t)-1] | |||||
case map[string]interface{}: | |||||
hash = t | |||||
default: | |||||
p.panicf("Key '%s' has already been defined.", keyContext) | |||||
} | |||||
} | |||||
keyContext = append(keyContext, key) | |||||
if _, ok := hash[key]; ok { | |||||
// Normally redefining keys isn't allowed, but the key could have been | |||||
// defined implicitly and it's allowed to be redefined concretely. (See | |||||
// the `valid/implicit-and-explicit-after.toml` in toml-test) | |||||
// | |||||
// But we have to make sure to stop marking it as an implicit. (So that | |||||
// another redefinition provokes an error.) | |||||
// | |||||
// Note that since it has already been defined (as a hash), we don't | |||||
// want to overwrite it. So our business is done. | |||||
if p.isArray(keyContext) { | |||||
p.removeImplicit(keyContext) | |||||
hash[key] = value | |||||
return | |||||
} | |||||
if p.isImplicit(keyContext) { | |||||
p.removeImplicit(keyContext) | |||||
return | |||||
} | |||||
// Otherwise, we have a concrete key trying to override a previous | |||||
// key, which is *always* wrong. | |||||
p.panicf("Key '%s' has already been defined.", keyContext) | |||||
} | |||||
hash[key] = value | |||||
} | |||||
// setType sets the type of a particular value at a given key. It should be | |||||
// called immediately AFTER setValue. | |||||
// | |||||
// Note that if `key` is empty, then the type given will be applied to the | |||||
// current context (which is either a table or an array of tables). | |||||
func (p *parser) setType(key string, typ tomlType) { | |||||
keyContext := make(Key, 0, len(p.context)+1) | |||||
keyContext = append(keyContext, p.context...) | |||||
if len(key) > 0 { // allow type setting for hashes | |||||
keyContext = append(keyContext, key) | |||||
} | |||||
// Special case to make empty keys ("" = 1) work. | |||||
// Without it it will set "" rather than `""`. | |||||
// TODO: why is this needed? And why is this only needed here? | |||||
if len(keyContext) == 0 { | |||||
keyContext = Key{""} | |||||
} | |||||
p.types[keyContext.String()] = typ | |||||
} | |||||
// Implicit keys need to be created when tables are implied in "a.b.c.d = 1" and | |||||
// "[a.b.c]" (the "a", "b", and "c" hashes are never created explicitly). | |||||
func (p *parser) addImplicit(key Key) { p.implicits[key.String()] = struct{}{} } | |||||
func (p *parser) removeImplicit(key Key) { delete(p.implicits, key.String()) } | |||||
func (p *parser) isImplicit(key Key) bool { _, ok := p.implicits[key.String()]; return ok } | |||||
func (p *parser) isArray(key Key) bool { return p.types[key.String()] == tomlArray } | |||||
func (p *parser) addImplicitContext(key Key) { | |||||
p.addImplicit(key) | |||||
p.addContext(key, false) | |||||
} | |||||
// current returns the full key name of the current context. | |||||
func (p *parser) current() string { | |||||
if len(p.currentKey) == 0 { | |||||
return p.context.String() | |||||
} | |||||
if len(p.context) == 0 { | |||||
return p.currentKey | |||||
} | |||||
return fmt.Sprintf("%s.%s", p.context, p.currentKey) | |||||
} | |||||
func stripFirstNewline(s string) string { | |||||
if len(s) > 0 && s[0] == '\n' { | |||||
return s[1:] | |||||
} | |||||
if len(s) > 1 && s[0] == '\r' && s[1] == '\n' { | |||||
return s[2:] | |||||
} | |||||
return s | |||||
} | |||||
// Remove newlines inside triple-quoted strings if a line ends with "\". | |||||
func stripEscapedNewlines(s string) string { | |||||
split := strings.Split(s, "\n") | |||||
if len(split) < 1 { | |||||
return s | |||||
} | |||||
escNL := false // Keep track of the last non-blank line was escaped. | |||||
for i, line := range split { | |||||
line = strings.TrimRight(line, " \t\r") | |||||
if len(line) == 0 || line[len(line)-1] != '\\' { | |||||
split[i] = strings.TrimRight(split[i], "\r") | |||||
if !escNL && i != len(split)-1 { | |||||
split[i] += "\n" | |||||
} | |||||
continue | |||||
} | |||||
escBS := true | |||||
for j := len(line) - 1; j >= 0 && line[j] == '\\'; j-- { | |||||
escBS = !escBS | |||||
} | |||||
if escNL { | |||||
line = strings.TrimLeft(line, " \t\r") | |||||
} | |||||
escNL = !escBS | |||||
if escBS { | |||||
split[i] += "\n" | |||||
continue | |||||
} | |||||
split[i] = line[:len(line)-1] // Remove \ | |||||
if len(split)-1 > i { | |||||
split[i+1] = strings.TrimLeft(split[i+1], " \t\r") | |||||
} | |||||
} | |||||
return strings.Join(split, "") | |||||
} | |||||
func (p *parser) replaceEscapes(it item, str string) string { | |||||
replaced := make([]rune, 0, len(str)) | |||||
s := []byte(str) | |||||
r := 0 | |||||
for r < len(s) { | |||||
if s[r] != '\\' { | |||||
c, size := utf8.DecodeRune(s[r:]) | |||||
r += size | |||||
replaced = append(replaced, c) | |||||
continue | |||||
} | |||||
r += 1 | |||||
if r >= len(s) { | |||||
p.bug("Escape sequence at end of string.") | |||||
return "" | |||||
} | |||||
switch s[r] { | |||||
default: | |||||
p.bug("Expected valid escape code after \\, but got %q.", s[r]) | |||||
return "" | |||||
case ' ', '\t': | |||||
p.panicItemf(it, "invalid escape: '\\%c'", s[r]) | |||||
return "" | |||||
case 'b': | |||||
replaced = append(replaced, rune(0x0008)) | |||||
r += 1 | |||||
case 't': | |||||
replaced = append(replaced, rune(0x0009)) | |||||
r += 1 | |||||
case 'n': | |||||
replaced = append(replaced, rune(0x000A)) | |||||
r += 1 | |||||
case 'f': | |||||
replaced = append(replaced, rune(0x000C)) | |||||
r += 1 | |||||
case 'r': | |||||
replaced = append(replaced, rune(0x000D)) | |||||
r += 1 | |||||
case '"': | |||||
replaced = append(replaced, rune(0x0022)) | |||||
r += 1 | |||||
case '\\': | |||||
replaced = append(replaced, rune(0x005C)) | |||||
r += 1 | |||||
case 'u': | |||||
// At this point, we know we have a Unicode escape of the form | |||||
// `uXXXX` at [r, r+5). (Because the lexer guarantees this | |||||
// for us.) | |||||
escaped := p.asciiEscapeToUnicode(it, s[r+1:r+5]) | |||||
replaced = append(replaced, escaped) | |||||
r += 5 | |||||
case 'U': | |||||
// At this point, we know we have a Unicode escape of the form | |||||
// `uXXXX` at [r, r+9). (Because the lexer guarantees this | |||||
// for us.) | |||||
escaped := p.asciiEscapeToUnicode(it, s[r+1:r+9]) | |||||
replaced = append(replaced, escaped) | |||||
r += 9 | |||||
} | |||||
} | |||||
return string(replaced) | |||||
} | |||||
func (p *parser) asciiEscapeToUnicode(it item, bs []byte) rune { | |||||
s := string(bs) | |||||
hex, err := strconv.ParseUint(strings.ToLower(s), 16, 32) | |||||
if err != nil { | |||||
p.bug("Could not parse '%s' as a hexadecimal number, but the lexer claims it's OK: %s", s, err) | |||||
} | |||||
if !utf8.ValidRune(rune(hex)) { | |||||
p.panicItemf(it, "Escaped character '\\u%s' is not valid UTF-8.", s) | |||||
} | |||||
return rune(hex) | |||||
} |
@@ -0,0 +1,242 @@ | |||||
package toml | |||||
// Struct field handling is adapted from code in encoding/json: | |||||
// | |||||
// Copyright 2010 The Go Authors. All rights reserved. | |||||
// Use of this source code is governed by a BSD-style | |||||
// license that can be found in the Go distribution. | |||||
import ( | |||||
"reflect" | |||||
"sort" | |||||
"sync" | |||||
) | |||||
// A field represents a single field found in a struct. | |||||
type field struct { | |||||
name string // the name of the field (`toml` tag included) | |||||
tag bool // whether field has a `toml` tag | |||||
index []int // represents the depth of an anonymous field | |||||
typ reflect.Type // the type of the field | |||||
} | |||||
// byName sorts field by name, breaking ties with depth, | |||||
// then breaking ties with "name came from toml tag", then | |||||
// breaking ties with index sequence. | |||||
type byName []field | |||||
func (x byName) Len() int { return len(x) } | |||||
func (x byName) Swap(i, j int) { x[i], x[j] = x[j], x[i] } | |||||
func (x byName) Less(i, j int) bool { | |||||
if x[i].name != x[j].name { | |||||
return x[i].name < x[j].name | |||||
} | |||||
if len(x[i].index) != len(x[j].index) { | |||||
return len(x[i].index) < len(x[j].index) | |||||
} | |||||
if x[i].tag != x[j].tag { | |||||
return x[i].tag | |||||
} | |||||
return byIndex(x).Less(i, j) | |||||
} | |||||
// byIndex sorts field by index sequence. | |||||
type byIndex []field | |||||
func (x byIndex) Len() int { return len(x) } | |||||
func (x byIndex) Swap(i, j int) { x[i], x[j] = x[j], x[i] } | |||||
func (x byIndex) Less(i, j int) bool { | |||||
for k, xik := range x[i].index { | |||||
if k >= len(x[j].index) { | |||||
return false | |||||
} | |||||
if xik != x[j].index[k] { | |||||
return xik < x[j].index[k] | |||||
} | |||||
} | |||||
return len(x[i].index) < len(x[j].index) | |||||
} | |||||
// typeFields returns a list of fields that TOML should recognize for the given | |||||
// type. The algorithm is breadth-first search over the set of structs to | |||||
// include - the top struct and then any reachable anonymous structs. | |||||
func typeFields(t reflect.Type) []field { | |||||
// Anonymous fields to explore at the current level and the next. | |||||
current := []field{} | |||||
next := []field{{typ: t}} | |||||
// Count of queued names for current level and the next. | |||||
var count map[reflect.Type]int | |||||
var nextCount map[reflect.Type]int | |||||
// Types already visited at an earlier level. | |||||
visited := map[reflect.Type]bool{} | |||||
// Fields found. | |||||
var fields []field | |||||
for len(next) > 0 { | |||||
current, next = next, current[:0] | |||||
count, nextCount = nextCount, map[reflect.Type]int{} | |||||
for _, f := range current { | |||||
if visited[f.typ] { | |||||
continue | |||||
} | |||||
visited[f.typ] = true | |||||
// Scan f.typ for fields to include. | |||||
for i := 0; i < f.typ.NumField(); i++ { | |||||
sf := f.typ.Field(i) | |||||
if sf.PkgPath != "" && !sf.Anonymous { // unexported | |||||
continue | |||||
} | |||||
opts := getOptions(sf.Tag) | |||||
if opts.skip { | |||||
continue | |||||
} | |||||
index := make([]int, len(f.index)+1) | |||||
copy(index, f.index) | |||||
index[len(f.index)] = i | |||||
ft := sf.Type | |||||
if ft.Name() == "" && ft.Kind() == reflect.Ptr { | |||||
// Follow pointer. | |||||
ft = ft.Elem() | |||||
} | |||||
// Record found field and index sequence. | |||||
if opts.name != "" || !sf.Anonymous || ft.Kind() != reflect.Struct { | |||||
tagged := opts.name != "" | |||||
name := opts.name | |||||
if name == "" { | |||||
name = sf.Name | |||||
} | |||||
fields = append(fields, field{name, tagged, index, ft}) | |||||
if count[f.typ] > 1 { | |||||
// If there were multiple instances, add a second, | |||||
// so that the annihilation code will see a duplicate. | |||||
// It only cares about the distinction between 1 or 2, | |||||
// so don't bother generating any more copies. | |||||
fields = append(fields, fields[len(fields)-1]) | |||||
} | |||||
continue | |||||
} | |||||
// Record new anonymous struct to explore in next round. | |||||
nextCount[ft]++ | |||||
if nextCount[ft] == 1 { | |||||
f := field{name: ft.Name(), index: index, typ: ft} | |||||
next = append(next, f) | |||||
} | |||||
} | |||||
} | |||||
} | |||||
sort.Sort(byName(fields)) | |||||
// Delete all fields that are hidden by the Go rules for embedded fields, | |||||
// except that fields with TOML tags are promoted. | |||||
// The fields are sorted in primary order of name, secondary order | |||||
// of field index length. Loop over names; for each name, delete | |||||
// hidden fields by choosing the one dominant field that survives. | |||||
out := fields[:0] | |||||
for advance, i := 0, 0; i < len(fields); i += advance { | |||||
// One iteration per name. | |||||
// Find the sequence of fields with the name of this first field. | |||||
fi := fields[i] | |||||
name := fi.name | |||||
for advance = 1; i+advance < len(fields); advance++ { | |||||
fj := fields[i+advance] | |||||
if fj.name != name { | |||||
break | |||||
} | |||||
} | |||||
if advance == 1 { // Only one field with this name | |||||
out = append(out, fi) | |||||
continue | |||||
} | |||||
dominant, ok := dominantField(fields[i : i+advance]) | |||||
if ok { | |||||
out = append(out, dominant) | |||||
} | |||||
} | |||||
fields = out | |||||
sort.Sort(byIndex(fields)) | |||||
return fields | |||||
} | |||||
// dominantField looks through the fields, all of which are known to | |||||
// have the same name, to find the single field that dominates the | |||||
// others using Go's embedding rules, modified by the presence of | |||||
// TOML tags. If there are multiple top-level fields, the boolean | |||||
// will be false: This condition is an error in Go and we skip all | |||||
// the fields. | |||||
func dominantField(fields []field) (field, bool) { | |||||
// The fields are sorted in increasing index-length order. The winner | |||||
// must therefore be one with the shortest index length. Drop all | |||||
// longer entries, which is easy: just truncate the slice. | |||||
length := len(fields[0].index) | |||||
tagged := -1 // Index of first tagged field. | |||||
for i, f := range fields { | |||||
if len(f.index) > length { | |||||
fields = fields[:i] | |||||
break | |||||
} | |||||
if f.tag { | |||||
if tagged >= 0 { | |||||
// Multiple tagged fields at the same level: conflict. | |||||
// Return no field. | |||||
return field{}, false | |||||
} | |||||
tagged = i | |||||
} | |||||
} | |||||
if tagged >= 0 { | |||||
return fields[tagged], true | |||||
} | |||||
// All remaining fields have the same length. If there's more than one, | |||||
// we have a conflict (two fields named "X" at the same level) and we | |||||
// return no field. | |||||
if len(fields) > 1 { | |||||
return field{}, false | |||||
} | |||||
return fields[0], true | |||||
} | |||||
var fieldCache struct { | |||||
sync.RWMutex | |||||
m map[reflect.Type][]field | |||||
} | |||||
// cachedTypeFields is like typeFields but uses a cache to avoid repeated work. | |||||
func cachedTypeFields(t reflect.Type) []field { | |||||
fieldCache.RLock() | |||||
f := fieldCache.m[t] | |||||
fieldCache.RUnlock() | |||||
if f != nil { | |||||
return f | |||||
} | |||||
// Compute fields without lock. | |||||
// Might duplicate effort but won't hold other computations back. | |||||
f = typeFields(t) | |||||
if f == nil { | |||||
f = []field{} | |||||
} | |||||
fieldCache.Lock() | |||||
if fieldCache.m == nil { | |||||
fieldCache.m = map[reflect.Type][]field{} | |||||
} | |||||
fieldCache.m[t] = f | |||||
fieldCache.Unlock() | |||||
return f | |||||
} |
@@ -0,0 +1,70 @@ | |||||
package toml | |||||
// tomlType represents any Go type that corresponds to a TOML type. | |||||
// While the first draft of the TOML spec has a simplistic type system that | |||||
// probably doesn't need this level of sophistication, we seem to be militating | |||||
// toward adding real composite types. | |||||
type tomlType interface { | |||||
typeString() string | |||||
} | |||||
// typeEqual accepts any two types and returns true if they are equal. | |||||
func typeEqual(t1, t2 tomlType) bool { | |||||
if t1 == nil || t2 == nil { | |||||
return false | |||||
} | |||||
return t1.typeString() == t2.typeString() | |||||
} | |||||
func typeIsTable(t tomlType) bool { | |||||
return typeEqual(t, tomlHash) || typeEqual(t, tomlArrayHash) | |||||
} | |||||
type tomlBaseType string | |||||
func (btype tomlBaseType) typeString() string { | |||||
return string(btype) | |||||
} | |||||
func (btype tomlBaseType) String() string { | |||||
return btype.typeString() | |||||
} | |||||
var ( | |||||
tomlInteger tomlBaseType = "Integer" | |||||
tomlFloat tomlBaseType = "Float" | |||||
tomlDatetime tomlBaseType = "Datetime" | |||||
tomlString tomlBaseType = "String" | |||||
tomlBool tomlBaseType = "Bool" | |||||
tomlArray tomlBaseType = "Array" | |||||
tomlHash tomlBaseType = "Hash" | |||||
tomlArrayHash tomlBaseType = "ArrayHash" | |||||
) | |||||
// typeOfPrimitive returns a tomlType of any primitive value in TOML. | |||||
// Primitive values are: Integer, Float, Datetime, String and Bool. | |||||
// | |||||
// Passing a lexer item other than the following will cause a BUG message | |||||
// to occur: itemString, itemBool, itemInteger, itemFloat, itemDatetime. | |||||
func (p *parser) typeOfPrimitive(lexItem item) tomlType { | |||||
switch lexItem.typ { | |||||
case itemInteger: | |||||
return tomlInteger | |||||
case itemFloat: | |||||
return tomlFloat | |||||
case itemDatetime: | |||||
return tomlDatetime | |||||
case itemString: | |||||
return tomlString | |||||
case itemMultilineString: | |||||
return tomlString | |||||
case itemRawString: | |||||
return tomlString | |||||
case itemRawMultilineString: | |||||
return tomlString | |||||
case itemBool: | |||||
return tomlBool | |||||
} | |||||
p.bug("Cannot infer primitive type of lex item '%s'.", lexItem) | |||||
panic("unreachable") | |||||
} |
@@ -0,0 +1,22 @@ | |||||
Copyright (c) 2016 Caleb Spare | |||||
MIT License | |||||
Permission is hereby granted, free of charge, to any person obtaining | |||||
a copy of this software and associated documentation files (the | |||||
"Software"), to deal in the Software without restriction, including | |||||
without limitation the rights to use, copy, modify, merge, publish, | |||||
distribute, sublicense, and/or sell copies of the Software, and to | |||||
permit persons to whom the Software is furnished to do so, subject to | |||||
the following conditions: | |||||
The above copyright notice and this permission notice shall be | |||||
included in all copies or substantial portions of the Software. | |||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, | |||||
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF | |||||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND | |||||
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE | |||||
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION | |||||
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION | |||||
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. |
@@ -0,0 +1,69 @@ | |||||
# xxhash | |||||
[![Go Reference](https://pkg.go.dev/badge/github.com/cespare/xxhash/v2.svg)](https://pkg.go.dev/github.com/cespare/xxhash/v2) | |||||
[![Test](https://github.com/cespare/xxhash/actions/workflows/test.yml/badge.svg)](https://github.com/cespare/xxhash/actions/workflows/test.yml) | |||||
xxhash is a Go implementation of the 64-bit | |||||
[xxHash](http://cyan4973.github.io/xxHash/) algorithm, XXH64. This is a | |||||
high-quality hashing algorithm that is much faster than anything in the Go | |||||
standard library. | |||||
This package provides a straightforward API: | |||||
``` | |||||
func Sum64(b []byte) uint64 | |||||
func Sum64String(s string) uint64 | |||||
type Digest struct{ ... } | |||||
func New() *Digest | |||||
``` | |||||
The `Digest` type implements hash.Hash64. Its key methods are: | |||||
``` | |||||
func (*Digest) Write([]byte) (int, error) | |||||
func (*Digest) WriteString(string) (int, error) | |||||
func (*Digest) Sum64() uint64 | |||||
``` | |||||
This implementation provides a fast pure-Go implementation and an even faster | |||||
assembly implementation for amd64. | |||||
## Compatibility | |||||
This package is in a module and the latest code is in version 2 of the module. | |||||
You need a version of Go with at least "minimal module compatibility" to use | |||||
github.com/cespare/xxhash/v2: | |||||
* 1.9.7+ for Go 1.9 | |||||
* 1.10.3+ for Go 1.10 | |||||
* Go 1.11 or later | |||||
I recommend using the latest release of Go. | |||||
## Benchmarks | |||||
Here are some quick benchmarks comparing the pure-Go and assembly | |||||
implementations of Sum64. | |||||
| input size | purego | asm | | |||||
| --- | --- | --- | | |||||
| 5 B | 979.66 MB/s | 1291.17 MB/s | | |||||
| 100 B | 7475.26 MB/s | 7973.40 MB/s | | |||||
| 4 KB | 17573.46 MB/s | 17602.65 MB/s | | |||||
| 10 MB | 17131.46 MB/s | 17142.16 MB/s | | |||||
These numbers were generated on Ubuntu 18.04 with an Intel i7-8700K CPU using | |||||
the following commands under Go 1.11.2: | |||||
``` | |||||
$ go test -tags purego -benchtime 10s -bench '/xxhash,direct,bytes' | |||||
$ go test -benchtime 10s -bench '/xxhash,direct,bytes' | |||||
``` | |||||
## Projects using this package | |||||
- [InfluxDB](https://github.com/influxdata/influxdb) | |||||
- [Prometheus](https://github.com/prometheus/prometheus) | |||||
- [VictoriaMetrics](https://github.com/VictoriaMetrics/VictoriaMetrics) | |||||
- [FreeCache](https://github.com/coocood/freecache) | |||||
- [FastCache](https://github.com/VictoriaMetrics/fastcache) |
@@ -0,0 +1,235 @@ | |||||
// Package xxhash implements the 64-bit variant of xxHash (XXH64) as described | |||||
// at http://cyan4973.github.io/xxHash/. | |||||
package xxhash | |||||
import ( | |||||
"encoding/binary" | |||||
"errors" | |||||
"math/bits" | |||||
) | |||||
const ( | |||||
prime1 uint64 = 11400714785074694791 | |||||
prime2 uint64 = 14029467366897019727 | |||||
prime3 uint64 = 1609587929392839161 | |||||
prime4 uint64 = 9650029242287828579 | |||||
prime5 uint64 = 2870177450012600261 | |||||
) | |||||
// NOTE(caleb): I'm using both consts and vars of the primes. Using consts where | |||||
// possible in the Go code is worth a small (but measurable) performance boost | |||||
// by avoiding some MOVQs. Vars are needed for the asm and also are useful for | |||||
// convenience in the Go code in a few places where we need to intentionally | |||||
// avoid constant arithmetic (e.g., v1 := prime1 + prime2 fails because the | |||||
// result overflows a uint64). | |||||
var ( | |||||
prime1v = prime1 | |||||
prime2v = prime2 | |||||
prime3v = prime3 | |||||
prime4v = prime4 | |||||
prime5v = prime5 | |||||
) | |||||
// Digest implements hash.Hash64. | |||||
type Digest struct { | |||||
v1 uint64 | |||||
v2 uint64 | |||||
v3 uint64 | |||||
v4 uint64 | |||||
total uint64 | |||||
mem [32]byte | |||||
n int // how much of mem is used | |||||
} | |||||
// New creates a new Digest that computes the 64-bit xxHash algorithm. | |||||
func New() *Digest { | |||||
var d Digest | |||||
d.Reset() | |||||
return &d | |||||
} | |||||
// Reset clears the Digest's state so that it can be reused. | |||||
func (d *Digest) Reset() { | |||||
d.v1 = prime1v + prime2 | |||||
d.v2 = prime2 | |||||
d.v3 = 0 | |||||
d.v4 = -prime1v | |||||
d.total = 0 | |||||
d.n = 0 | |||||
} | |||||
// Size always returns 8 bytes. | |||||
func (d *Digest) Size() int { return 8 } | |||||
// BlockSize always returns 32 bytes. | |||||
func (d *Digest) BlockSize() int { return 32 } | |||||
// Write adds more data to d. It always returns len(b), nil. | |||||
func (d *Digest) Write(b []byte) (n int, err error) { | |||||
n = len(b) | |||||
d.total += uint64(n) | |||||
if d.n+n < 32 { | |||||
// This new data doesn't even fill the current block. | |||||
copy(d.mem[d.n:], b) | |||||
d.n += n | |||||
return | |||||
} | |||||
if d.n > 0 { | |||||
// Finish off the partial block. | |||||
copy(d.mem[d.n:], b) | |||||
d.v1 = round(d.v1, u64(d.mem[0:8])) | |||||
d.v2 = round(d.v2, u64(d.mem[8:16])) | |||||
d.v3 = round(d.v3, u64(d.mem[16:24])) | |||||
d.v4 = round(d.v4, u64(d.mem[24:32])) | |||||
b = b[32-d.n:] | |||||
d.n = 0 | |||||
} | |||||
if len(b) >= 32 { | |||||
// One or more full blocks left. | |||||
nw := writeBlocks(d, b) | |||||
b = b[nw:] | |||||
} | |||||
// Store any remaining partial block. | |||||
copy(d.mem[:], b) | |||||
d.n = len(b) | |||||
return | |||||
} | |||||
// Sum appends the current hash to b and returns the resulting slice. | |||||
func (d *Digest) Sum(b []byte) []byte { | |||||
s := d.Sum64() | |||||
return append( | |||||
b, | |||||
byte(s>>56), | |||||
byte(s>>48), | |||||
byte(s>>40), | |||||
byte(s>>32), | |||||
byte(s>>24), | |||||
byte(s>>16), | |||||
byte(s>>8), | |||||
byte(s), | |||||
) | |||||
} | |||||
// Sum64 returns the current hash. | |||||
func (d *Digest) Sum64() uint64 { | |||||
var h uint64 | |||||
if d.total >= 32 { | |||||
v1, v2, v3, v4 := d.v1, d.v2, d.v3, d.v4 | |||||
h = rol1(v1) + rol7(v2) + rol12(v3) + rol18(v4) | |||||
h = mergeRound(h, v1) | |||||
h = mergeRound(h, v2) | |||||
h = mergeRound(h, v3) | |||||
h = mergeRound(h, v4) | |||||
} else { | |||||
h = d.v3 + prime5 | |||||
} | |||||
h += d.total | |||||
i, end := 0, d.n | |||||
for ; i+8 <= end; i += 8 { | |||||
k1 := round(0, u64(d.mem[i:i+8])) | |||||
h ^= k1 | |||||
h = rol27(h)*prime1 + prime4 | |||||
} | |||||
if i+4 <= end { | |||||
h ^= uint64(u32(d.mem[i:i+4])) * prime1 | |||||
h = rol23(h)*prime2 + prime3 | |||||
i += 4 | |||||
} | |||||
for i < end { | |||||
h ^= uint64(d.mem[i]) * prime5 | |||||
h = rol11(h) * prime1 | |||||
i++ | |||||
} | |||||
h ^= h >> 33 | |||||
h *= prime2 | |||||
h ^= h >> 29 | |||||
h *= prime3 | |||||
h ^= h >> 32 | |||||
return h | |||||
} | |||||
const ( | |||||
magic = "xxh\x06" | |||||
marshaledSize = len(magic) + 8*5 + 32 | |||||
) | |||||
// MarshalBinary implements the encoding.BinaryMarshaler interface. | |||||
func (d *Digest) MarshalBinary() ([]byte, error) { | |||||
b := make([]byte, 0, marshaledSize) | |||||
b = append(b, magic...) | |||||
b = appendUint64(b, d.v1) | |||||
b = appendUint64(b, d.v2) | |||||
b = appendUint64(b, d.v3) | |||||
b = appendUint64(b, d.v4) | |||||
b = appendUint64(b, d.total) | |||||
b = append(b, d.mem[:d.n]...) | |||||
b = b[:len(b)+len(d.mem)-d.n] | |||||
return b, nil | |||||
} | |||||
// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface. | |||||
func (d *Digest) UnmarshalBinary(b []byte) error { | |||||
if len(b) < len(magic) || string(b[:len(magic)]) != magic { | |||||
return errors.New("xxhash: invalid hash state identifier") | |||||
} | |||||
if len(b) != marshaledSize { | |||||
return errors.New("xxhash: invalid hash state size") | |||||
} | |||||
b = b[len(magic):] | |||||
b, d.v1 = consumeUint64(b) | |||||
b, d.v2 = consumeUint64(b) | |||||
b, d.v3 = consumeUint64(b) | |||||
b, d.v4 = consumeUint64(b) | |||||
b, d.total = consumeUint64(b) | |||||
copy(d.mem[:], b) | |||||
d.n = int(d.total % uint64(len(d.mem))) | |||||
return nil | |||||
} | |||||
func appendUint64(b []byte, x uint64) []byte { | |||||
var a [8]byte | |||||
binary.LittleEndian.PutUint64(a[:], x) | |||||
return append(b, a[:]...) | |||||
} | |||||
func consumeUint64(b []byte) ([]byte, uint64) { | |||||
x := u64(b) | |||||
return b[8:], x | |||||
} | |||||
func u64(b []byte) uint64 { return binary.LittleEndian.Uint64(b) } | |||||
func u32(b []byte) uint32 { return binary.LittleEndian.Uint32(b) } | |||||
func round(acc, input uint64) uint64 { | |||||
acc += input * prime2 | |||||
acc = rol31(acc) | |||||
acc *= prime1 | |||||
return acc | |||||
} | |||||
func mergeRound(acc, val uint64) uint64 { | |||||
val = round(0, val) | |||||
acc ^= val | |||||
acc = acc*prime1 + prime4 | |||||
return acc | |||||
} | |||||
func rol1(x uint64) uint64 { return bits.RotateLeft64(x, 1) } | |||||
func rol7(x uint64) uint64 { return bits.RotateLeft64(x, 7) } | |||||
func rol11(x uint64) uint64 { return bits.RotateLeft64(x, 11) } | |||||
func rol12(x uint64) uint64 { return bits.RotateLeft64(x, 12) } | |||||
func rol18(x uint64) uint64 { return bits.RotateLeft64(x, 18) } | |||||
func rol23(x uint64) uint64 { return bits.RotateLeft64(x, 23) } | |||||
func rol27(x uint64) uint64 { return bits.RotateLeft64(x, 27) } | |||||
func rol31(x uint64) uint64 { return bits.RotateLeft64(x, 31) } |
@@ -0,0 +1,13 @@ | |||||
// +build !appengine | |||||
// +build gc | |||||
// +build !purego | |||||
package xxhash | |||||
// Sum64 computes the 64-bit xxHash digest of b. | |||||
// | |||||
//go:noescape | |||||
func Sum64(b []byte) uint64 | |||||
//go:noescape | |||||
func writeBlocks(d *Digest, b []byte) int |
@@ -0,0 +1,215 @@ | |||||
// +build !appengine | |||||
// +build gc | |||||
// +build !purego | |||||
#include "textflag.h" | |||||
// Register allocation: | |||||
// AX h | |||||
// SI pointer to advance through b | |||||
// DX n | |||||
// BX loop end | |||||
// R8 v1, k1 | |||||
// R9 v2 | |||||
// R10 v3 | |||||
// R11 v4 | |||||
// R12 tmp | |||||
// R13 prime1v | |||||
// R14 prime2v | |||||
// DI prime4v | |||||
// round reads from and advances the buffer pointer in SI. | |||||
// It assumes that R13 has prime1v and R14 has prime2v. | |||||
#define round(r) \ | |||||
MOVQ (SI), R12 \ | |||||
ADDQ $8, SI \ | |||||
IMULQ R14, R12 \ | |||||
ADDQ R12, r \ | |||||
ROLQ $31, r \ | |||||
IMULQ R13, r | |||||
// mergeRound applies a merge round on the two registers acc and val. | |||||
// It assumes that R13 has prime1v, R14 has prime2v, and DI has prime4v. | |||||
#define mergeRound(acc, val) \ | |||||
IMULQ R14, val \ | |||||
ROLQ $31, val \ | |||||
IMULQ R13, val \ | |||||
XORQ val, acc \ | |||||
IMULQ R13, acc \ | |||||
ADDQ DI, acc | |||||
// func Sum64(b []byte) uint64 | |||||
TEXT ·Sum64(SB), NOSPLIT, $0-32 | |||||
// Load fixed primes. | |||||
MOVQ ·prime1v(SB), R13 | |||||
MOVQ ·prime2v(SB), R14 | |||||
MOVQ ·prime4v(SB), DI | |||||
// Load slice. | |||||
MOVQ b_base+0(FP), SI | |||||
MOVQ b_len+8(FP), DX | |||||
LEAQ (SI)(DX*1), BX | |||||
// The first loop limit will be len(b)-32. | |||||
SUBQ $32, BX | |||||
// Check whether we have at least one block. | |||||
CMPQ DX, $32 | |||||
JLT noBlocks | |||||
// Set up initial state (v1, v2, v3, v4). | |||||
MOVQ R13, R8 | |||||
ADDQ R14, R8 | |||||
MOVQ R14, R9 | |||||
XORQ R10, R10 | |||||
XORQ R11, R11 | |||||
SUBQ R13, R11 | |||||
// Loop until SI > BX. | |||||
blockLoop: | |||||
round(R8) | |||||
round(R9) | |||||
round(R10) | |||||
round(R11) | |||||
CMPQ SI, BX | |||||
JLE blockLoop | |||||
MOVQ R8, AX | |||||
ROLQ $1, AX | |||||
MOVQ R9, R12 | |||||
ROLQ $7, R12 | |||||
ADDQ R12, AX | |||||
MOVQ R10, R12 | |||||
ROLQ $12, R12 | |||||
ADDQ R12, AX | |||||
MOVQ R11, R12 | |||||
ROLQ $18, R12 | |||||
ADDQ R12, AX | |||||
mergeRound(AX, R8) | |||||
mergeRound(AX, R9) | |||||
mergeRound(AX, R10) | |||||
mergeRound(AX, R11) | |||||
JMP afterBlocks | |||||
noBlocks: | |||||
MOVQ ·prime5v(SB), AX | |||||
afterBlocks: | |||||
ADDQ DX, AX | |||||
// Right now BX has len(b)-32, and we want to loop until SI > len(b)-8. | |||||
ADDQ $24, BX | |||||
CMPQ SI, BX | |||||
JG fourByte | |||||
wordLoop: | |||||
// Calculate k1. | |||||
MOVQ (SI), R8 | |||||
ADDQ $8, SI | |||||
IMULQ R14, R8 | |||||
ROLQ $31, R8 | |||||
IMULQ R13, R8 | |||||
XORQ R8, AX | |||||
ROLQ $27, AX | |||||
IMULQ R13, AX | |||||
ADDQ DI, AX | |||||
CMPQ SI, BX | |||||
JLE wordLoop | |||||
fourByte: | |||||
ADDQ $4, BX | |||||
CMPQ SI, BX | |||||
JG singles | |||||
MOVL (SI), R8 | |||||
ADDQ $4, SI | |||||
IMULQ R13, R8 | |||||
XORQ R8, AX | |||||
ROLQ $23, AX | |||||
IMULQ R14, AX | |||||
ADDQ ·prime3v(SB), AX | |||||
singles: | |||||
ADDQ $4, BX | |||||
CMPQ SI, BX | |||||
JGE finalize | |||||
singlesLoop: | |||||
MOVBQZX (SI), R12 | |||||
ADDQ $1, SI | |||||
IMULQ ·prime5v(SB), R12 | |||||
XORQ R12, AX | |||||
ROLQ $11, AX | |||||
IMULQ R13, AX | |||||
CMPQ SI, BX | |||||
JL singlesLoop | |||||
finalize: | |||||
MOVQ AX, R12 | |||||
SHRQ $33, R12 | |||||
XORQ R12, AX | |||||
IMULQ R14, AX | |||||
MOVQ AX, R12 | |||||
SHRQ $29, R12 | |||||
XORQ R12, AX | |||||
IMULQ ·prime3v(SB), AX | |||||
MOVQ AX, R12 | |||||
SHRQ $32, R12 | |||||
XORQ R12, AX | |||||
MOVQ AX, ret+24(FP) | |||||
RET | |||||
// writeBlocks uses the same registers as above except that it uses AX to store | |||||
// the d pointer. | |||||
// func writeBlocks(d *Digest, b []byte) int | |||||
TEXT ·writeBlocks(SB), NOSPLIT, $0-40 | |||||
// Load fixed primes needed for round. | |||||
MOVQ ·prime1v(SB), R13 | |||||
MOVQ ·prime2v(SB), R14 | |||||
// Load slice. | |||||
MOVQ b_base+8(FP), SI | |||||
MOVQ b_len+16(FP), DX | |||||
LEAQ (SI)(DX*1), BX | |||||
SUBQ $32, BX | |||||
// Load vN from d. | |||||
MOVQ d+0(FP), AX | |||||
MOVQ 0(AX), R8 // v1 | |||||
MOVQ 8(AX), R9 // v2 | |||||
MOVQ 16(AX), R10 // v3 | |||||
MOVQ 24(AX), R11 // v4 | |||||
// We don't need to check the loop condition here; this function is | |||||
// always called with at least one block of data to process. | |||||
blockLoop: | |||||
round(R8) | |||||
round(R9) | |||||
round(R10) | |||||
round(R11) | |||||
CMPQ SI, BX | |||||
JLE blockLoop | |||||
// Copy vN back to d. | |||||
MOVQ R8, 0(AX) | |||||
MOVQ R9, 8(AX) | |||||
MOVQ R10, 16(AX) | |||||
MOVQ R11, 24(AX) | |||||
// The number of bytes written is SI minus the old base pointer. | |||||
SUBQ b_base+8(FP), SI | |||||
MOVQ SI, ret+32(FP) | |||||
RET |
@@ -0,0 +1,76 @@ | |||||
// +build !amd64 appengine !gc purego | |||||
package xxhash | |||||
// Sum64 computes the 64-bit xxHash digest of b. | |||||
func Sum64(b []byte) uint64 { | |||||
// A simpler version would be | |||||
// d := New() | |||||
// d.Write(b) | |||||
// return d.Sum64() | |||||
// but this is faster, particularly for small inputs. | |||||
n := len(b) | |||||
var h uint64 | |||||
if n >= 32 { | |||||
v1 := prime1v + prime2 | |||||
v2 := prime2 | |||||
v3 := uint64(0) | |||||
v4 := -prime1v | |||||
for len(b) >= 32 { | |||||
v1 = round(v1, u64(b[0:8:len(b)])) | |||||
v2 = round(v2, u64(b[8:16:len(b)])) | |||||
v3 = round(v3, u64(b[16:24:len(b)])) | |||||
v4 = round(v4, u64(b[24:32:len(b)])) | |||||
b = b[32:len(b):len(b)] | |||||
} | |||||
h = rol1(v1) + rol7(v2) + rol12(v3) + rol18(v4) | |||||
h = mergeRound(h, v1) | |||||
h = mergeRound(h, v2) | |||||
h = mergeRound(h, v3) | |||||
h = mergeRound(h, v4) | |||||
} else { | |||||
h = prime5 | |||||
} | |||||
h += uint64(n) | |||||
i, end := 0, len(b) | |||||
for ; i+8 <= end; i += 8 { | |||||
k1 := round(0, u64(b[i:i+8:len(b)])) | |||||
h ^= k1 | |||||
h = rol27(h)*prime1 + prime4 | |||||
} | |||||
if i+4 <= end { | |||||
h ^= uint64(u32(b[i:i+4:len(b)])) * prime1 | |||||
h = rol23(h)*prime2 + prime3 | |||||
i += 4 | |||||
} | |||||
for ; i < end; i++ { | |||||
h ^= uint64(b[i]) * prime5 | |||||
h = rol11(h) * prime1 | |||||
} | |||||
h ^= h >> 33 | |||||
h *= prime2 | |||||
h ^= h >> 29 | |||||
h *= prime3 | |||||
h ^= h >> 32 | |||||
return h | |||||
} | |||||
func writeBlocks(d *Digest, b []byte) int { | |||||
v1, v2, v3, v4 := d.v1, d.v2, d.v3, d.v4 | |||||
n := len(b) | |||||
for len(b) >= 32 { | |||||
v1 = round(v1, u64(b[0:8:len(b)])) | |||||
v2 = round(v2, u64(b[8:16:len(b)])) | |||||
v3 = round(v3, u64(b[16:24:len(b)])) | |||||
v4 = round(v4, u64(b[24:32:len(b)])) | |||||
b = b[32:len(b):len(b)] | |||||
} | |||||
d.v1, d.v2, d.v3, d.v4 = v1, v2, v3, v4 | |||||
return n - len(b) | |||||
} |
@@ -0,0 +1,15 @@ | |||||
// +build appengine | |||||
// This file contains the safe implementations of otherwise unsafe-using code. | |||||
package xxhash | |||||
// Sum64String computes the 64-bit xxHash digest of s. | |||||
func Sum64String(s string) uint64 { | |||||
return Sum64([]byte(s)) | |||||
} | |||||
// WriteString adds more data to d. It always returns len(s), nil. | |||||
func (d *Digest) WriteString(s string) (n int, err error) { | |||||
return d.Write([]byte(s)) | |||||
} |
@@ -0,0 +1,57 @@ | |||||
// +build !appengine | |||||
// This file encapsulates usage of unsafe. | |||||
// xxhash_safe.go contains the safe implementations. | |||||
package xxhash | |||||
import ( | |||||
"unsafe" | |||||
) | |||||
// In the future it's possible that compiler optimizations will make these | |||||
// XxxString functions unnecessary by realizing that calls such as | |||||
// Sum64([]byte(s)) don't need to copy s. See https://golang.org/issue/2205. | |||||
// If that happens, even if we keep these functions they can be replaced with | |||||
// the trivial safe code. | |||||
// NOTE: The usual way of doing an unsafe string-to-[]byte conversion is: | |||||
// | |||||
// var b []byte | |||||
// bh := (*reflect.SliceHeader)(unsafe.Pointer(&b)) | |||||
// bh.Data = (*reflect.StringHeader)(unsafe.Pointer(&s)).Data | |||||
// bh.Len = len(s) | |||||
// bh.Cap = len(s) | |||||
// | |||||
// Unfortunately, as of Go 1.15.3 the inliner's cost model assigns a high enough | |||||
// weight to this sequence of expressions that any function that uses it will | |||||
// not be inlined. Instead, the functions below use a different unsafe | |||||
// conversion designed to minimize the inliner weight and allow both to be | |||||
// inlined. There is also a test (TestInlining) which verifies that these are | |||||
// inlined. | |||||
// | |||||
// See https://github.com/golang/go/issues/42739 for discussion. | |||||
// Sum64String computes the 64-bit xxHash digest of s. | |||||
// It may be faster than Sum64([]byte(s)) by avoiding a copy. | |||||
func Sum64String(s string) uint64 { | |||||
b := *(*[]byte)(unsafe.Pointer(&sliceHeader{s, len(s)})) | |||||
return Sum64(b) | |||||
} | |||||
// WriteString adds more data to d. It always returns len(s), nil. | |||||
// It may be faster than Write([]byte(s)) by avoiding a copy. | |||||
func (d *Digest) WriteString(s string) (n int, err error) { | |||||
d.Write(*(*[]byte)(unsafe.Pointer(&sliceHeader{s, len(s)}))) | |||||
// d.Write always returns len(s), nil. | |||||
// Ignoring the return output and returning these fixed values buys a | |||||
// savings of 6 in the inliner's cost model. | |||||
return len(s), nil | |||||
} | |||||
// sliceHeader is similar to reflect.SliceHeader, but it assumes that the layout | |||||
// of the first two words is the same as the layout of a string. | |||||
type sliceHeader struct { | |||||
s string | |||||
cap int | |||||
} |
@@ -0,0 +1,21 @@ | |||||
The MIT License (MIT) | |||||
Copyright (c) 2017-2020 Damian Gryski <damian@gryski.com> | |||||
Permission is hereby granted, free of charge, to any person obtaining a copy | |||||
of this software and associated documentation files (the "Software"), to deal | |||||
in the Software without restriction, including without limitation the rights | |||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |||||
copies of the Software, and to permit persons to whom the Software is | |||||
furnished to do so, subject to the following conditions: | |||||
The above copyright notice and this permission notice shall be included in | |||||
all copies or substantial portions of the Software. | |||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN | |||||
THE SOFTWARE. |
@@ -0,0 +1,79 @@ | |||||
package rendezvous | |||||
type Rendezvous struct { | |||||
nodes map[string]int | |||||
nstr []string | |||||
nhash []uint64 | |||||
hash Hasher | |||||
} | |||||
type Hasher func(s string) uint64 | |||||
func New(nodes []string, hash Hasher) *Rendezvous { | |||||
r := &Rendezvous{ | |||||
nodes: make(map[string]int, len(nodes)), | |||||
nstr: make([]string, len(nodes)), | |||||
nhash: make([]uint64, len(nodes)), | |||||
hash: hash, | |||||
} | |||||
for i, n := range nodes { | |||||
r.nodes[n] = i | |||||
r.nstr[i] = n | |||||
r.nhash[i] = hash(n) | |||||
} | |||||
return r | |||||
} | |||||
func (r *Rendezvous) Lookup(k string) string { | |||||
// short-circuit if we're empty | |||||
if len(r.nodes) == 0 { | |||||
return "" | |||||
} | |||||
khash := r.hash(k) | |||||
var midx int | |||||
var mhash = xorshiftMult64(khash ^ r.nhash[0]) | |||||
for i, nhash := range r.nhash[1:] { | |||||
if h := xorshiftMult64(khash ^ nhash); h > mhash { | |||||
midx = i + 1 | |||||
mhash = h | |||||
} | |||||
} | |||||
return r.nstr[midx] | |||||
} | |||||
func (r *Rendezvous) Add(node string) { | |||||
r.nodes[node] = len(r.nstr) | |||||
r.nstr = append(r.nstr, node) | |||||
r.nhash = append(r.nhash, r.hash(node)) | |||||
} | |||||
func (r *Rendezvous) Remove(node string) { | |||||
// find index of node to remove | |||||
nidx := r.nodes[node] | |||||
// remove from the slices | |||||
l := len(r.nstr) | |||||
r.nstr[nidx] = r.nstr[l] | |||||
r.nstr = r.nstr[:l] | |||||
r.nhash[nidx] = r.nhash[l] | |||||
r.nhash = r.nhash[:l] | |||||
// update the map | |||||
delete(r.nodes, node) | |||||
moved := r.nstr[nidx] | |||||
r.nodes[moved] = nidx | |||||
} | |||||
func xorshiftMult64(x uint64) uint64 { | |||||
x ^= x >> 12 // a | |||||
x ^= x << 25 // b | |||||
x ^= x >> 27 // c | |||||
return x * 2685821657736338717 | |||||
} |
@@ -0,0 +1,3 @@ | |||||
*.rdb | |||||
testdata/*/ | |||||
.idea/ |
@@ -0,0 +1,27 @@ | |||||
run: | |||||
concurrency: 8 | |||||
deadline: 5m | |||||
tests: false | |||||
linters: | |||||
enable-all: true | |||||
disable: | |||||
- funlen | |||||
- gochecknoglobals | |||||
- gochecknoinits | |||||
- gocognit | |||||
- goconst | |||||
- godox | |||||
- gosec | |||||
- maligned | |||||
- wsl | |||||
- gomnd | |||||
- goerr113 | |||||
- exhaustive | |||||
- nestif | |||||
- nlreturn | |||||
- exhaustivestruct | |||||
- wrapcheck | |||||
- errorlint | |||||
- cyclop | |||||
- forcetypeassert | |||||
- forbidigo |
@@ -0,0 +1,4 @@ | |||||
semi: false | |||||
singleQuote: true | |||||
proseWrap: always | |||||
printWidth: 100 |
@@ -0,0 +1,149 @@ | |||||
## [8.11.4](https://github.com/go-redis/redis/compare/v8.11.3...v8.11.4) (2021-10-04) | |||||
### Features | |||||
* add acl auth support for sentinels ([f66582f](https://github.com/go-redis/redis/commit/f66582f44f3dc3a4705a5260f982043fde4aa634)) | |||||
* add Cmd.{String,Int,Float,Bool}Slice helpers and an example ([5d3d293](https://github.com/go-redis/redis/commit/5d3d293cc9c60b90871e2420602001463708ce24)) | |||||
* add SetVal method for each command ([168981d](https://github.com/go-redis/redis/commit/168981da2d84ee9e07d15d3e74d738c162e264c4)) | |||||
## v8.11 | |||||
- Remove OpenTelemetry metrics. | |||||
- Supports more redis commands and options. | |||||
## v8.10 | |||||
- Removed extra OpenTelemetry spans from go-redis core. Now go-redis instrumentation only adds a | |||||
single span with a Redis command (instead of 4 spans). There are multiple reasons behind this | |||||
decision: | |||||
- Traces become smaller and less noisy. | |||||
- It may be costly to process those 3 extra spans for each query. | |||||
- go-redis no longer depends on OpenTelemetry. | |||||
Eventually we hope to replace the information that we no longer collect with OpenTelemetry | |||||
Metrics. | |||||
## v8.9 | |||||
- Changed `PubSub.Channel` to only rely on `Ping` result. You can now use `WithChannelSize`, | |||||
`WithChannelHealthCheckInterval`, and `WithChannelSendTimeout` to override default settings. | |||||
## v8.8 | |||||
- To make updating easier, extra modules now have the same version as go-redis does. That means that | |||||
you need to update your imports: | |||||
``` | |||||
github.com/go-redis/redis/extra/redisotel -> github.com/go-redis/redis/extra/redisotel/v8 | |||||
github.com/go-redis/redis/extra/rediscensus -> github.com/go-redis/redis/extra/rediscensus/v8 | |||||
``` | |||||
## v8.5 | |||||
- [knadh](https://github.com/knadh) contributed long-awaited ability to scan Redis Hash into a | |||||
struct: | |||||
```go | |||||
err := rdb.HGetAll(ctx, "hash").Scan(&data) | |||||
err := rdb.MGet(ctx, "key1", "key2").Scan(&data) | |||||
``` | |||||
- Please check [redismock](https://github.com/go-redis/redismock) by | |||||
[monkey92t](https://github.com/monkey92t) if you are looking for mocking Redis Client. | |||||
## v8 | |||||
- All commands require `context.Context` as a first argument, e.g. `rdb.Ping(ctx)`. If you are not | |||||
using `context.Context` yet, the simplest option is to define global package variable | |||||
`var ctx = context.TODO()` and use it when `ctx` is required. | |||||
- Full support for `context.Context` canceling. | |||||
- Added `redis.NewFailoverClusterClient` that supports routing read-only commands to a slave node. | |||||
- Added `redisext.OpenTemetryHook` that adds | |||||
[Redis OpenTelemetry instrumentation](https://redis.uptrace.dev/tracing/). | |||||
- Redis slow log support. | |||||
- Ring uses Rendezvous Hashing by default which provides better distribution. You need to move | |||||
existing keys to a new location or keys will be inaccessible / lost. To use old hashing scheme: | |||||
```go | |||||
import "github.com/golang/groupcache/consistenthash" | |||||
ring := redis.NewRing(&redis.RingOptions{ | |||||
NewConsistentHash: func() { | |||||
return consistenthash.New(100, crc32.ChecksumIEEE) | |||||
}, | |||||
}) | |||||
``` | |||||
- `ClusterOptions.MaxRedirects` default value is changed from 8 to 3. | |||||
- `Options.MaxRetries` default value is changed from 0 to 3. | |||||
- `Cluster.ForEachNode` is renamed to `ForEachShard` for consistency with `Ring`. | |||||
## v7.3 | |||||
- New option `Options.Username` which causes client to use `AuthACL`. Be aware if your connection | |||||
URL contains username. | |||||
## v7.2 | |||||
- Existing `HMSet` is renamed to `HSet` and old deprecated `HMSet` is restored for Redis 3 users. | |||||
## v7.1 | |||||
- Existing `Cmd.String` is renamed to `Cmd.Text`. New `Cmd.String` implements `fmt.Stringer` | |||||
interface. | |||||
## v7 | |||||
- _Important_. Tx.Pipeline now returns a non-transactional pipeline. Use Tx.TxPipeline for a | |||||
transactional pipeline. | |||||
- WrapProcess is replaced with more convenient AddHook that has access to context.Context. | |||||
- WithContext now can not be used to create a shallow copy of the client. | |||||
- New methods ProcessContext, DoContext, and ExecContext. | |||||
- Client respects Context.Deadline when setting net.Conn deadline. | |||||
- Client listens on Context.Done while waiting for a connection from the pool and returns an error | |||||
when context context is cancelled. | |||||
- Add PubSub.ChannelWithSubscriptions that sends `*Subscription` in addition to `*Message` to allow | |||||
detecting reconnections. | |||||
- `time.Time` is now marshalled in RFC3339 format. `rdb.Get("foo").Time()` helper is added to parse | |||||
the time. | |||||
- `SetLimiter` is removed and added `Options.Limiter` instead. | |||||
- `HMSet` is deprecated as of Redis v4. | |||||
## v6.15 | |||||
- Cluster and Ring pipelines process commands for each node in its own goroutine. | |||||
## 6.14 | |||||
- Added Options.MinIdleConns. | |||||
- Added Options.MaxConnAge. | |||||
- PoolStats.FreeConns is renamed to PoolStats.IdleConns. | |||||
- Add Client.Do to simplify creating custom commands. | |||||
- Add Cmd.String, Cmd.Int, Cmd.Int64, Cmd.Uint64, Cmd.Float64, and Cmd.Bool helpers. | |||||
- Lower memory usage. | |||||
## v6.13 | |||||
- Ring got new options called `HashReplicas` and `Hash`. It is recommended to set | |||||
`HashReplicas = 1000` for better keys distribution between shards. | |||||
- Cluster client was optimized to use much less memory when reloading cluster state. | |||||
- PubSub.ReceiveMessage is re-worked to not use ReceiveTimeout so it does not lose data when timeout | |||||
occurres. In most cases it is recommended to use PubSub.Channel instead. | |||||
- Dialer.KeepAlive is set to 5 minutes by default. | |||||
## v6.12 | |||||
- ClusterClient got new option called `ClusterSlots` which allows to build cluster of normal Redis | |||||
Servers that don't have cluster mode enabled. See | |||||
https://godoc.org/github.com/go-redis/redis#example-NewClusterClient--ManualSetup |
@@ -0,0 +1,25 @@ | |||||
Copyright (c) 2013 The github.com/go-redis/redis Authors. | |||||
All rights reserved. | |||||
Redistribution and use in source and binary forms, with or without | |||||
modification, are permitted provided that the following conditions are | |||||
met: | |||||
* Redistributions of source code must retain the above copyright | |||||
notice, this list of conditions and the following disclaimer. | |||||
* Redistributions in binary form must reproduce the above | |||||
copyright notice, this list of conditions and the following disclaimer | |||||
in the documentation and/or other materials provided with the | |||||
distribution. | |||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS | |||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT | |||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR | |||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT | |||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, | |||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT | |||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY | |||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | |||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
@@ -0,0 +1,35 @@ | |||||
PACKAGE_DIRS := $(shell find . -mindepth 2 -type f -name 'go.mod' -exec dirname {} \; | sort) | |||||
test: testdeps | |||||
go test ./... | |||||
go test ./... -short -race | |||||
go test ./... -run=NONE -bench=. -benchmem | |||||
env GOOS=linux GOARCH=386 go test ./... | |||||
go vet | |||||
testdeps: testdata/redis/src/redis-server | |||||
bench: testdeps | |||||
go test ./... -test.run=NONE -test.bench=. -test.benchmem | |||||
.PHONY: all test testdeps bench | |||||
testdata/redis: | |||||
mkdir -p $@ | |||||
wget -qO- https://download.redis.io/releases/redis-6.2.5.tar.gz | tar xvz --strip-components=1 -C $@ | |||||
testdata/redis/src/redis-server: testdata/redis | |||||
cd $< && make all | |||||
fmt: | |||||
gofmt -w -s ./ | |||||
goimports -w -local github.com/go-redis/redis ./ | |||||
go_mod_tidy: | |||||
go get -u && go mod tidy | |||||
set -e; for dir in $(PACKAGE_DIRS); do \ | |||||
echo "go mod tidy in $${dir}"; \ | |||||
(cd "$${dir}" && \ | |||||
go get -u && \ | |||||
go mod tidy); \ | |||||
done |
@@ -0,0 +1,178 @@ | |||||
<p align="center"> | |||||
<a href="https://uptrace.dev/?utm_source=gh-redis&utm_campaign=gh-redis-banner1"> | |||||
<img src="https://raw.githubusercontent.com/uptrace/roadmap/master/banner1.png" alt="All-in-one tool to optimize performance and monitor errors & logs"> | |||||
</a> | |||||
</p> | |||||
# Redis client for Golang | |||||
![build workflow](https://github.com/go-redis/redis/actions/workflows/build.yml/badge.svg) | |||||
[![PkgGoDev](https://pkg.go.dev/badge/github.com/go-redis/redis/v8)](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc) | |||||
[![Documentation](https://img.shields.io/badge/redis-documentation-informational)](https://redis.uptrace.dev/) | |||||
[![Chat](https://discordapp.com/api/guilds/752070105847955518/widget.png)](https://discord.gg/rWtp5Aj) | |||||
- To ask questions, join [Discord](https://discord.gg/rWtp5Aj) or use | |||||
[Discussions](https://github.com/go-redis/redis/discussions). | |||||
- [Newsletter](https://blog.uptrace.dev/pages/newsletter.html) to get latest updates. | |||||
- [Documentation](https://redis.uptrace.dev) | |||||
- [Reference](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc) | |||||
- [Examples](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#pkg-examples) | |||||
- [RealWorld example app](https://github.com/uptrace/go-treemux-realworld-example-app) | |||||
Other projects you may like: | |||||
- [Bun](https://bun.uptrace.dev) - fast and simple SQL client for PostgreSQL, MySQL, and SQLite. | |||||
- [treemux](https://github.com/vmihailenco/treemux) - high-speed, flexible, tree-based HTTP router | |||||
for Go. | |||||
## Ecosystem | |||||
- [Redis Mock](https://github.com/go-redis/redismock). | |||||
- [Distributed Locks](https://github.com/bsm/redislock). | |||||
- [Redis Cache](https://github.com/go-redis/cache). | |||||
- [Rate limiting](https://github.com/go-redis/redis_rate). | |||||
## Features | |||||
- Redis 3 commands except QUIT, MONITOR, and SYNC. | |||||
- Automatic connection pooling with | |||||
[circuit breaker](https://en.wikipedia.org/wiki/Circuit_breaker_design_pattern) support. | |||||
- [Pub/Sub](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#PubSub). | |||||
- [Transactions](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#example-Client-TxPipeline). | |||||
- [Pipeline](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#example-Client-Pipeline) and | |||||
[TxPipeline](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#example-Client-TxPipeline). | |||||
- [Scripting](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#Script). | |||||
- [Timeouts](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#Options). | |||||
- [Redis Sentinel](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#NewFailoverClient). | |||||
- [Redis Cluster](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#NewClusterClient). | |||||
- [Cluster of Redis Servers](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#example-NewClusterClient--ManualSetup) | |||||
without using cluster mode and Redis Sentinel. | |||||
- [Ring](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#NewRing). | |||||
- [Instrumentation](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#ex-package--Instrumentation). | |||||
## Installation | |||||
go-redis supports 2 last Go versions and requires a Go version with | |||||
[modules](https://github.com/golang/go/wiki/Modules) support. So make sure to initialize a Go | |||||
module: | |||||
```shell | |||||
go mod init github.com/my/repo | |||||
``` | |||||
And then install go-redis/v8 (note _v8_ in the import; omitting it is a popular mistake): | |||||
```shell | |||||
go get github.com/go-redis/redis/v8 | |||||
``` | |||||
## Quickstart | |||||
```go | |||||
import ( | |||||
"context" | |||||
"github.com/go-redis/redis/v8" | |||||
) | |||||
var ctx = context.Background() | |||||
func ExampleClient() { | |||||
rdb := redis.NewClient(&redis.Options{ | |||||
Addr: "localhost:6379", | |||||
Password: "", // no password set | |||||
DB: 0, // use default DB | |||||
}) | |||||
err := rdb.Set(ctx, "key", "value", 0).Err() | |||||
if err != nil { | |||||
panic(err) | |||||
} | |||||
val, err := rdb.Get(ctx, "key").Result() | |||||
if err != nil { | |||||
panic(err) | |||||
} | |||||
fmt.Println("key", val) | |||||
val2, err := rdb.Get(ctx, "key2").Result() | |||||
if err == redis.Nil { | |||||
fmt.Println("key2 does not exist") | |||||
} else if err != nil { | |||||
panic(err) | |||||
} else { | |||||
fmt.Println("key2", val2) | |||||
} | |||||
// Output: key value | |||||
// key2 does not exist | |||||
} | |||||
``` | |||||
## Look and feel | |||||
Some corner cases: | |||||
```go | |||||
// SET key value EX 10 NX | |||||
set, err := rdb.SetNX(ctx, "key", "value", 10*time.Second).Result() | |||||
// SET key value keepttl NX | |||||
set, err := rdb.SetNX(ctx, "key", "value", redis.KeepTTL).Result() | |||||
// SORT list LIMIT 0 2 ASC | |||||
vals, err := rdb.Sort(ctx, "list", &redis.Sort{Offset: 0, Count: 2, Order: "ASC"}).Result() | |||||
// ZRANGEBYSCORE zset -inf +inf WITHSCORES LIMIT 0 2 | |||||
vals, err := rdb.ZRangeByScoreWithScores(ctx, "zset", &redis.ZRangeBy{ | |||||
Min: "-inf", | |||||
Max: "+inf", | |||||
Offset: 0, | |||||
Count: 2, | |||||
}).Result() | |||||
// ZINTERSTORE out 2 zset1 zset2 WEIGHTS 2 3 AGGREGATE SUM | |||||
vals, err := rdb.ZInterStore(ctx, "out", &redis.ZStore{ | |||||
Keys: []string{"zset1", "zset2"}, | |||||
Weights: []int64{2, 3} | |||||
}).Result() | |||||
// EVAL "return {KEYS[1],ARGV[1]}" 1 "key" "hello" | |||||
vals, err := rdb.Eval(ctx, "return {KEYS[1],ARGV[1]}", []string{"key"}, "hello").Result() | |||||
// custom command | |||||
res, err := rdb.Do(ctx, "set", "key", "value").Result() | |||||
``` | |||||
## Run the test | |||||
go-redis will start a redis-server and run the test cases. | |||||
The paths of redis-server bin file and redis config file are defined in `main_test.go`: | |||||
``` | |||||
var ( | |||||
redisServerBin, _ = filepath.Abs(filepath.Join("testdata", "redis", "src", "redis-server")) | |||||
redisServerConf, _ = filepath.Abs(filepath.Join("testdata", "redis", "redis.conf")) | |||||
) | |||||
``` | |||||
For local testing, you can change the variables to refer to your local files, or create a soft link | |||||
to the corresponding folder for redis-server and copy the config file to `testdata/redis/`: | |||||
``` | |||||
ln -s /usr/bin/redis-server ./go-redis/testdata/redis/src | |||||
cp ./go-redis/testdata/redis.conf ./go-redis/testdata/redis/ | |||||
``` | |||||
Lastly, run: | |||||
``` | |||||
go test | |||||
``` | |||||
## Contributors | |||||
Thanks to all the people who already contributed! | |||||
<a href="https://github.com/go-redis/redis/graphs/contributors"> | |||||
<img src="https://contributors-img.web.app/image?repo=go-redis/redis" /> | |||||
</a> |
@@ -0,0 +1,15 @@ | |||||
# Releasing | |||||
1. Run `release.sh` script which updates versions in go.mod files and pushes a new branch to GitHub: | |||||
```shell | |||||
TAG=v1.0.0 ./scripts/release.sh | |||||
``` | |||||
2. Open a pull request and wait for the build to finish. | |||||
3. Merge the pull request and run `tag.sh` to create tags for packages: | |||||
```shell | |||||
TAG=v1.0.0 ./scripts/tag.sh | |||||
``` |
@@ -0,0 +1,109 @@ | |||||
package redis | |||||
import ( | |||||
"context" | |||||
"sync" | |||||
"sync/atomic" | |||||
) | |||||
func (c *ClusterClient) DBSize(ctx context.Context) *IntCmd { | |||||
cmd := NewIntCmd(ctx, "dbsize") | |||||
_ = c.hooks.process(ctx, cmd, func(ctx context.Context, _ Cmder) error { | |||||
var size int64 | |||||
err := c.ForEachMaster(ctx, func(ctx context.Context, master *Client) error { | |||||
n, err := master.DBSize(ctx).Result() | |||||
if err != nil { | |||||
return err | |||||
} | |||||
atomic.AddInt64(&size, n) | |||||
return nil | |||||
}) | |||||
if err != nil { | |||||
cmd.SetErr(err) | |||||
} else { | |||||
cmd.val = size | |||||
} | |||||
return nil | |||||
}) | |||||
return cmd | |||||
} | |||||
func (c *ClusterClient) ScriptLoad(ctx context.Context, script string) *StringCmd { | |||||
cmd := NewStringCmd(ctx, "script", "load", script) | |||||
_ = c.hooks.process(ctx, cmd, func(ctx context.Context, _ Cmder) error { | |||||
mu := &sync.Mutex{} | |||||
err := c.ForEachShard(ctx, func(ctx context.Context, shard *Client) error { | |||||
val, err := shard.ScriptLoad(ctx, script).Result() | |||||
if err != nil { | |||||
return err | |||||
} | |||||
mu.Lock() | |||||
if cmd.Val() == "" { | |||||
cmd.val = val | |||||
} | |||||
mu.Unlock() | |||||
return nil | |||||
}) | |||||
if err != nil { | |||||
cmd.SetErr(err) | |||||
} | |||||
return nil | |||||
}) | |||||
return cmd | |||||
} | |||||
func (c *ClusterClient) ScriptFlush(ctx context.Context) *StatusCmd { | |||||
cmd := NewStatusCmd(ctx, "script", "flush") | |||||
_ = c.hooks.process(ctx, cmd, func(ctx context.Context, _ Cmder) error { | |||||
err := c.ForEachShard(ctx, func(ctx context.Context, shard *Client) error { | |||||
return shard.ScriptFlush(ctx).Err() | |||||
}) | |||||
if err != nil { | |||||
cmd.SetErr(err) | |||||
} | |||||
return nil | |||||
}) | |||||
return cmd | |||||
} | |||||
func (c *ClusterClient) ScriptExists(ctx context.Context, hashes ...string) *BoolSliceCmd { | |||||
args := make([]interface{}, 2+len(hashes)) | |||||
args[0] = "script" | |||||
args[1] = "exists" | |||||
for i, hash := range hashes { | |||||
args[2+i] = hash | |||||
} | |||||
cmd := NewBoolSliceCmd(ctx, args...) | |||||
result := make([]bool, len(hashes)) | |||||
for i := range result { | |||||
result[i] = true | |||||
} | |||||
_ = c.hooks.process(ctx, cmd, func(ctx context.Context, _ Cmder) error { | |||||
mu := &sync.Mutex{} | |||||
err := c.ForEachShard(ctx, func(ctx context.Context, shard *Client) error { | |||||
val, err := shard.ScriptExists(ctx, hashes...).Result() | |||||
if err != nil { | |||||
return err | |||||
} | |||||
mu.Lock() | |||||
for i, v := range val { | |||||
result[i] = result[i] && v | |||||
} | |||||
mu.Unlock() | |||||
return nil | |||||
}) | |||||
if err != nil { | |||||
cmd.SetErr(err) | |||||
} else { | |||||
cmd.val = result | |||||
} | |||||
return nil | |||||
}) | |||||
return cmd | |||||
} |
@@ -0,0 +1,4 @@ | |||||
/* | |||||
Package redis implements a Redis client. | |||||
*/ | |||||
package redis |
@@ -0,0 +1,144 @@ | |||||
package redis | |||||
import ( | |||||
"context" | |||||
"io" | |||||
"net" | |||||
"strings" | |||||
"github.com/go-redis/redis/v8/internal/pool" | |||||
"github.com/go-redis/redis/v8/internal/proto" | |||||
) | |||||
// ErrClosed performs any operation on the closed client will return this error. | |||||
var ErrClosed = pool.ErrClosed | |||||
type Error interface { | |||||
error | |||||
// RedisError is a no-op function but | |||||
// serves to distinguish types that are Redis | |||||
// errors from ordinary errors: a type is a | |||||
// Redis error if it has a RedisError method. | |||||
RedisError() | |||||
} | |||||
var _ Error = proto.RedisError("") | |||||
func shouldRetry(err error, retryTimeout bool) bool { | |||||
switch err { | |||||
case io.EOF, io.ErrUnexpectedEOF: | |||||
return true | |||||
case nil, context.Canceled, context.DeadlineExceeded: | |||||
return false | |||||
} | |||||
if v, ok := err.(timeoutError); ok { | |||||
if v.Timeout() { | |||||
return retryTimeout | |||||
} | |||||
return true | |||||
} | |||||
s := err.Error() | |||||
if s == "ERR max number of clients reached" { | |||||
return true | |||||
} | |||||
if strings.HasPrefix(s, "LOADING ") { | |||||
return true | |||||
} | |||||
if strings.HasPrefix(s, "READONLY ") { | |||||
return true | |||||
} | |||||
if strings.HasPrefix(s, "CLUSTERDOWN ") { | |||||
return true | |||||
} | |||||
if strings.HasPrefix(s, "TRYAGAIN ") { | |||||
return true | |||||
} | |||||
return false | |||||
} | |||||
func isRedisError(err error) bool { | |||||
_, ok := err.(proto.RedisError) | |||||
return ok | |||||
} | |||||
func isBadConn(err error, allowTimeout bool, addr string) bool { | |||||
switch err { | |||||
case nil: | |||||
return false | |||||
case context.Canceled, context.DeadlineExceeded: | |||||
return true | |||||
} | |||||
if isRedisError(err) { | |||||
switch { | |||||
case isReadOnlyError(err): | |||||
// Close connections in read only state in case domain addr is used | |||||
// and domain resolves to a different Redis Server. See #790. | |||||
return true | |||||
case isMovedSameConnAddr(err, addr): | |||||
// Close connections when we are asked to move to the same addr | |||||
// of the connection. Force a DNS resolution when all connections | |||||
// of the pool are recycled | |||||
return true | |||||
default: | |||||
return false | |||||
} | |||||
} | |||||
if allowTimeout { | |||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() { | |||||
return !netErr.Temporary() | |||||
} | |||||
} | |||||
return true | |||||
} | |||||
func isMovedError(err error) (moved bool, ask bool, addr string) { | |||||
if !isRedisError(err) { | |||||
return | |||||
} | |||||
s := err.Error() | |||||
switch { | |||||
case strings.HasPrefix(s, "MOVED "): | |||||
moved = true | |||||
case strings.HasPrefix(s, "ASK "): | |||||
ask = true | |||||
default: | |||||
return | |||||
} | |||||
ind := strings.LastIndex(s, " ") | |||||
if ind == -1 { | |||||
return false, false, "" | |||||
} | |||||
addr = s[ind+1:] | |||||
return | |||||
} | |||||
func isLoadingError(err error) bool { | |||||
return strings.HasPrefix(err.Error(), "LOADING ") | |||||
} | |||||
func isReadOnlyError(err error) bool { | |||||
return strings.HasPrefix(err.Error(), "READONLY ") | |||||
} | |||||
func isMovedSameConnAddr(err error, addr string) bool { | |||||
redisError := err.Error() | |||||
if !strings.HasPrefix(redisError, "MOVED ") { | |||||
return false | |||||
} | |||||
return strings.HasSuffix(redisError, addr) | |||||
} | |||||
//------------------------------------------------------------------------------ | |||||
type timeoutError interface { | |||||
Timeout() bool | |||||
} |
@@ -0,0 +1,56 @@ | |||||
package internal | |||||
import ( | |||||
"fmt" | |||||
"strconv" | |||||
"time" | |||||
) | |||||
func AppendArg(b []byte, v interface{}) []byte { | |||||
switch v := v.(type) { | |||||
case nil: | |||||
return append(b, "<nil>"...) | |||||
case string: | |||||
return appendUTF8String(b, Bytes(v)) | |||||
case []byte: | |||||
return appendUTF8String(b, v) | |||||
case int: | |||||
return strconv.AppendInt(b, int64(v), 10) | |||||
case int8: | |||||
return strconv.AppendInt(b, int64(v), 10) | |||||
case int16: | |||||
return strconv.AppendInt(b, int64(v), 10) | |||||
case int32: | |||||
return strconv.AppendInt(b, int64(v), 10) | |||||
case int64: | |||||
return strconv.AppendInt(b, v, 10) | |||||
case uint: | |||||
return strconv.AppendUint(b, uint64(v), 10) | |||||
case uint8: | |||||
return strconv.AppendUint(b, uint64(v), 10) | |||||
case uint16: | |||||
return strconv.AppendUint(b, uint64(v), 10) | |||||
case uint32: | |||||
return strconv.AppendUint(b, uint64(v), 10) | |||||
case uint64: | |||||
return strconv.AppendUint(b, v, 10) | |||||
case float32: | |||||
return strconv.AppendFloat(b, float64(v), 'f', -1, 64) | |||||
case float64: | |||||
return strconv.AppendFloat(b, v, 'f', -1, 64) | |||||
case bool: | |||||
if v { | |||||
return append(b, "true"...) | |||||
} | |||||
return append(b, "false"...) | |||||
case time.Time: | |||||
return v.AppendFormat(b, time.RFC3339Nano) | |||||
default: | |||||
return append(b, fmt.Sprint(v)...) | |||||
} | |||||
} | |||||
func appendUTF8String(dst []byte, src []byte) []byte { | |||||
dst = append(dst, src...) | |||||
return dst | |||||
} |
@@ -0,0 +1,78 @@ | |||||
package hashtag | |||||
import ( | |||||
"strings" | |||||
"github.com/go-redis/redis/v8/internal/rand" | |||||
) | |||||
const slotNumber = 16384 | |||||
// CRC16 implementation according to CCITT standards. | |||||
// Copyright 2001-2010 Georges Menie (www.menie.org) | |||||
// Copyright 2013 The Go Authors. All rights reserved. | |||||
// http://redis.io/topics/cluster-spec#appendix-a-crc16-reference-implementation-in-ansi-c | |||||
var crc16tab = [256]uint16{ | |||||
0x0000, 0x1021, 0x2042, 0x3063, 0x4084, 0x50a5, 0x60c6, 0x70e7, | |||||
0x8108, 0x9129, 0xa14a, 0xb16b, 0xc18c, 0xd1ad, 0xe1ce, 0xf1ef, | |||||
0x1231, 0x0210, 0x3273, 0x2252, 0x52b5, 0x4294, 0x72f7, 0x62d6, | |||||
0x9339, 0x8318, 0xb37b, 0xa35a, 0xd3bd, 0xc39c, 0xf3ff, 0xe3de, | |||||
0x2462, 0x3443, 0x0420, 0x1401, 0x64e6, 0x74c7, 0x44a4, 0x5485, | |||||
0xa56a, 0xb54b, 0x8528, 0x9509, 0xe5ee, 0xf5cf, 0xc5ac, 0xd58d, | |||||
0x3653, 0x2672, 0x1611, 0x0630, 0x76d7, 0x66f6, 0x5695, 0x46b4, | |||||
0xb75b, 0xa77a, 0x9719, 0x8738, 0xf7df, 0xe7fe, 0xd79d, 0xc7bc, | |||||
0x48c4, 0x58e5, 0x6886, 0x78a7, 0x0840, 0x1861, 0x2802, 0x3823, | |||||
0xc9cc, 0xd9ed, 0xe98e, 0xf9af, 0x8948, 0x9969, 0xa90a, 0xb92b, | |||||
0x5af5, 0x4ad4, 0x7ab7, 0x6a96, 0x1a71, 0x0a50, 0x3a33, 0x2a12, | |||||
0xdbfd, 0xcbdc, 0xfbbf, 0xeb9e, 0x9b79, 0x8b58, 0xbb3b, 0xab1a, | |||||
0x6ca6, 0x7c87, 0x4ce4, 0x5cc5, 0x2c22, 0x3c03, 0x0c60, 0x1c41, | |||||
0xedae, 0xfd8f, 0xcdec, 0xddcd, 0xad2a, 0xbd0b, 0x8d68, 0x9d49, | |||||
0x7e97, 0x6eb6, 0x5ed5, 0x4ef4, 0x3e13, 0x2e32, 0x1e51, 0x0e70, | |||||
0xff9f, 0xefbe, 0xdfdd, 0xcffc, 0xbf1b, 0xaf3a, 0x9f59, 0x8f78, | |||||
0x9188, 0x81a9, 0xb1ca, 0xa1eb, 0xd10c, 0xc12d, 0xf14e, 0xe16f, | |||||
0x1080, 0x00a1, 0x30c2, 0x20e3, 0x5004, 0x4025, 0x7046, 0x6067, | |||||
0x83b9, 0x9398, 0xa3fb, 0xb3da, 0xc33d, 0xd31c, 0xe37f, 0xf35e, | |||||
0x02b1, 0x1290, 0x22f3, 0x32d2, 0x4235, 0x5214, 0x6277, 0x7256, | |||||
0xb5ea, 0xa5cb, 0x95a8, 0x8589, 0xf56e, 0xe54f, 0xd52c, 0xc50d, | |||||
0x34e2, 0x24c3, 0x14a0, 0x0481, 0x7466, 0x6447, 0x5424, 0x4405, | |||||
0xa7db, 0xb7fa, 0x8799, 0x97b8, 0xe75f, 0xf77e, 0xc71d, 0xd73c, | |||||
0x26d3, 0x36f2, 0x0691, 0x16b0, 0x6657, 0x7676, 0x4615, 0x5634, | |||||
0xd94c, 0xc96d, 0xf90e, 0xe92f, 0x99c8, 0x89e9, 0xb98a, 0xa9ab, | |||||
0x5844, 0x4865, 0x7806, 0x6827, 0x18c0, 0x08e1, 0x3882, 0x28a3, | |||||
0xcb7d, 0xdb5c, 0xeb3f, 0xfb1e, 0x8bf9, 0x9bd8, 0xabbb, 0xbb9a, | |||||
0x4a75, 0x5a54, 0x6a37, 0x7a16, 0x0af1, 0x1ad0, 0x2ab3, 0x3a92, | |||||
0xfd2e, 0xed0f, 0xdd6c, 0xcd4d, 0xbdaa, 0xad8b, 0x9de8, 0x8dc9, | |||||
0x7c26, 0x6c07, 0x5c64, 0x4c45, 0x3ca2, 0x2c83, 0x1ce0, 0x0cc1, | |||||
0xef1f, 0xff3e, 0xcf5d, 0xdf7c, 0xaf9b, 0xbfba, 0x8fd9, 0x9ff8, | |||||
0x6e17, 0x7e36, 0x4e55, 0x5e74, 0x2e93, 0x3eb2, 0x0ed1, 0x1ef0, | |||||
} | |||||
func Key(key string) string { | |||||
if s := strings.IndexByte(key, '{'); s > -1 { | |||||
if e := strings.IndexByte(key[s+1:], '}'); e > 0 { | |||||
return key[s+1 : s+e+1] | |||||
} | |||||
} | |||||
return key | |||||
} | |||||
func RandomSlot() int { | |||||
return rand.Intn(slotNumber) | |||||
} | |||||
// Slot returns a consistent slot number between 0 and 16383 | |||||
// for any given string key. | |||||
func Slot(key string) int { | |||||
if key == "" { | |||||
return RandomSlot() | |||||
} | |||||
key = Key(key) | |||||
return int(crc16sum(key)) % slotNumber | |||||
} | |||||
func crc16sum(key string) (crc uint16) { | |||||
for i := 0; i < len(key); i++ { | |||||
crc = (crc << 8) ^ crc16tab[(byte(crc>>8)^key[i])&0x00ff] | |||||
} | |||||
return | |||||
} |
@@ -0,0 +1,201 @@ | |||||
package hscan | |||||
import ( | |||||
"errors" | |||||
"fmt" | |||||
"reflect" | |||||
"strconv" | |||||
) | |||||
// decoderFunc represents decoding functions for default built-in types. | |||||
type decoderFunc func(reflect.Value, string) error | |||||
var ( | |||||
// List of built-in decoders indexed by their numeric constant values (eg: reflect.Bool = 1). | |||||
decoders = []decoderFunc{ | |||||
reflect.Bool: decodeBool, | |||||
reflect.Int: decodeInt, | |||||
reflect.Int8: decodeInt8, | |||||
reflect.Int16: decodeInt16, | |||||
reflect.Int32: decodeInt32, | |||||
reflect.Int64: decodeInt64, | |||||
reflect.Uint: decodeUint, | |||||
reflect.Uint8: decodeUint8, | |||||
reflect.Uint16: decodeUint16, | |||||
reflect.Uint32: decodeUint32, | |||||
reflect.Uint64: decodeUint64, | |||||
reflect.Float32: decodeFloat32, | |||||
reflect.Float64: decodeFloat64, | |||||
reflect.Complex64: decodeUnsupported, | |||||
reflect.Complex128: decodeUnsupported, | |||||
reflect.Array: decodeUnsupported, | |||||
reflect.Chan: decodeUnsupported, | |||||
reflect.Func: decodeUnsupported, | |||||
reflect.Interface: decodeUnsupported, | |||||
reflect.Map: decodeUnsupported, | |||||
reflect.Ptr: decodeUnsupported, | |||||
reflect.Slice: decodeSlice, | |||||
reflect.String: decodeString, | |||||
reflect.Struct: decodeUnsupported, | |||||
reflect.UnsafePointer: decodeUnsupported, | |||||
} | |||||
// Global map of struct field specs that is populated once for every new | |||||
// struct type that is scanned. This caches the field types and the corresponding | |||||
// decoder functions to avoid iterating through struct fields on subsequent scans. | |||||
globalStructMap = newStructMap() | |||||
) | |||||
func Struct(dst interface{}) (StructValue, error) { | |||||
v := reflect.ValueOf(dst) | |||||
// The destination to scan into should be a struct pointer. | |||||
if v.Kind() != reflect.Ptr || v.IsNil() { | |||||
return StructValue{}, fmt.Errorf("redis.Scan(non-pointer %T)", dst) | |||||
} | |||||
v = v.Elem() | |||||
if v.Kind() != reflect.Struct { | |||||
return StructValue{}, fmt.Errorf("redis.Scan(non-struct %T)", dst) | |||||
} | |||||
return StructValue{ | |||||
spec: globalStructMap.get(v.Type()), | |||||
value: v, | |||||
}, nil | |||||
} | |||||
// Scan scans the results from a key-value Redis map result set to a destination struct. | |||||
// The Redis keys are matched to the struct's field with the `redis` tag. | |||||
func Scan(dst interface{}, keys []interface{}, vals []interface{}) error { | |||||
if len(keys) != len(vals) { | |||||
return errors.New("args should have the same number of keys and vals") | |||||
} | |||||
strct, err := Struct(dst) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
// Iterate through the (key, value) sequence. | |||||
for i := 0; i < len(vals); i++ { | |||||
key, ok := keys[i].(string) | |||||
if !ok { | |||||
continue | |||||
} | |||||
val, ok := vals[i].(string) | |||||
if !ok { | |||||
continue | |||||
} | |||||
if err := strct.Scan(key, val); err != nil { | |||||
return err | |||||
} | |||||
} | |||||
return nil | |||||
} | |||||
func decodeBool(f reflect.Value, s string) error { | |||||
b, err := strconv.ParseBool(s) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
f.SetBool(b) | |||||
return nil | |||||
} | |||||
func decodeInt8(f reflect.Value, s string) error { | |||||
return decodeNumber(f, s, 8) | |||||
} | |||||
func decodeInt16(f reflect.Value, s string) error { | |||||
return decodeNumber(f, s, 16) | |||||
} | |||||
func decodeInt32(f reflect.Value, s string) error { | |||||
return decodeNumber(f, s, 32) | |||||
} | |||||
func decodeInt64(f reflect.Value, s string) error { | |||||
return decodeNumber(f, s, 64) | |||||
} | |||||
func decodeInt(f reflect.Value, s string) error { | |||||
return decodeNumber(f, s, 0) | |||||
} | |||||
func decodeNumber(f reflect.Value, s string, bitSize int) error { | |||||
v, err := strconv.ParseInt(s, 10, bitSize) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
f.SetInt(v) | |||||
return nil | |||||
} | |||||
func decodeUint8(f reflect.Value, s string) error { | |||||
return decodeUnsignedNumber(f, s, 8) | |||||
} | |||||
func decodeUint16(f reflect.Value, s string) error { | |||||
return decodeUnsignedNumber(f, s, 16) | |||||
} | |||||
func decodeUint32(f reflect.Value, s string) error { | |||||
return decodeUnsignedNumber(f, s, 32) | |||||
} | |||||
func decodeUint64(f reflect.Value, s string) error { | |||||
return decodeUnsignedNumber(f, s, 64) | |||||
} | |||||
func decodeUint(f reflect.Value, s string) error { | |||||
return decodeUnsignedNumber(f, s, 0) | |||||
} | |||||
func decodeUnsignedNumber(f reflect.Value, s string, bitSize int) error { | |||||
v, err := strconv.ParseUint(s, 10, bitSize) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
f.SetUint(v) | |||||
return nil | |||||
} | |||||
func decodeFloat32(f reflect.Value, s string) error { | |||||
v, err := strconv.ParseFloat(s, 32) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
f.SetFloat(v) | |||||
return nil | |||||
} | |||||
// although the default is float64, but we better define it. | |||||
func decodeFloat64(f reflect.Value, s string) error { | |||||
v, err := strconv.ParseFloat(s, 64) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
f.SetFloat(v) | |||||
return nil | |||||
} | |||||
func decodeString(f reflect.Value, s string) error { | |||||
f.SetString(s) | |||||
return nil | |||||
} | |||||
func decodeSlice(f reflect.Value, s string) error { | |||||
// []byte slice ([]uint8). | |||||
if f.Type().Elem().Kind() == reflect.Uint8 { | |||||
f.SetBytes([]byte(s)) | |||||
} | |||||
return nil | |||||
} | |||||
func decodeUnsupported(v reflect.Value, s string) error { | |||||
return fmt.Errorf("redis.Scan(unsupported %s)", v.Type()) | |||||
} |
@@ -0,0 +1,93 @@ | |||||
package hscan | |||||
import ( | |||||
"fmt" | |||||
"reflect" | |||||
"strings" | |||||
"sync" | |||||
) | |||||
// structMap contains the map of struct fields for target structs | |||||
// indexed by the struct type. | |||||
type structMap struct { | |||||
m sync.Map | |||||
} | |||||
func newStructMap() *structMap { | |||||
return new(structMap) | |||||
} | |||||
func (s *structMap) get(t reflect.Type) *structSpec { | |||||
if v, ok := s.m.Load(t); ok { | |||||
return v.(*structSpec) | |||||
} | |||||
spec := newStructSpec(t, "redis") | |||||
s.m.Store(t, spec) | |||||
return spec | |||||
} | |||||
//------------------------------------------------------------------------------ | |||||
// structSpec contains the list of all fields in a target struct. | |||||
type structSpec struct { | |||||
m map[string]*structField | |||||
} | |||||
func (s *structSpec) set(tag string, sf *structField) { | |||||
s.m[tag] = sf | |||||
} | |||||
func newStructSpec(t reflect.Type, fieldTag string) *structSpec { | |||||
numField := t.NumField() | |||||
out := &structSpec{ | |||||
m: make(map[string]*structField, numField), | |||||
} | |||||
for i := 0; i < numField; i++ { | |||||
f := t.Field(i) | |||||
tag := f.Tag.Get(fieldTag) | |||||
if tag == "" || tag == "-" { | |||||
continue | |||||
} | |||||
tag = strings.Split(tag, ",")[0] | |||||
if tag == "" { | |||||
continue | |||||
} | |||||
// Use the built-in decoder. | |||||
out.set(tag, &structField{index: i, fn: decoders[f.Type.Kind()]}) | |||||
} | |||||
return out | |||||
} | |||||
//------------------------------------------------------------------------------ | |||||
// structField represents a single field in a target struct. | |||||
type structField struct { | |||||
index int | |||||
fn decoderFunc | |||||
} | |||||
//------------------------------------------------------------------------------ | |||||
type StructValue struct { | |||||
spec *structSpec | |||||
value reflect.Value | |||||
} | |||||
func (s StructValue) Scan(key string, value string) error { | |||||
field, ok := s.spec.m[key] | |||||
if !ok { | |||||
return nil | |||||
} | |||||
if err := field.fn(s.value.Field(field.index), value); err != nil { | |||||
t := s.value.Type() | |||||
return fmt.Errorf("cannot scan redis.result %s into struct field %s.%s of type %s, error-%s", | |||||
value, t.Name(), t.Field(field.index).Name, t.Field(field.index).Type, err.Error()) | |||||
} | |||||
return nil | |||||
} |
@@ -0,0 +1,29 @@ | |||||
package internal | |||||
import ( | |||||
"time" | |||||
"github.com/go-redis/redis/v8/internal/rand" | |||||
) | |||||
func RetryBackoff(retry int, minBackoff, maxBackoff time.Duration) time.Duration { | |||||
if retry < 0 { | |||||
panic("not reached") | |||||
} | |||||
if minBackoff == 0 { | |||||
return 0 | |||||
} | |||||
d := minBackoff << uint(retry) | |||||
if d < minBackoff { | |||||
return maxBackoff | |||||
} | |||||
d = minBackoff + time.Duration(rand.Int63n(int64(d))) | |||||
if d > maxBackoff || d < minBackoff { | |||||
d = maxBackoff | |||||
} | |||||
return d | |||||
} |
@@ -0,0 +1,26 @@ | |||||
package internal | |||||
import ( | |||||
"context" | |||||
"fmt" | |||||
"log" | |||||
"os" | |||||
) | |||||
type Logging interface { | |||||
Printf(ctx context.Context, format string, v ...interface{}) | |||||
} | |||||
type logger struct { | |||||
log *log.Logger | |||||
} | |||||
func (l *logger) Printf(ctx context.Context, format string, v ...interface{}) { | |||||
_ = l.log.Output(2, fmt.Sprintf(format, v...)) | |||||
} | |||||
// Logger calls Output to print to the stderr. | |||||
// Arguments are handled in the manner of fmt.Print. | |||||
var Logger Logging = &logger{ | |||||
log: log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile), | |||||
} |
@@ -0,0 +1,60 @@ | |||||
/* | |||||
Copyright 2014 The Camlistore Authors | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
*/ | |||||
package internal | |||||
import ( | |||||
"sync" | |||||
"sync/atomic" | |||||
) | |||||
// A Once will perform a successful action exactly once. | |||||
// | |||||
// Unlike a sync.Once, this Once's func returns an error | |||||
// and is re-armed on failure. | |||||
type Once struct { | |||||
m sync.Mutex | |||||
done uint32 | |||||
} | |||||
// Do calls the function f if and only if Do has not been invoked | |||||
// without error for this instance of Once. In other words, given | |||||
// var once Once | |||||
// if once.Do(f) is called multiple times, only the first call will | |||||
// invoke f, even if f has a different value in each invocation unless | |||||
// f returns an error. A new instance of Once is required for each | |||||
// function to execute. | |||||
// | |||||
// Do is intended for initialization that must be run exactly once. Since f | |||||
// is niladic, it may be necessary to use a function literal to capture the | |||||
// arguments to a function to be invoked by Do: | |||||
// err := config.once.Do(func() error { return config.init(filename) }) | |||||
func (o *Once) Do(f func() error) error { | |||||
if atomic.LoadUint32(&o.done) == 1 { | |||||
return nil | |||||
} | |||||
// Slow-path. | |||||
o.m.Lock() | |||||
defer o.m.Unlock() | |||||
var err error | |||||
if o.done == 0 { | |||||
err = f() | |||||
if err == nil { | |||||
atomic.StoreUint32(&o.done, 1) | |||||
} | |||||
} | |||||
return err | |||||
} |
@@ -0,0 +1,121 @@ | |||||
package pool | |||||
import ( | |||||
"bufio" | |||||
"context" | |||||
"net" | |||||
"sync/atomic" | |||||
"time" | |||||
"github.com/go-redis/redis/v8/internal/proto" | |||||
) | |||||
var noDeadline = time.Time{} | |||||
type Conn struct { | |||||
usedAt int64 // atomic | |||||
netConn net.Conn | |||||
rd *proto.Reader | |||||
bw *bufio.Writer | |||||
wr *proto.Writer | |||||
Inited bool | |||||
pooled bool | |||||
createdAt time.Time | |||||
} | |||||
func NewConn(netConn net.Conn) *Conn { | |||||
cn := &Conn{ | |||||
netConn: netConn, | |||||
createdAt: time.Now(), | |||||
} | |||||
cn.rd = proto.NewReader(netConn) | |||||
cn.bw = bufio.NewWriter(netConn) | |||||
cn.wr = proto.NewWriter(cn.bw) | |||||
cn.SetUsedAt(time.Now()) | |||||
return cn | |||||
} | |||||
func (cn *Conn) UsedAt() time.Time { | |||||
unix := atomic.LoadInt64(&cn.usedAt) | |||||
return time.Unix(unix, 0) | |||||
} | |||||
func (cn *Conn) SetUsedAt(tm time.Time) { | |||||
atomic.StoreInt64(&cn.usedAt, tm.Unix()) | |||||
} | |||||
func (cn *Conn) SetNetConn(netConn net.Conn) { | |||||
cn.netConn = netConn | |||||
cn.rd.Reset(netConn) | |||||
cn.bw.Reset(netConn) | |||||
} | |||||
func (cn *Conn) Write(b []byte) (int, error) { | |||||
return cn.netConn.Write(b) | |||||
} | |||||
func (cn *Conn) RemoteAddr() net.Addr { | |||||
if cn.netConn != nil { | |||||
return cn.netConn.RemoteAddr() | |||||
} | |||||
return nil | |||||
} | |||||
func (cn *Conn) WithReader(ctx context.Context, timeout time.Duration, fn func(rd *proto.Reader) error) error { | |||||
if err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout)); err != nil { | |||||
return err | |||||
} | |||||
return fn(cn.rd) | |||||
} | |||||
func (cn *Conn) WithWriter( | |||||
ctx context.Context, timeout time.Duration, fn func(wr *proto.Writer) error, | |||||
) error { | |||||
if err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout)); err != nil { | |||||
return err | |||||
} | |||||
if cn.bw.Buffered() > 0 { | |||||
cn.bw.Reset(cn.netConn) | |||||
} | |||||
if err := fn(cn.wr); err != nil { | |||||
return err | |||||
} | |||||
return cn.bw.Flush() | |||||
} | |||||
func (cn *Conn) Close() error { | |||||
return cn.netConn.Close() | |||||
} | |||||
func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time { | |||||
tm := time.Now() | |||||
cn.SetUsedAt(tm) | |||||
if timeout > 0 { | |||||
tm = tm.Add(timeout) | |||||
} | |||||
if ctx != nil { | |||||
deadline, ok := ctx.Deadline() | |||||
if ok { | |||||
if timeout == 0 { | |||||
return deadline | |||||
} | |||||
if deadline.Before(tm) { | |||||
return deadline | |||||
} | |||||
return tm | |||||
} | |||||
} | |||||
if timeout > 0 { | |||||
return tm | |||||
} | |||||
return noDeadline | |||||
} |
@@ -0,0 +1,557 @@ | |||||
package pool | |||||
import ( | |||||
"context" | |||||
"errors" | |||||
"net" | |||||
"sync" | |||||
"sync/atomic" | |||||
"time" | |||||
"github.com/go-redis/redis/v8/internal" | |||||
) | |||||
var ( | |||||
// ErrClosed performs any operation on the closed client will return this error. | |||||
ErrClosed = errors.New("redis: client is closed") | |||||
// ErrPoolTimeout timed out waiting to get a connection from the connection pool. | |||||
ErrPoolTimeout = errors.New("redis: connection pool timeout") | |||||
) | |||||
var timers = sync.Pool{ | |||||
New: func() interface{} { | |||||
t := time.NewTimer(time.Hour) | |||||
t.Stop() | |||||
return t | |||||
}, | |||||
} | |||||
// Stats contains pool state information and accumulated stats. | |||||
type Stats struct { | |||||
Hits uint32 // number of times free connection was found in the pool | |||||
Misses uint32 // number of times free connection was NOT found in the pool | |||||
Timeouts uint32 // number of times a wait timeout occurred | |||||
TotalConns uint32 // number of total connections in the pool | |||||
IdleConns uint32 // number of idle connections in the pool | |||||
StaleConns uint32 // number of stale connections removed from the pool | |||||
} | |||||
type Pooler interface { | |||||
NewConn(context.Context) (*Conn, error) | |||||
CloseConn(*Conn) error | |||||
Get(context.Context) (*Conn, error) | |||||
Put(context.Context, *Conn) | |||||
Remove(context.Context, *Conn, error) | |||||
Len() int | |||||
IdleLen() int | |||||
Stats() *Stats | |||||
Close() error | |||||
} | |||||
type Options struct { | |||||
Dialer func(context.Context) (net.Conn, error) | |||||
OnClose func(*Conn) error | |||||
PoolFIFO bool | |||||
PoolSize int | |||||
MinIdleConns int | |||||
MaxConnAge time.Duration | |||||
PoolTimeout time.Duration | |||||
IdleTimeout time.Duration | |||||
IdleCheckFrequency time.Duration | |||||
} | |||||
type lastDialErrorWrap struct { | |||||
err error | |||||
} | |||||
type ConnPool struct { | |||||
opt *Options | |||||
dialErrorsNum uint32 // atomic | |||||
lastDialError atomic.Value | |||||
queue chan struct{} | |||||
connsMu sync.Mutex | |||||
conns []*Conn | |||||
idleConns []*Conn | |||||
poolSize int | |||||
idleConnsLen int | |||||
stats Stats | |||||
_closed uint32 // atomic | |||||
closedCh chan struct{} | |||||
} | |||||
var _ Pooler = (*ConnPool)(nil) | |||||
func NewConnPool(opt *Options) *ConnPool { | |||||
p := &ConnPool{ | |||||
opt: opt, | |||||
queue: make(chan struct{}, opt.PoolSize), | |||||
conns: make([]*Conn, 0, opt.PoolSize), | |||||
idleConns: make([]*Conn, 0, opt.PoolSize), | |||||
closedCh: make(chan struct{}), | |||||
} | |||||
p.connsMu.Lock() | |||||
p.checkMinIdleConns() | |||||
p.connsMu.Unlock() | |||||
if opt.IdleTimeout > 0 && opt.IdleCheckFrequency > 0 { | |||||
go p.reaper(opt.IdleCheckFrequency) | |||||
} | |||||
return p | |||||
} | |||||
func (p *ConnPool) checkMinIdleConns() { | |||||
if p.opt.MinIdleConns == 0 { | |||||
return | |||||
} | |||||
for p.poolSize < p.opt.PoolSize && p.idleConnsLen < p.opt.MinIdleConns { | |||||
p.poolSize++ | |||||
p.idleConnsLen++ | |||||
go func() { | |||||
err := p.addIdleConn() | |||||
if err != nil && err != ErrClosed { | |||||
p.connsMu.Lock() | |||||
p.poolSize-- | |||||
p.idleConnsLen-- | |||||
p.connsMu.Unlock() | |||||
} | |||||
}() | |||||
} | |||||
} | |||||
func (p *ConnPool) addIdleConn() error { | |||||
cn, err := p.dialConn(context.TODO(), true) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
p.connsMu.Lock() | |||||
defer p.connsMu.Unlock() | |||||
// It is not allowed to add new connections to the closed connection pool. | |||||
if p.closed() { | |||||
_ = cn.Close() | |||||
return ErrClosed | |||||
} | |||||
p.conns = append(p.conns, cn) | |||||
p.idleConns = append(p.idleConns, cn) | |||||
return nil | |||||
} | |||||
func (p *ConnPool) NewConn(ctx context.Context) (*Conn, error) { | |||||
return p.newConn(ctx, false) | |||||
} | |||||
func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) { | |||||
cn, err := p.dialConn(ctx, pooled) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
p.connsMu.Lock() | |||||
defer p.connsMu.Unlock() | |||||
// It is not allowed to add new connections to the closed connection pool. | |||||
if p.closed() { | |||||
_ = cn.Close() | |||||
return nil, ErrClosed | |||||
} | |||||
p.conns = append(p.conns, cn) | |||||
if pooled { | |||||
// If pool is full remove the cn on next Put. | |||||
if p.poolSize >= p.opt.PoolSize { | |||||
cn.pooled = false | |||||
} else { | |||||
p.poolSize++ | |||||
} | |||||
} | |||||
return cn, nil | |||||
} | |||||
func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) { | |||||
if p.closed() { | |||||
return nil, ErrClosed | |||||
} | |||||
if atomic.LoadUint32(&p.dialErrorsNum) >= uint32(p.opt.PoolSize) { | |||||
return nil, p.getLastDialError() | |||||
} | |||||
netConn, err := p.opt.Dialer(ctx) | |||||
if err != nil { | |||||
p.setLastDialError(err) | |||||
if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.opt.PoolSize) { | |||||
go p.tryDial() | |||||
} | |||||
return nil, err | |||||
} | |||||
cn := NewConn(netConn) | |||||
cn.pooled = pooled | |||||
return cn, nil | |||||
} | |||||
func (p *ConnPool) tryDial() { | |||||
for { | |||||
if p.closed() { | |||||
return | |||||
} | |||||
conn, err := p.opt.Dialer(context.Background()) | |||||
if err != nil { | |||||
p.setLastDialError(err) | |||||
time.Sleep(time.Second) | |||||
continue | |||||
} | |||||
atomic.StoreUint32(&p.dialErrorsNum, 0) | |||||
_ = conn.Close() | |||||
return | |||||
} | |||||
} | |||||
func (p *ConnPool) setLastDialError(err error) { | |||||
p.lastDialError.Store(&lastDialErrorWrap{err: err}) | |||||
} | |||||
func (p *ConnPool) getLastDialError() error { | |||||
err, _ := p.lastDialError.Load().(*lastDialErrorWrap) | |||||
if err != nil { | |||||
return err.err | |||||
} | |||||
return nil | |||||
} | |||||
// Get returns existed connection from the pool or creates a new one. | |||||
func (p *ConnPool) Get(ctx context.Context) (*Conn, error) { | |||||
if p.closed() { | |||||
return nil, ErrClosed | |||||
} | |||||
if err := p.waitTurn(ctx); err != nil { | |||||
return nil, err | |||||
} | |||||
for { | |||||
p.connsMu.Lock() | |||||
cn, err := p.popIdle() | |||||
p.connsMu.Unlock() | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
if cn == nil { | |||||
break | |||||
} | |||||
if p.isStaleConn(cn) { | |||||
_ = p.CloseConn(cn) | |||||
continue | |||||
} | |||||
atomic.AddUint32(&p.stats.Hits, 1) | |||||
return cn, nil | |||||
} | |||||
atomic.AddUint32(&p.stats.Misses, 1) | |||||
newcn, err := p.newConn(ctx, true) | |||||
if err != nil { | |||||
p.freeTurn() | |||||
return nil, err | |||||
} | |||||
return newcn, nil | |||||
} | |||||
func (p *ConnPool) getTurn() { | |||||
p.queue <- struct{}{} | |||||
} | |||||
func (p *ConnPool) waitTurn(ctx context.Context) error { | |||||
select { | |||||
case <-ctx.Done(): | |||||
return ctx.Err() | |||||
default: | |||||
} | |||||
select { | |||||
case p.queue <- struct{}{}: | |||||
return nil | |||||
default: | |||||
} | |||||
timer := timers.Get().(*time.Timer) | |||||
timer.Reset(p.opt.PoolTimeout) | |||||
select { | |||||
case <-ctx.Done(): | |||||
if !timer.Stop() { | |||||
<-timer.C | |||||
} | |||||
timers.Put(timer) | |||||
return ctx.Err() | |||||
case p.queue <- struct{}{}: | |||||
if !timer.Stop() { | |||||
<-timer.C | |||||
} | |||||
timers.Put(timer) | |||||
return nil | |||||
case <-timer.C: | |||||
timers.Put(timer) | |||||
atomic.AddUint32(&p.stats.Timeouts, 1) | |||||
return ErrPoolTimeout | |||||
} | |||||
} | |||||
func (p *ConnPool) freeTurn() { | |||||
<-p.queue | |||||
} | |||||
func (p *ConnPool) popIdle() (*Conn, error) { | |||||
if p.closed() { | |||||
return nil, ErrClosed | |||||
} | |||||
n := len(p.idleConns) | |||||
if n == 0 { | |||||
return nil, nil | |||||
} | |||||
var cn *Conn | |||||
if p.opt.PoolFIFO { | |||||
cn = p.idleConns[0] | |||||
copy(p.idleConns, p.idleConns[1:]) | |||||
p.idleConns = p.idleConns[:n-1] | |||||
} else { | |||||
idx := n - 1 | |||||
cn = p.idleConns[idx] | |||||
p.idleConns = p.idleConns[:idx] | |||||
} | |||||
p.idleConnsLen-- | |||||
p.checkMinIdleConns() | |||||
return cn, nil | |||||
} | |||||
func (p *ConnPool) Put(ctx context.Context, cn *Conn) { | |||||
if cn.rd.Buffered() > 0 { | |||||
internal.Logger.Printf(ctx, "Conn has unread data") | |||||
p.Remove(ctx, cn, BadConnError{}) | |||||
return | |||||
} | |||||
if !cn.pooled { | |||||
p.Remove(ctx, cn, nil) | |||||
return | |||||
} | |||||
p.connsMu.Lock() | |||||
p.idleConns = append(p.idleConns, cn) | |||||
p.idleConnsLen++ | |||||
p.connsMu.Unlock() | |||||
p.freeTurn() | |||||
} | |||||
func (p *ConnPool) Remove(ctx context.Context, cn *Conn, reason error) { | |||||
p.removeConnWithLock(cn) | |||||
p.freeTurn() | |||||
_ = p.closeConn(cn) | |||||
} | |||||
func (p *ConnPool) CloseConn(cn *Conn) error { | |||||
p.removeConnWithLock(cn) | |||||
return p.closeConn(cn) | |||||
} | |||||
func (p *ConnPool) removeConnWithLock(cn *Conn) { | |||||
p.connsMu.Lock() | |||||
p.removeConn(cn) | |||||
p.connsMu.Unlock() | |||||
} | |||||
func (p *ConnPool) removeConn(cn *Conn) { | |||||
for i, c := range p.conns { | |||||
if c == cn { | |||||
p.conns = append(p.conns[:i], p.conns[i+1:]...) | |||||
if cn.pooled { | |||||
p.poolSize-- | |||||
p.checkMinIdleConns() | |||||
} | |||||
return | |||||
} | |||||
} | |||||
} | |||||
func (p *ConnPool) closeConn(cn *Conn) error { | |||||
if p.opt.OnClose != nil { | |||||
_ = p.opt.OnClose(cn) | |||||
} | |||||
return cn.Close() | |||||
} | |||||
// Len returns total number of connections. | |||||
func (p *ConnPool) Len() int { | |||||
p.connsMu.Lock() | |||||
n := len(p.conns) | |||||
p.connsMu.Unlock() | |||||
return n | |||||
} | |||||
// IdleLen returns number of idle connections. | |||||
func (p *ConnPool) IdleLen() int { | |||||
p.connsMu.Lock() | |||||
n := p.idleConnsLen | |||||
p.connsMu.Unlock() | |||||
return n | |||||
} | |||||
func (p *ConnPool) Stats() *Stats { | |||||
idleLen := p.IdleLen() | |||||
return &Stats{ | |||||
Hits: atomic.LoadUint32(&p.stats.Hits), | |||||
Misses: atomic.LoadUint32(&p.stats.Misses), | |||||
Timeouts: atomic.LoadUint32(&p.stats.Timeouts), | |||||
TotalConns: uint32(p.Len()), | |||||
IdleConns: uint32(idleLen), | |||||
StaleConns: atomic.LoadUint32(&p.stats.StaleConns), | |||||
} | |||||
} | |||||
func (p *ConnPool) closed() bool { | |||||
return atomic.LoadUint32(&p._closed) == 1 | |||||
} | |||||
func (p *ConnPool) Filter(fn func(*Conn) bool) error { | |||||
p.connsMu.Lock() | |||||
defer p.connsMu.Unlock() | |||||
var firstErr error | |||||
for _, cn := range p.conns { | |||||
if fn(cn) { | |||||
if err := p.closeConn(cn); err != nil && firstErr == nil { | |||||
firstErr = err | |||||
} | |||||
} | |||||
} | |||||
return firstErr | |||||
} | |||||
func (p *ConnPool) Close() error { | |||||
if !atomic.CompareAndSwapUint32(&p._closed, 0, 1) { | |||||
return ErrClosed | |||||
} | |||||
close(p.closedCh) | |||||
var firstErr error | |||||
p.connsMu.Lock() | |||||
for _, cn := range p.conns { | |||||
if err := p.closeConn(cn); err != nil && firstErr == nil { | |||||
firstErr = err | |||||
} | |||||
} | |||||
p.conns = nil | |||||
p.poolSize = 0 | |||||
p.idleConns = nil | |||||
p.idleConnsLen = 0 | |||||
p.connsMu.Unlock() | |||||
return firstErr | |||||
} | |||||
func (p *ConnPool) reaper(frequency time.Duration) { | |||||
ticker := time.NewTicker(frequency) | |||||
defer ticker.Stop() | |||||
for { | |||||
select { | |||||
case <-ticker.C: | |||||
// It is possible that ticker and closedCh arrive together, | |||||
// and select pseudo-randomly pick ticker case, we double | |||||
// check here to prevent being executed after closed. | |||||
if p.closed() { | |||||
return | |||||
} | |||||
_, err := p.ReapStaleConns() | |||||
if err != nil { | |||||
internal.Logger.Printf(context.Background(), "ReapStaleConns failed: %s", err) | |||||
continue | |||||
} | |||||
case <-p.closedCh: | |||||
return | |||||
} | |||||
} | |||||
} | |||||
func (p *ConnPool) ReapStaleConns() (int, error) { | |||||
var n int | |||||
for { | |||||
p.getTurn() | |||||
p.connsMu.Lock() | |||||
cn := p.reapStaleConn() | |||||
p.connsMu.Unlock() | |||||
p.freeTurn() | |||||
if cn != nil { | |||||
_ = p.closeConn(cn) | |||||
n++ | |||||
} else { | |||||
break | |||||
} | |||||
} | |||||
atomic.AddUint32(&p.stats.StaleConns, uint32(n)) | |||||
return n, nil | |||||
} | |||||
func (p *ConnPool) reapStaleConn() *Conn { | |||||
if len(p.idleConns) == 0 { | |||||
return nil | |||||
} | |||||
cn := p.idleConns[0] | |||||
if !p.isStaleConn(cn) { | |||||
return nil | |||||
} | |||||
p.idleConns = append(p.idleConns[:0], p.idleConns[1:]...) | |||||
p.idleConnsLen-- | |||||
p.removeConn(cn) | |||||
return cn | |||||
} | |||||
func (p *ConnPool) isStaleConn(cn *Conn) bool { | |||||
if p.opt.IdleTimeout == 0 && p.opt.MaxConnAge == 0 { | |||||
return false | |||||
} | |||||
now := time.Now() | |||||
if p.opt.IdleTimeout > 0 && now.Sub(cn.UsedAt()) >= p.opt.IdleTimeout { | |||||
return true | |||||
} | |||||
if p.opt.MaxConnAge > 0 && now.Sub(cn.createdAt) >= p.opt.MaxConnAge { | |||||
return true | |||||
} | |||||
return false | |||||
} |
@@ -0,0 +1,58 @@ | |||||
package pool | |||||
import "context" | |||||
type SingleConnPool struct { | |||||
pool Pooler | |||||
cn *Conn | |||||
stickyErr error | |||||
} | |||||
var _ Pooler = (*SingleConnPool)(nil) | |||||
func NewSingleConnPool(pool Pooler, cn *Conn) *SingleConnPool { | |||||
return &SingleConnPool{ | |||||
pool: pool, | |||||
cn: cn, | |||||
} | |||||
} | |||||
func (p *SingleConnPool) NewConn(ctx context.Context) (*Conn, error) { | |||||
return p.pool.NewConn(ctx) | |||||
} | |||||
func (p *SingleConnPool) CloseConn(cn *Conn) error { | |||||
return p.pool.CloseConn(cn) | |||||
} | |||||
func (p *SingleConnPool) Get(ctx context.Context) (*Conn, error) { | |||||
if p.stickyErr != nil { | |||||
return nil, p.stickyErr | |||||
} | |||||
return p.cn, nil | |||||
} | |||||
func (p *SingleConnPool) Put(ctx context.Context, cn *Conn) {} | |||||
func (p *SingleConnPool) Remove(ctx context.Context, cn *Conn, reason error) { | |||||
p.cn = nil | |||||
p.stickyErr = reason | |||||
} | |||||
func (p *SingleConnPool) Close() error { | |||||
p.cn = nil | |||||
p.stickyErr = ErrClosed | |||||
return nil | |||||
} | |||||
func (p *SingleConnPool) Len() int { | |||||
return 0 | |||||
} | |||||
func (p *SingleConnPool) IdleLen() int { | |||||
return 0 | |||||
} | |||||
func (p *SingleConnPool) Stats() *Stats { | |||||
return &Stats{} | |||||
} |
@@ -0,0 +1,201 @@ | |||||
package pool | |||||
import ( | |||||
"context" | |||||
"errors" | |||||
"fmt" | |||||
"sync/atomic" | |||||
) | |||||
const ( | |||||
stateDefault = 0 | |||||
stateInited = 1 | |||||
stateClosed = 2 | |||||
) | |||||
type BadConnError struct { | |||||
wrapped error | |||||
} | |||||
var _ error = (*BadConnError)(nil) | |||||
func (e BadConnError) Error() string { | |||||
s := "redis: Conn is in a bad state" | |||||
if e.wrapped != nil { | |||||
s += ": " + e.wrapped.Error() | |||||
} | |||||
return s | |||||
} | |||||
func (e BadConnError) Unwrap() error { | |||||
return e.wrapped | |||||
} | |||||
//------------------------------------------------------------------------------ | |||||
type StickyConnPool struct { | |||||
pool Pooler | |||||
shared int32 // atomic | |||||
state uint32 // atomic | |||||
ch chan *Conn | |||||
_badConnError atomic.Value | |||||
} | |||||
var _ Pooler = (*StickyConnPool)(nil) | |||||
func NewStickyConnPool(pool Pooler) *StickyConnPool { | |||||
p, ok := pool.(*StickyConnPool) | |||||
if !ok { | |||||
p = &StickyConnPool{ | |||||
pool: pool, | |||||
ch: make(chan *Conn, 1), | |||||
} | |||||
} | |||||
atomic.AddInt32(&p.shared, 1) | |||||
return p | |||||
} | |||||
func (p *StickyConnPool) NewConn(ctx context.Context) (*Conn, error) { | |||||
return p.pool.NewConn(ctx) | |||||
} | |||||
func (p *StickyConnPool) CloseConn(cn *Conn) error { | |||||
return p.pool.CloseConn(cn) | |||||
} | |||||
func (p *StickyConnPool) Get(ctx context.Context) (*Conn, error) { | |||||
// In worst case this races with Close which is not a very common operation. | |||||
for i := 0; i < 1000; i++ { | |||||
switch atomic.LoadUint32(&p.state) { | |||||
case stateDefault: | |||||
cn, err := p.pool.Get(ctx) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) { | |||||
return cn, nil | |||||
} | |||||
p.pool.Remove(ctx, cn, ErrClosed) | |||||
case stateInited: | |||||
if err := p.badConnError(); err != nil { | |||||
return nil, err | |||||
} | |||||
cn, ok := <-p.ch | |||||
if !ok { | |||||
return nil, ErrClosed | |||||
} | |||||
return cn, nil | |||||
case stateClosed: | |||||
return nil, ErrClosed | |||||
default: | |||||
panic("not reached") | |||||
} | |||||
} | |||||
return nil, fmt.Errorf("redis: StickyConnPool.Get: infinite loop") | |||||
} | |||||
func (p *StickyConnPool) Put(ctx context.Context, cn *Conn) { | |||||
defer func() { | |||||
if recover() != nil { | |||||
p.freeConn(ctx, cn) | |||||
} | |||||
}() | |||||
p.ch <- cn | |||||
} | |||||
func (p *StickyConnPool) freeConn(ctx context.Context, cn *Conn) { | |||||
if err := p.badConnError(); err != nil { | |||||
p.pool.Remove(ctx, cn, err) | |||||
} else { | |||||
p.pool.Put(ctx, cn) | |||||
} | |||||
} | |||||
func (p *StickyConnPool) Remove(ctx context.Context, cn *Conn, reason error) { | |||||
defer func() { | |||||
if recover() != nil { | |||||
p.pool.Remove(ctx, cn, ErrClosed) | |||||
} | |||||
}() | |||||
p._badConnError.Store(BadConnError{wrapped: reason}) | |||||
p.ch <- cn | |||||
} | |||||
func (p *StickyConnPool) Close() error { | |||||
if shared := atomic.AddInt32(&p.shared, -1); shared > 0 { | |||||
return nil | |||||
} | |||||
for i := 0; i < 1000; i++ { | |||||
state := atomic.LoadUint32(&p.state) | |||||
if state == stateClosed { | |||||
return ErrClosed | |||||
} | |||||
if atomic.CompareAndSwapUint32(&p.state, state, stateClosed) { | |||||
close(p.ch) | |||||
cn, ok := <-p.ch | |||||
if ok { | |||||
p.freeConn(context.TODO(), cn) | |||||
} | |||||
return nil | |||||
} | |||||
} | |||||
return errors.New("redis: StickyConnPool.Close: infinite loop") | |||||
} | |||||
func (p *StickyConnPool) Reset(ctx context.Context) error { | |||||
if p.badConnError() == nil { | |||||
return nil | |||||
} | |||||
select { | |||||
case cn, ok := <-p.ch: | |||||
if !ok { | |||||
return ErrClosed | |||||
} | |||||
p.pool.Remove(ctx, cn, ErrClosed) | |||||
p._badConnError.Store(BadConnError{wrapped: nil}) | |||||
default: | |||||
return errors.New("redis: StickyConnPool does not have a Conn") | |||||
} | |||||
if !atomic.CompareAndSwapUint32(&p.state, stateInited, stateDefault) { | |||||
state := atomic.LoadUint32(&p.state) | |||||
return fmt.Errorf("redis: invalid StickyConnPool state: %d", state) | |||||
} | |||||
return nil | |||||
} | |||||
func (p *StickyConnPool) badConnError() error { | |||||
if v := p._badConnError.Load(); v != nil { | |||||
if err := v.(BadConnError); err.wrapped != nil { | |||||
return err | |||||
} | |||||
} | |||||
return nil | |||||
} | |||||
func (p *StickyConnPool) Len() int { | |||||
switch atomic.LoadUint32(&p.state) { | |||||
case stateDefault: | |||||
return 0 | |||||
case stateInited: | |||||
return 1 | |||||
case stateClosed: | |||||
return 0 | |||||
default: | |||||
panic("not reached") | |||||
} | |||||
} | |||||
func (p *StickyConnPool) IdleLen() int { | |||||
return len(p.ch) | |||||
} | |||||
func (p *StickyConnPool) Stats() *Stats { | |||||
return &Stats{} | |||||
} |
@@ -0,0 +1,332 @@ | |||||
package proto | |||||
import ( | |||||
"bufio" | |||||
"fmt" | |||||
"io" | |||||
"github.com/go-redis/redis/v8/internal/util" | |||||
) | |||||
// redis resp protocol data type. | |||||
const ( | |||||
ErrorReply = '-' | |||||
StatusReply = '+' | |||||
IntReply = ':' | |||||
StringReply = '$' | |||||
ArrayReply = '*' | |||||
) | |||||
//------------------------------------------------------------------------------ | |||||
const Nil = RedisError("redis: nil") // nolint:errname | |||||
type RedisError string | |||||
func (e RedisError) Error() string { return string(e) } | |||||
func (RedisError) RedisError() {} | |||||
//------------------------------------------------------------------------------ | |||||
type MultiBulkParse func(*Reader, int64) (interface{}, error) | |||||
type Reader struct { | |||||
rd *bufio.Reader | |||||
_buf []byte | |||||
} | |||||
func NewReader(rd io.Reader) *Reader { | |||||
return &Reader{ | |||||
rd: bufio.NewReader(rd), | |||||
_buf: make([]byte, 64), | |||||
} | |||||
} | |||||
func (r *Reader) Buffered() int { | |||||
return r.rd.Buffered() | |||||
} | |||||
func (r *Reader) Peek(n int) ([]byte, error) { | |||||
return r.rd.Peek(n) | |||||
} | |||||
func (r *Reader) Reset(rd io.Reader) { | |||||
r.rd.Reset(rd) | |||||
} | |||||
func (r *Reader) ReadLine() ([]byte, error) { | |||||
line, err := r.readLine() | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
if isNilReply(line) { | |||||
return nil, Nil | |||||
} | |||||
return line, nil | |||||
} | |||||
// readLine that returns an error if: | |||||
// - there is a pending read error; | |||||
// - or line does not end with \r\n. | |||||
func (r *Reader) readLine() ([]byte, error) { | |||||
b, err := r.rd.ReadSlice('\n') | |||||
if err != nil { | |||||
if err != bufio.ErrBufferFull { | |||||
return nil, err | |||||
} | |||||
full := make([]byte, len(b)) | |||||
copy(full, b) | |||||
b, err = r.rd.ReadBytes('\n') | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
full = append(full, b...) //nolint:makezero | |||||
b = full | |||||
} | |||||
if len(b) <= 2 || b[len(b)-1] != '\n' || b[len(b)-2] != '\r' { | |||||
return nil, fmt.Errorf("redis: invalid reply: %q", b) | |||||
} | |||||
return b[:len(b)-2], nil | |||||
} | |||||
func (r *Reader) ReadReply(m MultiBulkParse) (interface{}, error) { | |||||
line, err := r.ReadLine() | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
switch line[0] { | |||||
case ErrorReply: | |||||
return nil, ParseErrorReply(line) | |||||
case StatusReply: | |||||
return string(line[1:]), nil | |||||
case IntReply: | |||||
return util.ParseInt(line[1:], 10, 64) | |||||
case StringReply: | |||||
return r.readStringReply(line) | |||||
case ArrayReply: | |||||
n, err := parseArrayLen(line) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
if m == nil { | |||||
err := fmt.Errorf("redis: got %.100q, but multi bulk parser is nil", line) | |||||
return nil, err | |||||
} | |||||
return m(r, n) | |||||
} | |||||
return nil, fmt.Errorf("redis: can't parse %.100q", line) | |||||
} | |||||
func (r *Reader) ReadIntReply() (int64, error) { | |||||
line, err := r.ReadLine() | |||||
if err != nil { | |||||
return 0, err | |||||
} | |||||
switch line[0] { | |||||
case ErrorReply: | |||||
return 0, ParseErrorReply(line) | |||||
case IntReply: | |||||
return util.ParseInt(line[1:], 10, 64) | |||||
default: | |||||
return 0, fmt.Errorf("redis: can't parse int reply: %.100q", line) | |||||
} | |||||
} | |||||
func (r *Reader) ReadString() (string, error) { | |||||
line, err := r.ReadLine() | |||||
if err != nil { | |||||
return "", err | |||||
} | |||||
switch line[0] { | |||||
case ErrorReply: | |||||
return "", ParseErrorReply(line) | |||||
case StringReply: | |||||
return r.readStringReply(line) | |||||
case StatusReply: | |||||
return string(line[1:]), nil | |||||
case IntReply: | |||||
return string(line[1:]), nil | |||||
default: | |||||
return "", fmt.Errorf("redis: can't parse reply=%.100q reading string", line) | |||||
} | |||||
} | |||||
func (r *Reader) readStringReply(line []byte) (string, error) { | |||||
if isNilReply(line) { | |||||
return "", Nil | |||||
} | |||||
replyLen, err := util.Atoi(line[1:]) | |||||
if err != nil { | |||||
return "", err | |||||
} | |||||
b := make([]byte, replyLen+2) | |||||
_, err = io.ReadFull(r.rd, b) | |||||
if err != nil { | |||||
return "", err | |||||
} | |||||
return util.BytesToString(b[:replyLen]), nil | |||||
} | |||||
func (r *Reader) ReadArrayReply(m MultiBulkParse) (interface{}, error) { | |||||
line, err := r.ReadLine() | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
switch line[0] { | |||||
case ErrorReply: | |||||
return nil, ParseErrorReply(line) | |||||
case ArrayReply: | |||||
n, err := parseArrayLen(line) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
return m(r, n) | |||||
default: | |||||
return nil, fmt.Errorf("redis: can't parse array reply: %.100q", line) | |||||
} | |||||
} | |||||
func (r *Reader) ReadArrayLen() (int, error) { | |||||
line, err := r.ReadLine() | |||||
if err != nil { | |||||
return 0, err | |||||
} | |||||
switch line[0] { | |||||
case ErrorReply: | |||||
return 0, ParseErrorReply(line) | |||||
case ArrayReply: | |||||
n, err := parseArrayLen(line) | |||||
if err != nil { | |||||
return 0, err | |||||
} | |||||
return int(n), nil | |||||
default: | |||||
return 0, fmt.Errorf("redis: can't parse array reply: %.100q", line) | |||||
} | |||||
} | |||||
func (r *Reader) ReadScanReply() ([]string, uint64, error) { | |||||
n, err := r.ReadArrayLen() | |||||
if err != nil { | |||||
return nil, 0, err | |||||
} | |||||
if n != 2 { | |||||
return nil, 0, fmt.Errorf("redis: got %d elements in scan reply, expected 2", n) | |||||
} | |||||
cursor, err := r.ReadUint() | |||||
if err != nil { | |||||
return nil, 0, err | |||||
} | |||||
n, err = r.ReadArrayLen() | |||||
if err != nil { | |||||
return nil, 0, err | |||||
} | |||||
keys := make([]string, n) | |||||
for i := 0; i < n; i++ { | |||||
key, err := r.ReadString() | |||||
if err != nil { | |||||
return nil, 0, err | |||||
} | |||||
keys[i] = key | |||||
} | |||||
return keys, cursor, err | |||||
} | |||||
func (r *Reader) ReadInt() (int64, error) { | |||||
b, err := r.readTmpBytesReply() | |||||
if err != nil { | |||||
return 0, err | |||||
} | |||||
return util.ParseInt(b, 10, 64) | |||||
} | |||||
func (r *Reader) ReadUint() (uint64, error) { | |||||
b, err := r.readTmpBytesReply() | |||||
if err != nil { | |||||
return 0, err | |||||
} | |||||
return util.ParseUint(b, 10, 64) | |||||
} | |||||
func (r *Reader) ReadFloatReply() (float64, error) { | |||||
b, err := r.readTmpBytesReply() | |||||
if err != nil { | |||||
return 0, err | |||||
} | |||||
return util.ParseFloat(b, 64) | |||||
} | |||||
func (r *Reader) readTmpBytesReply() ([]byte, error) { | |||||
line, err := r.ReadLine() | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
switch line[0] { | |||||
case ErrorReply: | |||||
return nil, ParseErrorReply(line) | |||||
case StringReply: | |||||
return r._readTmpBytesReply(line) | |||||
case StatusReply: | |||||
return line[1:], nil | |||||
default: | |||||
return nil, fmt.Errorf("redis: can't parse string reply: %.100q", line) | |||||
} | |||||
} | |||||
func (r *Reader) _readTmpBytesReply(line []byte) ([]byte, error) { | |||||
if isNilReply(line) { | |||||
return nil, Nil | |||||
} | |||||
replyLen, err := util.Atoi(line[1:]) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
buf := r.buf(replyLen + 2) | |||||
_, err = io.ReadFull(r.rd, buf) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
return buf[:replyLen], nil | |||||
} | |||||
func (r *Reader) buf(n int) []byte { | |||||
if n <= cap(r._buf) { | |||||
return r._buf[:n] | |||||
} | |||||
d := n - cap(r._buf) | |||||
r._buf = append(r._buf, make([]byte, d)...) | |||||
return r._buf | |||||
} | |||||
func isNilReply(b []byte) bool { | |||||
return len(b) == 3 && | |||||
(b[0] == StringReply || b[0] == ArrayReply) && | |||||
b[1] == '-' && b[2] == '1' | |||||
} | |||||
func ParseErrorReply(line []byte) error { | |||||
return RedisError(string(line[1:])) | |||||
} | |||||
func parseArrayLen(line []byte) (int64, error) { | |||||
if isNilReply(line) { | |||||
return 0, Nil | |||||
} | |||||
return util.ParseInt(line[1:], 10, 64) | |||||
} |
@@ -0,0 +1,172 @@ | |||||
package proto | |||||
import ( | |||||
"encoding" | |||||
"fmt" | |||||
"reflect" | |||||
"time" | |||||
"github.com/go-redis/redis/v8/internal/util" | |||||
) | |||||
// Scan parses bytes `b` to `v` with appropriate type. | |||||
func Scan(b []byte, v interface{}) error { | |||||
switch v := v.(type) { | |||||
case nil: | |||||
return fmt.Errorf("redis: Scan(nil)") | |||||
case *string: | |||||
*v = util.BytesToString(b) | |||||
return nil | |||||
case *[]byte: | |||||
*v = b | |||||
return nil | |||||
case *int: | |||||
var err error | |||||
*v, err = util.Atoi(b) | |||||
return err | |||||
case *int8: | |||||
n, err := util.ParseInt(b, 10, 8) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
*v = int8(n) | |||||
return nil | |||||
case *int16: | |||||
n, err := util.ParseInt(b, 10, 16) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
*v = int16(n) | |||||
return nil | |||||
case *int32: | |||||
n, err := util.ParseInt(b, 10, 32) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
*v = int32(n) | |||||
return nil | |||||
case *int64: | |||||
n, err := util.ParseInt(b, 10, 64) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
*v = n | |||||
return nil | |||||
case *uint: | |||||
n, err := util.ParseUint(b, 10, 64) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
*v = uint(n) | |||||
return nil | |||||
case *uint8: | |||||
n, err := util.ParseUint(b, 10, 8) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
*v = uint8(n) | |||||
return nil | |||||
case *uint16: | |||||
n, err := util.ParseUint(b, 10, 16) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
*v = uint16(n) | |||||
return nil | |||||
case *uint32: | |||||
n, err := util.ParseUint(b, 10, 32) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
*v = uint32(n) | |||||
return nil | |||||
case *uint64: | |||||
n, err := util.ParseUint(b, 10, 64) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
*v = n | |||||
return nil | |||||
case *float32: | |||||
n, err := util.ParseFloat(b, 32) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
*v = float32(n) | |||||
return err | |||||
case *float64: | |||||
var err error | |||||
*v, err = util.ParseFloat(b, 64) | |||||
return err | |||||
case *bool: | |||||
*v = len(b) == 1 && b[0] == '1' | |||||
return nil | |||||
case *time.Time: | |||||
var err error | |||||
*v, err = time.Parse(time.RFC3339Nano, util.BytesToString(b)) | |||||
return err | |||||
case encoding.BinaryUnmarshaler: | |||||
return v.UnmarshalBinary(b) | |||||
default: | |||||
return fmt.Errorf( | |||||
"redis: can't unmarshal %T (consider implementing BinaryUnmarshaler)", v) | |||||
} | |||||
} | |||||
func ScanSlice(data []string, slice interface{}) error { | |||||
v := reflect.ValueOf(slice) | |||||
if !v.IsValid() { | |||||
return fmt.Errorf("redis: ScanSlice(nil)") | |||||
} | |||||
if v.Kind() != reflect.Ptr { | |||||
return fmt.Errorf("redis: ScanSlice(non-pointer %T)", slice) | |||||
} | |||||
v = v.Elem() | |||||
if v.Kind() != reflect.Slice { | |||||
return fmt.Errorf("redis: ScanSlice(non-slice %T)", slice) | |||||
} | |||||
next := makeSliceNextElemFunc(v) | |||||
for i, s := range data { | |||||
elem := next() | |||||
if err := Scan([]byte(s), elem.Addr().Interface()); err != nil { | |||||
err = fmt.Errorf("redis: ScanSlice index=%d value=%q failed: %w", i, s, err) | |||||
return err | |||||
} | |||||
} | |||||
return nil | |||||
} | |||||
func makeSliceNextElemFunc(v reflect.Value) func() reflect.Value { | |||||
elemType := v.Type().Elem() | |||||
if elemType.Kind() == reflect.Ptr { | |||||
elemType = elemType.Elem() | |||||
return func() reflect.Value { | |||||
if v.Len() < v.Cap() { | |||||
v.Set(v.Slice(0, v.Len()+1)) | |||||
elem := v.Index(v.Len() - 1) | |||||
if elem.IsNil() { | |||||
elem.Set(reflect.New(elemType)) | |||||
} | |||||
return elem.Elem() | |||||
} | |||||
elem := reflect.New(elemType) | |||||
v.Set(reflect.Append(v, elem)) | |||||
return elem.Elem() | |||||
} | |||||
} | |||||
zero := reflect.Zero(elemType) | |||||
return func() reflect.Value { | |||||
if v.Len() < v.Cap() { | |||||
v.Set(v.Slice(0, v.Len()+1)) | |||||
return v.Index(v.Len() - 1) | |||||
} | |||||
v.Set(reflect.Append(v, zero)) | |||||
return v.Index(v.Len() - 1) | |||||
} | |||||
} |
@@ -0,0 +1,153 @@ | |||||
package proto | |||||
import ( | |||||
"encoding" | |||||
"fmt" | |||||
"io" | |||||
"strconv" | |||||
"time" | |||||
"github.com/go-redis/redis/v8/internal/util" | |||||
) | |||||
type writer interface { | |||||
io.Writer | |||||
io.ByteWriter | |||||
// io.StringWriter | |||||
WriteString(s string) (n int, err error) | |||||
} | |||||
type Writer struct { | |||||
writer | |||||
lenBuf []byte | |||||
numBuf []byte | |||||
} | |||||
func NewWriter(wr writer) *Writer { | |||||
return &Writer{ | |||||
writer: wr, | |||||
lenBuf: make([]byte, 64), | |||||
numBuf: make([]byte, 64), | |||||
} | |||||
} | |||||
func (w *Writer) WriteArgs(args []interface{}) error { | |||||
if err := w.WriteByte(ArrayReply); err != nil { | |||||
return err | |||||
} | |||||
if err := w.writeLen(len(args)); err != nil { | |||||
return err | |||||
} | |||||
for _, arg := range args { | |||||
if err := w.WriteArg(arg); err != nil { | |||||
return err | |||||
} | |||||
} | |||||
return nil | |||||
} | |||||
func (w *Writer) writeLen(n int) error { | |||||
w.lenBuf = strconv.AppendUint(w.lenBuf[:0], uint64(n), 10) | |||||
w.lenBuf = append(w.lenBuf, '\r', '\n') | |||||
_, err := w.Write(w.lenBuf) | |||||
return err | |||||
} | |||||
func (w *Writer) WriteArg(v interface{}) error { | |||||
switch v := v.(type) { | |||||
case nil: | |||||
return w.string("") | |||||
case string: | |||||
return w.string(v) | |||||
case []byte: | |||||
return w.bytes(v) | |||||
case int: | |||||
return w.int(int64(v)) | |||||
case int8: | |||||
return w.int(int64(v)) | |||||
case int16: | |||||
return w.int(int64(v)) | |||||
case int32: | |||||
return w.int(int64(v)) | |||||
case int64: | |||||
return w.int(v) | |||||
case uint: | |||||
return w.uint(uint64(v)) | |||||
case uint8: | |||||
return w.uint(uint64(v)) | |||||
case uint16: | |||||
return w.uint(uint64(v)) | |||||
case uint32: | |||||
return w.uint(uint64(v)) | |||||
case uint64: | |||||
return w.uint(v) | |||||
case float32: | |||||
return w.float(float64(v)) | |||||
case float64: | |||||
return w.float(v) | |||||
case bool: | |||||
if v { | |||||
return w.int(1) | |||||
} | |||||
return w.int(0) | |||||
case time.Time: | |||||
w.numBuf = v.AppendFormat(w.numBuf[:0], time.RFC3339Nano) | |||||
return w.bytes(w.numBuf) | |||||
case encoding.BinaryMarshaler: | |||||
b, err := v.MarshalBinary() | |||||
if err != nil { | |||||
return err | |||||
} | |||||
return w.bytes(b) | |||||
default: | |||||
return fmt.Errorf( | |||||
"redis: can't marshal %T (implement encoding.BinaryMarshaler)", v) | |||||
} | |||||
} | |||||
func (w *Writer) bytes(b []byte) error { | |||||
if err := w.WriteByte(StringReply); err != nil { | |||||
return err | |||||
} | |||||
if err := w.writeLen(len(b)); err != nil { | |||||
return err | |||||
} | |||||
if _, err := w.Write(b); err != nil { | |||||
return err | |||||
} | |||||
return w.crlf() | |||||
} | |||||
func (w *Writer) string(s string) error { | |||||
return w.bytes(util.StringToBytes(s)) | |||||
} | |||||
func (w *Writer) uint(n uint64) error { | |||||
w.numBuf = strconv.AppendUint(w.numBuf[:0], n, 10) | |||||
return w.bytes(w.numBuf) | |||||
} | |||||
func (w *Writer) int(n int64) error { | |||||
w.numBuf = strconv.AppendInt(w.numBuf[:0], n, 10) | |||||
return w.bytes(w.numBuf) | |||||
} | |||||
func (w *Writer) float(f float64) error { | |||||
w.numBuf = strconv.AppendFloat(w.numBuf[:0], f, 'f', -1, 64) | |||||
return w.bytes(w.numBuf) | |||||
} | |||||
func (w *Writer) crlf() error { | |||||
if err := w.WriteByte('\r'); err != nil { | |||||
return err | |||||
} | |||||
return w.WriteByte('\n') | |||||
} |
@@ -0,0 +1,50 @@ | |||||
package rand | |||||
import ( | |||||
"math/rand" | |||||
"sync" | |||||
) | |||||
// Int returns a non-negative pseudo-random int. | |||||
func Int() int { return pseudo.Int() } | |||||
// Intn returns, as an int, a non-negative pseudo-random number in [0,n). | |||||
// It panics if n <= 0. | |||||
func Intn(n int) int { return pseudo.Intn(n) } | |||||
// Int63n returns, as an int64, a non-negative pseudo-random number in [0,n). | |||||
// It panics if n <= 0. | |||||
func Int63n(n int64) int64 { return pseudo.Int63n(n) } | |||||
// Perm returns, as a slice of n ints, a pseudo-random permutation of the integers [0,n). | |||||
func Perm(n int) []int { return pseudo.Perm(n) } | |||||
// Seed uses the provided seed value to initialize the default Source to a | |||||
// deterministic state. If Seed is not called, the generator behaves as if | |||||
// seeded by Seed(1). | |||||
func Seed(n int64) { pseudo.Seed(n) } | |||||
var pseudo = rand.New(&source{src: rand.NewSource(1)}) | |||||
type source struct { | |||||
src rand.Source | |||||
mu sync.Mutex | |||||
} | |||||
func (s *source) Int63() int64 { | |||||
s.mu.Lock() | |||||
n := s.src.Int63() | |||||
s.mu.Unlock() | |||||
return n | |||||
} | |||||
func (s *source) Seed(seed int64) { | |||||
s.mu.Lock() | |||||
s.src.Seed(seed) | |||||
s.mu.Unlock() | |||||
} | |||||
// Shuffle pseudo-randomizes the order of elements. | |||||
// n is the number of elements. | |||||
// swap swaps the elements with indexes i and j. | |||||
func Shuffle(n int, swap func(i, j int)) { pseudo.Shuffle(n, swap) } |
@@ -0,0 +1,12 @@ | |||||
//go:build appengine | |||||
// +build appengine | |||||
package internal | |||||
func String(b []byte) string { | |||||
return string(b) | |||||
} | |||||
func Bytes(s string) []byte { | |||||
return []byte(s) | |||||
} |
@@ -0,0 +1,21 @@ | |||||
//go:build !appengine | |||||
// +build !appengine | |||||
package internal | |||||
import "unsafe" | |||||
// String converts byte slice to string. | |||||
func String(b []byte) string { | |||||
return *(*string)(unsafe.Pointer(&b)) | |||||
} | |||||
// Bytes converts string to byte slice. | |||||
func Bytes(s string) []byte { | |||||
return *(*[]byte)(unsafe.Pointer( | |||||
&struct { | |||||
string | |||||
Cap int | |||||
}{s, len(s)}, | |||||
)) | |||||
} |
@@ -0,0 +1,46 @@ | |||||
package internal | |||||
import ( | |||||
"context" | |||||
"time" | |||||
"github.com/go-redis/redis/v8/internal/util" | |||||
) | |||||
func Sleep(ctx context.Context, dur time.Duration) error { | |||||
t := time.NewTimer(dur) | |||||
defer t.Stop() | |||||
select { | |||||
case <-t.C: | |||||
return nil | |||||
case <-ctx.Done(): | |||||
return ctx.Err() | |||||
} | |||||
} | |||||
func ToLower(s string) string { | |||||
if isLower(s) { | |||||
return s | |||||
} | |||||
b := make([]byte, len(s)) | |||||
for i := range b { | |||||
c := s[i] | |||||
if c >= 'A' && c <= 'Z' { | |||||
c += 'a' - 'A' | |||||
} | |||||
b[i] = c | |||||
} | |||||
return util.BytesToString(b) | |||||
} | |||||
func isLower(s string) bool { | |||||
for i := 0; i < len(s); i++ { | |||||
c := s[i] | |||||
if c >= 'A' && c <= 'Z' { | |||||
return false | |||||
} | |||||
} | |||||
return true | |||||
} |
@@ -0,0 +1,12 @@ | |||||
//go:build appengine | |||||
// +build appengine | |||||
package util | |||||
func BytesToString(b []byte) string { | |||||
return string(b) | |||||
} | |||||
func StringToBytes(s string) []byte { | |||||
return []byte(s) | |||||
} |
@@ -0,0 +1,19 @@ | |||||
package util | |||||
import "strconv" | |||||
func Atoi(b []byte) (int, error) { | |||||
return strconv.Atoi(BytesToString(b)) | |||||
} | |||||
func ParseInt(b []byte, base int, bitSize int) (int64, error) { | |||||
return strconv.ParseInt(BytesToString(b), base, bitSize) | |||||
} | |||||
func ParseUint(b []byte, base int, bitSize int) (uint64, error) { | |||||
return strconv.ParseUint(BytesToString(b), base, bitSize) | |||||
} | |||||
func ParseFloat(b []byte, bitSize int) (float64, error) { | |||||
return strconv.ParseFloat(BytesToString(b), bitSize) | |||||
} |
@@ -0,0 +1,23 @@ | |||||
//go:build !appengine | |||||
// +build !appengine | |||||
package util | |||||
import ( | |||||
"unsafe" | |||||
) | |||||
// BytesToString converts byte slice to string. | |||||
func BytesToString(b []byte) string { | |||||
return *(*string)(unsafe.Pointer(&b)) | |||||
} | |||||
// StringToBytes converts string to byte slice. | |||||
func StringToBytes(s string) []byte { | |||||
return *(*[]byte)(unsafe.Pointer( | |||||
&struct { | |||||
string | |||||
Cap int | |||||
}{s, len(s)}, | |||||
)) | |||||
} |
@@ -0,0 +1,77 @@ | |||||
package redis | |||||
import ( | |||||
"context" | |||||
"sync" | |||||
) | |||||
// ScanIterator is used to incrementally iterate over a collection of elements. | |||||
// It's safe for concurrent use by multiple goroutines. | |||||
type ScanIterator struct { | |||||
mu sync.Mutex // protects Scanner and pos | |||||
cmd *ScanCmd | |||||
pos int | |||||
} | |||||
// Err returns the last iterator error, if any. | |||||
func (it *ScanIterator) Err() error { | |||||
it.mu.Lock() | |||||
err := it.cmd.Err() | |||||
it.mu.Unlock() | |||||
return err | |||||
} | |||||
// Next advances the cursor and returns true if more values can be read. | |||||
func (it *ScanIterator) Next(ctx context.Context) bool { | |||||
it.mu.Lock() | |||||
defer it.mu.Unlock() | |||||
// Instantly return on errors. | |||||
if it.cmd.Err() != nil { | |||||
return false | |||||
} | |||||
// Advance cursor, check if we are still within range. | |||||
if it.pos < len(it.cmd.page) { | |||||
it.pos++ | |||||
return true | |||||
} | |||||
for { | |||||
// Return if there is no more data to fetch. | |||||
if it.cmd.cursor == 0 { | |||||
return false | |||||
} | |||||
// Fetch next page. | |||||
switch it.cmd.args[0] { | |||||
case "scan", "qscan": | |||||
it.cmd.args[1] = it.cmd.cursor | |||||
default: | |||||
it.cmd.args[2] = it.cmd.cursor | |||||
} | |||||
err := it.cmd.process(ctx, it.cmd) | |||||
if err != nil { | |||||
return false | |||||
} | |||||
it.pos = 1 | |||||
// Redis can occasionally return empty page. | |||||
if len(it.cmd.page) > 0 { | |||||
return true | |||||
} | |||||
} | |||||
} | |||||
// Val returns the key/field at the current cursor position. | |||||
func (it *ScanIterator) Val() string { | |||||
var v string | |||||
it.mu.Lock() | |||||
if it.cmd.Err() == nil && it.pos > 0 && it.pos <= len(it.cmd.page) { | |||||
v = it.cmd.page[it.pos-1] | |||||
} | |||||
it.mu.Unlock() | |||||
return v | |||||
} |
@@ -0,0 +1,429 @@ | |||||
package redis | |||||
import ( | |||||
"context" | |||||
"crypto/tls" | |||||
"errors" | |||||
"fmt" | |||||
"net" | |||||
"net/url" | |||||
"runtime" | |||||
"sort" | |||||
"strconv" | |||||
"strings" | |||||
"time" | |||||
"github.com/go-redis/redis/v8/internal/pool" | |||||
) | |||||
// Limiter is the interface of a rate limiter or a circuit breaker. | |||||
type Limiter interface { | |||||
// Allow returns nil if operation is allowed or an error otherwise. | |||||
// If operation is allowed client must ReportResult of the operation | |||||
// whether it is a success or a failure. | |||||
Allow() error | |||||
// ReportResult reports the result of the previously allowed operation. | |||||
// nil indicates a success, non-nil error usually indicates a failure. | |||||
ReportResult(result error) | |||||
} | |||||
// Options keeps the settings to setup redis connection. | |||||
type Options struct { | |||||
// The network type, either tcp or unix. | |||||
// Default is tcp. | |||||
Network string | |||||
// host:port address. | |||||
Addr string | |||||
// Dialer creates new network connection and has priority over | |||||
// Network and Addr options. | |||||
Dialer func(ctx context.Context, network, addr string) (net.Conn, error) | |||||
// Hook that is called when new connection is established. | |||||
OnConnect func(ctx context.Context, cn *Conn) error | |||||
// Use the specified Username to authenticate the current connection | |||||
// with one of the connections defined in the ACL list when connecting | |||||
// to a Redis 6.0 instance, or greater, that is using the Redis ACL system. | |||||
Username string | |||||
// Optional password. Must match the password specified in the | |||||
// requirepass server configuration option (if connecting to a Redis 5.0 instance, or lower), | |||||
// or the User Password when connecting to a Redis 6.0 instance, or greater, | |||||
// that is using the Redis ACL system. | |||||
Password string | |||||
// Database to be selected after connecting to the server. | |||||
DB int | |||||
// Maximum number of retries before giving up. | |||||
// Default is 3 retries; -1 (not 0) disables retries. | |||||
MaxRetries int | |||||
// Minimum backoff between each retry. | |||||
// Default is 8 milliseconds; -1 disables backoff. | |||||
MinRetryBackoff time.Duration | |||||
// Maximum backoff between each retry. | |||||
// Default is 512 milliseconds; -1 disables backoff. | |||||
MaxRetryBackoff time.Duration | |||||
// Dial timeout for establishing new connections. | |||||
// Default is 5 seconds. | |||||
DialTimeout time.Duration | |||||
// Timeout for socket reads. If reached, commands will fail | |||||
// with a timeout instead of blocking. Use value -1 for no timeout and 0 for default. | |||||
// Default is 3 seconds. | |||||
ReadTimeout time.Duration | |||||
// Timeout for socket writes. If reached, commands will fail | |||||
// with a timeout instead of blocking. | |||||
// Default is ReadTimeout. | |||||
WriteTimeout time.Duration | |||||
// Type of connection pool. | |||||
// true for FIFO pool, false for LIFO pool. | |||||
// Note that fifo has higher overhead compared to lifo. | |||||
PoolFIFO bool | |||||
// Maximum number of socket connections. | |||||
// Default is 10 connections per every available CPU as reported by runtime.GOMAXPROCS. | |||||
PoolSize int | |||||
// Minimum number of idle connections which is useful when establishing | |||||
// new connection is slow. | |||||
MinIdleConns int | |||||
// Connection age at which client retires (closes) the connection. | |||||
// Default is to not close aged connections. | |||||
MaxConnAge time.Duration | |||||
// Amount of time client waits for connection if all connections | |||||
// are busy before returning an error. | |||||
// Default is ReadTimeout + 1 second. | |||||
PoolTimeout time.Duration | |||||
// Amount of time after which client closes idle connections. | |||||
// Should be less than server's timeout. | |||||
// Default is 5 minutes. -1 disables idle timeout check. | |||||
IdleTimeout time.Duration | |||||
// Frequency of idle checks made by idle connections reaper. | |||||
// Default is 1 minute. -1 disables idle connections reaper, | |||||
// but idle connections are still discarded by the client | |||||
// if IdleTimeout is set. | |||||
IdleCheckFrequency time.Duration | |||||
// Enables read only queries on slave nodes. | |||||
readOnly bool | |||||
// TLS Config to use. When set TLS will be negotiated. | |||||
TLSConfig *tls.Config | |||||
// Limiter interface used to implemented circuit breaker or rate limiter. | |||||
Limiter Limiter | |||||
} | |||||
func (opt *Options) init() { | |||||
if opt.Addr == "" { | |||||
opt.Addr = "localhost:6379" | |||||
} | |||||
if opt.Network == "" { | |||||
if strings.HasPrefix(opt.Addr, "/") { | |||||
opt.Network = "unix" | |||||
} else { | |||||
opt.Network = "tcp" | |||||
} | |||||
} | |||||
if opt.DialTimeout == 0 { | |||||
opt.DialTimeout = 5 * time.Second | |||||
} | |||||
if opt.Dialer == nil { | |||||
opt.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) { | |||||
netDialer := &net.Dialer{ | |||||
Timeout: opt.DialTimeout, | |||||
KeepAlive: 5 * time.Minute, | |||||
} | |||||
if opt.TLSConfig == nil { | |||||
return netDialer.DialContext(ctx, network, addr) | |||||
} | |||||
return tls.DialWithDialer(netDialer, network, addr, opt.TLSConfig) | |||||
} | |||||
} | |||||
if opt.PoolSize == 0 { | |||||
opt.PoolSize = 10 * runtime.GOMAXPROCS(0) | |||||
} | |||||
switch opt.ReadTimeout { | |||||
case -1: | |||||
opt.ReadTimeout = 0 | |||||
case 0: | |||||
opt.ReadTimeout = 3 * time.Second | |||||
} | |||||
switch opt.WriteTimeout { | |||||
case -1: | |||||
opt.WriteTimeout = 0 | |||||
case 0: | |||||
opt.WriteTimeout = opt.ReadTimeout | |||||
} | |||||
if opt.PoolTimeout == 0 { | |||||
opt.PoolTimeout = opt.ReadTimeout + time.Second | |||||
} | |||||
if opt.IdleTimeout == 0 { | |||||
opt.IdleTimeout = 5 * time.Minute | |||||
} | |||||
if opt.IdleCheckFrequency == 0 { | |||||
opt.IdleCheckFrequency = time.Minute | |||||
} | |||||
if opt.MaxRetries == -1 { | |||||
opt.MaxRetries = 0 | |||||
} else if opt.MaxRetries == 0 { | |||||
opt.MaxRetries = 3 | |||||
} | |||||
switch opt.MinRetryBackoff { | |||||
case -1: | |||||
opt.MinRetryBackoff = 0 | |||||
case 0: | |||||
opt.MinRetryBackoff = 8 * time.Millisecond | |||||
} | |||||
switch opt.MaxRetryBackoff { | |||||
case -1: | |||||
opt.MaxRetryBackoff = 0 | |||||
case 0: | |||||
opt.MaxRetryBackoff = 512 * time.Millisecond | |||||
} | |||||
} | |||||
func (opt *Options) clone() *Options { | |||||
clone := *opt | |||||
return &clone | |||||
} | |||||
// ParseURL parses an URL into Options that can be used to connect to Redis. | |||||
// Scheme is required. | |||||
// There are two connection types: by tcp socket and by unix socket. | |||||
// Tcp connection: | |||||
// redis://<user>:<password>@<host>:<port>/<db_number> | |||||
// Unix connection: | |||||
// unix://<user>:<password>@</path/to/redis.sock>?db=<db_number> | |||||
// Most Option fields can be set using query parameters, with the following restrictions: | |||||
// - field names are mapped using snake-case conversion: to set MaxRetries, use max_retries | |||||
// - only scalar type fields are supported (bool, int, time.Duration) | |||||
// - for time.Duration fields, values must be a valid input for time.ParseDuration(); | |||||
// additionally a plain integer as value (i.e. without unit) is intepreted as seconds | |||||
// - to disable a duration field, use value less than or equal to 0; to use the default | |||||
// value, leave the value blank or remove the parameter | |||||
// - only the last value is interpreted if a parameter is given multiple times | |||||
// - fields "network", "addr", "username" and "password" can only be set using other | |||||
// URL attributes (scheme, host, userinfo, resp.), query paremeters using these | |||||
// names will be treated as unknown parameters | |||||
// - unknown parameter names will result in an error | |||||
// Examples: | |||||
// redis://user:password@localhost:6789/3?dial_timeout=3&db=1&read_timeout=6s&max_retries=2 | |||||
// is equivalent to: | |||||
// &Options{ | |||||
// Network: "tcp", | |||||
// Addr: "localhost:6789", | |||||
// DB: 1, // path "/3" was overridden by "&db=1" | |||||
// DialTimeout: 3 * time.Second, // no time unit = seconds | |||||
// ReadTimeout: 6 * time.Second, | |||||
// MaxRetries: 2, | |||||
// } | |||||
func ParseURL(redisURL string) (*Options, error) { | |||||
u, err := url.Parse(redisURL) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
switch u.Scheme { | |||||
case "redis", "rediss": | |||||
return setupTCPConn(u) | |||||
case "unix": | |||||
return setupUnixConn(u) | |||||
default: | |||||
return nil, fmt.Errorf("redis: invalid URL scheme: %s", u.Scheme) | |||||
} | |||||
} | |||||
func setupTCPConn(u *url.URL) (*Options, error) { | |||||
o := &Options{Network: "tcp"} | |||||
o.Username, o.Password = getUserPassword(u) | |||||
h, p, err := net.SplitHostPort(u.Host) | |||||
if err != nil { | |||||
h = u.Host | |||||
} | |||||
if h == "" { | |||||
h = "localhost" | |||||
} | |||||
if p == "" { | |||||
p = "6379" | |||||
} | |||||
o.Addr = net.JoinHostPort(h, p) | |||||
f := strings.FieldsFunc(u.Path, func(r rune) bool { | |||||
return r == '/' | |||||
}) | |||||
switch len(f) { | |||||
case 0: | |||||
o.DB = 0 | |||||
case 1: | |||||
if o.DB, err = strconv.Atoi(f[0]); err != nil { | |||||
return nil, fmt.Errorf("redis: invalid database number: %q", f[0]) | |||||
} | |||||
default: | |||||
return nil, fmt.Errorf("redis: invalid URL path: %s", u.Path) | |||||
} | |||||
if u.Scheme == "rediss" { | |||||
o.TLSConfig = &tls.Config{ServerName: h} | |||||
} | |||||
return setupConnParams(u, o) | |||||
} | |||||
func setupUnixConn(u *url.URL) (*Options, error) { | |||||
o := &Options{ | |||||
Network: "unix", | |||||
} | |||||
if strings.TrimSpace(u.Path) == "" { // path is required with unix connection | |||||
return nil, errors.New("redis: empty unix socket path") | |||||
} | |||||
o.Addr = u.Path | |||||
o.Username, o.Password = getUserPassword(u) | |||||
return setupConnParams(u, o) | |||||
} | |||||
type queryOptions struct { | |||||
q url.Values | |||||
err error | |||||
} | |||||
func (o *queryOptions) string(name string) string { | |||||
vs := o.q[name] | |||||
if len(vs) == 0 { | |||||
return "" | |||||
} | |||||
delete(o.q, name) // enable detection of unknown parameters | |||||
return vs[len(vs)-1] | |||||
} | |||||
func (o *queryOptions) int(name string) int { | |||||
s := o.string(name) | |||||
if s == "" { | |||||
return 0 | |||||
} | |||||
i, err := strconv.Atoi(s) | |||||
if err == nil { | |||||
return i | |||||
} | |||||
if o.err == nil { | |||||
o.err = fmt.Errorf("redis: invalid %s number: %s", name, err) | |||||
} | |||||
return 0 | |||||
} | |||||
func (o *queryOptions) duration(name string) time.Duration { | |||||
s := o.string(name) | |||||
if s == "" { | |||||
return 0 | |||||
} | |||||
// try plain number first | |||||
if i, err := strconv.Atoi(s); err == nil { | |||||
if i <= 0 { | |||||
// disable timeouts | |||||
return -1 | |||||
} | |||||
return time.Duration(i) * time.Second | |||||
} | |||||
dur, err := time.ParseDuration(s) | |||||
if err == nil { | |||||
return dur | |||||
} | |||||
if o.err == nil { | |||||
o.err = fmt.Errorf("redis: invalid %s duration: %w", name, err) | |||||
} | |||||
return 0 | |||||
} | |||||
func (o *queryOptions) bool(name string) bool { | |||||
switch s := o.string(name); s { | |||||
case "true", "1": | |||||
return true | |||||
case "false", "0", "": | |||||
return false | |||||
default: | |||||
if o.err == nil { | |||||
o.err = fmt.Errorf("redis: invalid %s boolean: expected true/false/1/0 or an empty string, got %q", name, s) | |||||
} | |||||
return false | |||||
} | |||||
} | |||||
func (o *queryOptions) remaining() []string { | |||||
if len(o.q) == 0 { | |||||
return nil | |||||
} | |||||
keys := make([]string, 0, len(o.q)) | |||||
for k := range o.q { | |||||
keys = append(keys, k) | |||||
} | |||||
sort.Strings(keys) | |||||
return keys | |||||
} | |||||
// setupConnParams converts query parameters in u to option value in o. | |||||
func setupConnParams(u *url.URL, o *Options) (*Options, error) { | |||||
q := queryOptions{q: u.Query()} | |||||
// compat: a future major release may use q.int("db") | |||||
if tmp := q.string("db"); tmp != "" { | |||||
db, err := strconv.Atoi(tmp) | |||||
if err != nil { | |||||
return nil, fmt.Errorf("redis: invalid database number: %w", err) | |||||
} | |||||
o.DB = db | |||||
} | |||||
o.MaxRetries = q.int("max_retries") | |||||
o.MinRetryBackoff = q.duration("min_retry_backoff") | |||||
o.MaxRetryBackoff = q.duration("max_retry_backoff") | |||||
o.DialTimeout = q.duration("dial_timeout") | |||||
o.ReadTimeout = q.duration("read_timeout") | |||||
o.WriteTimeout = q.duration("write_timeout") | |||||
o.PoolFIFO = q.bool("pool_fifo") | |||||
o.PoolSize = q.int("pool_size") | |||||
o.MinIdleConns = q.int("min_idle_conns") | |||||
o.MaxConnAge = q.duration("max_conn_age") | |||||
o.PoolTimeout = q.duration("pool_timeout") | |||||
o.IdleTimeout = q.duration("idle_timeout") | |||||
o.IdleCheckFrequency = q.duration("idle_check_frequency") | |||||
if q.err != nil { | |||||
return nil, q.err | |||||
} | |||||
// any parameters left? | |||||
if r := q.remaining(); len(r) > 0 { | |||||
return nil, fmt.Errorf("redis: unexpected option: %s", strings.Join(r, ", ")) | |||||
} | |||||
return o, nil | |||||
} | |||||
func getUserPassword(u *url.URL) (string, string) { | |||||
var user, password string | |||||
if u.User != nil { | |||||
user = u.User.Username() | |||||
if p, ok := u.User.Password(); ok { | |||||
password = p | |||||
} | |||||
} | |||||
return user, password | |||||
} | |||||
func newConnPool(opt *Options) *pool.ConnPool { | |||||
return pool.NewConnPool(&pool.Options{ | |||||
Dialer: func(ctx context.Context) (net.Conn, error) { | |||||
return opt.Dialer(ctx, opt.Network, opt.Addr) | |||||
}, | |||||
PoolFIFO: opt.PoolFIFO, | |||||
PoolSize: opt.PoolSize, | |||||
MinIdleConns: opt.MinIdleConns, | |||||
MaxConnAge: opt.MaxConnAge, | |||||
PoolTimeout: opt.PoolTimeout, | |||||
IdleTimeout: opt.IdleTimeout, | |||||
IdleCheckFrequency: opt.IdleCheckFrequency, | |||||
}) | |||||
} |
@@ -0,0 +1,8 @@ | |||||
{ | |||||
"name": "redis", | |||||
"version": "8.11.4", | |||||
"main": "index.js", | |||||
"repository": "git@github.com:go-redis/redis.git", | |||||
"author": "Vladimir Mihailenco <vladimir.webdev@gmail.com>", | |||||
"license": "BSD-2-clause" | |||||
} |
@@ -0,0 +1,137 @@ | |||||
package redis | |||||
import ( | |||||
"context" | |||||
"sync" | |||||
"github.com/go-redis/redis/v8/internal/pool" | |||||
) | |||||
type pipelineExecer func(context.Context, []Cmder) error | |||||
// Pipeliner is an mechanism to realise Redis Pipeline technique. | |||||
// | |||||
// Pipelining is a technique to extremely speed up processing by packing | |||||
// operations to batches, send them at once to Redis and read a replies in a | |||||
// singe step. | |||||
// See https://redis.io/topics/pipelining | |||||
// | |||||
// Pay attention, that Pipeline is not a transaction, so you can get unexpected | |||||
// results in case of big pipelines and small read/write timeouts. | |||||
// Redis client has retransmission logic in case of timeouts, pipeline | |||||
// can be retransmitted and commands can be executed more then once. | |||||
// To avoid this: it is good idea to use reasonable bigger read/write timeouts | |||||
// depends of your batch size and/or use TxPipeline. | |||||
type Pipeliner interface { | |||||
StatefulCmdable | |||||
Do(ctx context.Context, args ...interface{}) *Cmd | |||||
Process(ctx context.Context, cmd Cmder) error | |||||
Close() error | |||||
Discard() error | |||||
Exec(ctx context.Context) ([]Cmder, error) | |||||
} | |||||
var _ Pipeliner = (*Pipeline)(nil) | |||||
// Pipeline implements pipelining as described in | |||||
// http://redis.io/topics/pipelining. It's safe for concurrent use | |||||
// by multiple goroutines. | |||||
type Pipeline struct { | |||||
cmdable | |||||
statefulCmdable | |||||
ctx context.Context | |||||
exec pipelineExecer | |||||
mu sync.Mutex | |||||
cmds []Cmder | |||||
closed bool | |||||
} | |||||
func (c *Pipeline) init() { | |||||
c.cmdable = c.Process | |||||
c.statefulCmdable = c.Process | |||||
} | |||||
func (c *Pipeline) Do(ctx context.Context, args ...interface{}) *Cmd { | |||||
cmd := NewCmd(ctx, args...) | |||||
_ = c.Process(ctx, cmd) | |||||
return cmd | |||||
} | |||||
// Process queues the cmd for later execution. | |||||
func (c *Pipeline) Process(ctx context.Context, cmd Cmder) error { | |||||
c.mu.Lock() | |||||
c.cmds = append(c.cmds, cmd) | |||||
c.mu.Unlock() | |||||
return nil | |||||
} | |||||
// Close closes the pipeline, releasing any open resources. | |||||
func (c *Pipeline) Close() error { | |||||
c.mu.Lock() | |||||
_ = c.discard() | |||||
c.closed = true | |||||
c.mu.Unlock() | |||||
return nil | |||||
} | |||||
// Discard resets the pipeline and discards queued commands. | |||||
func (c *Pipeline) Discard() error { | |||||
c.mu.Lock() | |||||
err := c.discard() | |||||
c.mu.Unlock() | |||||
return err | |||||
} | |||||
func (c *Pipeline) discard() error { | |||||
if c.closed { | |||||
return pool.ErrClosed | |||||
} | |||||
c.cmds = c.cmds[:0] | |||||
return nil | |||||
} | |||||
// Exec executes all previously queued commands using one | |||||
// client-server roundtrip. | |||||
// | |||||
// Exec always returns list of commands and error of the first failed | |||||
// command if any. | |||||
func (c *Pipeline) Exec(ctx context.Context) ([]Cmder, error) { | |||||
c.mu.Lock() | |||||
defer c.mu.Unlock() | |||||
if c.closed { | |||||
return nil, pool.ErrClosed | |||||
} | |||||
if len(c.cmds) == 0 { | |||||
return nil, nil | |||||
} | |||||
cmds := c.cmds | |||||
c.cmds = nil | |||||
return cmds, c.exec(ctx, cmds) | |||||
} | |||||
func (c *Pipeline) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { | |||||
if err := fn(c); err != nil { | |||||
return nil, err | |||||
} | |||||
cmds, err := c.Exec(ctx) | |||||
_ = c.Close() | |||||
return cmds, err | |||||
} | |||||
func (c *Pipeline) Pipeline() Pipeliner { | |||||
return c | |||||
} | |||||
func (c *Pipeline) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { | |||||
return c.Pipelined(ctx, fn) | |||||
} | |||||
func (c *Pipeline) TxPipeline() Pipeliner { | |||||
return c | |||||
} |
@@ -0,0 +1,668 @@ | |||||
package redis | |||||
import ( | |||||
"context" | |||||
"fmt" | |||||
"strings" | |||||
"sync" | |||||
"time" | |||||
"github.com/go-redis/redis/v8/internal" | |||||
"github.com/go-redis/redis/v8/internal/pool" | |||||
"github.com/go-redis/redis/v8/internal/proto" | |||||
) | |||||
// PubSub implements Pub/Sub commands as described in | |||||
// http://redis.io/topics/pubsub. Message receiving is NOT safe | |||||
// for concurrent use by multiple goroutines. | |||||
// | |||||
// PubSub automatically reconnects to Redis Server and resubscribes | |||||
// to the channels in case of network errors. | |||||
type PubSub struct { | |||||
opt *Options | |||||
newConn func(ctx context.Context, channels []string) (*pool.Conn, error) | |||||
closeConn func(*pool.Conn) error | |||||
mu sync.Mutex | |||||
cn *pool.Conn | |||||
channels map[string]struct{} | |||||
patterns map[string]struct{} | |||||
closed bool | |||||
exit chan struct{} | |||||
cmd *Cmd | |||||
chOnce sync.Once | |||||
msgCh *channel | |||||
allCh *channel | |||||
} | |||||
func (c *PubSub) init() { | |||||
c.exit = make(chan struct{}) | |||||
} | |||||
func (c *PubSub) String() string { | |||||
channels := mapKeys(c.channels) | |||||
channels = append(channels, mapKeys(c.patterns)...) | |||||
return fmt.Sprintf("PubSub(%s)", strings.Join(channels, ", ")) | |||||
} | |||||
func (c *PubSub) connWithLock(ctx context.Context) (*pool.Conn, error) { | |||||
c.mu.Lock() | |||||
cn, err := c.conn(ctx, nil) | |||||
c.mu.Unlock() | |||||
return cn, err | |||||
} | |||||
func (c *PubSub) conn(ctx context.Context, newChannels []string) (*pool.Conn, error) { | |||||
if c.closed { | |||||
return nil, pool.ErrClosed | |||||
} | |||||
if c.cn != nil { | |||||
return c.cn, nil | |||||
} | |||||
channels := mapKeys(c.channels) | |||||
channels = append(channels, newChannels...) | |||||
cn, err := c.newConn(ctx, channels) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
if err := c.resubscribe(ctx, cn); err != nil { | |||||
_ = c.closeConn(cn) | |||||
return nil, err | |||||
} | |||||
c.cn = cn | |||||
return cn, nil | |||||
} | |||||
func (c *PubSub) writeCmd(ctx context.Context, cn *pool.Conn, cmd Cmder) error { | |||||
return cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { | |||||
return writeCmd(wr, cmd) | |||||
}) | |||||
} | |||||
func (c *PubSub) resubscribe(ctx context.Context, cn *pool.Conn) error { | |||||
var firstErr error | |||||
if len(c.channels) > 0 { | |||||
firstErr = c._subscribe(ctx, cn, "subscribe", mapKeys(c.channels)) | |||||
} | |||||
if len(c.patterns) > 0 { | |||||
err := c._subscribe(ctx, cn, "psubscribe", mapKeys(c.patterns)) | |||||
if err != nil && firstErr == nil { | |||||
firstErr = err | |||||
} | |||||
} | |||||
return firstErr | |||||
} | |||||
func mapKeys(m map[string]struct{}) []string { | |||||
s := make([]string, len(m)) | |||||
i := 0 | |||||
for k := range m { | |||||
s[i] = k | |||||
i++ | |||||
} | |||||
return s | |||||
} | |||||
func (c *PubSub) _subscribe( | |||||
ctx context.Context, cn *pool.Conn, redisCmd string, channels []string, | |||||
) error { | |||||
args := make([]interface{}, 0, 1+len(channels)) | |||||
args = append(args, redisCmd) | |||||
for _, channel := range channels { | |||||
args = append(args, channel) | |||||
} | |||||
cmd := NewSliceCmd(ctx, args...) | |||||
return c.writeCmd(ctx, cn, cmd) | |||||
} | |||||
func (c *PubSub) releaseConnWithLock( | |||||
ctx context.Context, | |||||
cn *pool.Conn, | |||||
err error, | |||||
allowTimeout bool, | |||||
) { | |||||
c.mu.Lock() | |||||
c.releaseConn(ctx, cn, err, allowTimeout) | |||||
c.mu.Unlock() | |||||
} | |||||
func (c *PubSub) releaseConn(ctx context.Context, cn *pool.Conn, err error, allowTimeout bool) { | |||||
if c.cn != cn { | |||||
return | |||||
} | |||||
if isBadConn(err, allowTimeout, c.opt.Addr) { | |||||
c.reconnect(ctx, err) | |||||
} | |||||
} | |||||
func (c *PubSub) reconnect(ctx context.Context, reason error) { | |||||
_ = c.closeTheCn(reason) | |||||
_, _ = c.conn(ctx, nil) | |||||
} | |||||
func (c *PubSub) closeTheCn(reason error) error { | |||||
if c.cn == nil { | |||||
return nil | |||||
} | |||||
if !c.closed { | |||||
internal.Logger.Printf(c.getContext(), "redis: discarding bad PubSub connection: %s", reason) | |||||
} | |||||
err := c.closeConn(c.cn) | |||||
c.cn = nil | |||||
return err | |||||
} | |||||
func (c *PubSub) Close() error { | |||||
c.mu.Lock() | |||||
defer c.mu.Unlock() | |||||
if c.closed { | |||||
return pool.ErrClosed | |||||
} | |||||
c.closed = true | |||||
close(c.exit) | |||||
return c.closeTheCn(pool.ErrClosed) | |||||
} | |||||
// Subscribe the client to the specified channels. It returns | |||||
// empty subscription if there are no channels. | |||||
func (c *PubSub) Subscribe(ctx context.Context, channels ...string) error { | |||||
c.mu.Lock() | |||||
defer c.mu.Unlock() | |||||
err := c.subscribe(ctx, "subscribe", channels...) | |||||
if c.channels == nil { | |||||
c.channels = make(map[string]struct{}) | |||||
} | |||||
for _, s := range channels { | |||||
c.channels[s] = struct{}{} | |||||
} | |||||
return err | |||||
} | |||||
// PSubscribe the client to the given patterns. It returns | |||||
// empty subscription if there are no patterns. | |||||
func (c *PubSub) PSubscribe(ctx context.Context, patterns ...string) error { | |||||
c.mu.Lock() | |||||
defer c.mu.Unlock() | |||||
err := c.subscribe(ctx, "psubscribe", patterns...) | |||||
if c.patterns == nil { | |||||
c.patterns = make(map[string]struct{}) | |||||
} | |||||
for _, s := range patterns { | |||||
c.patterns[s] = struct{}{} | |||||
} | |||||
return err | |||||
} | |||||
// Unsubscribe the client from the given channels, or from all of | |||||
// them if none is given. | |||||
func (c *PubSub) Unsubscribe(ctx context.Context, channels ...string) error { | |||||
c.mu.Lock() | |||||
defer c.mu.Unlock() | |||||
for _, channel := range channels { | |||||
delete(c.channels, channel) | |||||
} | |||||
err := c.subscribe(ctx, "unsubscribe", channels...) | |||||
return err | |||||
} | |||||
// PUnsubscribe the client from the given patterns, or from all of | |||||
// them if none is given. | |||||
func (c *PubSub) PUnsubscribe(ctx context.Context, patterns ...string) error { | |||||
c.mu.Lock() | |||||
defer c.mu.Unlock() | |||||
for _, pattern := range patterns { | |||||
delete(c.patterns, pattern) | |||||
} | |||||
err := c.subscribe(ctx, "punsubscribe", patterns...) | |||||
return err | |||||
} | |||||
func (c *PubSub) subscribe(ctx context.Context, redisCmd string, channels ...string) error { | |||||
cn, err := c.conn(ctx, channels) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
err = c._subscribe(ctx, cn, redisCmd, channels) | |||||
c.releaseConn(ctx, cn, err, false) | |||||
return err | |||||
} | |||||
func (c *PubSub) Ping(ctx context.Context, payload ...string) error { | |||||
args := []interface{}{"ping"} | |||||
if len(payload) == 1 { | |||||
args = append(args, payload[0]) | |||||
} | |||||
cmd := NewCmd(ctx, args...) | |||||
c.mu.Lock() | |||||
defer c.mu.Unlock() | |||||
cn, err := c.conn(ctx, nil) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
err = c.writeCmd(ctx, cn, cmd) | |||||
c.releaseConn(ctx, cn, err, false) | |||||
return err | |||||
} | |||||
// Subscription received after a successful subscription to channel. | |||||
type Subscription struct { | |||||
// Can be "subscribe", "unsubscribe", "psubscribe" or "punsubscribe". | |||||
Kind string | |||||
// Channel name we have subscribed to. | |||||
Channel string | |||||
// Number of channels we are currently subscribed to. | |||||
Count int | |||||
} | |||||
func (m *Subscription) String() string { | |||||
return fmt.Sprintf("%s: %s", m.Kind, m.Channel) | |||||
} | |||||
// Message received as result of a PUBLISH command issued by another client. | |||||
type Message struct { | |||||
Channel string | |||||
Pattern string | |||||
Payload string | |||||
PayloadSlice []string | |||||
} | |||||
func (m *Message) String() string { | |||||
return fmt.Sprintf("Message<%s: %s>", m.Channel, m.Payload) | |||||
} | |||||
// Pong received as result of a PING command issued by another client. | |||||
type Pong struct { | |||||
Payload string | |||||
} | |||||
func (p *Pong) String() string { | |||||
if p.Payload != "" { | |||||
return fmt.Sprintf("Pong<%s>", p.Payload) | |||||
} | |||||
return "Pong" | |||||
} | |||||
func (c *PubSub) newMessage(reply interface{}) (interface{}, error) { | |||||
switch reply := reply.(type) { | |||||
case string: | |||||
return &Pong{ | |||||
Payload: reply, | |||||
}, nil | |||||
case []interface{}: | |||||
switch kind := reply[0].(string); kind { | |||||
case "subscribe", "unsubscribe", "psubscribe", "punsubscribe": | |||||
// Can be nil in case of "unsubscribe". | |||||
channel, _ := reply[1].(string) | |||||
return &Subscription{ | |||||
Kind: kind, | |||||
Channel: channel, | |||||
Count: int(reply[2].(int64)), | |||||
}, nil | |||||
case "message": | |||||
switch payload := reply[2].(type) { | |||||
case string: | |||||
return &Message{ | |||||
Channel: reply[1].(string), | |||||
Payload: payload, | |||||
}, nil | |||||
case []interface{}: | |||||
ss := make([]string, len(payload)) | |||||
for i, s := range payload { | |||||
ss[i] = s.(string) | |||||
} | |||||
return &Message{ | |||||
Channel: reply[1].(string), | |||||
PayloadSlice: ss, | |||||
}, nil | |||||
default: | |||||
return nil, fmt.Errorf("redis: unsupported pubsub message payload: %T", payload) | |||||
} | |||||
case "pmessage": | |||||
return &Message{ | |||||
Pattern: reply[1].(string), | |||||
Channel: reply[2].(string), | |||||
Payload: reply[3].(string), | |||||
}, nil | |||||
case "pong": | |||||
return &Pong{ | |||||
Payload: reply[1].(string), | |||||
}, nil | |||||
default: | |||||
return nil, fmt.Errorf("redis: unsupported pubsub message: %q", kind) | |||||
} | |||||
default: | |||||
return nil, fmt.Errorf("redis: unsupported pubsub message: %#v", reply) | |||||
} | |||||
} | |||||
// ReceiveTimeout acts like Receive but returns an error if message | |||||
// is not received in time. This is low-level API and in most cases | |||||
// Channel should be used instead. | |||||
func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (interface{}, error) { | |||||
if c.cmd == nil { | |||||
c.cmd = NewCmd(ctx) | |||||
} | |||||
// Don't hold the lock to allow subscriptions and pings. | |||||
cn, err := c.connWithLock(ctx) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
err = cn.WithReader(ctx, timeout, func(rd *proto.Reader) error { | |||||
return c.cmd.readReply(rd) | |||||
}) | |||||
c.releaseConnWithLock(ctx, cn, err, timeout > 0) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
return c.newMessage(c.cmd.Val()) | |||||
} | |||||
// Receive returns a message as a Subscription, Message, Pong or error. | |||||
// See PubSub example for details. This is low-level API and in most cases | |||||
// Channel should be used instead. | |||||
func (c *PubSub) Receive(ctx context.Context) (interface{}, error) { | |||||
return c.ReceiveTimeout(ctx, 0) | |||||
} | |||||
// ReceiveMessage returns a Message or error ignoring Subscription and Pong | |||||
// messages. This is low-level API and in most cases Channel should be used | |||||
// instead. | |||||
func (c *PubSub) ReceiveMessage(ctx context.Context) (*Message, error) { | |||||
for { | |||||
msg, err := c.Receive(ctx) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
switch msg := msg.(type) { | |||||
case *Subscription: | |||||
// Ignore. | |||||
case *Pong: | |||||
// Ignore. | |||||
case *Message: | |||||
return msg, nil | |||||
default: | |||||
err := fmt.Errorf("redis: unknown message: %T", msg) | |||||
return nil, err | |||||
} | |||||
} | |||||
} | |||||
func (c *PubSub) getContext() context.Context { | |||||
if c.cmd != nil { | |||||
return c.cmd.ctx | |||||
} | |||||
return context.Background() | |||||
} | |||||
//------------------------------------------------------------------------------ | |||||
// Channel returns a Go channel for concurrently receiving messages. | |||||
// The channel is closed together with the PubSub. If the Go channel | |||||
// is blocked full for 30 seconds the message is dropped. | |||||
// Receive* APIs can not be used after channel is created. | |||||
// | |||||
// go-redis periodically sends ping messages to test connection health | |||||
// and re-subscribes if ping can not not received for 30 seconds. | |||||
func (c *PubSub) Channel(opts ...ChannelOption) <-chan *Message { | |||||
c.chOnce.Do(func() { | |||||
c.msgCh = newChannel(c, opts...) | |||||
c.msgCh.initMsgChan() | |||||
}) | |||||
if c.msgCh == nil { | |||||
err := fmt.Errorf("redis: Channel can't be called after ChannelWithSubscriptions") | |||||
panic(err) | |||||
} | |||||
return c.msgCh.msgCh | |||||
} | |||||
// ChannelSize is like Channel, but creates a Go channel | |||||
// with specified buffer size. | |||||
// | |||||
// Deprecated: use Channel(WithChannelSize(size)), remove in v9. | |||||
func (c *PubSub) ChannelSize(size int) <-chan *Message { | |||||
return c.Channel(WithChannelSize(size)) | |||||
} | |||||
// ChannelWithSubscriptions is like Channel, but message type can be either | |||||
// *Subscription or *Message. Subscription messages can be used to detect | |||||
// reconnections. | |||||
// | |||||
// ChannelWithSubscriptions can not be used together with Channel or ChannelSize. | |||||
func (c *PubSub) ChannelWithSubscriptions(_ context.Context, size int) <-chan interface{} { | |||||
c.chOnce.Do(func() { | |||||
c.allCh = newChannel(c, WithChannelSize(size)) | |||||
c.allCh.initAllChan() | |||||
}) | |||||
if c.allCh == nil { | |||||
err := fmt.Errorf("redis: ChannelWithSubscriptions can't be called after Channel") | |||||
panic(err) | |||||
} | |||||
return c.allCh.allCh | |||||
} | |||||
type ChannelOption func(c *channel) | |||||
// WithChannelSize specifies the Go chan size that is used to buffer incoming messages. | |||||
// | |||||
// The default is 100 messages. | |||||
func WithChannelSize(size int) ChannelOption { | |||||
return func(c *channel) { | |||||
c.chanSize = size | |||||
} | |||||
} | |||||
// WithChannelHealthCheckInterval specifies the health check interval. | |||||
// PubSub will ping Redis Server if it does not receive any messages within the interval. | |||||
// To disable health check, use zero interval. | |||||
// | |||||
// The default is 3 seconds. | |||||
func WithChannelHealthCheckInterval(d time.Duration) ChannelOption { | |||||
return func(c *channel) { | |||||
c.checkInterval = d | |||||
} | |||||
} | |||||
// WithChannelSendTimeout specifies the channel send timeout after which | |||||
// the message is dropped. | |||||
// | |||||
// The default is 60 seconds. | |||||
func WithChannelSendTimeout(d time.Duration) ChannelOption { | |||||
return func(c *channel) { | |||||
c.chanSendTimeout = d | |||||
} | |||||
} | |||||
type channel struct { | |||||
pubSub *PubSub | |||||
msgCh chan *Message | |||||
allCh chan interface{} | |||||
ping chan struct{} | |||||
chanSize int | |||||
chanSendTimeout time.Duration | |||||
checkInterval time.Duration | |||||
} | |||||
func newChannel(pubSub *PubSub, opts ...ChannelOption) *channel { | |||||
c := &channel{ | |||||
pubSub: pubSub, | |||||
chanSize: 100, | |||||
chanSendTimeout: time.Minute, | |||||
checkInterval: 3 * time.Second, | |||||
} | |||||
for _, opt := range opts { | |||||
opt(c) | |||||
} | |||||
if c.checkInterval > 0 { | |||||
c.initHealthCheck() | |||||
} | |||||
return c | |||||
} | |||||
func (c *channel) initHealthCheck() { | |||||
ctx := context.TODO() | |||||
c.ping = make(chan struct{}, 1) | |||||
go func() { | |||||
timer := time.NewTimer(time.Minute) | |||||
timer.Stop() | |||||
for { | |||||
timer.Reset(c.checkInterval) | |||||
select { | |||||
case <-c.ping: | |||||
if !timer.Stop() { | |||||
<-timer.C | |||||
} | |||||
case <-timer.C: | |||||
if pingErr := c.pubSub.Ping(ctx); pingErr != nil { | |||||
c.pubSub.mu.Lock() | |||||
c.pubSub.reconnect(ctx, pingErr) | |||||
c.pubSub.mu.Unlock() | |||||
} | |||||
case <-c.pubSub.exit: | |||||
return | |||||
} | |||||
} | |||||
}() | |||||
} | |||||
// initMsgChan must be in sync with initAllChan. | |||||
func (c *channel) initMsgChan() { | |||||
ctx := context.TODO() | |||||
c.msgCh = make(chan *Message, c.chanSize) | |||||
go func() { | |||||
timer := time.NewTimer(time.Minute) | |||||
timer.Stop() | |||||
var errCount int | |||||
for { | |||||
msg, err := c.pubSub.Receive(ctx) | |||||
if err != nil { | |||||
if err == pool.ErrClosed { | |||||
close(c.msgCh) | |||||
return | |||||
} | |||||
if errCount > 0 { | |||||
time.Sleep(100 * time.Millisecond) | |||||
} | |||||
errCount++ | |||||
continue | |||||
} | |||||
errCount = 0 | |||||
// Any message is as good as a ping. | |||||
select { | |||||
case c.ping <- struct{}{}: | |||||
default: | |||||
} | |||||
switch msg := msg.(type) { | |||||
case *Subscription: | |||||
// Ignore. | |||||
case *Pong: | |||||
// Ignore. | |||||
case *Message: | |||||
timer.Reset(c.chanSendTimeout) | |||||
select { | |||||
case c.msgCh <- msg: | |||||
if !timer.Stop() { | |||||
<-timer.C | |||||
} | |||||
case <-timer.C: | |||||
internal.Logger.Printf( | |||||
ctx, "redis: %s channel is full for %s (message is dropped)", | |||||
c, c.chanSendTimeout) | |||||
} | |||||
default: | |||||
internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg) | |||||
} | |||||
} | |||||
}() | |||||
} | |||||
// initAllChan must be in sync with initMsgChan. | |||||
func (c *channel) initAllChan() { | |||||
ctx := context.TODO() | |||||
c.allCh = make(chan interface{}, c.chanSize) | |||||
go func() { | |||||
timer := time.NewTimer(time.Minute) | |||||
timer.Stop() | |||||
var errCount int | |||||
for { | |||||
msg, err := c.pubSub.Receive(ctx) | |||||
if err != nil { | |||||
if err == pool.ErrClosed { | |||||
close(c.allCh) | |||||
return | |||||
} | |||||
if errCount > 0 { | |||||
time.Sleep(100 * time.Millisecond) | |||||
} | |||||
errCount++ | |||||
continue | |||||
} | |||||
errCount = 0 | |||||
// Any message is as good as a ping. | |||||
select { | |||||
case c.ping <- struct{}{}: | |||||
default: | |||||
} | |||||
switch msg := msg.(type) { | |||||
case *Pong: | |||||
// Ignore. | |||||
case *Subscription, *Message: | |||||
timer.Reset(c.chanSendTimeout) | |||||
select { | |||||
case c.allCh <- msg: | |||||
if !timer.Stop() { | |||||
<-timer.C | |||||
} | |||||
case <-timer.C: | |||||
internal.Logger.Printf( | |||||
ctx, "redis: %s channel is full for %s (message is dropped)", | |||||
c, c.chanSendTimeout) | |||||
} | |||||
default: | |||||
internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg) | |||||
} | |||||
} | |||||
}() | |||||
} |
@@ -0,0 +1,773 @@ | |||||
package redis | |||||
import ( | |||||
"context" | |||||
"errors" | |||||
"fmt" | |||||
"sync/atomic" | |||||
"time" | |||||
"github.com/go-redis/redis/v8/internal" | |||||
"github.com/go-redis/redis/v8/internal/pool" | |||||
"github.com/go-redis/redis/v8/internal/proto" | |||||
) | |||||
// Nil reply returned by Redis when key does not exist. | |||||
const Nil = proto.Nil | |||||
func SetLogger(logger internal.Logging) { | |||||
internal.Logger = logger | |||||
} | |||||
//------------------------------------------------------------------------------ | |||||
type Hook interface { | |||||
BeforeProcess(ctx context.Context, cmd Cmder) (context.Context, error) | |||||
AfterProcess(ctx context.Context, cmd Cmder) error | |||||
BeforeProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error) | |||||
AfterProcessPipeline(ctx context.Context, cmds []Cmder) error | |||||
} | |||||
type hooks struct { | |||||
hooks []Hook | |||||
} | |||||
func (hs *hooks) lock() { | |||||
hs.hooks = hs.hooks[:len(hs.hooks):len(hs.hooks)] | |||||
} | |||||
func (hs hooks) clone() hooks { | |||||
clone := hs | |||||
clone.lock() | |||||
return clone | |||||
} | |||||
func (hs *hooks) AddHook(hook Hook) { | |||||
hs.hooks = append(hs.hooks, hook) | |||||
} | |||||
func (hs hooks) process( | |||||
ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error, | |||||
) error { | |||||
if len(hs.hooks) == 0 { | |||||
err := fn(ctx, cmd) | |||||
cmd.SetErr(err) | |||||
return err | |||||
} | |||||
var hookIndex int | |||||
var retErr error | |||||
for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ { | |||||
ctx, retErr = hs.hooks[hookIndex].BeforeProcess(ctx, cmd) | |||||
if retErr != nil { | |||||
cmd.SetErr(retErr) | |||||
} | |||||
} | |||||
if retErr == nil { | |||||
retErr = fn(ctx, cmd) | |||||
cmd.SetErr(retErr) | |||||
} | |||||
for hookIndex--; hookIndex >= 0; hookIndex-- { | |||||
if err := hs.hooks[hookIndex].AfterProcess(ctx, cmd); err != nil { | |||||
retErr = err | |||||
cmd.SetErr(retErr) | |||||
} | |||||
} | |||||
return retErr | |||||
} | |||||
func (hs hooks) processPipeline( | |||||
ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error, | |||||
) error { | |||||
if len(hs.hooks) == 0 { | |||||
err := fn(ctx, cmds) | |||||
return err | |||||
} | |||||
var hookIndex int | |||||
var retErr error | |||||
for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ { | |||||
ctx, retErr = hs.hooks[hookIndex].BeforeProcessPipeline(ctx, cmds) | |||||
if retErr != nil { | |||||
setCmdsErr(cmds, retErr) | |||||
} | |||||
} | |||||
if retErr == nil { | |||||
retErr = fn(ctx, cmds) | |||||
} | |||||
for hookIndex--; hookIndex >= 0; hookIndex-- { | |||||
if err := hs.hooks[hookIndex].AfterProcessPipeline(ctx, cmds); err != nil { | |||||
retErr = err | |||||
setCmdsErr(cmds, retErr) | |||||
} | |||||
} | |||||
return retErr | |||||
} | |||||
func (hs hooks) processTxPipeline( | |||||
ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error, | |||||
) error { | |||||
cmds = wrapMultiExec(ctx, cmds) | |||||
return hs.processPipeline(ctx, cmds, fn) | |||||
} | |||||
//------------------------------------------------------------------------------ | |||||
type baseClient struct { | |||||
opt *Options | |||||
connPool pool.Pooler | |||||
onClose func() error // hook called when client is closed | |||||
} | |||||
func newBaseClient(opt *Options, connPool pool.Pooler) *baseClient { | |||||
return &baseClient{ | |||||
opt: opt, | |||||
connPool: connPool, | |||||
} | |||||
} | |||||
func (c *baseClient) clone() *baseClient { | |||||
clone := *c | |||||
return &clone | |||||
} | |||||
func (c *baseClient) withTimeout(timeout time.Duration) *baseClient { | |||||
opt := c.opt.clone() | |||||
opt.ReadTimeout = timeout | |||||
opt.WriteTimeout = timeout | |||||
clone := c.clone() | |||||
clone.opt = opt | |||||
return clone | |||||
} | |||||
func (c *baseClient) String() string { | |||||
return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB) | |||||
} | |||||
func (c *baseClient) newConn(ctx context.Context) (*pool.Conn, error) { | |||||
cn, err := c.connPool.NewConn(ctx) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
err = c.initConn(ctx, cn) | |||||
if err != nil { | |||||
_ = c.connPool.CloseConn(cn) | |||||
return nil, err | |||||
} | |||||
return cn, nil | |||||
} | |||||
func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) { | |||||
if c.opt.Limiter != nil { | |||||
err := c.opt.Limiter.Allow() | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
} | |||||
cn, err := c._getConn(ctx) | |||||
if err != nil { | |||||
if c.opt.Limiter != nil { | |||||
c.opt.Limiter.ReportResult(err) | |||||
} | |||||
return nil, err | |||||
} | |||||
return cn, nil | |||||
} | |||||
func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) { | |||||
cn, err := c.connPool.Get(ctx) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
if cn.Inited { | |||||
return cn, nil | |||||
} | |||||
if err := c.initConn(ctx, cn); err != nil { | |||||
c.connPool.Remove(ctx, cn, err) | |||||
if err := errors.Unwrap(err); err != nil { | |||||
return nil, err | |||||
} | |||||
return nil, err | |||||
} | |||||
return cn, nil | |||||
} | |||||
func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { | |||||
if cn.Inited { | |||||
return nil | |||||
} | |||||
cn.Inited = true | |||||
if c.opt.Password == "" && | |||||
c.opt.DB == 0 && | |||||
!c.opt.readOnly && | |||||
c.opt.OnConnect == nil { | |||||
return nil | |||||
} | |||||
connPool := pool.NewSingleConnPool(c.connPool, cn) | |||||
conn := newConn(ctx, c.opt, connPool) | |||||
_, err := conn.Pipelined(ctx, func(pipe Pipeliner) error { | |||||
if c.opt.Password != "" { | |||||
if c.opt.Username != "" { | |||||
pipe.AuthACL(ctx, c.opt.Username, c.opt.Password) | |||||
} else { | |||||
pipe.Auth(ctx, c.opt.Password) | |||||
} | |||||
} | |||||
if c.opt.DB > 0 { | |||||
pipe.Select(ctx, c.opt.DB) | |||||
} | |||||
if c.opt.readOnly { | |||||
pipe.ReadOnly(ctx) | |||||
} | |||||
return nil | |||||
}) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
if c.opt.OnConnect != nil { | |||||
return c.opt.OnConnect(ctx, conn) | |||||
} | |||||
return nil | |||||
} | |||||
func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error) { | |||||
if c.opt.Limiter != nil { | |||||
c.opt.Limiter.ReportResult(err) | |||||
} | |||||
if isBadConn(err, false, c.opt.Addr) { | |||||
c.connPool.Remove(ctx, cn, err) | |||||
} else { | |||||
c.connPool.Put(ctx, cn) | |||||
} | |||||
} | |||||
func (c *baseClient) withConn( | |||||
ctx context.Context, fn func(context.Context, *pool.Conn) error, | |||||
) error { | |||||
cn, err := c.getConn(ctx) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
defer func() { | |||||
c.releaseConn(ctx, cn, err) | |||||
}() | |||||
done := ctx.Done() //nolint:ifshort | |||||
if done == nil { | |||||
err = fn(ctx, cn) | |||||
return err | |||||
} | |||||
errc := make(chan error, 1) | |||||
go func() { errc <- fn(ctx, cn) }() | |||||
select { | |||||
case <-done: | |||||
_ = cn.Close() | |||||
// Wait for the goroutine to finish and send something. | |||||
<-errc | |||||
err = ctx.Err() | |||||
return err | |||||
case err = <-errc: | |||||
return err | |||||
} | |||||
} | |||||
func (c *baseClient) process(ctx context.Context, cmd Cmder) error { | |||||
var lastErr error | |||||
for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { | |||||
attempt := attempt | |||||
retry, err := c._process(ctx, cmd, attempt) | |||||
if err == nil || !retry { | |||||
return err | |||||
} | |||||
lastErr = err | |||||
} | |||||
return lastErr | |||||
} | |||||
func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool, error) { | |||||
if attempt > 0 { | |||||
if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil { | |||||
return false, err | |||||
} | |||||
} | |||||
retryTimeout := uint32(1) | |||||
err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { | |||||
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { | |||||
return writeCmd(wr, cmd) | |||||
}) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
err = cn.WithReader(ctx, c.cmdTimeout(cmd), cmd.readReply) | |||||
if err != nil { | |||||
if cmd.readTimeout() == nil { | |||||
atomic.StoreUint32(&retryTimeout, 1) | |||||
} | |||||
return err | |||||
} | |||||
return nil | |||||
}) | |||||
if err == nil { | |||||
return false, nil | |||||
} | |||||
retry := shouldRetry(err, atomic.LoadUint32(&retryTimeout) == 1) | |||||
return retry, err | |||||
} | |||||
func (c *baseClient) retryBackoff(attempt int) time.Duration { | |||||
return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff) | |||||
} | |||||
func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration { | |||||
if timeout := cmd.readTimeout(); timeout != nil { | |||||
t := *timeout | |||||
if t == 0 { | |||||
return 0 | |||||
} | |||||
return t + 10*time.Second | |||||
} | |||||
return c.opt.ReadTimeout | |||||
} | |||||
// Close closes the client, releasing any open resources. | |||||
// | |||||
// It is rare to Close a Client, as the Client is meant to be | |||||
// long-lived and shared between many goroutines. | |||||
func (c *baseClient) Close() error { | |||||
var firstErr error | |||||
if c.onClose != nil { | |||||
if err := c.onClose(); err != nil { | |||||
firstErr = err | |||||
} | |||||
} | |||||
if err := c.connPool.Close(); err != nil && firstErr == nil { | |||||
firstErr = err | |||||
} | |||||
return firstErr | |||||
} | |||||
func (c *baseClient) getAddr() string { | |||||
return c.opt.Addr | |||||
} | |||||
func (c *baseClient) processPipeline(ctx context.Context, cmds []Cmder) error { | |||||
return c.generalProcessPipeline(ctx, cmds, c.pipelineProcessCmds) | |||||
} | |||||
func (c *baseClient) processTxPipeline(ctx context.Context, cmds []Cmder) error { | |||||
return c.generalProcessPipeline(ctx, cmds, c.txPipelineProcessCmds) | |||||
} | |||||
type pipelineProcessor func(context.Context, *pool.Conn, []Cmder) (bool, error) | |||||
func (c *baseClient) generalProcessPipeline( | |||||
ctx context.Context, cmds []Cmder, p pipelineProcessor, | |||||
) error { | |||||
err := c._generalProcessPipeline(ctx, cmds, p) | |||||
if err != nil { | |||||
setCmdsErr(cmds, err) | |||||
return err | |||||
} | |||||
return cmdsFirstErr(cmds) | |||||
} | |||||
func (c *baseClient) _generalProcessPipeline( | |||||
ctx context.Context, cmds []Cmder, p pipelineProcessor, | |||||
) error { | |||||
var lastErr error | |||||
for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { | |||||
if attempt > 0 { | |||||
if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil { | |||||
return err | |||||
} | |||||
} | |||||
var canRetry bool | |||||
lastErr = c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { | |||||
var err error | |||||
canRetry, err = p(ctx, cn, cmds) | |||||
return err | |||||
}) | |||||
if lastErr == nil || !canRetry || !shouldRetry(lastErr, true) { | |||||
return lastErr | |||||
} | |||||
} | |||||
return lastErr | |||||
} | |||||
func (c *baseClient) pipelineProcessCmds( | |||||
ctx context.Context, cn *pool.Conn, cmds []Cmder, | |||||
) (bool, error) { | |||||
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { | |||||
return writeCmds(wr, cmds) | |||||
}) | |||||
if err != nil { | |||||
return true, err | |||||
} | |||||
err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { | |||||
return pipelineReadCmds(rd, cmds) | |||||
}) | |||||
return true, err | |||||
} | |||||
func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error { | |||||
for _, cmd := range cmds { | |||||
err := cmd.readReply(rd) | |||||
cmd.SetErr(err) | |||||
if err != nil && !isRedisError(err) { | |||||
return err | |||||
} | |||||
} | |||||
return nil | |||||
} | |||||
func (c *baseClient) txPipelineProcessCmds( | |||||
ctx context.Context, cn *pool.Conn, cmds []Cmder, | |||||
) (bool, error) { | |||||
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { | |||||
return writeCmds(wr, cmds) | |||||
}) | |||||
if err != nil { | |||||
return true, err | |||||
} | |||||
err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { | |||||
statusCmd := cmds[0].(*StatusCmd) | |||||
// Trim multi and exec. | |||||
cmds = cmds[1 : len(cmds)-1] | |||||
err := txPipelineReadQueued(rd, statusCmd, cmds) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
return pipelineReadCmds(rd, cmds) | |||||
}) | |||||
return false, err | |||||
} | |||||
func wrapMultiExec(ctx context.Context, cmds []Cmder) []Cmder { | |||||
if len(cmds) == 0 { | |||||
panic("not reached") | |||||
} | |||||
cmdCopy := make([]Cmder, len(cmds)+2) | |||||
cmdCopy[0] = NewStatusCmd(ctx, "multi") | |||||
copy(cmdCopy[1:], cmds) | |||||
cmdCopy[len(cmdCopy)-1] = NewSliceCmd(ctx, "exec") | |||||
return cmdCopy | |||||
} | |||||
func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder) error { | |||||
// Parse queued replies. | |||||
if err := statusCmd.readReply(rd); err != nil { | |||||
return err | |||||
} | |||||
for range cmds { | |||||
if err := statusCmd.readReply(rd); err != nil && !isRedisError(err) { | |||||
return err | |||||
} | |||||
} | |||||
// Parse number of replies. | |||||
line, err := rd.ReadLine() | |||||
if err != nil { | |||||
if err == Nil { | |||||
err = TxFailedErr | |||||
} | |||||
return err | |||||
} | |||||
switch line[0] { | |||||
case proto.ErrorReply: | |||||
return proto.ParseErrorReply(line) | |||||
case proto.ArrayReply: | |||||
// ok | |||||
default: | |||||
err := fmt.Errorf("redis: expected '*', but got line %q", line) | |||||
return err | |||||
} | |||||
return nil | |||||
} | |||||
//------------------------------------------------------------------------------ | |||||
// Client is a Redis client representing a pool of zero or more | |||||
// underlying connections. It's safe for concurrent use by multiple | |||||
// goroutines. | |||||
type Client struct { | |||||
*baseClient | |||||
cmdable | |||||
hooks | |||||
ctx context.Context | |||||
} | |||||
// NewClient returns a client to the Redis Server specified by Options. | |||||
func NewClient(opt *Options) *Client { | |||||
opt.init() | |||||
c := Client{ | |||||
baseClient: newBaseClient(opt, newConnPool(opt)), | |||||
ctx: context.Background(), | |||||
} | |||||
c.cmdable = c.Process | |||||
return &c | |||||
} | |||||
func (c *Client) clone() *Client { | |||||
clone := *c | |||||
clone.cmdable = clone.Process | |||||
clone.hooks.lock() | |||||
return &clone | |||||
} | |||||
func (c *Client) WithTimeout(timeout time.Duration) *Client { | |||||
clone := c.clone() | |||||
clone.baseClient = c.baseClient.withTimeout(timeout) | |||||
return clone | |||||
} | |||||
func (c *Client) Context() context.Context { | |||||
return c.ctx | |||||
} | |||||
func (c *Client) WithContext(ctx context.Context) *Client { | |||||
if ctx == nil { | |||||
panic("nil context") | |||||
} | |||||
clone := c.clone() | |||||
clone.ctx = ctx | |||||
return clone | |||||
} | |||||
func (c *Client) Conn(ctx context.Context) *Conn { | |||||
return newConn(ctx, c.opt, pool.NewStickyConnPool(c.connPool)) | |||||
} | |||||
// Do creates a Cmd from the args and processes the cmd. | |||||
func (c *Client) Do(ctx context.Context, args ...interface{}) *Cmd { | |||||
cmd := NewCmd(ctx, args...) | |||||
_ = c.Process(ctx, cmd) | |||||
return cmd | |||||
} | |||||
func (c *Client) Process(ctx context.Context, cmd Cmder) error { | |||||
return c.hooks.process(ctx, cmd, c.baseClient.process) | |||||
} | |||||
func (c *Client) processPipeline(ctx context.Context, cmds []Cmder) error { | |||||
return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline) | |||||
} | |||||
func (c *Client) processTxPipeline(ctx context.Context, cmds []Cmder) error { | |||||
return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline) | |||||
} | |||||
// Options returns read-only Options that were used to create the client. | |||||
func (c *Client) Options() *Options { | |||||
return c.opt | |||||
} | |||||
type PoolStats pool.Stats | |||||
// PoolStats returns connection pool stats. | |||||
func (c *Client) PoolStats() *PoolStats { | |||||
stats := c.connPool.Stats() | |||||
return (*PoolStats)(stats) | |||||
} | |||||
func (c *Client) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { | |||||
return c.Pipeline().Pipelined(ctx, fn) | |||||
} | |||||
func (c *Client) Pipeline() Pipeliner { | |||||
pipe := Pipeline{ | |||||
ctx: c.ctx, | |||||
exec: c.processPipeline, | |||||
} | |||||
pipe.init() | |||||
return &pipe | |||||
} | |||||
func (c *Client) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { | |||||
return c.TxPipeline().Pipelined(ctx, fn) | |||||
} | |||||
// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC. | |||||
func (c *Client) TxPipeline() Pipeliner { | |||||
pipe := Pipeline{ | |||||
ctx: c.ctx, | |||||
exec: c.processTxPipeline, | |||||
} | |||||
pipe.init() | |||||
return &pipe | |||||
} | |||||
func (c *Client) pubSub() *PubSub { | |||||
pubsub := &PubSub{ | |||||
opt: c.opt, | |||||
newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) { | |||||
return c.newConn(ctx) | |||||
}, | |||||
closeConn: c.connPool.CloseConn, | |||||
} | |||||
pubsub.init() | |||||
return pubsub | |||||
} | |||||
// Subscribe subscribes the client to the specified channels. | |||||
// Channels can be omitted to create empty subscription. | |||||
// Note that this method does not wait on a response from Redis, so the | |||||
// subscription may not be active immediately. To force the connection to wait, | |||||
// you may call the Receive() method on the returned *PubSub like so: | |||||
// | |||||
// sub := client.Subscribe(queryResp) | |||||
// iface, err := sub.Receive() | |||||
// if err != nil { | |||||
// // handle error | |||||
// } | |||||
// | |||||
// // Should be *Subscription, but others are possible if other actions have been | |||||
// // taken on sub since it was created. | |||||
// switch iface.(type) { | |||||
// case *Subscription: | |||||
// // subscribe succeeded | |||||
// case *Message: | |||||
// // received first message | |||||
// case *Pong: | |||||
// // pong received | |||||
// default: | |||||
// // handle error | |||||
// } | |||||
// | |||||
// ch := sub.Channel() | |||||
func (c *Client) Subscribe(ctx context.Context, channels ...string) *PubSub { | |||||
pubsub := c.pubSub() | |||||
if len(channels) > 0 { | |||||
_ = pubsub.Subscribe(ctx, channels...) | |||||
} | |||||
return pubsub | |||||
} | |||||
// PSubscribe subscribes the client to the given patterns. | |||||
// Patterns can be omitted to create empty subscription. | |||||
func (c *Client) PSubscribe(ctx context.Context, channels ...string) *PubSub { | |||||
pubsub := c.pubSub() | |||||
if len(channels) > 0 { | |||||
_ = pubsub.PSubscribe(ctx, channels...) | |||||
} | |||||
return pubsub | |||||
} | |||||
//------------------------------------------------------------------------------ | |||||
type conn struct { | |||||
baseClient | |||||
cmdable | |||||
statefulCmdable | |||||
hooks // TODO: inherit hooks | |||||
} | |||||
// Conn represents a single Redis connection rather than a pool of connections. | |||||
// Prefer running commands from Client unless there is a specific need | |||||
// for a continuous single Redis connection. | |||||
type Conn struct { | |||||
*conn | |||||
ctx context.Context | |||||
} | |||||
func newConn(ctx context.Context, opt *Options, connPool pool.Pooler) *Conn { | |||||
c := Conn{ | |||||
conn: &conn{ | |||||
baseClient: baseClient{ | |||||
opt: opt, | |||||
connPool: connPool, | |||||
}, | |||||
}, | |||||
ctx: ctx, | |||||
} | |||||
c.cmdable = c.Process | |||||
c.statefulCmdable = c.Process | |||||
return &c | |||||
} | |||||
func (c *Conn) Process(ctx context.Context, cmd Cmder) error { | |||||
return c.hooks.process(ctx, cmd, c.baseClient.process) | |||||
} | |||||
func (c *Conn) processPipeline(ctx context.Context, cmds []Cmder) error { | |||||
return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline) | |||||
} | |||||
func (c *Conn) processTxPipeline(ctx context.Context, cmds []Cmder) error { | |||||
return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline) | |||||
} | |||||
func (c *Conn) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { | |||||
return c.Pipeline().Pipelined(ctx, fn) | |||||
} | |||||
func (c *Conn) Pipeline() Pipeliner { | |||||
pipe := Pipeline{ | |||||
ctx: c.ctx, | |||||
exec: c.processPipeline, | |||||
} | |||||
pipe.init() | |||||
return &pipe | |||||
} | |||||
func (c *Conn) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { | |||||
return c.TxPipeline().Pipelined(ctx, fn) | |||||
} | |||||
// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC. | |||||
func (c *Conn) TxPipeline() Pipeliner { | |||||
pipe := Pipeline{ | |||||
ctx: c.ctx, | |||||
exec: c.processTxPipeline, | |||||
} | |||||
pipe.init() | |||||
return &pipe | |||||
} |
@@ -0,0 +1,180 @@ | |||||
package redis | |||||
import "time" | |||||
// NewCmdResult returns a Cmd initialised with val and err for testing. | |||||
func NewCmdResult(val interface{}, err error) *Cmd { | |||||
var cmd Cmd | |||||
cmd.val = val | |||||
cmd.SetErr(err) | |||||
return &cmd | |||||
} | |||||
// NewSliceResult returns a SliceCmd initialised with val and err for testing. | |||||
func NewSliceResult(val []interface{}, err error) *SliceCmd { | |||||
var cmd SliceCmd | |||||
cmd.val = val | |||||
cmd.SetErr(err) | |||||
return &cmd | |||||
} | |||||
// NewStatusResult returns a StatusCmd initialised with val and err for testing. | |||||
func NewStatusResult(val string, err error) *StatusCmd { | |||||
var cmd StatusCmd | |||||
cmd.val = val | |||||
cmd.SetErr(err) | |||||
return &cmd | |||||
} | |||||
// NewIntResult returns an IntCmd initialised with val and err for testing. | |||||
func NewIntResult(val int64, err error) *IntCmd { | |||||
var cmd IntCmd | |||||
cmd.val = val | |||||
cmd.SetErr(err) | |||||
return &cmd | |||||
} | |||||
// NewDurationResult returns a DurationCmd initialised with val and err for testing. | |||||
func NewDurationResult(val time.Duration, err error) *DurationCmd { | |||||
var cmd DurationCmd | |||||
cmd.val = val | |||||
cmd.SetErr(err) | |||||
return &cmd | |||||
} | |||||
// NewBoolResult returns a BoolCmd initialised with val and err for testing. | |||||
func NewBoolResult(val bool, err error) *BoolCmd { | |||||
var cmd BoolCmd | |||||
cmd.val = val | |||||
cmd.SetErr(err) | |||||
return &cmd | |||||
} | |||||
// NewStringResult returns a StringCmd initialised with val and err for testing. | |||||
func NewStringResult(val string, err error) *StringCmd { | |||||
var cmd StringCmd | |||||
cmd.val = val | |||||
cmd.SetErr(err) | |||||
return &cmd | |||||
} | |||||
// NewFloatResult returns a FloatCmd initialised with val and err for testing. | |||||
func NewFloatResult(val float64, err error) *FloatCmd { | |||||
var cmd FloatCmd | |||||
cmd.val = val | |||||
cmd.SetErr(err) | |||||
return &cmd | |||||
} | |||||
// NewStringSliceResult returns a StringSliceCmd initialised with val and err for testing. | |||||
func NewStringSliceResult(val []string, err error) *StringSliceCmd { | |||||
var cmd StringSliceCmd | |||||
cmd.val = val | |||||
cmd.SetErr(err) | |||||
return &cmd | |||||
} | |||||
// NewBoolSliceResult returns a BoolSliceCmd initialised with val and err for testing. | |||||
func NewBoolSliceResult(val []bool, err error) *BoolSliceCmd { | |||||
var cmd BoolSliceCmd | |||||
cmd.val = val | |||||
cmd.SetErr(err) | |||||
return &cmd | |||||
} | |||||
// NewStringStringMapResult returns a StringStringMapCmd initialised with val and err for testing. | |||||
func NewStringStringMapResult(val map[string]string, err error) *StringStringMapCmd { | |||||
var cmd StringStringMapCmd | |||||
cmd.val = val | |||||
cmd.SetErr(err) | |||||
return &cmd | |||||
} | |||||
// NewStringIntMapCmdResult returns a StringIntMapCmd initialised with val and err for testing. | |||||
func NewStringIntMapCmdResult(val map[string]int64, err error) *StringIntMapCmd { | |||||
var cmd StringIntMapCmd | |||||
cmd.val = val | |||||
cmd.SetErr(err) | |||||
return &cmd | |||||
} | |||||
// NewTimeCmdResult returns a TimeCmd initialised with val and err for testing. | |||||
func NewTimeCmdResult(val time.Time, err error) *TimeCmd { | |||||
var cmd TimeCmd | |||||
cmd.val = val | |||||
cmd.SetErr(err) | |||||
return &cmd | |||||
} | |||||
// NewZSliceCmdResult returns a ZSliceCmd initialised with val and err for testing. | |||||
func NewZSliceCmdResult(val []Z, err error) *ZSliceCmd { | |||||
var cmd ZSliceCmd | |||||
cmd.val = val | |||||
cmd.SetErr(err) | |||||
return &cmd | |||||
} | |||||
// NewZWithKeyCmdResult returns a NewZWithKeyCmd initialised with val and err for testing. | |||||
func NewZWithKeyCmdResult(val *ZWithKey, err error) *ZWithKeyCmd { | |||||
var cmd ZWithKeyCmd | |||||
cmd.val = val | |||||
cmd.SetErr(err) | |||||
return &cmd | |||||
} | |||||
// NewScanCmdResult returns a ScanCmd initialised with val and err for testing. | |||||
func NewScanCmdResult(keys []string, cursor uint64, err error) *ScanCmd { | |||||
var cmd ScanCmd | |||||
cmd.page = keys | |||||
cmd.cursor = cursor | |||||
cmd.SetErr(err) | |||||
return &cmd | |||||
} | |||||
// NewClusterSlotsCmdResult returns a ClusterSlotsCmd initialised with val and err for testing. | |||||
func NewClusterSlotsCmdResult(val []ClusterSlot, err error) *ClusterSlotsCmd { | |||||
var cmd ClusterSlotsCmd | |||||
cmd.val = val | |||||
cmd.SetErr(err) | |||||
return &cmd | |||||
} | |||||
// NewGeoLocationCmdResult returns a GeoLocationCmd initialised with val and err for testing. | |||||
func NewGeoLocationCmdResult(val []GeoLocation, err error) *GeoLocationCmd { | |||||
var cmd GeoLocationCmd | |||||
cmd.locations = val | |||||
cmd.SetErr(err) | |||||
return &cmd | |||||
} | |||||
// NewGeoPosCmdResult returns a GeoPosCmd initialised with val and err for testing. | |||||
func NewGeoPosCmdResult(val []*GeoPos, err error) *GeoPosCmd { | |||||
var cmd GeoPosCmd | |||||
cmd.val = val | |||||
cmd.SetErr(err) | |||||
return &cmd | |||||
} | |||||
// NewCommandsInfoCmdResult returns a CommandsInfoCmd initialised with val and err for testing. | |||||
func NewCommandsInfoCmdResult(val map[string]*CommandInfo, err error) *CommandsInfoCmd { | |||||
var cmd CommandsInfoCmd | |||||
cmd.val = val | |||||
cmd.SetErr(err) | |||||
return &cmd | |||||
} | |||||
// NewXMessageSliceCmdResult returns a XMessageSliceCmd initialised with val and err for testing. | |||||
func NewXMessageSliceCmdResult(val []XMessage, err error) *XMessageSliceCmd { | |||||
var cmd XMessageSliceCmd | |||||
cmd.val = val | |||||
cmd.SetErr(err) | |||||
return &cmd | |||||
} | |||||
// NewXStreamSliceCmdResult returns a XStreamSliceCmd initialised with val and err for testing. | |||||
func NewXStreamSliceCmdResult(val []XStream, err error) *XStreamSliceCmd { | |||||
var cmd XStreamSliceCmd | |||||
cmd.val = val | |||||
cmd.SetErr(err) | |||||
return &cmd | |||||
} |
@@ -0,0 +1,736 @@ | |||||
package redis | |||||
import ( | |||||
"context" | |||||
"crypto/tls" | |||||
"errors" | |||||
"fmt" | |||||
"net" | |||||
"strconv" | |||||
"sync" | |||||
"sync/atomic" | |||||
"time" | |||||
"github.com/cespare/xxhash/v2" | |||||
rendezvous "github.com/dgryski/go-rendezvous" //nolint | |||||
"github.com/go-redis/redis/v8/internal" | |||||
"github.com/go-redis/redis/v8/internal/hashtag" | |||||
"github.com/go-redis/redis/v8/internal/pool" | |||||
"github.com/go-redis/redis/v8/internal/rand" | |||||
) | |||||
var errRingShardsDown = errors.New("redis: all ring shards are down") | |||||
//------------------------------------------------------------------------------ | |||||
type ConsistentHash interface { | |||||
Get(string) string | |||||
} | |||||
type rendezvousWrapper struct { | |||||
*rendezvous.Rendezvous | |||||
} | |||||
func (w rendezvousWrapper) Get(key string) string { | |||||
return w.Lookup(key) | |||||
} | |||||
func newRendezvous(shards []string) ConsistentHash { | |||||
return rendezvousWrapper{rendezvous.New(shards, xxhash.Sum64String)} | |||||
} | |||||
//------------------------------------------------------------------------------ | |||||
// RingOptions are used to configure a ring client and should be | |||||
// passed to NewRing. | |||||
type RingOptions struct { | |||||
// Map of name => host:port addresses of ring shards. | |||||
Addrs map[string]string | |||||
// NewClient creates a shard client with provided name and options. | |||||
NewClient func(name string, opt *Options) *Client | |||||
// Frequency of PING commands sent to check shards availability. | |||||
// Shard is considered down after 3 subsequent failed checks. | |||||
HeartbeatFrequency time.Duration | |||||
// NewConsistentHash returns a consistent hash that is used | |||||
// to distribute keys across the shards. | |||||
// | |||||
// See https://medium.com/@dgryski/consistent-hashing-algorithmic-tradeoffs-ef6b8e2fcae8 | |||||
// for consistent hashing algorithmic tradeoffs. | |||||
NewConsistentHash func(shards []string) ConsistentHash | |||||
// Following options are copied from Options struct. | |||||
Dialer func(ctx context.Context, network, addr string) (net.Conn, error) | |||||
OnConnect func(ctx context.Context, cn *Conn) error | |||||
Username string | |||||
Password string | |||||
DB int | |||||
MaxRetries int | |||||
MinRetryBackoff time.Duration | |||||
MaxRetryBackoff time.Duration | |||||
DialTimeout time.Duration | |||||
ReadTimeout time.Duration | |||||
WriteTimeout time.Duration | |||||
// PoolFIFO uses FIFO mode for each node connection pool GET/PUT (default LIFO). | |||||
PoolFIFO bool | |||||
PoolSize int | |||||
MinIdleConns int | |||||
MaxConnAge time.Duration | |||||
PoolTimeout time.Duration | |||||
IdleTimeout time.Duration | |||||
IdleCheckFrequency time.Duration | |||||
TLSConfig *tls.Config | |||||
Limiter Limiter | |||||
} | |||||
func (opt *RingOptions) init() { | |||||
if opt.NewClient == nil { | |||||
opt.NewClient = func(name string, opt *Options) *Client { | |||||
return NewClient(opt) | |||||
} | |||||
} | |||||
if opt.HeartbeatFrequency == 0 { | |||||
opt.HeartbeatFrequency = 500 * time.Millisecond | |||||
} | |||||
if opt.NewConsistentHash == nil { | |||||
opt.NewConsistentHash = newRendezvous | |||||
} | |||||
if opt.MaxRetries == -1 { | |||||
opt.MaxRetries = 0 | |||||
} else if opt.MaxRetries == 0 { | |||||
opt.MaxRetries = 3 | |||||
} | |||||
switch opt.MinRetryBackoff { | |||||
case -1: | |||||
opt.MinRetryBackoff = 0 | |||||
case 0: | |||||
opt.MinRetryBackoff = 8 * time.Millisecond | |||||
} | |||||
switch opt.MaxRetryBackoff { | |||||
case -1: | |||||
opt.MaxRetryBackoff = 0 | |||||
case 0: | |||||
opt.MaxRetryBackoff = 512 * time.Millisecond | |||||
} | |||||
} | |||||
func (opt *RingOptions) clientOptions() *Options { | |||||
return &Options{ | |||||
Dialer: opt.Dialer, | |||||
OnConnect: opt.OnConnect, | |||||
Username: opt.Username, | |||||
Password: opt.Password, | |||||
DB: opt.DB, | |||||
MaxRetries: -1, | |||||
DialTimeout: opt.DialTimeout, | |||||
ReadTimeout: opt.ReadTimeout, | |||||
WriteTimeout: opt.WriteTimeout, | |||||
PoolFIFO: opt.PoolFIFO, | |||||
PoolSize: opt.PoolSize, | |||||
MinIdleConns: opt.MinIdleConns, | |||||
MaxConnAge: opt.MaxConnAge, | |||||
PoolTimeout: opt.PoolTimeout, | |||||
IdleTimeout: opt.IdleTimeout, | |||||
IdleCheckFrequency: opt.IdleCheckFrequency, | |||||
TLSConfig: opt.TLSConfig, | |||||
Limiter: opt.Limiter, | |||||
} | |||||
} | |||||
//------------------------------------------------------------------------------ | |||||
type ringShard struct { | |||||
Client *Client | |||||
down int32 | |||||
} | |||||
func newRingShard(opt *RingOptions, name, addr string) *ringShard { | |||||
clopt := opt.clientOptions() | |||||
clopt.Addr = addr | |||||
return &ringShard{ | |||||
Client: opt.NewClient(name, clopt), | |||||
} | |||||
} | |||||
func (shard *ringShard) String() string { | |||||
var state string | |||||
if shard.IsUp() { | |||||
state = "up" | |||||
} else { | |||||
state = "down" | |||||
} | |||||
return fmt.Sprintf("%s is %s", shard.Client, state) | |||||
} | |||||
func (shard *ringShard) IsDown() bool { | |||||
const threshold = 3 | |||||
return atomic.LoadInt32(&shard.down) >= threshold | |||||
} | |||||
func (shard *ringShard) IsUp() bool { | |||||
return !shard.IsDown() | |||||
} | |||||
// Vote votes to set shard state and returns true if state was changed. | |||||
func (shard *ringShard) Vote(up bool) bool { | |||||
if up { | |||||
changed := shard.IsDown() | |||||
atomic.StoreInt32(&shard.down, 0) | |||||
return changed | |||||
} | |||||
if shard.IsDown() { | |||||
return false | |||||
} | |||||
atomic.AddInt32(&shard.down, 1) | |||||
return shard.IsDown() | |||||
} | |||||
//------------------------------------------------------------------------------ | |||||
type ringShards struct { | |||||
opt *RingOptions | |||||
mu sync.RWMutex | |||||
hash ConsistentHash | |||||
shards map[string]*ringShard // read only | |||||
list []*ringShard // read only | |||||
numShard int | |||||
closed bool | |||||
} | |||||
func newRingShards(opt *RingOptions) *ringShards { | |||||
shards := make(map[string]*ringShard, len(opt.Addrs)) | |||||
list := make([]*ringShard, 0, len(shards)) | |||||
for name, addr := range opt.Addrs { | |||||
shard := newRingShard(opt, name, addr) | |||||
shards[name] = shard | |||||
list = append(list, shard) | |||||
} | |||||
c := &ringShards{ | |||||
opt: opt, | |||||
shards: shards, | |||||
list: list, | |||||
} | |||||
c.rebalance() | |||||
return c | |||||
} | |||||
func (c *ringShards) List() []*ringShard { | |||||
var list []*ringShard | |||||
c.mu.RLock() | |||||
if !c.closed { | |||||
list = c.list | |||||
} | |||||
c.mu.RUnlock() | |||||
return list | |||||
} | |||||
func (c *ringShards) Hash(key string) string { | |||||
key = hashtag.Key(key) | |||||
var hash string | |||||
c.mu.RLock() | |||||
if c.numShard > 0 { | |||||
hash = c.hash.Get(key) | |||||
} | |||||
c.mu.RUnlock() | |||||
return hash | |||||
} | |||||
func (c *ringShards) GetByKey(key string) (*ringShard, error) { | |||||
key = hashtag.Key(key) | |||||
c.mu.RLock() | |||||
if c.closed { | |||||
c.mu.RUnlock() | |||||
return nil, pool.ErrClosed | |||||
} | |||||
if c.numShard == 0 { | |||||
c.mu.RUnlock() | |||||
return nil, errRingShardsDown | |||||
} | |||||
hash := c.hash.Get(key) | |||||
if hash == "" { | |||||
c.mu.RUnlock() | |||||
return nil, errRingShardsDown | |||||
} | |||||
shard := c.shards[hash] | |||||
c.mu.RUnlock() | |||||
return shard, nil | |||||
} | |||||
func (c *ringShards) GetByName(shardName string) (*ringShard, error) { | |||||
if shardName == "" { | |||||
return c.Random() | |||||
} | |||||
c.mu.RLock() | |||||
shard := c.shards[shardName] | |||||
c.mu.RUnlock() | |||||
return shard, nil | |||||
} | |||||
func (c *ringShards) Random() (*ringShard, error) { | |||||
return c.GetByKey(strconv.Itoa(rand.Int())) | |||||
} | |||||
// heartbeat monitors state of each shard in the ring. | |||||
func (c *ringShards) Heartbeat(frequency time.Duration) { | |||||
ticker := time.NewTicker(frequency) | |||||
defer ticker.Stop() | |||||
ctx := context.Background() | |||||
for range ticker.C { | |||||
var rebalance bool | |||||
for _, shard := range c.List() { | |||||
err := shard.Client.Ping(ctx).Err() | |||||
isUp := err == nil || err == pool.ErrPoolTimeout | |||||
if shard.Vote(isUp) { | |||||
internal.Logger.Printf(context.Background(), "ring shard state changed: %s", shard) | |||||
rebalance = true | |||||
} | |||||
} | |||||
if rebalance { | |||||
c.rebalance() | |||||
} | |||||
} | |||||
} | |||||
// rebalance removes dead shards from the Ring. | |||||
func (c *ringShards) rebalance() { | |||||
c.mu.RLock() | |||||
shards := c.shards | |||||
c.mu.RUnlock() | |||||
liveShards := make([]string, 0, len(shards)) | |||||
for name, shard := range shards { | |||||
if shard.IsUp() { | |||||
liveShards = append(liveShards, name) | |||||
} | |||||
} | |||||
hash := c.opt.NewConsistentHash(liveShards) | |||||
c.mu.Lock() | |||||
c.hash = hash | |||||
c.numShard = len(liveShards) | |||||
c.mu.Unlock() | |||||
} | |||||
func (c *ringShards) Len() int { | |||||
c.mu.RLock() | |||||
l := c.numShard | |||||
c.mu.RUnlock() | |||||
return l | |||||
} | |||||
func (c *ringShards) Close() error { | |||||
c.mu.Lock() | |||||
defer c.mu.Unlock() | |||||
if c.closed { | |||||
return nil | |||||
} | |||||
c.closed = true | |||||
var firstErr error | |||||
for _, shard := range c.shards { | |||||
if err := shard.Client.Close(); err != nil && firstErr == nil { | |||||
firstErr = err | |||||
} | |||||
} | |||||
c.hash = nil | |||||
c.shards = nil | |||||
c.list = nil | |||||
return firstErr | |||||
} | |||||
//------------------------------------------------------------------------------ | |||||
type ring struct { | |||||
opt *RingOptions | |||||
shards *ringShards | |||||
cmdsInfoCache *cmdsInfoCache //nolint:structcheck | |||||
} | |||||
// Ring is a Redis client that uses consistent hashing to distribute | |||||
// keys across multiple Redis servers (shards). It's safe for | |||||
// concurrent use by multiple goroutines. | |||||
// | |||||
// Ring monitors the state of each shard and removes dead shards from | |||||
// the ring. When a shard comes online it is added back to the ring. This | |||||
// gives you maximum availability and partition tolerance, but no | |||||
// consistency between different shards or even clients. Each client | |||||
// uses shards that are available to the client and does not do any | |||||
// coordination when shard state is changed. | |||||
// | |||||
// Ring should be used when you need multiple Redis servers for caching | |||||
// and can tolerate losing data when one of the servers dies. | |||||
// Otherwise you should use Redis Cluster. | |||||
type Ring struct { | |||||
*ring | |||||
cmdable | |||||
hooks | |||||
ctx context.Context | |||||
} | |||||
func NewRing(opt *RingOptions) *Ring { | |||||
opt.init() | |||||
ring := Ring{ | |||||
ring: &ring{ | |||||
opt: opt, | |||||
shards: newRingShards(opt), | |||||
}, | |||||
ctx: context.Background(), | |||||
} | |||||
ring.cmdsInfoCache = newCmdsInfoCache(ring.cmdsInfo) | |||||
ring.cmdable = ring.Process | |||||
go ring.shards.Heartbeat(opt.HeartbeatFrequency) | |||||
return &ring | |||||
} | |||||
func (c *Ring) Context() context.Context { | |||||
return c.ctx | |||||
} | |||||
func (c *Ring) WithContext(ctx context.Context) *Ring { | |||||
if ctx == nil { | |||||
panic("nil context") | |||||
} | |||||
clone := *c | |||||
clone.cmdable = clone.Process | |||||
clone.hooks.lock() | |||||
clone.ctx = ctx | |||||
return &clone | |||||
} | |||||
// Do creates a Cmd from the args and processes the cmd. | |||||
func (c *Ring) Do(ctx context.Context, args ...interface{}) *Cmd { | |||||
cmd := NewCmd(ctx, args...) | |||||
_ = c.Process(ctx, cmd) | |||||
return cmd | |||||
} | |||||
func (c *Ring) Process(ctx context.Context, cmd Cmder) error { | |||||
return c.hooks.process(ctx, cmd, c.process) | |||||
} | |||||
// Options returns read-only Options that were used to create the client. | |||||
func (c *Ring) Options() *RingOptions { | |||||
return c.opt | |||||
} | |||||
func (c *Ring) retryBackoff(attempt int) time.Duration { | |||||
return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff) | |||||
} | |||||
// PoolStats returns accumulated connection pool stats. | |||||
func (c *Ring) PoolStats() *PoolStats { | |||||
shards := c.shards.List() | |||||
var acc PoolStats | |||||
for _, shard := range shards { | |||||
s := shard.Client.connPool.Stats() | |||||
acc.Hits += s.Hits | |||||
acc.Misses += s.Misses | |||||
acc.Timeouts += s.Timeouts | |||||
acc.TotalConns += s.TotalConns | |||||
acc.IdleConns += s.IdleConns | |||||
} | |||||
return &acc | |||||
} | |||||
// Len returns the current number of shards in the ring. | |||||
func (c *Ring) Len() int { | |||||
return c.shards.Len() | |||||
} | |||||
// Subscribe subscribes the client to the specified channels. | |||||
func (c *Ring) Subscribe(ctx context.Context, channels ...string) *PubSub { | |||||
if len(channels) == 0 { | |||||
panic("at least one channel is required") | |||||
} | |||||
shard, err := c.shards.GetByKey(channels[0]) | |||||
if err != nil { | |||||
// TODO: return PubSub with sticky error | |||||
panic(err) | |||||
} | |||||
return shard.Client.Subscribe(ctx, channels...) | |||||
} | |||||
// PSubscribe subscribes the client to the given patterns. | |||||
func (c *Ring) PSubscribe(ctx context.Context, channels ...string) *PubSub { | |||||
if len(channels) == 0 { | |||||
panic("at least one channel is required") | |||||
} | |||||
shard, err := c.shards.GetByKey(channels[0]) | |||||
if err != nil { | |||||
// TODO: return PubSub with sticky error | |||||
panic(err) | |||||
} | |||||
return shard.Client.PSubscribe(ctx, channels...) | |||||
} | |||||
// ForEachShard concurrently calls the fn on each live shard in the ring. | |||||
// It returns the first error if any. | |||||
func (c *Ring) ForEachShard( | |||||
ctx context.Context, | |||||
fn func(ctx context.Context, client *Client) error, | |||||
) error { | |||||
shards := c.shards.List() | |||||
var wg sync.WaitGroup | |||||
errCh := make(chan error, 1) | |||||
for _, shard := range shards { | |||||
if shard.IsDown() { | |||||
continue | |||||
} | |||||
wg.Add(1) | |||||
go func(shard *ringShard) { | |||||
defer wg.Done() | |||||
err := fn(ctx, shard.Client) | |||||
if err != nil { | |||||
select { | |||||
case errCh <- err: | |||||
default: | |||||
} | |||||
} | |||||
}(shard) | |||||
} | |||||
wg.Wait() | |||||
select { | |||||
case err := <-errCh: | |||||
return err | |||||
default: | |||||
return nil | |||||
} | |||||
} | |||||
func (c *Ring) cmdsInfo(ctx context.Context) (map[string]*CommandInfo, error) { | |||||
shards := c.shards.List() | |||||
var firstErr error | |||||
for _, shard := range shards { | |||||
cmdsInfo, err := shard.Client.Command(ctx).Result() | |||||
if err == nil { | |||||
return cmdsInfo, nil | |||||
} | |||||
if firstErr == nil { | |||||
firstErr = err | |||||
} | |||||
} | |||||
if firstErr == nil { | |||||
return nil, errRingShardsDown | |||||
} | |||||
return nil, firstErr | |||||
} | |||||
func (c *Ring) cmdInfo(ctx context.Context, name string) *CommandInfo { | |||||
cmdsInfo, err := c.cmdsInfoCache.Get(ctx) | |||||
if err != nil { | |||||
return nil | |||||
} | |||||
info := cmdsInfo[name] | |||||
if info == nil { | |||||
internal.Logger.Printf(c.Context(), "info for cmd=%s not found", name) | |||||
} | |||||
return info | |||||
} | |||||
func (c *Ring) cmdShard(ctx context.Context, cmd Cmder) (*ringShard, error) { | |||||
cmdInfo := c.cmdInfo(ctx, cmd.Name()) | |||||
pos := cmdFirstKeyPos(cmd, cmdInfo) | |||||
if pos == 0 { | |||||
return c.shards.Random() | |||||
} | |||||
firstKey := cmd.stringArg(pos) | |||||
return c.shards.GetByKey(firstKey) | |||||
} | |||||
func (c *Ring) process(ctx context.Context, cmd Cmder) error { | |||||
var lastErr error | |||||
for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { | |||||
if attempt > 0 { | |||||
if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil { | |||||
return err | |||||
} | |||||
} | |||||
shard, err := c.cmdShard(ctx, cmd) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
lastErr = shard.Client.Process(ctx, cmd) | |||||
if lastErr == nil || !shouldRetry(lastErr, cmd.readTimeout() == nil) { | |||||
return lastErr | |||||
} | |||||
} | |||||
return lastErr | |||||
} | |||||
func (c *Ring) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { | |||||
return c.Pipeline().Pipelined(ctx, fn) | |||||
} | |||||
func (c *Ring) Pipeline() Pipeliner { | |||||
pipe := Pipeline{ | |||||
ctx: c.ctx, | |||||
exec: c.processPipeline, | |||||
} | |||||
pipe.init() | |||||
return &pipe | |||||
} | |||||
func (c *Ring) processPipeline(ctx context.Context, cmds []Cmder) error { | |||||
return c.hooks.processPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { | |||||
return c.generalProcessPipeline(ctx, cmds, false) | |||||
}) | |||||
} | |||||
func (c *Ring) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { | |||||
return c.TxPipeline().Pipelined(ctx, fn) | |||||
} | |||||
func (c *Ring) TxPipeline() Pipeliner { | |||||
pipe := Pipeline{ | |||||
ctx: c.ctx, | |||||
exec: c.processTxPipeline, | |||||
} | |||||
pipe.init() | |||||
return &pipe | |||||
} | |||||
func (c *Ring) processTxPipeline(ctx context.Context, cmds []Cmder) error { | |||||
return c.hooks.processPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { | |||||
return c.generalProcessPipeline(ctx, cmds, true) | |||||
}) | |||||
} | |||||
func (c *Ring) generalProcessPipeline( | |||||
ctx context.Context, cmds []Cmder, tx bool, | |||||
) error { | |||||
cmdsMap := make(map[string][]Cmder) | |||||
for _, cmd := range cmds { | |||||
cmdInfo := c.cmdInfo(ctx, cmd.Name()) | |||||
hash := cmd.stringArg(cmdFirstKeyPos(cmd, cmdInfo)) | |||||
if hash != "" { | |||||
hash = c.shards.Hash(hash) | |||||
} | |||||
cmdsMap[hash] = append(cmdsMap[hash], cmd) | |||||
} | |||||
var wg sync.WaitGroup | |||||
for hash, cmds := range cmdsMap { | |||||
wg.Add(1) | |||||
go func(hash string, cmds []Cmder) { | |||||
defer wg.Done() | |||||
_ = c.processShardPipeline(ctx, hash, cmds, tx) | |||||
}(hash, cmds) | |||||
} | |||||
wg.Wait() | |||||
return cmdsFirstErr(cmds) | |||||
} | |||||
func (c *Ring) processShardPipeline( | |||||
ctx context.Context, hash string, cmds []Cmder, tx bool, | |||||
) error { | |||||
// TODO: retry? | |||||
shard, err := c.shards.GetByName(hash) | |||||
if err != nil { | |||||
setCmdsErr(cmds, err) | |||||
return err | |||||
} | |||||
if tx { | |||||
return shard.Client.processTxPipeline(ctx, cmds) | |||||
} | |||||
return shard.Client.processPipeline(ctx, cmds) | |||||
} | |||||
func (c *Ring) Watch(ctx context.Context, fn func(*Tx) error, keys ...string) error { | |||||
if len(keys) == 0 { | |||||
return fmt.Errorf("redis: Watch requires at least one key") | |||||
} | |||||
var shards []*ringShard | |||||
for _, key := range keys { | |||||
if key != "" { | |||||
shard, err := c.shards.GetByKey(hashtag.Key(key)) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
shards = append(shards, shard) | |||||
} | |||||
} | |||||
if len(shards) == 0 { | |||||
return fmt.Errorf("redis: Watch requires at least one shard") | |||||
} | |||||
if len(shards) > 1 { | |||||
for _, shard := range shards[1:] { | |||||
if shard.Client != shards[0].Client { | |||||
err := fmt.Errorf("redis: Watch requires all keys to be in the same shard") | |||||
return err | |||||
} | |||||
} | |||||
} | |||||
return shards[0].Client.Watch(ctx, fn, keys...) | |||||
} | |||||
// Close closes the ring client, releasing any open resources. | |||||
// | |||||
// It is rare to Close a Ring, as the Ring is meant to be long-lived | |||||
// and shared between many goroutines. | |||||
func (c *Ring) Close() error { | |||||
return c.shards.Close() | |||||
} |
@@ -0,0 +1,65 @@ | |||||
package redis | |||||
import ( | |||||
"context" | |||||
"crypto/sha1" | |||||
"encoding/hex" | |||||
"io" | |||||
"strings" | |||||
) | |||||
type Scripter interface { | |||||
Eval(ctx context.Context, script string, keys []string, args ...interface{}) *Cmd | |||||
EvalSha(ctx context.Context, sha1 string, keys []string, args ...interface{}) *Cmd | |||||
ScriptExists(ctx context.Context, hashes ...string) *BoolSliceCmd | |||||
ScriptLoad(ctx context.Context, script string) *StringCmd | |||||
} | |||||
var ( | |||||
_ Scripter = (*Client)(nil) | |||||
_ Scripter = (*Ring)(nil) | |||||
_ Scripter = (*ClusterClient)(nil) | |||||
) | |||||
type Script struct { | |||||
src, hash string | |||||
} | |||||
func NewScript(src string) *Script { | |||||
h := sha1.New() | |||||
_, _ = io.WriteString(h, src) | |||||
return &Script{ | |||||
src: src, | |||||
hash: hex.EncodeToString(h.Sum(nil)), | |||||
} | |||||
} | |||||
func (s *Script) Hash() string { | |||||
return s.hash | |||||
} | |||||
func (s *Script) Load(ctx context.Context, c Scripter) *StringCmd { | |||||
return c.ScriptLoad(ctx, s.src) | |||||
} | |||||
func (s *Script) Exists(ctx context.Context, c Scripter) *BoolSliceCmd { | |||||
return c.ScriptExists(ctx, s.hash) | |||||
} | |||||
func (s *Script) Eval(ctx context.Context, c Scripter, keys []string, args ...interface{}) *Cmd { | |||||
return c.Eval(ctx, s.src, keys, args...) | |||||
} | |||||
func (s *Script) EvalSha(ctx context.Context, c Scripter, keys []string, args ...interface{}) *Cmd { | |||||
return c.EvalSha(ctx, s.hash, keys, args...) | |||||
} | |||||
// Run optimistically uses EVALSHA to run the script. If script does not exist | |||||
// it is retried using EVAL. | |||||
func (s *Script) Run(ctx context.Context, c Scripter, keys []string, args ...interface{}) *Cmd { | |||||
r := s.EvalSha(ctx, c, keys, args...) | |||||
if err := r.Err(); err != nil && strings.HasPrefix(err.Error(), "NOSCRIPT ") { | |||||
return s.Eval(ctx, c, keys, args...) | |||||
} | |||||
return r | |||||
} |
@@ -0,0 +1,796 @@ | |||||
package redis | |||||
import ( | |||||
"context" | |||||
"crypto/tls" | |||||
"errors" | |||||
"net" | |||||
"strings" | |||||
"sync" | |||||
"time" | |||||
"github.com/go-redis/redis/v8/internal" | |||||
"github.com/go-redis/redis/v8/internal/pool" | |||||
"github.com/go-redis/redis/v8/internal/rand" | |||||
) | |||||
//------------------------------------------------------------------------------ | |||||
// FailoverOptions are used to configure a failover client and should | |||||
// be passed to NewFailoverClient. | |||||
type FailoverOptions struct { | |||||
// The master name. | |||||
MasterName string | |||||
// A seed list of host:port addresses of sentinel nodes. | |||||
SentinelAddrs []string | |||||
// If specified with SentinelPassword, enables ACL-based authentication (via | |||||
// AUTH <user> <pass>). | |||||
SentinelUsername string | |||||
// Sentinel password from "requirepass <password>" (if enabled) in Sentinel | |||||
// configuration, or, if SentinelUsername is also supplied, used for ACL-based | |||||
// authentication. | |||||
SentinelPassword string | |||||
// Allows routing read-only commands to the closest master or slave node. | |||||
// This option only works with NewFailoverClusterClient. | |||||
RouteByLatency bool | |||||
// Allows routing read-only commands to the random master or slave node. | |||||
// This option only works with NewFailoverClusterClient. | |||||
RouteRandomly bool | |||||
// Route all commands to slave read-only nodes. | |||||
SlaveOnly bool | |||||
// Use slaves disconnected with master when cannot get connected slaves | |||||
// Now, this option only works in RandomSlaveAddr function. | |||||
UseDisconnectedSlaves bool | |||||
// Following options are copied from Options struct. | |||||
Dialer func(ctx context.Context, network, addr string) (net.Conn, error) | |||||
OnConnect func(ctx context.Context, cn *Conn) error | |||||
Username string | |||||
Password string | |||||
DB int | |||||
MaxRetries int | |||||
MinRetryBackoff time.Duration | |||||
MaxRetryBackoff time.Duration | |||||
DialTimeout time.Duration | |||||
ReadTimeout time.Duration | |||||
WriteTimeout time.Duration | |||||
// PoolFIFO uses FIFO mode for each node connection pool GET/PUT (default LIFO). | |||||
PoolFIFO bool | |||||
PoolSize int | |||||
MinIdleConns int | |||||
MaxConnAge time.Duration | |||||
PoolTimeout time.Duration | |||||
IdleTimeout time.Duration | |||||
IdleCheckFrequency time.Duration | |||||
TLSConfig *tls.Config | |||||
} | |||||
func (opt *FailoverOptions) clientOptions() *Options { | |||||
return &Options{ | |||||
Addr: "FailoverClient", | |||||
Dialer: opt.Dialer, | |||||
OnConnect: opt.OnConnect, | |||||
DB: opt.DB, | |||||
Username: opt.Username, | |||||
Password: opt.Password, | |||||
MaxRetries: opt.MaxRetries, | |||||
MinRetryBackoff: opt.MinRetryBackoff, | |||||
MaxRetryBackoff: opt.MaxRetryBackoff, | |||||
DialTimeout: opt.DialTimeout, | |||||
ReadTimeout: opt.ReadTimeout, | |||||
WriteTimeout: opt.WriteTimeout, | |||||
PoolFIFO: opt.PoolFIFO, | |||||
PoolSize: opt.PoolSize, | |||||
PoolTimeout: opt.PoolTimeout, | |||||
IdleTimeout: opt.IdleTimeout, | |||||
IdleCheckFrequency: opt.IdleCheckFrequency, | |||||
MinIdleConns: opt.MinIdleConns, | |||||
MaxConnAge: opt.MaxConnAge, | |||||
TLSConfig: opt.TLSConfig, | |||||
} | |||||
} | |||||
func (opt *FailoverOptions) sentinelOptions(addr string) *Options { | |||||
return &Options{ | |||||
Addr: addr, | |||||
Dialer: opt.Dialer, | |||||
OnConnect: opt.OnConnect, | |||||
DB: 0, | |||||
Username: opt.SentinelUsername, | |||||
Password: opt.SentinelPassword, | |||||
MaxRetries: opt.MaxRetries, | |||||
MinRetryBackoff: opt.MinRetryBackoff, | |||||
MaxRetryBackoff: opt.MaxRetryBackoff, | |||||
DialTimeout: opt.DialTimeout, | |||||
ReadTimeout: opt.ReadTimeout, | |||||
WriteTimeout: opt.WriteTimeout, | |||||
PoolFIFO: opt.PoolFIFO, | |||||
PoolSize: opt.PoolSize, | |||||
PoolTimeout: opt.PoolTimeout, | |||||
IdleTimeout: opt.IdleTimeout, | |||||
IdleCheckFrequency: opt.IdleCheckFrequency, | |||||
MinIdleConns: opt.MinIdleConns, | |||||
MaxConnAge: opt.MaxConnAge, | |||||
TLSConfig: opt.TLSConfig, | |||||
} | |||||
} | |||||
func (opt *FailoverOptions) clusterOptions() *ClusterOptions { | |||||
return &ClusterOptions{ | |||||
Dialer: opt.Dialer, | |||||
OnConnect: opt.OnConnect, | |||||
Username: opt.Username, | |||||
Password: opt.Password, | |||||
MaxRedirects: opt.MaxRetries, | |||||
RouteByLatency: opt.RouteByLatency, | |||||
RouteRandomly: opt.RouteRandomly, | |||||
MinRetryBackoff: opt.MinRetryBackoff, | |||||
MaxRetryBackoff: opt.MaxRetryBackoff, | |||||
DialTimeout: opt.DialTimeout, | |||||
ReadTimeout: opt.ReadTimeout, | |||||
WriteTimeout: opt.WriteTimeout, | |||||
PoolFIFO: opt.PoolFIFO, | |||||
PoolSize: opt.PoolSize, | |||||
PoolTimeout: opt.PoolTimeout, | |||||
IdleTimeout: opt.IdleTimeout, | |||||
IdleCheckFrequency: opt.IdleCheckFrequency, | |||||
MinIdleConns: opt.MinIdleConns, | |||||
MaxConnAge: opt.MaxConnAge, | |||||
TLSConfig: opt.TLSConfig, | |||||
} | |||||
} | |||||
// NewFailoverClient returns a Redis client that uses Redis Sentinel | |||||
// for automatic failover. It's safe for concurrent use by multiple | |||||
// goroutines. | |||||
func NewFailoverClient(failoverOpt *FailoverOptions) *Client { | |||||
if failoverOpt.RouteByLatency { | |||||
panic("to route commands by latency, use NewFailoverClusterClient") | |||||
} | |||||
if failoverOpt.RouteRandomly { | |||||
panic("to route commands randomly, use NewFailoverClusterClient") | |||||
} | |||||
sentinelAddrs := make([]string, len(failoverOpt.SentinelAddrs)) | |||||
copy(sentinelAddrs, failoverOpt.SentinelAddrs) | |||||
rand.Shuffle(len(sentinelAddrs), func(i, j int) { | |||||
sentinelAddrs[i], sentinelAddrs[j] = sentinelAddrs[j], sentinelAddrs[i] | |||||
}) | |||||
failover := &sentinelFailover{ | |||||
opt: failoverOpt, | |||||
sentinelAddrs: sentinelAddrs, | |||||
} | |||||
opt := failoverOpt.clientOptions() | |||||
opt.Dialer = masterSlaveDialer(failover) | |||||
opt.init() | |||||
connPool := newConnPool(opt) | |||||
failover.mu.Lock() | |||||
failover.onFailover = func(ctx context.Context, addr string) { | |||||
_ = connPool.Filter(func(cn *pool.Conn) bool { | |||||
return cn.RemoteAddr().String() != addr | |||||
}) | |||||
} | |||||
failover.mu.Unlock() | |||||
c := Client{ | |||||
baseClient: newBaseClient(opt, connPool), | |||||
ctx: context.Background(), | |||||
} | |||||
c.cmdable = c.Process | |||||
c.onClose = failover.Close | |||||
return &c | |||||
} | |||||
func masterSlaveDialer( | |||||
failover *sentinelFailover, | |||||
) func(ctx context.Context, network, addr string) (net.Conn, error) { | |||||
return func(ctx context.Context, network, _ string) (net.Conn, error) { | |||||
var addr string | |||||
var err error | |||||
if failover.opt.SlaveOnly { | |||||
addr, err = failover.RandomSlaveAddr(ctx) | |||||
} else { | |||||
addr, err = failover.MasterAddr(ctx) | |||||
if err == nil { | |||||
failover.trySwitchMaster(ctx, addr) | |||||
} | |||||
} | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
if failover.opt.Dialer != nil { | |||||
return failover.opt.Dialer(ctx, network, addr) | |||||
} | |||||
netDialer := &net.Dialer{ | |||||
Timeout: failover.opt.DialTimeout, | |||||
KeepAlive: 5 * time.Minute, | |||||
} | |||||
if failover.opt.TLSConfig == nil { | |||||
return netDialer.DialContext(ctx, network, addr) | |||||
} | |||||
return tls.DialWithDialer(netDialer, network, addr, failover.opt.TLSConfig) | |||||
} | |||||
} | |||||
//------------------------------------------------------------------------------ | |||||
// SentinelClient is a client for a Redis Sentinel. | |||||
type SentinelClient struct { | |||||
*baseClient | |||||
hooks | |||||
ctx context.Context | |||||
} | |||||
func NewSentinelClient(opt *Options) *SentinelClient { | |||||
opt.init() | |||||
c := &SentinelClient{ | |||||
baseClient: &baseClient{ | |||||
opt: opt, | |||||
connPool: newConnPool(opt), | |||||
}, | |||||
ctx: context.Background(), | |||||
} | |||||
return c | |||||
} | |||||
func (c *SentinelClient) Context() context.Context { | |||||
return c.ctx | |||||
} | |||||
func (c *SentinelClient) WithContext(ctx context.Context) *SentinelClient { | |||||
if ctx == nil { | |||||
panic("nil context") | |||||
} | |||||
clone := *c | |||||
clone.ctx = ctx | |||||
return &clone | |||||
} | |||||
func (c *SentinelClient) Process(ctx context.Context, cmd Cmder) error { | |||||
return c.hooks.process(ctx, cmd, c.baseClient.process) | |||||
} | |||||
func (c *SentinelClient) pubSub() *PubSub { | |||||
pubsub := &PubSub{ | |||||
opt: c.opt, | |||||
newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) { | |||||
return c.newConn(ctx) | |||||
}, | |||||
closeConn: c.connPool.CloseConn, | |||||
} | |||||
pubsub.init() | |||||
return pubsub | |||||
} | |||||
// Ping is used to test if a connection is still alive, or to | |||||
// measure latency. | |||||
func (c *SentinelClient) Ping(ctx context.Context) *StringCmd { | |||||
cmd := NewStringCmd(ctx, "ping") | |||||
_ = c.Process(ctx, cmd) | |||||
return cmd | |||||
} | |||||
// Subscribe subscribes the client to the specified channels. | |||||
// Channels can be omitted to create empty subscription. | |||||
func (c *SentinelClient) Subscribe(ctx context.Context, channels ...string) *PubSub { | |||||
pubsub := c.pubSub() | |||||
if len(channels) > 0 { | |||||
_ = pubsub.Subscribe(ctx, channels...) | |||||
} | |||||
return pubsub | |||||
} | |||||
// PSubscribe subscribes the client to the given patterns. | |||||
// Patterns can be omitted to create empty subscription. | |||||
func (c *SentinelClient) PSubscribe(ctx context.Context, channels ...string) *PubSub { | |||||
pubsub := c.pubSub() | |||||
if len(channels) > 0 { | |||||
_ = pubsub.PSubscribe(ctx, channels...) | |||||
} | |||||
return pubsub | |||||
} | |||||
func (c *SentinelClient) GetMasterAddrByName(ctx context.Context, name string) *StringSliceCmd { | |||||
cmd := NewStringSliceCmd(ctx, "sentinel", "get-master-addr-by-name", name) | |||||
_ = c.Process(ctx, cmd) | |||||
return cmd | |||||
} | |||||
func (c *SentinelClient) Sentinels(ctx context.Context, name string) *SliceCmd { | |||||
cmd := NewSliceCmd(ctx, "sentinel", "sentinels", name) | |||||
_ = c.Process(ctx, cmd) | |||||
return cmd | |||||
} | |||||
// Failover forces a failover as if the master was not reachable, and without | |||||
// asking for agreement to other Sentinels. | |||||
func (c *SentinelClient) Failover(ctx context.Context, name string) *StatusCmd { | |||||
cmd := NewStatusCmd(ctx, "sentinel", "failover", name) | |||||
_ = c.Process(ctx, cmd) | |||||
return cmd | |||||
} | |||||
// Reset resets all the masters with matching name. The pattern argument is a | |||||
// glob-style pattern. The reset process clears any previous state in a master | |||||
// (including a failover in progress), and removes every slave and sentinel | |||||
// already discovered and associated with the master. | |||||
func (c *SentinelClient) Reset(ctx context.Context, pattern string) *IntCmd { | |||||
cmd := NewIntCmd(ctx, "sentinel", "reset", pattern) | |||||
_ = c.Process(ctx, cmd) | |||||
return cmd | |||||
} | |||||
// FlushConfig forces Sentinel to rewrite its configuration on disk, including | |||||
// the current Sentinel state. | |||||
func (c *SentinelClient) FlushConfig(ctx context.Context) *StatusCmd { | |||||
cmd := NewStatusCmd(ctx, "sentinel", "flushconfig") | |||||
_ = c.Process(ctx, cmd) | |||||
return cmd | |||||
} | |||||
// Master shows the state and info of the specified master. | |||||
func (c *SentinelClient) Master(ctx context.Context, name string) *StringStringMapCmd { | |||||
cmd := NewStringStringMapCmd(ctx, "sentinel", "master", name) | |||||
_ = c.Process(ctx, cmd) | |||||
return cmd | |||||
} | |||||
// Masters shows a list of monitored masters and their state. | |||||
func (c *SentinelClient) Masters(ctx context.Context) *SliceCmd { | |||||
cmd := NewSliceCmd(ctx, "sentinel", "masters") | |||||
_ = c.Process(ctx, cmd) | |||||
return cmd | |||||
} | |||||
// Slaves shows a list of slaves for the specified master and their state. | |||||
func (c *SentinelClient) Slaves(ctx context.Context, name string) *SliceCmd { | |||||
cmd := NewSliceCmd(ctx, "sentinel", "slaves", name) | |||||
_ = c.Process(ctx, cmd) | |||||
return cmd | |||||
} | |||||
// CkQuorum checks if the current Sentinel configuration is able to reach the | |||||
// quorum needed to failover a master, and the majority needed to authorize the | |||||
// failover. This command should be used in monitoring systems to check if a | |||||
// Sentinel deployment is ok. | |||||
func (c *SentinelClient) CkQuorum(ctx context.Context, name string) *StringCmd { | |||||
cmd := NewStringCmd(ctx, "sentinel", "ckquorum", name) | |||||
_ = c.Process(ctx, cmd) | |||||
return cmd | |||||
} | |||||
// Monitor tells the Sentinel to start monitoring a new master with the specified | |||||
// name, ip, port, and quorum. | |||||
func (c *SentinelClient) Monitor(ctx context.Context, name, ip, port, quorum string) *StringCmd { | |||||
cmd := NewStringCmd(ctx, "sentinel", "monitor", name, ip, port, quorum) | |||||
_ = c.Process(ctx, cmd) | |||||
return cmd | |||||
} | |||||
// Set is used in order to change configuration parameters of a specific master. | |||||
func (c *SentinelClient) Set(ctx context.Context, name, option, value string) *StringCmd { | |||||
cmd := NewStringCmd(ctx, "sentinel", "set", name, option, value) | |||||
_ = c.Process(ctx, cmd) | |||||
return cmd | |||||
} | |||||
// Remove is used in order to remove the specified master: the master will no | |||||
// longer be monitored, and will totally be removed from the internal state of | |||||
// the Sentinel. | |||||
func (c *SentinelClient) Remove(ctx context.Context, name string) *StringCmd { | |||||
cmd := NewStringCmd(ctx, "sentinel", "remove", name) | |||||
_ = c.Process(ctx, cmd) | |||||
return cmd | |||||
} | |||||
//------------------------------------------------------------------------------ | |||||
type sentinelFailover struct { | |||||
opt *FailoverOptions | |||||
sentinelAddrs []string | |||||
onFailover func(ctx context.Context, addr string) | |||||
onUpdate func(ctx context.Context) | |||||
mu sync.RWMutex | |||||
_masterAddr string | |||||
sentinel *SentinelClient | |||||
pubsub *PubSub | |||||
} | |||||
func (c *sentinelFailover) Close() error { | |||||
c.mu.Lock() | |||||
defer c.mu.Unlock() | |||||
if c.sentinel != nil { | |||||
return c.closeSentinel() | |||||
} | |||||
return nil | |||||
} | |||||
func (c *sentinelFailover) closeSentinel() error { | |||||
firstErr := c.pubsub.Close() | |||||
c.pubsub = nil | |||||
err := c.sentinel.Close() | |||||
if err != nil && firstErr == nil { | |||||
firstErr = err | |||||
} | |||||
c.sentinel = nil | |||||
return firstErr | |||||
} | |||||
func (c *sentinelFailover) RandomSlaveAddr(ctx context.Context) (string, error) { | |||||
if c.opt == nil { | |||||
return "", errors.New("opt is nil") | |||||
} | |||||
addresses, err := c.slaveAddrs(ctx, false) | |||||
if err != nil { | |||||
return "", err | |||||
} | |||||
if len(addresses) == 0 && c.opt.UseDisconnectedSlaves { | |||||
addresses, err = c.slaveAddrs(ctx, true) | |||||
if err != nil { | |||||
return "", err | |||||
} | |||||
} | |||||
if len(addresses) == 0 { | |||||
return c.MasterAddr(ctx) | |||||
} | |||||
return addresses[rand.Intn(len(addresses))], nil | |||||
} | |||||
func (c *sentinelFailover) MasterAddr(ctx context.Context) (string, error) { | |||||
c.mu.RLock() | |||||
sentinel := c.sentinel | |||||
c.mu.RUnlock() | |||||
if sentinel != nil { | |||||
addr := c.getMasterAddr(ctx, sentinel) | |||||
if addr != "" { | |||||
return addr, nil | |||||
} | |||||
} | |||||
c.mu.Lock() | |||||
defer c.mu.Unlock() | |||||
if c.sentinel != nil { | |||||
addr := c.getMasterAddr(ctx, c.sentinel) | |||||
if addr != "" { | |||||
return addr, nil | |||||
} | |||||
_ = c.closeSentinel() | |||||
} | |||||
for i, sentinelAddr := range c.sentinelAddrs { | |||||
sentinel := NewSentinelClient(c.opt.sentinelOptions(sentinelAddr)) | |||||
masterAddr, err := sentinel.GetMasterAddrByName(ctx, c.opt.MasterName).Result() | |||||
if err != nil { | |||||
internal.Logger.Printf(ctx, "sentinel: GetMasterAddrByName master=%q failed: %s", | |||||
c.opt.MasterName, err) | |||||
_ = sentinel.Close() | |||||
continue | |||||
} | |||||
// Push working sentinel to the top. | |||||
c.sentinelAddrs[0], c.sentinelAddrs[i] = c.sentinelAddrs[i], c.sentinelAddrs[0] | |||||
c.setSentinel(ctx, sentinel) | |||||
addr := net.JoinHostPort(masterAddr[0], masterAddr[1]) | |||||
return addr, nil | |||||
} | |||||
return "", errors.New("redis: all sentinels specified in configuration are unreachable") | |||||
} | |||||
func (c *sentinelFailover) slaveAddrs(ctx context.Context, useDisconnected bool) ([]string, error) { | |||||
c.mu.RLock() | |||||
sentinel := c.sentinel | |||||
c.mu.RUnlock() | |||||
if sentinel != nil { | |||||
addrs := c.getSlaveAddrs(ctx, sentinel) | |||||
if len(addrs) > 0 { | |||||
return addrs, nil | |||||
} | |||||
} | |||||
c.mu.Lock() | |||||
defer c.mu.Unlock() | |||||
if c.sentinel != nil { | |||||
addrs := c.getSlaveAddrs(ctx, c.sentinel) | |||||
if len(addrs) > 0 { | |||||
return addrs, nil | |||||
} | |||||
_ = c.closeSentinel() | |||||
} | |||||
var sentinelReachable bool | |||||
for i, sentinelAddr := range c.sentinelAddrs { | |||||
sentinel := NewSentinelClient(c.opt.sentinelOptions(sentinelAddr)) | |||||
slaves, err := sentinel.Slaves(ctx, c.opt.MasterName).Result() | |||||
if err != nil { | |||||
internal.Logger.Printf(ctx, "sentinel: Slaves master=%q failed: %s", | |||||
c.opt.MasterName, err) | |||||
_ = sentinel.Close() | |||||
continue | |||||
} | |||||
sentinelReachable = true | |||||
addrs := parseSlaveAddrs(slaves, useDisconnected) | |||||
if len(addrs) == 0 { | |||||
continue | |||||
} | |||||
// Push working sentinel to the top. | |||||
c.sentinelAddrs[0], c.sentinelAddrs[i] = c.sentinelAddrs[i], c.sentinelAddrs[0] | |||||
c.setSentinel(ctx, sentinel) | |||||
return addrs, nil | |||||
} | |||||
if sentinelReachable { | |||||
return []string{}, nil | |||||
} | |||||
return []string{}, errors.New("redis: all sentinels specified in configuration are unreachable") | |||||
} | |||||
func (c *sentinelFailover) getMasterAddr(ctx context.Context, sentinel *SentinelClient) string { | |||||
addr, err := sentinel.GetMasterAddrByName(ctx, c.opt.MasterName).Result() | |||||
if err != nil { | |||||
internal.Logger.Printf(ctx, "sentinel: GetMasterAddrByName name=%q failed: %s", | |||||
c.opt.MasterName, err) | |||||
return "" | |||||
} | |||||
return net.JoinHostPort(addr[0], addr[1]) | |||||
} | |||||
func (c *sentinelFailover) getSlaveAddrs(ctx context.Context, sentinel *SentinelClient) []string { | |||||
addrs, err := sentinel.Slaves(ctx, c.opt.MasterName).Result() | |||||
if err != nil { | |||||
internal.Logger.Printf(ctx, "sentinel: Slaves name=%q failed: %s", | |||||
c.opt.MasterName, err) | |||||
return []string{} | |||||
} | |||||
return parseSlaveAddrs(addrs, false) | |||||
} | |||||
func parseSlaveAddrs(addrs []interface{}, keepDisconnected bool) []string { | |||||
nodes := make([]string, 0, len(addrs)) | |||||
for _, node := range addrs { | |||||
ip := "" | |||||
port := "" | |||||
flags := []string{} | |||||
lastkey := "" | |||||
isDown := false | |||||
for _, key := range node.([]interface{}) { | |||||
switch lastkey { | |||||
case "ip": | |||||
ip = key.(string) | |||||
case "port": | |||||
port = key.(string) | |||||
case "flags": | |||||
flags = strings.Split(key.(string), ",") | |||||
} | |||||
lastkey = key.(string) | |||||
} | |||||
for _, flag := range flags { | |||||
switch flag { | |||||
case "s_down", "o_down": | |||||
isDown = true | |||||
case "disconnected": | |||||
if !keepDisconnected { | |||||
isDown = true | |||||
} | |||||
} | |||||
} | |||||
if !isDown { | |||||
nodes = append(nodes, net.JoinHostPort(ip, port)) | |||||
} | |||||
} | |||||
return nodes | |||||
} | |||||
func (c *sentinelFailover) trySwitchMaster(ctx context.Context, addr string) { | |||||
c.mu.RLock() | |||||
currentAddr := c._masterAddr //nolint:ifshort | |||||
c.mu.RUnlock() | |||||
if addr == currentAddr { | |||||
return | |||||
} | |||||
c.mu.Lock() | |||||
defer c.mu.Unlock() | |||||
if addr == c._masterAddr { | |||||
return | |||||
} | |||||
c._masterAddr = addr | |||||
internal.Logger.Printf(ctx, "sentinel: new master=%q addr=%q", | |||||
c.opt.MasterName, addr) | |||||
if c.onFailover != nil { | |||||
c.onFailover(ctx, addr) | |||||
} | |||||
} | |||||
func (c *sentinelFailover) setSentinel(ctx context.Context, sentinel *SentinelClient) { | |||||
if c.sentinel != nil { | |||||
panic("not reached") | |||||
} | |||||
c.sentinel = sentinel | |||||
c.discoverSentinels(ctx) | |||||
c.pubsub = sentinel.Subscribe(ctx, "+switch-master", "+slave-reconf-done") | |||||
go c.listen(c.pubsub) | |||||
} | |||||
func (c *sentinelFailover) discoverSentinels(ctx context.Context) { | |||||
sentinels, err := c.sentinel.Sentinels(ctx, c.opt.MasterName).Result() | |||||
if err != nil { | |||||
internal.Logger.Printf(ctx, "sentinel: Sentinels master=%q failed: %s", c.opt.MasterName, err) | |||||
return | |||||
} | |||||
for _, sentinel := range sentinels { | |||||
vals := sentinel.([]interface{}) | |||||
var ip, port string | |||||
for i := 0; i < len(vals); i += 2 { | |||||
key := vals[i].(string) | |||||
switch key { | |||||
case "ip": | |||||
ip = vals[i+1].(string) | |||||
case "port": | |||||
port = vals[i+1].(string) | |||||
} | |||||
} | |||||
if ip != "" && port != "" { | |||||
sentinelAddr := net.JoinHostPort(ip, port) | |||||
if !contains(c.sentinelAddrs, sentinelAddr) { | |||||
internal.Logger.Printf(ctx, "sentinel: discovered new sentinel=%q for master=%q", | |||||
sentinelAddr, c.opt.MasterName) | |||||
c.sentinelAddrs = append(c.sentinelAddrs, sentinelAddr) | |||||
} | |||||
} | |||||
} | |||||
} | |||||
func (c *sentinelFailover) listen(pubsub *PubSub) { | |||||
ctx := context.TODO() | |||||
if c.onUpdate != nil { | |||||
c.onUpdate(ctx) | |||||
} | |||||
ch := pubsub.Channel() | |||||
for msg := range ch { | |||||
if msg.Channel == "+switch-master" { | |||||
parts := strings.Split(msg.Payload, " ") | |||||
if parts[0] != c.opt.MasterName { | |||||
internal.Logger.Printf(pubsub.getContext(), "sentinel: ignore addr for master=%q", parts[0]) | |||||
continue | |||||
} | |||||
addr := net.JoinHostPort(parts[3], parts[4]) | |||||
c.trySwitchMaster(pubsub.getContext(), addr) | |||||
} | |||||
if c.onUpdate != nil { | |||||
c.onUpdate(ctx) | |||||
} | |||||
} | |||||
} | |||||
func contains(slice []string, str string) bool { | |||||
for _, s := range slice { | |||||
if s == str { | |||||
return true | |||||
} | |||||
} | |||||
return false | |||||
} | |||||
//------------------------------------------------------------------------------ | |||||
// NewFailoverClusterClient returns a client that supports routing read-only commands | |||||
// to a slave node. | |||||
func NewFailoverClusterClient(failoverOpt *FailoverOptions) *ClusterClient { | |||||
sentinelAddrs := make([]string, len(failoverOpt.SentinelAddrs)) | |||||
copy(sentinelAddrs, failoverOpt.SentinelAddrs) | |||||
failover := &sentinelFailover{ | |||||
opt: failoverOpt, | |||||
sentinelAddrs: sentinelAddrs, | |||||
} | |||||
opt := failoverOpt.clusterOptions() | |||||
opt.ClusterSlots = func(ctx context.Context) ([]ClusterSlot, error) { | |||||
masterAddr, err := failover.MasterAddr(ctx) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
nodes := []ClusterNode{{ | |||||
Addr: masterAddr, | |||||
}} | |||||
slaveAddrs, err := failover.slaveAddrs(ctx, false) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
for _, slaveAddr := range slaveAddrs { | |||||
nodes = append(nodes, ClusterNode{ | |||||
Addr: slaveAddr, | |||||
}) | |||||
} | |||||
slots := []ClusterSlot{ | |||||
{ | |||||
Start: 0, | |||||
End: 16383, | |||||
Nodes: nodes, | |||||
}, | |||||
} | |||||
return slots, nil | |||||
} | |||||
c := NewClusterClient(opt) | |||||
failover.mu.Lock() | |||||
failover.onUpdate = func(ctx context.Context) { | |||||
c.ReloadState(ctx) | |||||
} | |||||
failover.mu.Unlock() | |||||
return c | |||||
} |
@@ -0,0 +1,149 @@ | |||||
package redis | |||||
import ( | |||||
"context" | |||||
"github.com/go-redis/redis/v8/internal/pool" | |||||
"github.com/go-redis/redis/v8/internal/proto" | |||||
) | |||||
// TxFailedErr transaction redis failed. | |||||
const TxFailedErr = proto.RedisError("redis: transaction failed") | |||||
// Tx implements Redis transactions as described in | |||||
// http://redis.io/topics/transactions. It's NOT safe for concurrent use | |||||
// by multiple goroutines, because Exec resets list of watched keys. | |||||
// | |||||
// If you don't need WATCH, use Pipeline instead. | |||||
type Tx struct { | |||||
baseClient | |||||
cmdable | |||||
statefulCmdable | |||||
hooks | |||||
ctx context.Context | |||||
} | |||||
func (c *Client) newTx(ctx context.Context) *Tx { | |||||
tx := Tx{ | |||||
baseClient: baseClient{ | |||||
opt: c.opt, | |||||
connPool: pool.NewStickyConnPool(c.connPool), | |||||
}, | |||||
hooks: c.hooks.clone(), | |||||
ctx: ctx, | |||||
} | |||||
tx.init() | |||||
return &tx | |||||
} | |||||
func (c *Tx) init() { | |||||
c.cmdable = c.Process | |||||
c.statefulCmdable = c.Process | |||||
} | |||||
func (c *Tx) Context() context.Context { | |||||
return c.ctx | |||||
} | |||||
func (c *Tx) WithContext(ctx context.Context) *Tx { | |||||
if ctx == nil { | |||||
panic("nil context") | |||||
} | |||||
clone := *c | |||||
clone.init() | |||||
clone.hooks.lock() | |||||
clone.ctx = ctx | |||||
return &clone | |||||
} | |||||
func (c *Tx) Process(ctx context.Context, cmd Cmder) error { | |||||
return c.hooks.process(ctx, cmd, c.baseClient.process) | |||||
} | |||||
// Watch prepares a transaction and marks the keys to be watched | |||||
// for conditional execution if there are any keys. | |||||
// | |||||
// The transaction is automatically closed when fn exits. | |||||
func (c *Client) Watch(ctx context.Context, fn func(*Tx) error, keys ...string) error { | |||||
tx := c.newTx(ctx) | |||||
defer tx.Close(ctx) | |||||
if len(keys) > 0 { | |||||
if err := tx.Watch(ctx, keys...).Err(); err != nil { | |||||
return err | |||||
} | |||||
} | |||||
return fn(tx) | |||||
} | |||||
// Close closes the transaction, releasing any open resources. | |||||
func (c *Tx) Close(ctx context.Context) error { | |||||
_ = c.Unwatch(ctx).Err() | |||||
return c.baseClient.Close() | |||||
} | |||||
// Watch marks the keys to be watched for conditional execution | |||||
// of a transaction. | |||||
func (c *Tx) Watch(ctx context.Context, keys ...string) *StatusCmd { | |||||
args := make([]interface{}, 1+len(keys)) | |||||
args[0] = "watch" | |||||
for i, key := range keys { | |||||
args[1+i] = key | |||||
} | |||||
cmd := NewStatusCmd(ctx, args...) | |||||
_ = c.Process(ctx, cmd) | |||||
return cmd | |||||
} | |||||
// Unwatch flushes all the previously watched keys for a transaction. | |||||
func (c *Tx) Unwatch(ctx context.Context, keys ...string) *StatusCmd { | |||||
args := make([]interface{}, 1+len(keys)) | |||||
args[0] = "unwatch" | |||||
for i, key := range keys { | |||||
args[1+i] = key | |||||
} | |||||
cmd := NewStatusCmd(ctx, args...) | |||||
_ = c.Process(ctx, cmd) | |||||
return cmd | |||||
} | |||||
// Pipeline creates a pipeline. Usually it is more convenient to use Pipelined. | |||||
func (c *Tx) Pipeline() Pipeliner { | |||||
pipe := Pipeline{ | |||||
ctx: c.ctx, | |||||
exec: func(ctx context.Context, cmds []Cmder) error { | |||||
return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline) | |||||
}, | |||||
} | |||||
pipe.init() | |||||
return &pipe | |||||
} | |||||
// Pipelined executes commands queued in the fn outside of the transaction. | |||||
// Use TxPipelined if you need transactional behavior. | |||||
func (c *Tx) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { | |||||
return c.Pipeline().Pipelined(ctx, fn) | |||||
} | |||||
// TxPipelined executes commands queued in the fn in the transaction. | |||||
// | |||||
// When using WATCH, EXEC will execute commands only if the watched keys | |||||
// were not modified, allowing for a check-and-set mechanism. | |||||
// | |||||
// Exec always returns list of commands. If transaction fails | |||||
// TxFailedErr is returned. Otherwise Exec returns an error of the first | |||||
// failed command or nil. | |||||
func (c *Tx) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { | |||||
return c.TxPipeline().Pipelined(ctx, fn) | |||||
} | |||||
// TxPipeline creates a pipeline. Usually it is more convenient to use TxPipelined. | |||||
func (c *Tx) TxPipeline() Pipeliner { | |||||
pipe := Pipeline{ | |||||
ctx: c.ctx, | |||||
exec: func(ctx context.Context, cmds []Cmder) error { | |||||
return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline) | |||||
}, | |||||
} | |||||
pipe.init() | |||||
return &pipe | |||||
} |
@@ -0,0 +1,213 @@ | |||||
package redis | |||||
import ( | |||||
"context" | |||||
"crypto/tls" | |||||
"net" | |||||
"time" | |||||
) | |||||
// UniversalOptions information is required by UniversalClient to establish | |||||
// connections. | |||||
type UniversalOptions struct { | |||||
// Either a single address or a seed list of host:port addresses | |||||
// of cluster/sentinel nodes. | |||||
Addrs []string | |||||
// Database to be selected after connecting to the server. | |||||
// Only single-node and failover clients. | |||||
DB int | |||||
// Common options. | |||||
Dialer func(ctx context.Context, network, addr string) (net.Conn, error) | |||||
OnConnect func(ctx context.Context, cn *Conn) error | |||||
Username string | |||||
Password string | |||||
SentinelPassword string | |||||
MaxRetries int | |||||
MinRetryBackoff time.Duration | |||||
MaxRetryBackoff time.Duration | |||||
DialTimeout time.Duration | |||||
ReadTimeout time.Duration | |||||
WriteTimeout time.Duration | |||||
// PoolFIFO uses FIFO mode for each node connection pool GET/PUT (default LIFO). | |||||
PoolFIFO bool | |||||
PoolSize int | |||||
MinIdleConns int | |||||
MaxConnAge time.Duration | |||||
PoolTimeout time.Duration | |||||
IdleTimeout time.Duration | |||||
IdleCheckFrequency time.Duration | |||||
TLSConfig *tls.Config | |||||
// Only cluster clients. | |||||
MaxRedirects int | |||||
ReadOnly bool | |||||
RouteByLatency bool | |||||
RouteRandomly bool | |||||
// The sentinel master name. | |||||
// Only failover clients. | |||||
MasterName string | |||||
} | |||||
// Cluster returns cluster options created from the universal options. | |||||
func (o *UniversalOptions) Cluster() *ClusterOptions { | |||||
if len(o.Addrs) == 0 { | |||||
o.Addrs = []string{"127.0.0.1:6379"} | |||||
} | |||||
return &ClusterOptions{ | |||||
Addrs: o.Addrs, | |||||
Dialer: o.Dialer, | |||||
OnConnect: o.OnConnect, | |||||
Username: o.Username, | |||||
Password: o.Password, | |||||
MaxRedirects: o.MaxRedirects, | |||||
ReadOnly: o.ReadOnly, | |||||
RouteByLatency: o.RouteByLatency, | |||||
RouteRandomly: o.RouteRandomly, | |||||
MaxRetries: o.MaxRetries, | |||||
MinRetryBackoff: o.MinRetryBackoff, | |||||
MaxRetryBackoff: o.MaxRetryBackoff, | |||||
DialTimeout: o.DialTimeout, | |||||
ReadTimeout: o.ReadTimeout, | |||||
WriteTimeout: o.WriteTimeout, | |||||
PoolFIFO: o.PoolFIFO, | |||||
PoolSize: o.PoolSize, | |||||
MinIdleConns: o.MinIdleConns, | |||||
MaxConnAge: o.MaxConnAge, | |||||
PoolTimeout: o.PoolTimeout, | |||||
IdleTimeout: o.IdleTimeout, | |||||
IdleCheckFrequency: o.IdleCheckFrequency, | |||||
TLSConfig: o.TLSConfig, | |||||
} | |||||
} | |||||
// Failover returns failover options created from the universal options. | |||||
func (o *UniversalOptions) Failover() *FailoverOptions { | |||||
if len(o.Addrs) == 0 { | |||||
o.Addrs = []string{"127.0.0.1:26379"} | |||||
} | |||||
return &FailoverOptions{ | |||||
SentinelAddrs: o.Addrs, | |||||
MasterName: o.MasterName, | |||||
Dialer: o.Dialer, | |||||
OnConnect: o.OnConnect, | |||||
DB: o.DB, | |||||
Username: o.Username, | |||||
Password: o.Password, | |||||
SentinelPassword: o.SentinelPassword, | |||||
MaxRetries: o.MaxRetries, | |||||
MinRetryBackoff: o.MinRetryBackoff, | |||||
MaxRetryBackoff: o.MaxRetryBackoff, | |||||
DialTimeout: o.DialTimeout, | |||||
ReadTimeout: o.ReadTimeout, | |||||
WriteTimeout: o.WriteTimeout, | |||||
PoolFIFO: o.PoolFIFO, | |||||
PoolSize: o.PoolSize, | |||||
MinIdleConns: o.MinIdleConns, | |||||
MaxConnAge: o.MaxConnAge, | |||||
PoolTimeout: o.PoolTimeout, | |||||
IdleTimeout: o.IdleTimeout, | |||||
IdleCheckFrequency: o.IdleCheckFrequency, | |||||
TLSConfig: o.TLSConfig, | |||||
} | |||||
} | |||||
// Simple returns basic options created from the universal options. | |||||
func (o *UniversalOptions) Simple() *Options { | |||||
addr := "127.0.0.1:6379" | |||||
if len(o.Addrs) > 0 { | |||||
addr = o.Addrs[0] | |||||
} | |||||
return &Options{ | |||||
Addr: addr, | |||||
Dialer: o.Dialer, | |||||
OnConnect: o.OnConnect, | |||||
DB: o.DB, | |||||
Username: o.Username, | |||||
Password: o.Password, | |||||
MaxRetries: o.MaxRetries, | |||||
MinRetryBackoff: o.MinRetryBackoff, | |||||
MaxRetryBackoff: o.MaxRetryBackoff, | |||||
DialTimeout: o.DialTimeout, | |||||
ReadTimeout: o.ReadTimeout, | |||||
WriteTimeout: o.WriteTimeout, | |||||
PoolFIFO: o.PoolFIFO, | |||||
PoolSize: o.PoolSize, | |||||
MinIdleConns: o.MinIdleConns, | |||||
MaxConnAge: o.MaxConnAge, | |||||
PoolTimeout: o.PoolTimeout, | |||||
IdleTimeout: o.IdleTimeout, | |||||
IdleCheckFrequency: o.IdleCheckFrequency, | |||||
TLSConfig: o.TLSConfig, | |||||
} | |||||
} | |||||
// -------------------------------------------------------------------- | |||||
// UniversalClient is an abstract client which - based on the provided options - | |||||
// represents either a ClusterClient, a FailoverClient, or a single-node Client. | |||||
// This can be useful for testing cluster-specific applications locally or having different | |||||
// clients in different environments. | |||||
type UniversalClient interface { | |||||
Cmdable | |||||
Context() context.Context | |||||
AddHook(Hook) | |||||
Watch(ctx context.Context, fn func(*Tx) error, keys ...string) error | |||||
Do(ctx context.Context, args ...interface{}) *Cmd | |||||
Process(ctx context.Context, cmd Cmder) error | |||||
Subscribe(ctx context.Context, channels ...string) *PubSub | |||||
PSubscribe(ctx context.Context, channels ...string) *PubSub | |||||
Close() error | |||||
PoolStats() *PoolStats | |||||
} | |||||
var ( | |||||
_ UniversalClient = (*Client)(nil) | |||||
_ UniversalClient = (*ClusterClient)(nil) | |||||
_ UniversalClient = (*Ring)(nil) | |||||
) | |||||
// NewUniversalClient returns a new multi client. The type of the returned client depends | |||||
// on the following conditions: | |||||
// | |||||
// 1. If the MasterName option is specified, a sentinel-backed FailoverClient is returned. | |||||
// 2. if the number of Addrs is two or more, a ClusterClient is returned. | |||||
// 3. Otherwise, a single-node Client is returned. | |||||
func NewUniversalClient(opts *UniversalOptions) UniversalClient { | |||||
if opts.MasterName != "" { | |||||
return NewFailoverClient(opts.Failover()) | |||||
} else if len(opts.Addrs) > 1 { | |||||
return NewClusterClient(opts.Cluster()) | |||||
} | |||||
return NewClient(opts.Simple()) | |||||
} |
@@ -0,0 +1,6 @@ | |||||
package redis | |||||
// Version is the current release version. | |||||
func Version() string { | |||||
return "8.11.4" | |||||
} |
@@ -0,0 +1,20 @@ | |||||
# github.com/BurntSushi/toml v1.0.0 | |||||
## explicit; go 1.16 | |||||
github.com/BurntSushi/toml | |||||
github.com/BurntSushi/toml/internal | |||||
# github.com/cespare/xxhash/v2 v2.1.2 | |||||
## explicit; go 1.11 | |||||
github.com/cespare/xxhash/v2 | |||||
# github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f | |||||
## explicit | |||||
github.com/dgryski/go-rendezvous | |||||
# github.com/go-redis/redis/v8 v8.11.4 | |||||
## explicit; go 1.13 | |||||
github.com/go-redis/redis/v8 | |||||
github.com/go-redis/redis/v8/internal | |||||
github.com/go-redis/redis/v8/internal/hashtag | |||||
github.com/go-redis/redis/v8/internal/hscan | |||||
github.com/go-redis/redis/v8/internal/pool | |||||
github.com/go-redis/redis/v8/internal/proto | |||||
github.com/go-redis/redis/v8/internal/rand | |||||
github.com/go-redis/redis/v8/internal/util |