You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

245 lines
5.1 KiB

  1. /*
  2. *
  3. * Copyright 2017 gRPC authors.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. *
  17. */
  18. // Package bufconn provides a net.Conn implemented by a buffer and related
  19. // dialing and listening functionality.
  20. package bufconn
  21. import (
  22. "fmt"
  23. "io"
  24. "net"
  25. "sync"
  26. "time"
  27. )
  28. // Listener implements a net.Listener that creates local, buffered net.Conns
  29. // via its Accept and Dial method.
  30. type Listener struct {
  31. mu sync.Mutex
  32. sz int
  33. ch chan net.Conn
  34. done chan struct{}
  35. }
  36. var errClosed = fmt.Errorf("closed")
  37. // Listen returns a Listener that can only be contacted by its own Dialers and
  38. // creates buffered connections between the two.
  39. func Listen(sz int) *Listener {
  40. return &Listener{sz: sz, ch: make(chan net.Conn), done: make(chan struct{})}
  41. }
  42. // Accept blocks until Dial is called, then returns a net.Conn for the server
  43. // half of the connection.
  44. func (l *Listener) Accept() (net.Conn, error) {
  45. select {
  46. case <-l.done:
  47. return nil, errClosed
  48. case c := <-l.ch:
  49. return c, nil
  50. }
  51. }
  52. // Close stops the listener.
  53. func (l *Listener) Close() error {
  54. l.mu.Lock()
  55. defer l.mu.Unlock()
  56. select {
  57. case <-l.done:
  58. // Already closed.
  59. break
  60. default:
  61. close(l.done)
  62. }
  63. return nil
  64. }
  65. // Addr reports the address of the listener.
  66. func (l *Listener) Addr() net.Addr { return addr{} }
  67. // Dial creates an in-memory full-duplex network connection, unblocks Accept by
  68. // providing it the server half of the connection, and returns the client half
  69. // of the connection.
  70. func (l *Listener) Dial() (net.Conn, error) {
  71. p1, p2 := newPipe(l.sz), newPipe(l.sz)
  72. select {
  73. case <-l.done:
  74. return nil, errClosed
  75. case l.ch <- &conn{p1, p2}:
  76. return &conn{p2, p1}, nil
  77. }
  78. }
  79. type pipe struct {
  80. mu sync.Mutex
  81. // buf contains the data in the pipe. It is a ring buffer of fixed capacity,
  82. // with r and w pointing to the offset to read and write, respsectively.
  83. //
  84. // Data is read between [r, w) and written to [w, r), wrapping around the end
  85. // of the slice if necessary.
  86. //
  87. // The buffer is empty if r == len(buf), otherwise if r == w, it is full.
  88. //
  89. // w and r are always in the range [0, cap(buf)) and [0, len(buf)].
  90. buf []byte
  91. w, r int
  92. wwait sync.Cond
  93. rwait sync.Cond
  94. closed bool
  95. writeClosed bool
  96. }
  97. func newPipe(sz int) *pipe {
  98. p := &pipe{buf: make([]byte, 0, sz)}
  99. p.wwait.L = &p.mu
  100. p.rwait.L = &p.mu
  101. return p
  102. }
  103. func (p *pipe) empty() bool {
  104. return p.r == len(p.buf)
  105. }
  106. func (p *pipe) full() bool {
  107. return p.r < len(p.buf) && p.r == p.w
  108. }
  109. func (p *pipe) Read(b []byte) (n int, err error) {
  110. p.mu.Lock()
  111. defer p.mu.Unlock()
  112. // Block until p has data.
  113. for {
  114. if p.closed {
  115. return 0, io.ErrClosedPipe
  116. }
  117. if !p.empty() {
  118. break
  119. }
  120. if p.writeClosed {
  121. return 0, io.EOF
  122. }
  123. p.rwait.Wait()
  124. }
  125. wasFull := p.full()
  126. n = copy(b, p.buf[p.r:len(p.buf)])
  127. p.r += n
  128. if p.r == cap(p.buf) {
  129. p.r = 0
  130. p.buf = p.buf[:p.w]
  131. }
  132. // Signal a blocked writer, if any
  133. if wasFull {
  134. p.wwait.Signal()
  135. }
  136. return n, nil
  137. }
  138. func (p *pipe) Write(b []byte) (n int, err error) {
  139. p.mu.Lock()
  140. defer p.mu.Unlock()
  141. if p.closed {
  142. return 0, io.ErrClosedPipe
  143. }
  144. for len(b) > 0 {
  145. // Block until p is not full.
  146. for {
  147. if p.closed || p.writeClosed {
  148. return 0, io.ErrClosedPipe
  149. }
  150. if !p.full() {
  151. break
  152. }
  153. p.wwait.Wait()
  154. }
  155. wasEmpty := p.empty()
  156. end := cap(p.buf)
  157. if p.w < p.r {
  158. end = p.r
  159. }
  160. x := copy(p.buf[p.w:end], b)
  161. b = b[x:]
  162. n += x
  163. p.w += x
  164. if p.w > len(p.buf) {
  165. p.buf = p.buf[:p.w]
  166. }
  167. if p.w == cap(p.buf) {
  168. p.w = 0
  169. }
  170. // Signal a blocked reader, if any.
  171. if wasEmpty {
  172. p.rwait.Signal()
  173. }
  174. }
  175. return n, nil
  176. }
  177. func (p *pipe) Close() error {
  178. p.mu.Lock()
  179. defer p.mu.Unlock()
  180. p.closed = true
  181. // Signal all blocked readers and writers to return an error.
  182. p.rwait.Broadcast()
  183. p.wwait.Broadcast()
  184. return nil
  185. }
  186. func (p *pipe) closeWrite() error {
  187. p.mu.Lock()
  188. defer p.mu.Unlock()
  189. p.writeClosed = true
  190. // Signal all blocked readers and writers to return an error.
  191. p.rwait.Broadcast()
  192. p.wwait.Broadcast()
  193. return nil
  194. }
  195. type conn struct {
  196. io.Reader
  197. io.Writer
  198. }
  199. func (c *conn) Close() error {
  200. err1 := c.Reader.(*pipe).Close()
  201. err2 := c.Writer.(*pipe).closeWrite()
  202. if err1 != nil {
  203. return err1
  204. }
  205. return err2
  206. }
  207. func (*conn) LocalAddr() net.Addr { return addr{} }
  208. func (*conn) RemoteAddr() net.Addr { return addr{} }
  209. func (c *conn) SetDeadline(t time.Time) error { return fmt.Errorf("unsupported") }
  210. func (c *conn) SetReadDeadline(t time.Time) error { return fmt.Errorf("unsupported") }
  211. func (c *conn) SetWriteDeadline(t time.Time) error { return fmt.Errorf("unsupported") }
  212. type addr struct{}
  213. func (addr) Network() string { return "bufconn" }
  214. func (addr) String() string { return "bufconn" }