You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

571 lines
12 KiB

  1. // Copyright 2012 Gary Burd
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License"): you may
  4. // not use this file except in compliance with the License. You may obtain
  5. // a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  11. // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  12. // License for the specific language governing permissions and limitations
  13. // under the License.
  14. package redis
  15. import (
  16. "bufio"
  17. "bytes"
  18. "errors"
  19. "fmt"
  20. "io"
  21. "net"
  22. "net/url"
  23. "regexp"
  24. "strconv"
  25. "sync"
  26. "time"
  27. )
  28. // conn is the low-level implementation of Conn
  29. type conn struct {
  30. // Shared
  31. mu sync.Mutex
  32. pending int
  33. err error
  34. conn net.Conn
  35. // Read
  36. readTimeout time.Duration
  37. br *bufio.Reader
  38. // Write
  39. writeTimeout time.Duration
  40. bw *bufio.Writer
  41. // Scratch space for formatting argument length.
  42. // '*' or '$', length, "\r\n"
  43. lenScratch [32]byte
  44. // Scratch space for formatting integers and floats.
  45. numScratch [40]byte
  46. }
  47. // DialTimeout acts like Dial but takes timeouts for establishing the
  48. // connection to the server, writing a command and reading a reply.
  49. //
  50. // Deprecated: Use Dial with options instead.
  51. func DialTimeout(network, address string, connectTimeout, readTimeout, writeTimeout time.Duration) (Conn, error) {
  52. return Dial(network, address,
  53. DialConnectTimeout(connectTimeout),
  54. DialReadTimeout(readTimeout),
  55. DialWriteTimeout(writeTimeout))
  56. }
  57. // DialOption specifies an option for dialing a Redis server.
  58. type DialOption struct {
  59. f func(*dialOptions)
  60. }
  61. type dialOptions struct {
  62. readTimeout time.Duration
  63. writeTimeout time.Duration
  64. dial func(network, addr string) (net.Conn, error)
  65. db int
  66. password string
  67. }
  68. // DialReadTimeout specifies the timeout for reading a single command reply.
  69. func DialReadTimeout(d time.Duration) DialOption {
  70. return DialOption{func(do *dialOptions) {
  71. do.readTimeout = d
  72. }}
  73. }
  74. // DialWriteTimeout specifies the timeout for writing a single command.
  75. func DialWriteTimeout(d time.Duration) DialOption {
  76. return DialOption{func(do *dialOptions) {
  77. do.writeTimeout = d
  78. }}
  79. }
  80. // DialConnectTimeout specifies the timeout for connecting to the Redis server.
  81. func DialConnectTimeout(d time.Duration) DialOption {
  82. return DialOption{func(do *dialOptions) {
  83. dialer := net.Dialer{Timeout: d}
  84. do.dial = dialer.Dial
  85. }}
  86. }
  87. // DialNetDial specifies a custom dial function for creating TCP
  88. // connections. If this option is left out, then net.Dial is
  89. // used. DialNetDial overrides DialConnectTimeout.
  90. func DialNetDial(dial func(network, addr string) (net.Conn, error)) DialOption {
  91. return DialOption{func(do *dialOptions) {
  92. do.dial = dial
  93. }}
  94. }
  95. // DialDatabase specifies the database to select when dialing a connection.
  96. func DialDatabase(db int) DialOption {
  97. return DialOption{func(do *dialOptions) {
  98. do.db = db
  99. }}
  100. }
  101. // DialPassword specifies the password to use when connecting to
  102. // the Redis server.
  103. func DialPassword(password string) DialOption {
  104. return DialOption{func(do *dialOptions) {
  105. do.password = password
  106. }}
  107. }
  108. // Dial connects to the Redis server at the given network and
  109. // address using the specified options.
  110. func Dial(network, address string, options ...DialOption) (Conn, error) {
  111. do := dialOptions{
  112. dial: net.Dial,
  113. }
  114. for _, option := range options {
  115. option.f(&do)
  116. }
  117. netConn, err := do.dial(network, address)
  118. if err != nil {
  119. return nil, err
  120. }
  121. c := &conn{
  122. conn: netConn,
  123. bw: bufio.NewWriter(netConn),
  124. br: bufio.NewReader(netConn),
  125. readTimeout: do.readTimeout,
  126. writeTimeout: do.writeTimeout,
  127. }
  128. if do.password != "" {
  129. if _, err := c.Do("AUTH", do.password); err != nil {
  130. netConn.Close()
  131. return nil, err
  132. }
  133. }
  134. if do.db != 0 {
  135. if _, err := c.Do("SELECT", do.db); err != nil {
  136. netConn.Close()
  137. return nil, err
  138. }
  139. }
  140. return c, nil
  141. }
  142. var pathDBRegexp = regexp.MustCompile(`/(\d*)\z`)
  143. // DialURL connects to a Redis server at the given URL using the Redis
  144. // URI scheme. URLs should follow the draft IANA specification for the
  145. // scheme (https://www.iana.org/assignments/uri-schemes/prov/redis).
  146. func DialURL(rawurl string, options ...DialOption) (Conn, error) {
  147. u, err := url.Parse(rawurl)
  148. if err != nil {
  149. return nil, err
  150. }
  151. if u.Scheme != "redis" {
  152. return nil, fmt.Errorf("invalid redis URL scheme: %s", u.Scheme)
  153. }
  154. // As per the IANA draft spec, the host defaults to localhost and
  155. // the port defaults to 6379.
  156. host, port, err := net.SplitHostPort(u.Host)
  157. if err != nil {
  158. // assume port is missing
  159. host = u.Host
  160. port = "6379"
  161. }
  162. if host == "" {
  163. host = "localhost"
  164. }
  165. address := net.JoinHostPort(host, port)
  166. if u.User != nil {
  167. password, isSet := u.User.Password()
  168. if isSet {
  169. options = append(options, DialPassword(password))
  170. }
  171. }
  172. match := pathDBRegexp.FindStringSubmatch(u.Path)
  173. if len(match) == 2 {
  174. db := 0
  175. if len(match[1]) > 0 {
  176. db, err = strconv.Atoi(match[1])
  177. if err != nil {
  178. return nil, fmt.Errorf("invalid database: %s", u.Path[1:])
  179. }
  180. }
  181. if db != 0 {
  182. options = append(options, DialDatabase(db))
  183. }
  184. } else if u.Path != "" {
  185. return nil, fmt.Errorf("invalid database: %s", u.Path[1:])
  186. }
  187. return Dial("tcp", address, options...)
  188. }
  189. // NewConn returns a new Redigo connection for the given net connection.
  190. func NewConn(netConn net.Conn, readTimeout, writeTimeout time.Duration) Conn {
  191. return &conn{
  192. conn: netConn,
  193. bw: bufio.NewWriter(netConn),
  194. br: bufio.NewReader(netConn),
  195. readTimeout: readTimeout,
  196. writeTimeout: writeTimeout,
  197. }
  198. }
  199. func (c *conn) Close() error {
  200. c.mu.Lock()
  201. err := c.err
  202. if c.err == nil {
  203. c.err = errors.New("redigo: closed")
  204. err = c.conn.Close()
  205. }
  206. c.mu.Unlock()
  207. return err
  208. }
  209. func (c *conn) fatal(err error) error {
  210. c.mu.Lock()
  211. if c.err == nil {
  212. c.err = err
  213. // Close connection to force errors on subsequent calls and to unblock
  214. // other reader or writer.
  215. c.conn.Close()
  216. }
  217. c.mu.Unlock()
  218. return err
  219. }
  220. func (c *conn) Err() error {
  221. c.mu.Lock()
  222. err := c.err
  223. c.mu.Unlock()
  224. return err
  225. }
  226. func (c *conn) writeLen(prefix byte, n int) error {
  227. c.lenScratch[len(c.lenScratch)-1] = '\n'
  228. c.lenScratch[len(c.lenScratch)-2] = '\r'
  229. i := len(c.lenScratch) - 3
  230. for {
  231. c.lenScratch[i] = byte('0' + n%10)
  232. i -= 1
  233. n = n / 10
  234. if n == 0 {
  235. break
  236. }
  237. }
  238. c.lenScratch[i] = prefix
  239. _, err := c.bw.Write(c.lenScratch[i:])
  240. return err
  241. }
  242. func (c *conn) writeString(s string) error {
  243. c.writeLen('$', len(s))
  244. c.bw.WriteString(s)
  245. _, err := c.bw.WriteString("\r\n")
  246. return err
  247. }
  248. func (c *conn) writeBytes(p []byte) error {
  249. c.writeLen('$', len(p))
  250. c.bw.Write(p)
  251. _, err := c.bw.WriteString("\r\n")
  252. return err
  253. }
  254. func (c *conn) writeInt64(n int64) error {
  255. return c.writeBytes(strconv.AppendInt(c.numScratch[:0], n, 10))
  256. }
  257. func (c *conn) writeFloat64(n float64) error {
  258. return c.writeBytes(strconv.AppendFloat(c.numScratch[:0], n, 'g', -1, 64))
  259. }
  260. func (c *conn) writeCommand(cmd string, args []interface{}) (err error) {
  261. c.writeLen('*', 1+len(args))
  262. err = c.writeString(cmd)
  263. for _, arg := range args {
  264. if err != nil {
  265. break
  266. }
  267. switch arg := arg.(type) {
  268. case string:
  269. err = c.writeString(arg)
  270. case []byte:
  271. err = c.writeBytes(arg)
  272. case int:
  273. err = c.writeInt64(int64(arg))
  274. case int64:
  275. err = c.writeInt64(arg)
  276. case float64:
  277. err = c.writeFloat64(arg)
  278. case bool:
  279. if arg {
  280. err = c.writeString("1")
  281. } else {
  282. err = c.writeString("0")
  283. }
  284. case nil:
  285. err = c.writeString("")
  286. default:
  287. var buf bytes.Buffer
  288. fmt.Fprint(&buf, arg)
  289. err = c.writeBytes(buf.Bytes())
  290. }
  291. }
  292. return err
  293. }
  294. type protocolError string
  295. func (pe protocolError) Error() string {
  296. return fmt.Sprintf("redigo: %s (possible server error or unsupported concurrent read by application)", string(pe))
  297. }
  298. func (c *conn) readLine() ([]byte, error) {
  299. p, err := c.br.ReadSlice('\n')
  300. if err == bufio.ErrBufferFull {
  301. return nil, protocolError("long response line")
  302. }
  303. if err != nil {
  304. return nil, err
  305. }
  306. i := len(p) - 2
  307. if i < 0 || p[i] != '\r' {
  308. return nil, protocolError("bad response line terminator")
  309. }
  310. return p[:i], nil
  311. }
  312. // parseLen parses bulk string and array lengths.
  313. func parseLen(p []byte) (int, error) {
  314. if len(p) == 0 {
  315. return -1, protocolError("malformed length")
  316. }
  317. if p[0] == '-' && len(p) == 2 && p[1] == '1' {
  318. // handle $-1 and $-1 null replies.
  319. return -1, nil
  320. }
  321. var n int
  322. for _, b := range p {
  323. n *= 10
  324. if b < '0' || b > '9' {
  325. return -1, protocolError("illegal bytes in length")
  326. }
  327. n += int(b - '0')
  328. }
  329. return n, nil
  330. }
  331. // parseInt parses an integer reply.
  332. func parseInt(p []byte) (interface{}, error) {
  333. if len(p) == 0 {
  334. return 0, protocolError("malformed integer")
  335. }
  336. var negate bool
  337. if p[0] == '-' {
  338. negate = true
  339. p = p[1:]
  340. if len(p) == 0 {
  341. return 0, protocolError("malformed integer")
  342. }
  343. }
  344. var n int64
  345. for _, b := range p {
  346. n *= 10
  347. if b < '0' || b > '9' {
  348. return 0, protocolError("illegal bytes in length")
  349. }
  350. n += int64(b - '0')
  351. }
  352. if negate {
  353. n = -n
  354. }
  355. return n, nil
  356. }
  357. var (
  358. okReply interface{} = "OK"
  359. pongReply interface{} = "PONG"
  360. )
  361. func (c *conn) readReply() (interface{}, error) {
  362. line, err := c.readLine()
  363. if err != nil {
  364. return nil, err
  365. }
  366. if len(line) == 0 {
  367. return nil, protocolError("short response line")
  368. }
  369. switch line[0] {
  370. case '+':
  371. switch {
  372. case len(line) == 3 && line[1] == 'O' && line[2] == 'K':
  373. // Avoid allocation for frequent "+OK" response.
  374. return okReply, nil
  375. case len(line) == 5 && line[1] == 'P' && line[2] == 'O' && line[3] == 'N' && line[4] == 'G':
  376. // Avoid allocation in PING command benchmarks :)
  377. return pongReply, nil
  378. default:
  379. return string(line[1:]), nil
  380. }
  381. case '-':
  382. return Error(string(line[1:])), nil
  383. case ':':
  384. return parseInt(line[1:])
  385. case '$':
  386. n, err := parseLen(line[1:])
  387. if n < 0 || err != nil {
  388. return nil, err
  389. }
  390. p := make([]byte, n)
  391. _, err = io.ReadFull(c.br, p)
  392. if err != nil {
  393. return nil, err
  394. }
  395. if line, err := c.readLine(); err != nil {
  396. return nil, err
  397. } else if len(line) != 0 {
  398. return nil, protocolError("bad bulk string format")
  399. }
  400. return p, nil
  401. case '*':
  402. n, err := parseLen(line[1:])
  403. if n < 0 || err != nil {
  404. return nil, err
  405. }
  406. r := make([]interface{}, n)
  407. for i := range r {
  408. r[i], err = c.readReply()
  409. if err != nil {
  410. return nil, err
  411. }
  412. }
  413. return r, nil
  414. }
  415. return nil, protocolError("unexpected response line")
  416. }
  417. func (c *conn) Send(cmd string, args ...interface{}) error {
  418. c.mu.Lock()
  419. c.pending += 1
  420. c.mu.Unlock()
  421. if c.writeTimeout != 0 {
  422. c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
  423. }
  424. if err := c.writeCommand(cmd, args); err != nil {
  425. return c.fatal(err)
  426. }
  427. return nil
  428. }
  429. func (c *conn) Flush() error {
  430. if c.writeTimeout != 0 {
  431. c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
  432. }
  433. if err := c.bw.Flush(); err != nil {
  434. return c.fatal(err)
  435. }
  436. return nil
  437. }
  438. func (c *conn) Receive() (reply interface{}, err error) {
  439. if c.readTimeout != 0 {
  440. c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
  441. }
  442. if reply, err = c.readReply(); err != nil {
  443. return nil, c.fatal(err)
  444. }
  445. // When using pub/sub, the number of receives can be greater than the
  446. // number of sends. To enable normal use of the connection after
  447. // unsubscribing from all channels, we do not decrement pending to a
  448. // negative value.
  449. //
  450. // The pending field is decremented after the reply is read to handle the
  451. // case where Receive is called before Send.
  452. c.mu.Lock()
  453. if c.pending > 0 {
  454. c.pending -= 1
  455. }
  456. c.mu.Unlock()
  457. if err, ok := reply.(Error); ok {
  458. return nil, err
  459. }
  460. return
  461. }
  462. func (c *conn) Do(cmd string, args ...interface{}) (interface{}, error) {
  463. c.mu.Lock()
  464. pending := c.pending
  465. c.pending = 0
  466. c.mu.Unlock()
  467. if cmd == "" && pending == 0 {
  468. return nil, nil
  469. }
  470. if c.writeTimeout != 0 {
  471. c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
  472. }
  473. if cmd != "" {
  474. if err := c.writeCommand(cmd, args); err != nil {
  475. return nil, c.fatal(err)
  476. }
  477. }
  478. if err := c.bw.Flush(); err != nil {
  479. return nil, c.fatal(err)
  480. }
  481. if c.readTimeout != 0 {
  482. c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
  483. }
  484. if cmd == "" {
  485. reply := make([]interface{}, pending)
  486. for i := range reply {
  487. r, e := c.readReply()
  488. if e != nil {
  489. return nil, c.fatal(e)
  490. }
  491. reply[i] = r
  492. }
  493. return reply, nil
  494. }
  495. var err error
  496. var reply interface{}
  497. for i := 0; i <= pending; i++ {
  498. var e error
  499. if reply, e = c.readReply(); e != nil {
  500. return nil, c.fatal(e)
  501. }
  502. if e, ok := reply.(Error); ok && err == nil {
  503. err = e
  504. }
  505. }
  506. return reply, err
  507. }