You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

275 lines
8.9 KiB

  1. /*
  2. *
  3. * Copyright 2018 gRPC authors.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. *
  17. */
  18. package conn
  19. import (
  20. "bytes"
  21. "encoding/binary"
  22. "fmt"
  23. "io"
  24. "math"
  25. "net"
  26. "reflect"
  27. "testing"
  28. core "google.golang.org/grpc/credentials/alts/internal"
  29. )
  30. var (
  31. nextProtocols = []string{"ALTSRP_GCM_AES128"}
  32. altsRecordFuncs = map[string]ALTSRecordFunc{
  33. // ALTS handshaker protocols.
  34. "ALTSRP_GCM_AES128": func(s core.Side, keyData []byte) (ALTSRecordCrypto, error) {
  35. return NewAES128GCM(s, keyData)
  36. },
  37. }
  38. )
  39. func init() {
  40. for protocol, f := range altsRecordFuncs {
  41. if err := RegisterProtocol(protocol, f); err != nil {
  42. panic(err)
  43. }
  44. }
  45. }
  // testConn is an in-memory stand-in for the peer's net.Conn. Read drains
  // the in buffer and Write appends to the out buffer. Only Read, Write and
  // Close are overridden; every other net.Conn method is promoted from the
  // embedded net.Conn, which is left unset by this file's helpers, so calling
  // one of them would dereference a nil interface.
  type testConn struct {
  	net.Conn
  	in, out *bytes.Buffer
  }

  // Read consumes bytes previously queued in tc.in.
  func (tc *testConn) Read(p []byte) (int, error) { return tc.in.Read(p) }

  // Write appends p to tc.out.
  func (tc *testConn) Write(p []byte) (int, error) { return tc.out.Write(p) }

  // Close does nothing and always reports success.
  func (tc *testConn) Close() error { return nil }
  61. func newTestALTSRecordConn(in, out *bytes.Buffer, side core.Side, np string) *conn {
  62. key := []byte{
  63. // 16 arbitrary bytes.
  64. 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xd2, 0x4c, 0xce, 0x4f, 0x49}
  65. tc := testConn{
  66. in: in,
  67. out: out,
  68. }
  69. c, err := NewConn(&tc, side, np, key, nil)
  70. if err != nil {
  71. panic(fmt.Sprintf("Unexpected error creating test ALTS record connection: %v", err))
  72. }
  73. return c.(*conn)
  74. }
  75. func newConnPair(np string) (client, server *conn) {
  76. clientBuf := new(bytes.Buffer)
  77. serverBuf := new(bytes.Buffer)
  78. clientConn := newTestALTSRecordConn(clientBuf, serverBuf, core.ClientSide, np)
  79. serverConn := newTestALTSRecordConn(serverBuf, clientBuf, core.ServerSide, np)
  80. return clientConn, serverConn
  81. }
  82. func testPingPong(t *testing.T, np string) {
  83. clientConn, serverConn := newConnPair(np)
  84. clientMsg := []byte("Client Message")
  85. if n, err := clientConn.Write(clientMsg); n != len(clientMsg) || err != nil {
  86. t.Fatalf("Client Write() = %v, %v; want %v, <nil>", n, err, len(clientMsg))
  87. }
  88. rcvClientMsg := make([]byte, len(clientMsg))
  89. if n, err := serverConn.Read(rcvClientMsg); n != len(rcvClientMsg) || err != nil {
  90. t.Fatalf("Server Read() = %v, %v; want %v, <nil>", n, err, len(rcvClientMsg))
  91. }
  92. if !reflect.DeepEqual(clientMsg, rcvClientMsg) {
  93. t.Fatalf("Client Write()/Server Read() = %v, want %v", rcvClientMsg, clientMsg)
  94. }
  95. serverMsg := []byte("Server Message")
  96. if n, err := serverConn.Write(serverMsg); n != len(serverMsg) || err != nil {
  97. t.Fatalf("Server Write() = %v, %v; want %v, <nil>", n, err, len(serverMsg))
  98. }
  99. rcvServerMsg := make([]byte, len(serverMsg))
  100. if n, err := clientConn.Read(rcvServerMsg); n != len(rcvServerMsg) || err != nil {
  101. t.Fatalf("Client Read() = %v, %v; want %v, <nil>", n, err, len(rcvServerMsg))
  102. }
  103. if !reflect.DeepEqual(serverMsg, rcvServerMsg) {
  104. t.Fatalf("Server Write()/Client Read() = %v, want %v", rcvServerMsg, serverMsg)
  105. }
  106. }
  107. func TestPingPong(t *testing.T) {
  108. for _, np := range nextProtocols {
  109. testPingPong(t, np)
  110. }
  111. }
  112. func testSmallReadBuffer(t *testing.T, np string) {
  113. clientConn, serverConn := newConnPair(np)
  114. msg := []byte("Very Important Message")
  115. if n, err := clientConn.Write(msg); err != nil {
  116. t.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg))
  117. }
  118. rcvMsg := make([]byte, len(msg))
  119. n := 2 // Arbitrary index to break rcvMsg in two.
  120. rcvMsg1 := rcvMsg[:n]
  121. rcvMsg2 := rcvMsg[n:]
  122. if n, err := serverConn.Read(rcvMsg1); n != len(rcvMsg1) || err != nil {
  123. t.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg1))
  124. }
  125. if n, err := serverConn.Read(rcvMsg2); n != len(rcvMsg2) || err != nil {
  126. t.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg2))
  127. }
  128. if !reflect.DeepEqual(msg, rcvMsg) {
  129. t.Fatalf("Write()/Read() = %v, want %v", rcvMsg, msg)
  130. }
  131. }
  132. func TestSmallReadBuffer(t *testing.T) {
  133. for _, np := range nextProtocols {
  134. testSmallReadBuffer(t, np)
  135. }
  136. }
  137. func testLargeMsg(t *testing.T, np string) {
  138. clientConn, serverConn := newConnPair(np)
  139. // msgLen is such that the length in the framing is larger than the
  140. // default size of one frame.
  141. msgLen := altsRecordDefaultLength - msgTypeFieldSize - clientConn.crypto.EncryptionOverhead() + 1
  142. msg := make([]byte, msgLen)
  143. if n, err := clientConn.Write(msg); n != len(msg) || err != nil {
  144. t.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg))
  145. }
  146. rcvMsg := make([]byte, len(msg))
  147. if n, err := io.ReadFull(serverConn, rcvMsg); n != len(rcvMsg) || err != nil {
  148. t.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg))
  149. }
  150. if !reflect.DeepEqual(msg, rcvMsg) {
  151. t.Fatalf("Write()/Server Read() = %v, want %v", rcvMsg, msg)
  152. }
  153. }
  154. func TestLargeMsg(t *testing.T) {
  155. for _, np := range nextProtocols {
  156. testLargeMsg(t, np)
  157. }
  158. }
  159. func testIncorrectMsgType(t *testing.T, np string) {
  160. // framedMsg is an empty ciphertext with correct framing but wrong
  161. // message type.
  162. framedMsg := make([]byte, MsgLenFieldSize+msgTypeFieldSize)
  163. binary.LittleEndian.PutUint32(framedMsg[:MsgLenFieldSize], msgTypeFieldSize)
  164. wrongMsgType := uint32(0x22)
  165. binary.LittleEndian.PutUint32(framedMsg[MsgLenFieldSize:], wrongMsgType)
  166. in := bytes.NewBuffer(framedMsg)
  167. c := newTestALTSRecordConn(in, nil, core.ClientSide, np)
  168. b := make([]byte, 1)
  169. if n, err := c.Read(b); n != 0 || err == nil {
  170. t.Fatalf("Read() = <nil>, want %v", fmt.Errorf("received frame with incorrect message type %v", wrongMsgType))
  171. }
  172. }
  173. func TestIncorrectMsgType(t *testing.T) {
  174. for _, np := range nextProtocols {
  175. testIncorrectMsgType(t, np)
  176. }
  177. }
  178. func testFrameTooLarge(t *testing.T, np string) {
  179. buf := new(bytes.Buffer)
  180. clientConn := newTestALTSRecordConn(nil, buf, core.ClientSide, np)
  181. serverConn := newTestALTSRecordConn(buf, nil, core.ServerSide, np)
  182. // payloadLen is such that the length in the framing is larger than
  183. // allowed in one frame.
  184. payloadLen := altsRecordLengthLimit - msgTypeFieldSize - clientConn.crypto.EncryptionOverhead() + 1
  185. payload := make([]byte, payloadLen)
  186. c, err := clientConn.crypto.Encrypt(nil, payload)
  187. if err != nil {
  188. t.Fatalf(fmt.Sprintf("Error encrypting message: %v", err))
  189. }
  190. msgLen := msgTypeFieldSize + len(c)
  191. framedMsg := make([]byte, MsgLenFieldSize+msgLen)
  192. binary.LittleEndian.PutUint32(framedMsg[:MsgLenFieldSize], uint32(msgTypeFieldSize+len(c)))
  193. msg := framedMsg[MsgLenFieldSize:]
  194. binary.LittleEndian.PutUint32(msg[:msgTypeFieldSize], altsRecordMsgType)
  195. copy(msg[msgTypeFieldSize:], c)
  196. if _, err = buf.Write(framedMsg); err != nil {
  197. t.Fatal(fmt.Sprintf("Unexpected error writing to buffer: %v", err))
  198. }
  199. b := make([]byte, 1)
  200. if n, err := serverConn.Read(b); n != 0 || err == nil {
  201. t.Fatalf("Read() = <nil>, want %v", fmt.Errorf("received the frame length %d larger than the limit %d", altsRecordLengthLimit+1, altsRecordLengthLimit))
  202. }
  203. }
  204. func TestFrameTooLarge(t *testing.T) {
  205. for _, np := range nextProtocols {
  206. testFrameTooLarge(t, np)
  207. }
  208. }
  209. func testWriteLargeData(t *testing.T, np string) {
  210. // Test sending and receiving messages larger than the maximum write
  211. // buffer size.
  212. clientConn, serverConn := newConnPair(np)
  213. // Message size is intentionally chosen to not be multiple of
  214. // payloadLengthLimtit.
  215. msgSize := altsWriteBufferMaxSize + (100 * 1024)
  216. clientMsg := make([]byte, msgSize)
  217. for i := 0; i < msgSize; i++ {
  218. clientMsg[i] = 0xAA
  219. }
  220. if n, err := clientConn.Write(clientMsg); n != len(clientMsg) || err != nil {
  221. t.Fatalf("Client Write() = %v, %v; want %v, <nil>", n, err, len(clientMsg))
  222. }
  223. // We need to keep reading until the entire message is received. The
  224. // reason we set all bytes of the message to a value other than zero is
  225. // to avoid ambiguous zero-init value of rcvClientMsg buffer and the
  226. // actual received data.
  227. rcvClientMsg := make([]byte, 0, msgSize)
  228. numberOfExpectedFrames := int(math.Ceil(float64(msgSize) / float64(serverConn.payloadLengthLimit)))
  229. for i := 0; i < numberOfExpectedFrames; i++ {
  230. expectedRcvSize := serverConn.payloadLengthLimit
  231. if i == numberOfExpectedFrames-1 {
  232. // Last frame might be smaller.
  233. expectedRcvSize = msgSize % serverConn.payloadLengthLimit
  234. }
  235. tmpBuf := make([]byte, expectedRcvSize)
  236. if n, err := serverConn.Read(tmpBuf); n != len(tmpBuf) || err != nil {
  237. t.Fatalf("Server Read() = %v, %v; want %v, <nil>", n, err, len(tmpBuf))
  238. }
  239. rcvClientMsg = append(rcvClientMsg, tmpBuf...)
  240. }
  241. if !reflect.DeepEqual(clientMsg, rcvClientMsg) {
  242. t.Fatalf("Client Write()/Server Read() = %v, want %v", rcvClientMsg, clientMsg)
  243. }
  244. }
  245. func TestWriteLargeData(t *testing.T) {
  246. for _, np := range nextProtocols {
  247. testWriteLargeData(t, np)
  248. }
  249. }