You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

1737 lines
49 KiB

  1. /*
  2. Copyright 2017 Google LLC
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package spanner
  14. import (
  15. "context"
  16. "errors"
  17. "fmt"
  18. "io"
  19. "sync/atomic"
  20. "testing"
  21. "time"
  22. "cloud.google.com/go/spanner/internal/backoff"
  23. "cloud.google.com/go/spanner/internal/testutil"
  24. "github.com/golang/protobuf/proto"
  25. proto3 "github.com/golang/protobuf/ptypes/struct"
  26. "google.golang.org/api/iterator"
  27. sppb "google.golang.org/genproto/googleapis/spanner/v1"
  28. "google.golang.org/grpc"
  29. "google.golang.org/grpc/codes"
  30. "google.golang.org/grpc/status"
  31. )
  32. var (
  33. // Mocked transaction timestamp.
  34. trxTs = time.Unix(1, 2)
  35. // Metadata for mocked KV table, its rows are returned by SingleUse transactions.
  36. kvMeta = func() *sppb.ResultSetMetadata {
  37. meta := testutil.KvMeta
  38. meta.Transaction = &sppb.Transaction{
  39. ReadTimestamp: timestampProto(trxTs),
  40. }
  41. return &meta
  42. }()
  43. // Metadata for mocked ListKV table, which uses List for its key and value.
  44. // Its rows are returned by snapshot readonly transactions, as indicated in the transaction metadata.
  45. kvListMeta = &sppb.ResultSetMetadata{
  46. RowType: &sppb.StructType{
  47. Fields: []*sppb.StructType_Field{
  48. {
  49. Name: "Key",
  50. Type: &sppb.Type{
  51. Code: sppb.TypeCode_ARRAY,
  52. ArrayElementType: &sppb.Type{
  53. Code: sppb.TypeCode_STRING,
  54. },
  55. },
  56. },
  57. {
  58. Name: "Value",
  59. Type: &sppb.Type{
  60. Code: sppb.TypeCode_ARRAY,
  61. ArrayElementType: &sppb.Type{
  62. Code: sppb.TypeCode_STRING,
  63. },
  64. },
  65. },
  66. },
  67. },
  68. Transaction: &sppb.Transaction{
  69. Id: transactionID{5, 6, 7, 8, 9},
  70. ReadTimestamp: timestampProto(trxTs),
  71. },
  72. }
  73. // Metadata for mocked schema of a query result set, which has two struct
  74. // columns named "Col1" and "Col2", the struct's schema is like the
  75. // following:
  76. //
  77. // STRUCT {
  78. // INT
  79. // LIST<STRING>
  80. // }
  81. //
  82. // Its rows are returned in readwrite transaction, as indicated in the transaction metadata.
  83. kvObjectMeta = &sppb.ResultSetMetadata{
  84. RowType: &sppb.StructType{
  85. Fields: []*sppb.StructType_Field{
  86. {
  87. Name: "Col1",
  88. Type: &sppb.Type{
  89. Code: sppb.TypeCode_STRUCT,
  90. StructType: &sppb.StructType{
  91. Fields: []*sppb.StructType_Field{
  92. {
  93. Name: "foo-f1",
  94. Type: &sppb.Type{
  95. Code: sppb.TypeCode_INT64,
  96. },
  97. },
  98. {
  99. Name: "foo-f2",
  100. Type: &sppb.Type{
  101. Code: sppb.TypeCode_ARRAY,
  102. ArrayElementType: &sppb.Type{
  103. Code: sppb.TypeCode_STRING,
  104. },
  105. },
  106. },
  107. },
  108. },
  109. },
  110. },
  111. {
  112. Name: "Col2",
  113. Type: &sppb.Type{
  114. Code: sppb.TypeCode_STRUCT,
  115. StructType: &sppb.StructType{
  116. Fields: []*sppb.StructType_Field{
  117. {
  118. Name: "bar-f1",
  119. Type: &sppb.Type{
  120. Code: sppb.TypeCode_INT64,
  121. },
  122. },
  123. {
  124. Name: "bar-f2",
  125. Type: &sppb.Type{
  126. Code: sppb.TypeCode_ARRAY,
  127. ArrayElementType: &sppb.Type{
  128. Code: sppb.TypeCode_STRING,
  129. },
  130. },
  131. },
  132. },
  133. },
  134. },
  135. },
  136. },
  137. },
  138. Transaction: &sppb.Transaction{
  139. Id: transactionID{1, 2, 3, 4, 5},
  140. },
  141. }
  142. )
  143. // String implements fmt.stringer.
  144. func (r *Row) String() string {
  145. return fmt.Sprintf("{fields: %s, val: %s}", r.fields, r.vals)
  146. }
  147. func describeRows(l []*Row) string {
  148. // generate a nice test failure description
  149. var s = "["
  150. for i, r := range l {
  151. if i != 0 {
  152. s += ",\n "
  153. }
  154. s += fmt.Sprint(r)
  155. }
  156. s += "]"
  157. return s
  158. }
  159. // Helper for generating proto3 Value_ListValue instances, making
  160. // test code shorter and readable.
  161. func genProtoListValue(v ...string) *proto3.Value_ListValue {
  162. r := &proto3.Value_ListValue{
  163. ListValue: &proto3.ListValue{
  164. Values: []*proto3.Value{},
  165. },
  166. }
  167. for _, e := range v {
  168. r.ListValue.Values = append(
  169. r.ListValue.Values,
  170. &proto3.Value{
  171. Kind: &proto3.Value_StringValue{StringValue: e},
  172. },
  173. )
  174. }
  175. return r
  176. }
  177. // Test Row generation logics of partialResultSetDecoder.
  178. func TestPartialResultSetDecoder(t *testing.T) {
  179. restore := setMaxBytesBetweenResumeTokens()
  180. defer restore()
  181. var tests = []struct {
  182. input []*sppb.PartialResultSet
  183. wantF []*Row
  184. wantTxID transactionID
  185. wantTs time.Time
  186. wantD bool
  187. }{
  188. {
  189. // Empty input.
  190. wantD: true,
  191. },
  192. // String merging examples.
  193. {
  194. // Single KV result.
  195. input: []*sppb.PartialResultSet{
  196. {
  197. Metadata: kvMeta,
  198. Values: []*proto3.Value{
  199. {Kind: &proto3.Value_StringValue{StringValue: "foo"}},
  200. {Kind: &proto3.Value_StringValue{StringValue: "bar"}},
  201. },
  202. },
  203. },
  204. wantF: []*Row{
  205. {
  206. fields: kvMeta.RowType.Fields,
  207. vals: []*proto3.Value{
  208. {Kind: &proto3.Value_StringValue{StringValue: "foo"}},
  209. {Kind: &proto3.Value_StringValue{StringValue: "bar"}},
  210. },
  211. },
  212. },
  213. wantTs: trxTs,
  214. wantD: true,
  215. },
  216. {
  217. // Incomplete partial result.
  218. input: []*sppb.PartialResultSet{
  219. {
  220. Metadata: kvMeta,
  221. Values: []*proto3.Value{
  222. {Kind: &proto3.Value_StringValue{StringValue: "foo"}},
  223. },
  224. },
  225. },
  226. wantTs: trxTs,
  227. wantD: false,
  228. },
  229. {
  230. // Complete splitted result.
  231. input: []*sppb.PartialResultSet{
  232. {
  233. Metadata: kvMeta,
  234. Values: []*proto3.Value{
  235. {Kind: &proto3.Value_StringValue{StringValue: "foo"}},
  236. },
  237. },
  238. {
  239. Values: []*proto3.Value{
  240. {Kind: &proto3.Value_StringValue{StringValue: "bar"}},
  241. },
  242. },
  243. },
  244. wantF: []*Row{
  245. {
  246. fields: kvMeta.RowType.Fields,
  247. vals: []*proto3.Value{
  248. {Kind: &proto3.Value_StringValue{StringValue: "foo"}},
  249. {Kind: &proto3.Value_StringValue{StringValue: "bar"}},
  250. },
  251. },
  252. },
  253. wantTs: trxTs,
  254. wantD: true,
  255. },
  256. {
  257. // Multi-row example with splitted row in the middle.
  258. input: []*sppb.PartialResultSet{
  259. {
  260. Metadata: kvMeta,
  261. Values: []*proto3.Value{
  262. {Kind: &proto3.Value_StringValue{StringValue: "foo"}},
  263. {Kind: &proto3.Value_StringValue{StringValue: "bar"}},
  264. {Kind: &proto3.Value_StringValue{StringValue: "A"}},
  265. },
  266. },
  267. {
  268. Values: []*proto3.Value{
  269. {Kind: &proto3.Value_StringValue{StringValue: "1"}},
  270. {Kind: &proto3.Value_StringValue{StringValue: "B"}},
  271. {Kind: &proto3.Value_StringValue{StringValue: "2"}},
  272. },
  273. },
  274. },
  275. wantF: []*Row{
  276. {
  277. fields: kvMeta.RowType.Fields,
  278. vals: []*proto3.Value{
  279. {Kind: &proto3.Value_StringValue{StringValue: "foo"}},
  280. {Kind: &proto3.Value_StringValue{StringValue: "bar"}},
  281. },
  282. },
  283. {
  284. fields: kvMeta.RowType.Fields,
  285. vals: []*proto3.Value{
  286. {Kind: &proto3.Value_StringValue{StringValue: "A"}},
  287. {Kind: &proto3.Value_StringValue{StringValue: "1"}},
  288. },
  289. },
  290. {
  291. fields: kvMeta.RowType.Fields,
  292. vals: []*proto3.Value{
  293. {Kind: &proto3.Value_StringValue{StringValue: "B"}},
  294. {Kind: &proto3.Value_StringValue{StringValue: "2"}},
  295. },
  296. },
  297. },
  298. wantTs: trxTs,
  299. wantD: true,
  300. },
  301. {
  302. // Merging example in result_set.proto.
  303. input: []*sppb.PartialResultSet{
  304. {
  305. Metadata: kvMeta,
  306. Values: []*proto3.Value{
  307. {Kind: &proto3.Value_StringValue{StringValue: "Hello"}},
  308. {Kind: &proto3.Value_StringValue{StringValue: "W"}},
  309. },
  310. ChunkedValue: true,
  311. },
  312. {
  313. Values: []*proto3.Value{
  314. {Kind: &proto3.Value_StringValue{StringValue: "orl"}},
  315. },
  316. ChunkedValue: true,
  317. },
  318. {
  319. Values: []*proto3.Value{
  320. {Kind: &proto3.Value_StringValue{StringValue: "d"}},
  321. },
  322. },
  323. },
  324. wantF: []*Row{
  325. {
  326. fields: kvMeta.RowType.Fields,
  327. vals: []*proto3.Value{
  328. {Kind: &proto3.Value_StringValue{StringValue: "Hello"}},
  329. {Kind: &proto3.Value_StringValue{StringValue: "World"}},
  330. },
  331. },
  332. },
  333. wantTs: trxTs,
  334. wantD: true,
  335. },
  336. {
  337. // More complex example showing completing a merge and
  338. // starting a new merge in the same partialResultSet.
  339. input: []*sppb.PartialResultSet{
  340. {
  341. Metadata: kvMeta,
  342. Values: []*proto3.Value{
  343. {Kind: &proto3.Value_StringValue{StringValue: "Hello"}},
  344. {Kind: &proto3.Value_StringValue{StringValue: "W"}}, // start split in value
  345. },
  346. ChunkedValue: true,
  347. },
  348. {
  349. Values: []*proto3.Value{
  350. {Kind: &proto3.Value_StringValue{StringValue: "orld"}}, // complete value
  351. {Kind: &proto3.Value_StringValue{StringValue: "i"}}, // start split in key
  352. },
  353. ChunkedValue: true,
  354. },
  355. {
  356. Values: []*proto3.Value{
  357. {Kind: &proto3.Value_StringValue{StringValue: "s"}}, // complete key
  358. {Kind: &proto3.Value_StringValue{StringValue: "not"}},
  359. {Kind: &proto3.Value_StringValue{StringValue: "a"}},
  360. {Kind: &proto3.Value_StringValue{StringValue: "qu"}}, // split in value
  361. },
  362. ChunkedValue: true,
  363. },
  364. {
  365. Values: []*proto3.Value{
  366. {Kind: &proto3.Value_StringValue{StringValue: "estion"}}, // complete value
  367. },
  368. },
  369. },
  370. wantF: []*Row{
  371. {
  372. fields: kvMeta.RowType.Fields,
  373. vals: []*proto3.Value{
  374. {Kind: &proto3.Value_StringValue{StringValue: "Hello"}},
  375. {Kind: &proto3.Value_StringValue{StringValue: "World"}},
  376. },
  377. },
  378. {
  379. fields: kvMeta.RowType.Fields,
  380. vals: []*proto3.Value{
  381. {Kind: &proto3.Value_StringValue{StringValue: "is"}},
  382. {Kind: &proto3.Value_StringValue{StringValue: "not"}},
  383. },
  384. },
  385. {
  386. fields: kvMeta.RowType.Fields,
  387. vals: []*proto3.Value{
  388. {Kind: &proto3.Value_StringValue{StringValue: "a"}},
  389. {Kind: &proto3.Value_StringValue{StringValue: "question"}},
  390. },
  391. },
  392. },
  393. wantTs: trxTs,
  394. wantD: true,
  395. },
  396. // List merging examples.
  397. {
  398. // Non-splitting Lists.
  399. input: []*sppb.PartialResultSet{
  400. {
  401. Metadata: kvListMeta,
  402. Values: []*proto3.Value{
  403. {
  404. Kind: genProtoListValue("foo-1", "foo-2"),
  405. },
  406. },
  407. },
  408. {
  409. Values: []*proto3.Value{
  410. {
  411. Kind: genProtoListValue("bar-1", "bar-2"),
  412. },
  413. },
  414. },
  415. },
  416. wantF: []*Row{
  417. {
  418. fields: kvListMeta.RowType.Fields,
  419. vals: []*proto3.Value{
  420. {
  421. Kind: genProtoListValue("foo-1", "foo-2"),
  422. },
  423. {
  424. Kind: genProtoListValue("bar-1", "bar-2"),
  425. },
  426. },
  427. },
  428. },
  429. wantTxID: transactionID{5, 6, 7, 8, 9},
  430. wantTs: trxTs,
  431. wantD: true,
  432. },
  433. {
  434. // Simple List merge case: splitted string element.
  435. input: []*sppb.PartialResultSet{
  436. {
  437. Metadata: kvListMeta,
  438. Values: []*proto3.Value{
  439. {
  440. Kind: genProtoListValue("foo-1", "foo-"),
  441. },
  442. },
  443. ChunkedValue: true,
  444. },
  445. {
  446. Values: []*proto3.Value{
  447. {
  448. Kind: genProtoListValue("2"),
  449. },
  450. },
  451. },
  452. {
  453. Values: []*proto3.Value{
  454. {
  455. Kind: genProtoListValue("bar-1", "bar-2"),
  456. },
  457. },
  458. },
  459. },
  460. wantF: []*Row{
  461. {
  462. fields: kvListMeta.RowType.Fields,
  463. vals: []*proto3.Value{
  464. {
  465. Kind: genProtoListValue("foo-1", "foo-2"),
  466. },
  467. {
  468. Kind: genProtoListValue("bar-1", "bar-2"),
  469. },
  470. },
  471. },
  472. },
  473. wantTxID: transactionID{5, 6, 7, 8, 9},
  474. wantTs: trxTs,
  475. wantD: true,
  476. },
  477. {
  478. // Struct merging is also implemented by List merging. Note that
  479. // Cloud Spanner uses proto.ListValue to encode Structs as well.
  480. input: []*sppb.PartialResultSet{
  481. {
  482. Metadata: kvObjectMeta,
  483. Values: []*proto3.Value{
  484. {
  485. Kind: &proto3.Value_ListValue{
  486. ListValue: &proto3.ListValue{
  487. Values: []*proto3.Value{
  488. {Kind: &proto3.Value_NumberValue{NumberValue: 23}},
  489. {Kind: genProtoListValue("foo-1", "fo")},
  490. },
  491. },
  492. },
  493. },
  494. },
  495. ChunkedValue: true,
  496. },
  497. {
  498. Values: []*proto3.Value{
  499. {
  500. Kind: &proto3.Value_ListValue{
  501. ListValue: &proto3.ListValue{
  502. Values: []*proto3.Value{
  503. {Kind: genProtoListValue("o-2", "f")},
  504. },
  505. },
  506. },
  507. },
  508. },
  509. ChunkedValue: true,
  510. },
  511. {
  512. Values: []*proto3.Value{
  513. {
  514. Kind: &proto3.Value_ListValue{
  515. ListValue: &proto3.ListValue{
  516. Values: []*proto3.Value{
  517. {Kind: genProtoListValue("oo-3")},
  518. },
  519. },
  520. },
  521. },
  522. {
  523. Kind: &proto3.Value_ListValue{
  524. ListValue: &proto3.ListValue{
  525. Values: []*proto3.Value{
  526. {Kind: &proto3.Value_NumberValue{NumberValue: 45}},
  527. {Kind: genProtoListValue("bar-1")},
  528. },
  529. },
  530. },
  531. },
  532. },
  533. },
  534. },
  535. wantF: []*Row{
  536. {
  537. fields: kvObjectMeta.RowType.Fields,
  538. vals: []*proto3.Value{
  539. {
  540. Kind: &proto3.Value_ListValue{
  541. ListValue: &proto3.ListValue{
  542. Values: []*proto3.Value{
  543. {Kind: &proto3.Value_NumberValue{NumberValue: 23}},
  544. {Kind: genProtoListValue("foo-1", "foo-2", "foo-3")},
  545. },
  546. },
  547. },
  548. },
  549. {
  550. Kind: &proto3.Value_ListValue{
  551. ListValue: &proto3.ListValue{
  552. Values: []*proto3.Value{
  553. {Kind: &proto3.Value_NumberValue{NumberValue: 45}},
  554. {Kind: genProtoListValue("bar-1")},
  555. },
  556. },
  557. },
  558. },
  559. },
  560. },
  561. },
  562. wantTxID: transactionID{1, 2, 3, 4, 5},
  563. wantD: true,
  564. },
  565. }
  566. nextTest:
  567. for i, test := range tests {
  568. var rows []*Row
  569. p := &partialResultSetDecoder{}
  570. for j, v := range test.input {
  571. rs, err := p.add(v)
  572. if err != nil {
  573. t.Errorf("test %d.%d: partialResultSetDecoder.add(%v) = %v; want nil", i, j, v, err)
  574. continue nextTest
  575. }
  576. rows = append(rows, rs...)
  577. }
  578. if !testEqual(p.ts, test.wantTs) {
  579. t.Errorf("got transaction(%v), want %v", p.ts, test.wantTs)
  580. }
  581. if !testEqual(rows, test.wantF) {
  582. t.Errorf("test %d: rows=\n%v\n; want\n%v\n; p.row:\n%v\n", i, describeRows(rows), describeRows(test.wantF), p.row)
  583. }
  584. if got := p.done(); got != test.wantD {
  585. t.Errorf("test %d: partialResultSetDecoder.done() = %v", i, got)
  586. }
  587. }
  588. }
  589. const (
  590. maxBuffers = 16 // max number of PartialResultSets that will be buffered in tests.
  591. )
  592. // setMaxBytesBetweenResumeTokens sets the global maxBytesBetweenResumeTokens to a smaller
  593. // value more suitable for tests. It returns a function which should be called to restore
  594. // the maxBytesBetweenResumeTokens to its old value
  595. func setMaxBytesBetweenResumeTokens() func() {
  596. o := atomic.LoadInt32(&maxBytesBetweenResumeTokens)
  597. atomic.StoreInt32(&maxBytesBetweenResumeTokens, int32(maxBuffers*proto.Size(&sppb.PartialResultSet{
  598. Metadata: kvMeta,
  599. Values: []*proto3.Value{
  600. {Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
  601. {Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
  602. },
  603. })))
  604. return func() {
  605. atomic.StoreInt32(&maxBytesBetweenResumeTokens, o)
  606. }
  607. }
  608. // keyStr generates key string for kvMeta schema.
  609. func keyStr(i int) string {
  610. return fmt.Sprintf("foo-%02d", i)
  611. }
  612. // valStr generates value string for kvMeta schema.
  613. func valStr(i int) string {
  614. return fmt.Sprintf("bar-%02d", i)
  615. }
  616. // Test state transitions of resumableStreamDecoder where state machine
  617. // ends up to a non-blocking state(resumableStreamDecoder.Next returns
  618. // on non-blocking state).
  619. func TestRsdNonblockingStates(t *testing.T) {
  620. restore := setMaxBytesBetweenResumeTokens()
  621. defer restore()
  622. tests := []struct {
  623. name string
  624. msgs []testutil.MockCtlMsg
  625. rpc func(ct context.Context, resumeToken []byte) (streamingReceiver, error)
  626. sql string
  627. // Expected values
  628. want []*sppb.PartialResultSet // PartialResultSets that should be returned to caller
  629. queue []*sppb.PartialResultSet // PartialResultSets that should be buffered
  630. resumeToken []byte // Resume token that is maintained by resumableStreamDecoder
  631. stateHistory []resumableStreamDecoderState // State transition history of resumableStreamDecoder
  632. wantErr error
  633. }{
  634. {
  635. // unConnected->queueingRetryable->finished
  636. name: "unConnected->queueingRetryable->finished",
  637. msgs: []testutil.MockCtlMsg{
  638. {},
  639. {},
  640. {Err: io.EOF, ResumeToken: false},
  641. },
  642. sql: "SELECT t.key key, t.value value FROM t_mock t",
  643. want: []*sppb.PartialResultSet{
  644. {
  645. Metadata: kvMeta,
  646. Values: []*proto3.Value{
  647. {Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
  648. {Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
  649. },
  650. },
  651. },
  652. queue: []*sppb.PartialResultSet{
  653. {
  654. Metadata: kvMeta,
  655. Values: []*proto3.Value{
  656. {Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}},
  657. {Kind: &proto3.Value_StringValue{StringValue: valStr(1)}},
  658. },
  659. },
  660. },
  661. stateHistory: []resumableStreamDecoderState{
  662. queueingRetryable, // do RPC
  663. queueingRetryable, // got foo-00
  664. queueingRetryable, // got foo-01
  665. finished, // got EOF
  666. },
  667. },
  668. {
  669. // unConnected->queueingRetryable->aborted
  670. name: "unConnected->queueingRetryable->aborted",
  671. msgs: []testutil.MockCtlMsg{
  672. {},
  673. {Err: nil, ResumeToken: true},
  674. {},
  675. {Err: errors.New("I quit"), ResumeToken: false},
  676. },
  677. sql: "SELECT t.key key, t.value value FROM t_mock t",
  678. want: []*sppb.PartialResultSet{
  679. {
  680. Metadata: kvMeta,
  681. Values: []*proto3.Value{
  682. {Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
  683. {Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
  684. },
  685. },
  686. {
  687. Metadata: kvMeta,
  688. Values: []*proto3.Value{
  689. {Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}},
  690. {Kind: &proto3.Value_StringValue{StringValue: valStr(1)}},
  691. },
  692. ResumeToken: testutil.EncodeResumeToken(1),
  693. },
  694. },
  695. stateHistory: []resumableStreamDecoderState{
  696. queueingRetryable, // do RPC
  697. queueingRetryable, // got foo-00
  698. queueingRetryable, // got foo-01
  699. queueingRetryable, // foo-01, resume token
  700. queueingRetryable, // got foo-02
  701. aborted, // got error
  702. },
  703. wantErr: status.Errorf(codes.Unknown, "I quit"),
  704. },
  705. {
  706. // unConnected->queueingRetryable->queueingUnretryable->queueingUnretryable
  707. name: "unConnected->queueingRetryable->queueingUnretryable->queueingUnretryable",
  708. msgs: func() (m []testutil.MockCtlMsg) {
  709. for i := 0; i < maxBuffers+1; i++ {
  710. m = append(m, testutil.MockCtlMsg{})
  711. }
  712. return m
  713. }(),
  714. sql: "SELECT t.key key, t.value value FROM t_mock t",
  715. want: func() (s []*sppb.PartialResultSet) {
  716. for i := 0; i < maxBuffers+1; i++ {
  717. s = append(s, &sppb.PartialResultSet{
  718. Metadata: kvMeta,
  719. Values: []*proto3.Value{
  720. {Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
  721. {Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
  722. },
  723. })
  724. }
  725. return s
  726. }(),
  727. stateHistory: func() (s []resumableStreamDecoderState) {
  728. s = append(s, queueingRetryable) // RPC
  729. for i := 0; i < maxBuffers; i++ {
  730. s = append(s, queueingRetryable) // the internal queue of resumableStreamDecoder fills up
  731. }
  732. // the first item fills up the queue and triggers state transition;
  733. // the second item is received under queueingUnretryable state.
  734. s = append(s, queueingUnretryable)
  735. s = append(s, queueingUnretryable)
  736. return s
  737. }(),
  738. },
  739. {
  740. // unConnected->queueingRetryable->queueingUnretryable->aborted
  741. name: "unConnected->queueingRetryable->queueingUnretryable->aborted",
  742. msgs: func() (m []testutil.MockCtlMsg) {
  743. for i := 0; i < maxBuffers; i++ {
  744. m = append(m, testutil.MockCtlMsg{})
  745. }
  746. m = append(m, testutil.MockCtlMsg{Err: errors.New("Just Abort It"), ResumeToken: false})
  747. return m
  748. }(),
  749. sql: "SELECT t.key key, t.value value FROM t_mock t",
  750. want: func() (s []*sppb.PartialResultSet) {
  751. for i := 0; i < maxBuffers; i++ {
  752. s = append(s, &sppb.PartialResultSet{
  753. Metadata: kvMeta,
  754. Values: []*proto3.Value{
  755. {Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
  756. {Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
  757. },
  758. })
  759. }
  760. return s
  761. }(),
  762. stateHistory: func() (s []resumableStreamDecoderState) {
  763. s = append(s, queueingRetryable) // RPC
  764. for i := 0; i < maxBuffers; i++ {
  765. s = append(s, queueingRetryable) // internal queue of resumableStreamDecoder fills up
  766. }
  767. s = append(s, queueingUnretryable) // the last row triggers state change
  768. s = append(s, aborted) // Error happens
  769. return s
  770. }(),
  771. wantErr: status.Errorf(codes.Unknown, "Just Abort It"),
  772. },
  773. }
  774. nextTest:
  775. for _, test := range tests {
  776. ms := testutil.NewMockCloudSpanner(t, trxTs)
  777. ms.Serve()
  778. mc := sppb.NewSpannerClient(dialMock(t, ms))
  779. if test.rpc == nil {
  780. test.rpc = func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
  781. return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
  782. Sql: test.sql,
  783. ResumeToken: resumeToken,
  784. })
  785. }
  786. }
  787. ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
  788. defer cancel()
  789. r := newResumableStreamDecoder(
  790. ctx,
  791. test.rpc,
  792. )
  793. st := []resumableStreamDecoderState{}
  794. var lastErr error
  795. // Once the expected number of state transitions are observed,
  796. // send a signal by setting stateDone = true.
  797. stateDone := false
  798. // Set stateWitness to listen to state changes.
  799. hl := len(test.stateHistory) // To avoid data race on test.
  800. r.stateWitness = func(rs resumableStreamDecoderState) {
  801. if !stateDone {
  802. // Record state transitions.
  803. st = append(st, rs)
  804. if len(st) == hl {
  805. lastErr = r.lastErr()
  806. stateDone = true
  807. }
  808. }
  809. }
  810. // Let mock server stream given messages to resumableStreamDecoder.
  811. for _, m := range test.msgs {
  812. ms.AddMsg(m.Err, m.ResumeToken)
  813. }
  814. var rs []*sppb.PartialResultSet
  815. for {
  816. select {
  817. case <-ctx.Done():
  818. t.Errorf("context cancelled or timeout during test")
  819. continue nextTest
  820. default:
  821. }
  822. if stateDone {
  823. // Check if resumableStreamDecoder carried out expected
  824. // state transitions.
  825. if !testEqual(st, test.stateHistory) {
  826. t.Errorf("%v: observed state transitions: \n%v\n, want \n%v\n",
  827. test.name, st, test.stateHistory)
  828. }
  829. // Check if resumableStreamDecoder returns expected array of
  830. // PartialResultSets.
  831. if !testEqual(rs, test.want) {
  832. t.Errorf("%v: received PartialResultSets: \n%v\n, want \n%v\n", test.name, rs, test.want)
  833. }
  834. // Verify that resumableStreamDecoder's internal buffering is also correct.
  835. var q []*sppb.PartialResultSet
  836. for {
  837. item := r.q.pop()
  838. if item == nil {
  839. break
  840. }
  841. q = append(q, item)
  842. }
  843. if !testEqual(q, test.queue) {
  844. t.Errorf("%v: PartialResultSets still queued: \n%v\n, want \n%v\n", test.name, q, test.queue)
  845. }
  846. // Verify resume token.
  847. if test.resumeToken != nil && !testEqual(r.resumeToken, test.resumeToken) {
  848. t.Errorf("%v: Resume token is %v, want %v\n", test.name, r.resumeToken, test.resumeToken)
  849. }
  850. // Verify error message.
  851. if !testEqual(lastErr, test.wantErr) {
  852. t.Errorf("%v: got error %v, want %v", test.name, lastErr, test.wantErr)
  853. }
  854. // Proceed to next test
  855. continue nextTest
  856. }
  857. // Receive next decoded item.
  858. if r.next() {
  859. rs = append(rs, r.get())
  860. }
  861. }
  862. }
  863. }
  864. // Test state transitions of resumableStreamDecoder where state machine
  865. // ends up to a blocking state(resumableStreamDecoder.Next blocks
  866. // on blocking state).
  867. func TestRsdBlockingStates(t *testing.T) {
  868. restore := setMaxBytesBetweenResumeTokens()
  869. defer restore()
  870. tests := []struct {
  871. name string
  872. msgs []testutil.MockCtlMsg
  873. rpc func(ct context.Context, resumeToken []byte) (streamingReceiver, error)
  874. sql string
  875. // Expected values
  876. want []*sppb.PartialResultSet // PartialResultSets that should be returned to caller
  877. queue []*sppb.PartialResultSet // PartialResultSets that should be buffered
  878. resumeToken []byte // Resume token that is maintained by resumableStreamDecoder
  879. stateHistory []resumableStreamDecoderState // State transition history of resumableStreamDecoder
  880. wantErr error
  881. }{
  882. {
  883. // unConnected -> unConnected
  884. name: "unConnected -> unConnected",
  885. rpc: func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
  886. return nil, status.Errorf(codes.Unavailable, "trust me: server is unavailable")
  887. },
  888. sql: "SELECT * from t_whatever",
  889. stateHistory: []resumableStreamDecoderState{unConnected, unConnected, unConnected},
  890. wantErr: status.Errorf(codes.Unavailable, "trust me: server is unavailable"),
  891. },
  892. {
  893. // unConnected -> queueingRetryable
  894. name: "unConnected -> queueingRetryable",
  895. sql: "SELECT t.key key, t.value value FROM t_mock t",
  896. stateHistory: []resumableStreamDecoderState{queueingRetryable},
  897. },
  898. {
  899. // unConnected->queueingRetryable->queueingRetryable
  900. name: "unConnected->queueingRetryable->queueingRetryable",
  901. msgs: []testutil.MockCtlMsg{
  902. {},
  903. {Err: nil, ResumeToken: true},
  904. {Err: nil, ResumeToken: true},
  905. {},
  906. },
  907. sql: "SELECT t.key key, t.value value FROM t_mock t",
  908. want: []*sppb.PartialResultSet{
  909. {
  910. Metadata: kvMeta,
  911. Values: []*proto3.Value{
  912. {Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
  913. {Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
  914. },
  915. },
  916. {
  917. Metadata: kvMeta,
  918. Values: []*proto3.Value{
  919. {Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}},
  920. {Kind: &proto3.Value_StringValue{StringValue: valStr(1)}},
  921. },
  922. ResumeToken: testutil.EncodeResumeToken(1),
  923. },
  924. {
  925. Metadata: kvMeta,
  926. Values: []*proto3.Value{
  927. {Kind: &proto3.Value_StringValue{StringValue: keyStr(2)}},
  928. {Kind: &proto3.Value_StringValue{StringValue: valStr(2)}},
  929. },
  930. ResumeToken: testutil.EncodeResumeToken(2),
  931. },
  932. },
  933. queue: []*sppb.PartialResultSet{
  934. {
  935. Metadata: kvMeta,
  936. Values: []*proto3.Value{
  937. {Kind: &proto3.Value_StringValue{StringValue: keyStr(3)}},
  938. {Kind: &proto3.Value_StringValue{StringValue: valStr(3)}},
  939. },
  940. },
  941. },
  942. resumeToken: testutil.EncodeResumeToken(2),
  943. stateHistory: []resumableStreamDecoderState{
  944. queueingRetryable, // do RPC
  945. queueingRetryable, // got foo-00
  946. queueingRetryable, // got foo-01
  947. queueingRetryable, // foo-01, resume token
  948. queueingRetryable, // got foo-02
  949. queueingRetryable, // foo-02, resume token
  950. queueingRetryable, // got foo-03
  951. },
  952. },
  953. {
  954. // unConnected->queueingRetryable->queueingUnretryable->queueingRetryable->queueingRetryable
  955. name: "unConnected->queueingRetryable->queueingUnretryable->queueingRetryable->queueingRetryable",
  956. msgs: func() (m []testutil.MockCtlMsg) {
  957. for i := 0; i < maxBuffers+1; i++ {
  958. m = append(m, testutil.MockCtlMsg{})
  959. }
  960. m = append(m, testutil.MockCtlMsg{Err: nil, ResumeToken: true})
  961. m = append(m, testutil.MockCtlMsg{})
  962. return m
  963. }(),
  964. sql: "SELECT t.key key, t.value value FROM t_mock t",
  965. want: func() (s []*sppb.PartialResultSet) {
  966. for i := 0; i < maxBuffers+2; i++ {
  967. s = append(s, &sppb.PartialResultSet{
  968. Metadata: kvMeta,
  969. Values: []*proto3.Value{
  970. {Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
  971. {Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
  972. },
  973. })
  974. }
  975. s[maxBuffers+1].ResumeToken = testutil.EncodeResumeToken(maxBuffers + 1)
  976. return s
  977. }(),
  978. resumeToken: testutil.EncodeResumeToken(maxBuffers + 1),
  979. queue: []*sppb.PartialResultSet{
  980. {
  981. Metadata: kvMeta,
  982. Values: []*proto3.Value{
  983. {Kind: &proto3.Value_StringValue{StringValue: keyStr(maxBuffers + 2)}},
  984. {Kind: &proto3.Value_StringValue{StringValue: valStr(maxBuffers + 2)}},
  985. },
  986. },
  987. },
  988. stateHistory: func() (s []resumableStreamDecoderState) {
  989. s = append(s, queueingRetryable) // RPC
  990. for i := 0; i < maxBuffers; i++ {
  991. s = append(s, queueingRetryable) // internal queue of resumableStreamDecoder filles up
  992. }
  993. for i := maxBuffers - 1; i < maxBuffers+1; i++ {
  994. // the first item fills up the queue and triggers state change;
  995. // the second item is received under queueingUnretryable state.
  996. s = append(s, queueingUnretryable)
  997. }
  998. s = append(s, queueingUnretryable) // got (maxBuffers+1)th row under Unretryable state
  999. s = append(s, queueingRetryable) // (maxBuffers+1)th row has resume token
  1000. s = append(s, queueingRetryable) // (maxBuffers+2)th row has no resume token
  1001. return s
  1002. }(),
  1003. },
  1004. {
  1005. // unConnected->queueingRetryable->queueingUnretryable->finished
  1006. name: "unConnected->queueingRetryable->queueingUnretryable->finished",
  1007. msgs: func() (m []testutil.MockCtlMsg) {
  1008. for i := 0; i < maxBuffers; i++ {
  1009. m = append(m, testutil.MockCtlMsg{})
  1010. }
  1011. m = append(m, testutil.MockCtlMsg{Err: io.EOF, ResumeToken: false})
  1012. return m
  1013. }(),
  1014. sql: "SELECT t.key key, t.value value FROM t_mock t",
  1015. want: func() (s []*sppb.PartialResultSet) {
  1016. for i := 0; i < maxBuffers; i++ {
  1017. s = append(s, &sppb.PartialResultSet{
  1018. Metadata: kvMeta,
  1019. Values: []*proto3.Value{
  1020. {Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
  1021. {Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
  1022. },
  1023. })
  1024. }
  1025. return s
  1026. }(),
  1027. stateHistory: func() (s []resumableStreamDecoderState) {
  1028. s = append(s, queueingRetryable) // RPC
  1029. for i := 0; i < maxBuffers; i++ {
  1030. s = append(s, queueingRetryable) // internal queue of resumableStreamDecoder fills up
  1031. }
  1032. s = append(s, queueingUnretryable) // last row triggers state change
  1033. s = append(s, finished) // query finishes
  1034. return s
  1035. }(),
  1036. },
  1037. }
  1038. for _, test := range tests {
  1039. ms := testutil.NewMockCloudSpanner(t, trxTs)
  1040. ms.Serve()
  1041. cc := dialMock(t, ms)
  1042. mc := sppb.NewSpannerClient(cc)
  1043. if test.rpc == nil {
  1044. // Avoid using test.sql directly in closure because for loop changes test.
  1045. sql := test.sql
  1046. test.rpc = func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
  1047. return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
  1048. Sql: sql,
  1049. ResumeToken: resumeToken,
  1050. })
  1051. }
  1052. }
  1053. ctx, cancel := context.WithCancel(context.Background())
  1054. defer cancel()
  1055. r := newResumableStreamDecoder(
  1056. ctx,
  1057. test.rpc,
  1058. )
  1059. // Override backoff to make the test run faster.
  1060. r.backoff = backoff.ExponentialBackoff{
  1061. Min: 1 * time.Nanosecond,
  1062. Max: 1 * time.Nanosecond,
  1063. }
  1064. // st is the set of observed state transitions.
  1065. st := []resumableStreamDecoderState{}
  1066. // q is the content of the decoder's partial result queue when expected number of state transitions are done.
  1067. q := []*sppb.PartialResultSet{}
  1068. var lastErr error
  1069. // Once the expected number of state transitions are observed,
  1070. // send a signal to channel stateDone.
  1071. stateDone := make(chan int)
  1072. // Set stateWitness to listen to state changes.
  1073. hl := len(test.stateHistory) // To avoid data race on test.
  1074. r.stateWitness = func(rs resumableStreamDecoderState) {
  1075. select {
  1076. case <-stateDone:
  1077. // Noop after expected number of state transitions
  1078. default:
  1079. // Record state transitions.
  1080. st = append(st, rs)
  1081. if len(st) == hl {
  1082. lastErr = r.lastErr()
  1083. q = r.q.dump()
  1084. close(stateDone)
  1085. }
  1086. }
  1087. }
  1088. // Let mock server stream given messages to resumableStreamDecoder.
  1089. for _, m := range test.msgs {
  1090. ms.AddMsg(m.Err, m.ResumeToken)
  1091. }
  1092. var rs []*sppb.PartialResultSet
  1093. go func() {
  1094. for {
  1095. if !r.next() {
  1096. // Note that r.Next also exits on context cancel/timeout.
  1097. return
  1098. }
  1099. rs = append(rs, r.get())
  1100. }
  1101. }()
  1102. // Verify that resumableStreamDecoder reaches expected state.
  1103. select {
  1104. case <-stateDone: // Note that at this point, receiver is still blocking on r.next().
  1105. // Check if resumableStreamDecoder carried out expected
  1106. // state transitions.
  1107. if !testEqual(st, test.stateHistory) {
  1108. t.Errorf("%v: observed state transitions: \n%v\n, want \n%v\n",
  1109. test.name, st, test.stateHistory)
  1110. }
  1111. // Check if resumableStreamDecoder returns expected array of
  1112. // PartialResultSets.
  1113. if !testEqual(rs, test.want) {
  1114. t.Errorf("%v: received PartialResultSets: \n%v\n, want \n%v\n", test.name, rs, test.want)
  1115. }
  1116. // Verify that resumableStreamDecoder's internal buffering is also correct.
  1117. if !testEqual(q, test.queue) {
  1118. t.Errorf("%v: PartialResultSets still queued: \n%v\n, want \n%v\n", test.name, q, test.queue)
  1119. }
  1120. // Verify resume token.
  1121. if test.resumeToken != nil && !testEqual(r.resumeToken, test.resumeToken) {
  1122. t.Errorf("%v: Resume token is %v, want %v\n", test.name, r.resumeToken, test.resumeToken)
  1123. }
  1124. // Verify error message.
  1125. if !testEqual(lastErr, test.wantErr) {
  1126. t.Errorf("%v: got error %v, want %v", test.name, lastErr, test.wantErr)
  1127. }
  1128. case <-time.After(1 * time.Second):
  1129. t.Errorf("%v: Timeout in waiting for state change", test.name)
  1130. }
  1131. ms.Stop()
  1132. cc.Close()
  1133. }
  1134. }
  1135. // sReceiver signals every receiving attempt through a channel,
  1136. // used by TestResumeToken to determine if the receiving of a certain
  1137. // PartialResultSet will be attempted next.
  1138. type sReceiver struct {
  1139. c chan int
  1140. rpcReceiver sppb.Spanner_ExecuteStreamingSqlClient
  1141. }
  1142. // Recv() implements streamingReceiver.Recv for sReceiver.
  1143. func (sr *sReceiver) Recv() (*sppb.PartialResultSet, error) {
  1144. sr.c <- 1
  1145. return sr.rpcReceiver.Recv()
  1146. }
  1147. // waitn waits for nth receiving attempt from now on, until
  1148. // the signal for nth Recv() attempts is received or timeout.
  1149. // Note that because the way stream() works, the signal for the
  1150. // nth Recv() means that the previous n - 1 PartialResultSets
  1151. // has already been returned to caller or queued, if no error happened.
  1152. func (sr *sReceiver) waitn(n int) error {
  1153. for i := 0; i < n; i++ {
  1154. select {
  1155. case <-sr.c:
  1156. case <-time.After(10 * time.Second):
  1157. return fmt.Errorf("timeout in waiting for %v-th Recv()", i+1)
  1158. }
  1159. }
  1160. return nil
  1161. }
  1162. // Test the handling of resumableStreamDecoder.bytesBetweenResumeTokens.
  1163. func TestQueueBytes(t *testing.T) {
  1164. restore := setMaxBytesBetweenResumeTokens()
  1165. defer restore()
  1166. ms := testutil.NewMockCloudSpanner(t, trxTs)
  1167. ms.Serve()
  1168. defer ms.Stop()
  1169. cc := dialMock(t, ms)
  1170. defer cc.Close()
  1171. mc := sppb.NewSpannerClient(cc)
  1172. sr := &sReceiver{
  1173. c: make(chan int, 1000), // will never block in this test
  1174. }
  1175. wantQueueBytes := 0
  1176. ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
  1177. defer cancel()
  1178. r := newResumableStreamDecoder(
  1179. ctx,
  1180. func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
  1181. r, err := mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
  1182. Sql: "SELECT t.key key, t.value value FROM t_mock t",
  1183. ResumeToken: resumeToken,
  1184. })
  1185. sr.rpcReceiver = r
  1186. return sr, err
  1187. },
  1188. )
  1189. go func() {
  1190. for r.next() {
  1191. }
  1192. }()
  1193. // Let server send maxBuffers / 2 rows.
  1194. for i := 0; i < maxBuffers/2; i++ {
  1195. wantQueueBytes += proto.Size(&sppb.PartialResultSet{
  1196. Metadata: kvMeta,
  1197. Values: []*proto3.Value{
  1198. {Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
  1199. {Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
  1200. },
  1201. })
  1202. ms.AddMsg(nil, false)
  1203. }
  1204. if err := sr.waitn(maxBuffers/2 + 1); err != nil {
  1205. t.Fatalf("failed to wait for the first %v recv() calls: %v", maxBuffers, err)
  1206. }
  1207. if int32(wantQueueBytes) != r.bytesBetweenResumeTokens {
  1208. t.Errorf("r.bytesBetweenResumeTokens = %v, want %v", r.bytesBetweenResumeTokens, wantQueueBytes)
  1209. }
  1210. // Now send a resume token to drain the queue.
  1211. ms.AddMsg(nil, true)
  1212. // Wait for all rows to be processes.
  1213. if err := sr.waitn(1); err != nil {
  1214. t.Fatalf("failed to wait for rows to be processed: %v", err)
  1215. }
  1216. if r.bytesBetweenResumeTokens != 0 {
  1217. t.Errorf("r.bytesBetweenResumeTokens = %v, want 0", r.bytesBetweenResumeTokens)
  1218. }
  1219. // Let server send maxBuffers - 1 rows.
  1220. wantQueueBytes = 0
  1221. for i := 0; i < maxBuffers-1; i++ {
  1222. wantQueueBytes += proto.Size(&sppb.PartialResultSet{
  1223. Metadata: kvMeta,
  1224. Values: []*proto3.Value{
  1225. {Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
  1226. {Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
  1227. },
  1228. })
  1229. ms.AddMsg(nil, false)
  1230. }
  1231. if err := sr.waitn(maxBuffers - 1); err != nil {
  1232. t.Fatalf("failed to wait for %v rows to be processed: %v", maxBuffers-1, err)
  1233. }
  1234. if int32(wantQueueBytes) != r.bytesBetweenResumeTokens {
  1235. t.Errorf("r.bytesBetweenResumeTokens = %v, want 0", r.bytesBetweenResumeTokens)
  1236. }
  1237. // Trigger a state transition: queueingRetryable -> queueingUnretryable.
  1238. ms.AddMsg(nil, false)
  1239. if err := sr.waitn(1); err != nil {
  1240. t.Fatalf("failed to wait for state transition: %v", err)
  1241. }
  1242. if r.bytesBetweenResumeTokens != 0 {
  1243. t.Errorf("r.bytesBetweenResumeTokens = %v, want 0", r.bytesBetweenResumeTokens)
  1244. }
  1245. }
  1246. // Verify that client can deal with resume token correctly
  1247. func TestResumeToken(t *testing.T) {
  1248. restore := setMaxBytesBetweenResumeTokens()
  1249. defer restore()
  1250. ms := testutil.NewMockCloudSpanner(t, trxTs)
  1251. ms.Serve()
  1252. defer ms.Stop()
  1253. cc := dialMock(t, ms)
  1254. defer cc.Close()
  1255. mc := sppb.NewSpannerClient(cc)
  1256. sr := &sReceiver{
  1257. c: make(chan int, 1000), // will never block in this test
  1258. }
  1259. rows := []*Row{}
  1260. done := make(chan error)
  1261. streaming := func() {
  1262. // Establish a stream to mock cloud spanner server.
  1263. iter := stream(context.Background(),
  1264. func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
  1265. r, err := mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
  1266. Sql: "SELECT t.key key, t.value value FROM t_mock t",
  1267. ResumeToken: resumeToken,
  1268. })
  1269. sr.rpcReceiver = r
  1270. return sr, err
  1271. },
  1272. nil,
  1273. func(error) {})
  1274. defer iter.Stop()
  1275. var err error
  1276. for {
  1277. var row *Row
  1278. row, err = iter.Next()
  1279. if err == iterator.Done {
  1280. err = nil
  1281. break
  1282. }
  1283. if err != nil {
  1284. break
  1285. }
  1286. rows = append(rows, row)
  1287. }
  1288. done <- err
  1289. }
  1290. go streaming()
  1291. // Server streaming row 0 - 2, only row 1 has resume token.
  1292. // Client will receive row 0 - 2, so it will try receiving for
  1293. // 4 times (the last recv will block), and only row 0 - 1 will
  1294. // be yielded.
  1295. for i := 0; i < 3; i++ {
  1296. if i == 1 {
  1297. ms.AddMsg(nil, true)
  1298. } else {
  1299. ms.AddMsg(nil, false)
  1300. }
  1301. }
  1302. // Wait for 4 receive attempts, as explained above.
  1303. if err := sr.waitn(4); err != nil {
  1304. t.Fatalf("failed to wait for row 0 - 2: %v", err)
  1305. }
  1306. want := []*Row{
  1307. {
  1308. fields: kvMeta.RowType.Fields,
  1309. vals: []*proto3.Value{
  1310. {Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
  1311. {Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
  1312. },
  1313. },
  1314. {
  1315. fields: kvMeta.RowType.Fields,
  1316. vals: []*proto3.Value{
  1317. {Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}},
  1318. {Kind: &proto3.Value_StringValue{StringValue: valStr(1)}},
  1319. },
  1320. },
  1321. }
  1322. if !testEqual(rows, want) {
  1323. t.Errorf("received rows: \n%v\n; but want\n%v\n", rows, want)
  1324. }
  1325. // Inject resumable failure.
  1326. ms.AddMsg(
  1327. status.Errorf(codes.Unavailable, "mock server unavailable"),
  1328. false,
  1329. )
  1330. // Test if client detects the resumable failure and retries.
  1331. if err := sr.waitn(1); err != nil {
  1332. t.Fatalf("failed to wait for client to retry: %v", err)
  1333. }
  1334. // Client has resumed the query, now server resend row 2.
  1335. ms.AddMsg(nil, true)
  1336. if err := sr.waitn(1); err != nil {
  1337. t.Fatalf("failed to wait for resending row 2: %v", err)
  1338. }
  1339. // Now client should have received row 0 - 2.
  1340. want = append(want, &Row{
  1341. fields: kvMeta.RowType.Fields,
  1342. vals: []*proto3.Value{
  1343. {Kind: &proto3.Value_StringValue{StringValue: keyStr(2)}},
  1344. {Kind: &proto3.Value_StringValue{StringValue: valStr(2)}},
  1345. },
  1346. })
  1347. if !testEqual(rows, want) {
  1348. t.Errorf("received rows: \n%v\n, want\n%v\n", rows, want)
  1349. }
  1350. // Sending 3rd - (maxBuffers+1)th rows without resume tokens, client should buffer them.
  1351. for i := 3; i < maxBuffers+2; i++ {
  1352. ms.AddMsg(nil, false)
  1353. }
  1354. if err := sr.waitn(maxBuffers - 1); err != nil {
  1355. t.Fatalf("failed to wait for row 3-%v: %v", maxBuffers+1, err)
  1356. }
  1357. // Received rows should be unchanged.
  1358. if !testEqual(rows, want) {
  1359. t.Errorf("receive rows: \n%v\n, want\n%v\n", rows, want)
  1360. }
  1361. // Send (maxBuffers+2)th row to trigger state change of resumableStreamDecoder:
  1362. // queueingRetryable -> queueingUnretryable
  1363. ms.AddMsg(nil, false)
  1364. if err := sr.waitn(1); err != nil {
  1365. t.Fatalf("failed to wait for row %v: %v", maxBuffers+2, err)
  1366. }
  1367. // Client should yield row 3rd - (maxBuffers+2)th to application. Therefore, application should
  1368. // see row 0 - (maxBuffers+2)th so far.
  1369. for i := 3; i < maxBuffers+3; i++ {
  1370. want = append(want, &Row{
  1371. fields: kvMeta.RowType.Fields,
  1372. vals: []*proto3.Value{
  1373. {Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
  1374. {Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
  1375. },
  1376. })
  1377. }
  1378. if !testEqual(rows, want) {
  1379. t.Errorf("received rows: \n%v\n; want\n%v\n", rows, want)
  1380. }
  1381. // Inject resumable error, but since resumableStreamDecoder is already at queueingUnretryable
  1382. // state, query will just fail.
  1383. ms.AddMsg(
  1384. status.Errorf(codes.Unavailable, "mock server wants some sleep"),
  1385. false,
  1386. )
  1387. var gotErr error
  1388. select {
  1389. case gotErr = <-done:
  1390. case <-time.After(10 * time.Second):
  1391. t.Fatalf("timeout in waiting for failed query to return.")
  1392. }
  1393. if wantErr := toSpannerError(status.Errorf(codes.Unavailable, "mock server wants some sleep")); !testEqual(gotErr, wantErr) {
  1394. t.Fatalf("stream() returns error: %v, but want error: %v", gotErr, wantErr)
  1395. }
  1396. // Reconnect to mock Cloud Spanner.
  1397. rows = []*Row{}
  1398. go streaming()
  1399. // Let server send two rows without resume token.
  1400. for i := maxBuffers + 3; i < maxBuffers+5; i++ {
  1401. ms.AddMsg(nil, false)
  1402. }
  1403. if err := sr.waitn(3); err != nil {
  1404. t.Fatalf("failed to wait for row %v - %v: %v", maxBuffers+3, maxBuffers+5, err)
  1405. }
  1406. if len(rows) > 0 {
  1407. t.Errorf("client received some rows unexpectedly: %v, want nothing", rows)
  1408. }
  1409. // Let server end the query.
  1410. ms.AddMsg(io.EOF, false)
  1411. select {
  1412. case gotErr = <-done:
  1413. case <-time.After(10 * time.Second):
  1414. t.Fatalf("timeout in waiting for failed query to return")
  1415. }
  1416. if gotErr != nil {
  1417. t.Fatalf("stream() returns unexpected error: %v, but want no error", gotErr)
  1418. }
  1419. // Verify if a normal server side EOF flushes all queued rows.
  1420. want = []*Row{
  1421. {
  1422. fields: kvMeta.RowType.Fields,
  1423. vals: []*proto3.Value{
  1424. {Kind: &proto3.Value_StringValue{StringValue: keyStr(maxBuffers + 3)}},
  1425. {Kind: &proto3.Value_StringValue{StringValue: valStr(maxBuffers + 3)}},
  1426. },
  1427. },
  1428. {
  1429. fields: kvMeta.RowType.Fields,
  1430. vals: []*proto3.Value{
  1431. {Kind: &proto3.Value_StringValue{StringValue: keyStr(maxBuffers + 4)}},
  1432. {Kind: &proto3.Value_StringValue{StringValue: valStr(maxBuffers + 4)}},
  1433. },
  1434. },
  1435. }
  1436. if !testEqual(rows, want) {
  1437. t.Errorf("received rows: \n%v\n; but want\n%v\n", rows, want)
  1438. }
  1439. }
  1440. // Verify that streaming query get retried upon real gRPC server transport failures.
  1441. func TestGrpcReconnect(t *testing.T) {
  1442. restore := setMaxBytesBetweenResumeTokens()
  1443. defer restore()
  1444. ms := testutil.NewMockCloudSpanner(t, trxTs)
  1445. ms.Serve()
  1446. defer ms.Stop()
  1447. cc := dialMock(t, ms)
  1448. defer cc.Close()
  1449. mc := sppb.NewSpannerClient(cc)
  1450. retry := make(chan int)
  1451. row := make(chan int)
  1452. var err error
  1453. go func() {
  1454. r := 0
  1455. // Establish a stream to mock cloud spanner server.
  1456. iter := stream(context.Background(),
  1457. func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
  1458. if r > 0 {
  1459. // This RPC attempt is a retry, signal it.
  1460. retry <- r
  1461. }
  1462. r++
  1463. return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
  1464. Sql: "SELECT t.key key, t.value value FROM t_mock t",
  1465. ResumeToken: resumeToken,
  1466. })
  1467. },
  1468. nil,
  1469. func(error) {})
  1470. defer iter.Stop()
  1471. for {
  1472. _, err = iter.Next()
  1473. if err == iterator.Done {
  1474. err = nil
  1475. break
  1476. }
  1477. if err != nil {
  1478. break
  1479. }
  1480. row <- 0
  1481. }
  1482. }()
  1483. // Add a message and wait for the receipt.
  1484. ms.AddMsg(nil, true)
  1485. select {
  1486. case <-row:
  1487. case <-time.After(10 * time.Second):
  1488. t.Fatalf("expect stream to be established within 10 seconds, but it didn't")
  1489. }
  1490. // Error injection: force server to close all connections.
  1491. ms.Stop()
  1492. // Test to see if client respond to the real RPC failure correctly by
  1493. // retrying RPC.
  1494. select {
  1495. case r, ok := <-retry:
  1496. if ok && r == 1 {
  1497. break
  1498. }
  1499. t.Errorf("retry count = %v, want 1", r)
  1500. case <-time.After(10 * time.Second):
  1501. t.Errorf("client library failed to respond after 10 seconds, aborting")
  1502. return
  1503. }
  1504. }
  1505. // Test cancel/timeout for client operations.
  1506. func TestCancelTimeout(t *testing.T) {
  1507. restore := setMaxBytesBetweenResumeTokens()
  1508. defer restore()
  1509. ms := testutil.NewMockCloudSpanner(t, trxTs)
  1510. ms.Serve()
  1511. defer ms.Stop()
  1512. cc := dialMock(t, ms)
  1513. defer cc.Close()
  1514. mc := sppb.NewSpannerClient(cc)
  1515. done := make(chan int)
  1516. go func() {
  1517. for {
  1518. ms.AddMsg(nil, true)
  1519. }
  1520. }()
  1521. // Test cancelling query.
  1522. ctx, cancel := context.WithCancel(context.Background())
  1523. var err error
  1524. go func() {
  1525. // Establish a stream to mock cloud spanner server.
  1526. iter := stream(ctx,
  1527. func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
  1528. return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
  1529. Sql: "SELECT t.key key, t.value value FROM t_mock t",
  1530. ResumeToken: resumeToken,
  1531. })
  1532. },
  1533. nil,
  1534. func(error) {})
  1535. defer iter.Stop()
  1536. for {
  1537. _, err = iter.Next()
  1538. if err == iterator.Done {
  1539. break
  1540. }
  1541. if err != nil {
  1542. done <- 0
  1543. break
  1544. }
  1545. }
  1546. }()
  1547. cancel()
  1548. select {
  1549. case <-done:
  1550. if ErrCode(err) != codes.Canceled {
  1551. t.Errorf("streaming query is canceled and returns error %v, want error code %v", err, codes.Canceled)
  1552. }
  1553. case <-time.After(1 * time.Second):
  1554. t.Errorf("query doesn't exit timely after being cancelled")
  1555. }
  1556. // Test query timeout.
  1557. ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second)
  1558. defer cancel()
  1559. go func() {
  1560. // Establish a stream to mock cloud spanner server.
  1561. iter := stream(ctx,
  1562. func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
  1563. return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
  1564. Sql: "SELECT t.key key, t.value value FROM t_mock t",
  1565. ResumeToken: resumeToken,
  1566. })
  1567. },
  1568. nil,
  1569. func(error) {})
  1570. defer iter.Stop()
  1571. for {
  1572. _, err = iter.Next()
  1573. if err == iterator.Done {
  1574. err = nil
  1575. break
  1576. }
  1577. if err != nil {
  1578. break
  1579. }
  1580. }
  1581. done <- 0
  1582. }()
  1583. select {
  1584. case <-done:
  1585. if wantErr := codes.DeadlineExceeded; ErrCode(err) != wantErr {
  1586. t.Errorf("streaming query timeout returns error %v, want error code %v", err, wantErr)
  1587. }
  1588. case <-time.After(2 * time.Second):
  1589. t.Errorf("query doesn't timeout as expected")
  1590. }
  1591. }
  1592. func TestRowIteratorDo(t *testing.T) {
  1593. restore := setMaxBytesBetweenResumeTokens()
  1594. defer restore()
  1595. ms := testutil.NewMockCloudSpanner(t, trxTs)
  1596. ms.Serve()
  1597. defer ms.Stop()
  1598. cc := dialMock(t, ms)
  1599. defer cc.Close()
  1600. mc := sppb.NewSpannerClient(cc)
  1601. for i := 0; i < 3; i++ {
  1602. ms.AddMsg(nil, false)
  1603. }
  1604. ms.AddMsg(io.EOF, true)
  1605. nRows := 0
  1606. iter := stream(context.Background(),
  1607. func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
  1608. return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
  1609. Sql: "SELECT t.key key, t.value value FROM t_mock t",
  1610. ResumeToken: resumeToken,
  1611. })
  1612. },
  1613. nil,
  1614. func(error) {})
  1615. err := iter.Do(func(r *Row) error { nRows++; return nil })
  1616. if err != nil {
  1617. t.Errorf("Using Do: %v", err)
  1618. }
  1619. if nRows != 3 {
  1620. t.Errorf("got %d rows, want 3", nRows)
  1621. }
  1622. }
  1623. func TestRowIteratorDoWithError(t *testing.T) {
  1624. restore := setMaxBytesBetweenResumeTokens()
  1625. defer restore()
  1626. ms := testutil.NewMockCloudSpanner(t, trxTs)
  1627. ms.Serve()
  1628. defer ms.Stop()
  1629. cc := dialMock(t, ms)
  1630. defer cc.Close()
  1631. mc := sppb.NewSpannerClient(cc)
  1632. for i := 0; i < 3; i++ {
  1633. ms.AddMsg(nil, false)
  1634. }
  1635. ms.AddMsg(io.EOF, true)
  1636. iter := stream(context.Background(),
  1637. func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
  1638. return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
  1639. Sql: "SELECT t.key key, t.value value FROM t_mock t",
  1640. ResumeToken: resumeToken,
  1641. })
  1642. },
  1643. nil,
  1644. func(error) {})
  1645. injected := errors.New("Failed iterator")
  1646. err := iter.Do(func(r *Row) error { return injected })
  1647. if err != injected {
  1648. t.Errorf("got <%v>, want <%v>", err, injected)
  1649. }
  1650. }
  1651. func TestIteratorStopEarly(t *testing.T) {
  1652. ctx := context.Background()
  1653. restore := setMaxBytesBetweenResumeTokens()
  1654. defer restore()
  1655. ms := testutil.NewMockCloudSpanner(t, trxTs)
  1656. ms.Serve()
  1657. defer ms.Stop()
  1658. cc := dialMock(t, ms)
  1659. defer cc.Close()
  1660. mc := sppb.NewSpannerClient(cc)
  1661. ms.AddMsg(nil, false)
  1662. ms.AddMsg(nil, false)
  1663. ms.AddMsg(io.EOF, true)
  1664. iter := stream(ctx,
  1665. func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
  1666. return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
  1667. Sql: "SELECT t.key key, t.value value FROM t_mock t",
  1668. ResumeToken: resumeToken,
  1669. })
  1670. },
  1671. nil,
  1672. func(error) {})
  1673. _, err := iter.Next()
  1674. if err != nil {
  1675. t.Fatalf("before Stop: %v", err)
  1676. }
  1677. iter.Stop()
  1678. // Stop sets r.err to the FailedPrecondition error "Next called after Stop".
  1679. // Override that here so this test can observe the Canceled error from the stream.
  1680. iter.err = nil
  1681. iter.Next()
  1682. if ErrCode(iter.streamd.lastErr()) != codes.Canceled {
  1683. t.Errorf("after Stop: got %v, wanted Canceled", err)
  1684. }
  1685. }
  1686. func TestIteratorWithError(t *testing.T) {
  1687. injected := errors.New("Failed iterator")
  1688. iter := RowIterator{err: injected}
  1689. defer iter.Stop()
  1690. if _, err := iter.Next(); err != injected {
  1691. t.Fatalf("Expected error: %v, got %v", injected, err)
  1692. }
  1693. }
  1694. func dialMock(t *testing.T, ms *testutil.MockCloudSpanner) *grpc.ClientConn {
  1695. cc, err := grpc.Dial(ms.Addr(), grpc.WithInsecure(), grpc.WithBlock())
  1696. if err != nil {
  1697. t.Fatalf("Dial(%q) = %v", ms.Addr(), err)
  1698. }
  1699. return cc
  1700. }