You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

1734 rivejä
49 KiB

  1. /*
  2. Copyright 2017 Google LLC
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package spanner
  14. import (
  15. "errors"
  16. "fmt"
  17. "io"
  18. "sync/atomic"
  19. "testing"
  20. "time"
  21. "golang.org/x/net/context"
  22. "google.golang.org/grpc/status"
  23. "github.com/golang/protobuf/proto"
  24. proto3 "github.com/golang/protobuf/ptypes/struct"
  25. "cloud.google.com/go/spanner/internal/testutil"
  26. "google.golang.org/api/iterator"
  27. sppb "google.golang.org/genproto/googleapis/spanner/v1"
  28. "google.golang.org/grpc"
  29. "google.golang.org/grpc/codes"
  30. )
  31. var (
  32. // Mocked transaction timestamp.
  33. trxTs = time.Unix(1, 2)
  34. // Metadata for mocked KV table, its rows are returned by SingleUse transactions.
  35. kvMeta = func() *sppb.ResultSetMetadata {
  36. meta := testutil.KvMeta
  37. meta.Transaction = &sppb.Transaction{
  38. ReadTimestamp: timestampProto(trxTs),
  39. }
  40. return &meta
  41. }()
  42. // Metadata for mocked ListKV table, which uses List for its key and value.
  43. // Its rows are returned by snapshot readonly transactions, as indicated in the transaction metadata.
  44. kvListMeta = &sppb.ResultSetMetadata{
  45. RowType: &sppb.StructType{
  46. Fields: []*sppb.StructType_Field{
  47. {
  48. Name: "Key",
  49. Type: &sppb.Type{
  50. Code: sppb.TypeCode_ARRAY,
  51. ArrayElementType: &sppb.Type{
  52. Code: sppb.TypeCode_STRING,
  53. },
  54. },
  55. },
  56. {
  57. Name: "Value",
  58. Type: &sppb.Type{
  59. Code: sppb.TypeCode_ARRAY,
  60. ArrayElementType: &sppb.Type{
  61. Code: sppb.TypeCode_STRING,
  62. },
  63. },
  64. },
  65. },
  66. },
  67. Transaction: &sppb.Transaction{
  68. Id: transactionID{5, 6, 7, 8, 9},
  69. ReadTimestamp: timestampProto(trxTs),
  70. },
  71. }
  72. // Metadata for mocked schema of a query result set, which has two struct
  73. // columns named "Col1" and "Col2", the struct's schema is like the
  74. // following:
  75. //
  76. // STRUCT {
  77. // INT
  78. // LIST<STRING>
  79. // }
  80. //
  81. // Its rows are returned in readwrite transaction, as indicated in the transaction metadata.
  82. kvObjectMeta = &sppb.ResultSetMetadata{
  83. RowType: &sppb.StructType{
  84. Fields: []*sppb.StructType_Field{
  85. {
  86. Name: "Col1",
  87. Type: &sppb.Type{
  88. Code: sppb.TypeCode_STRUCT,
  89. StructType: &sppb.StructType{
  90. Fields: []*sppb.StructType_Field{
  91. {
  92. Name: "foo-f1",
  93. Type: &sppb.Type{
  94. Code: sppb.TypeCode_INT64,
  95. },
  96. },
  97. {
  98. Name: "foo-f2",
  99. Type: &sppb.Type{
  100. Code: sppb.TypeCode_ARRAY,
  101. ArrayElementType: &sppb.Type{
  102. Code: sppb.TypeCode_STRING,
  103. },
  104. },
  105. },
  106. },
  107. },
  108. },
  109. },
  110. {
  111. Name: "Col2",
  112. Type: &sppb.Type{
  113. Code: sppb.TypeCode_STRUCT,
  114. StructType: &sppb.StructType{
  115. Fields: []*sppb.StructType_Field{
  116. {
  117. Name: "bar-f1",
  118. Type: &sppb.Type{
  119. Code: sppb.TypeCode_INT64,
  120. },
  121. },
  122. {
  123. Name: "bar-f2",
  124. Type: &sppb.Type{
  125. Code: sppb.TypeCode_ARRAY,
  126. ArrayElementType: &sppb.Type{
  127. Code: sppb.TypeCode_STRING,
  128. },
  129. },
  130. },
  131. },
  132. },
  133. },
  134. },
  135. },
  136. },
  137. Transaction: &sppb.Transaction{
  138. Id: transactionID{1, 2, 3, 4, 5},
  139. },
  140. }
  141. )
  142. // String implements fmt.stringer.
  143. func (r *Row) String() string {
  144. return fmt.Sprintf("{fields: %s, val: %s}", r.fields, r.vals)
  145. }
  146. func describeRows(l []*Row) string {
  147. // generate a nice test failure description
  148. var s = "["
  149. for i, r := range l {
  150. if i != 0 {
  151. s += ",\n "
  152. }
  153. s += fmt.Sprint(r)
  154. }
  155. s += "]"
  156. return s
  157. }
  158. // Helper for generating proto3 Value_ListValue instances, making
  159. // test code shorter and readable.
  160. func genProtoListValue(v ...string) *proto3.Value_ListValue {
  161. r := &proto3.Value_ListValue{
  162. ListValue: &proto3.ListValue{
  163. Values: []*proto3.Value{},
  164. },
  165. }
  166. for _, e := range v {
  167. r.ListValue.Values = append(
  168. r.ListValue.Values,
  169. &proto3.Value{
  170. Kind: &proto3.Value_StringValue{StringValue: e},
  171. },
  172. )
  173. }
  174. return r
  175. }
  176. // Test Row generation logics of partialResultSetDecoder.
  177. func TestPartialResultSetDecoder(t *testing.T) {
  178. restore := setMaxBytesBetweenResumeTokens()
  179. defer restore()
  180. var tests = []struct {
  181. input []*sppb.PartialResultSet
  182. wantF []*Row
  183. wantTxID transactionID
  184. wantTs time.Time
  185. wantD bool
  186. }{
  187. {
  188. // Empty input.
  189. wantD: true,
  190. },
  191. // String merging examples.
  192. {
  193. // Single KV result.
  194. input: []*sppb.PartialResultSet{
  195. {
  196. Metadata: kvMeta,
  197. Values: []*proto3.Value{
  198. {Kind: &proto3.Value_StringValue{StringValue: "foo"}},
  199. {Kind: &proto3.Value_StringValue{StringValue: "bar"}},
  200. },
  201. },
  202. },
  203. wantF: []*Row{
  204. {
  205. fields: kvMeta.RowType.Fields,
  206. vals: []*proto3.Value{
  207. {Kind: &proto3.Value_StringValue{StringValue: "foo"}},
  208. {Kind: &proto3.Value_StringValue{StringValue: "bar"}},
  209. },
  210. },
  211. },
  212. wantTs: trxTs,
  213. wantD: true,
  214. },
  215. {
  216. // Incomplete partial result.
  217. input: []*sppb.PartialResultSet{
  218. {
  219. Metadata: kvMeta,
  220. Values: []*proto3.Value{
  221. {Kind: &proto3.Value_StringValue{StringValue: "foo"}},
  222. },
  223. },
  224. },
  225. wantTs: trxTs,
  226. wantD: false,
  227. },
  228. {
  229. // Complete splitted result.
  230. input: []*sppb.PartialResultSet{
  231. {
  232. Metadata: kvMeta,
  233. Values: []*proto3.Value{
  234. {Kind: &proto3.Value_StringValue{StringValue: "foo"}},
  235. },
  236. },
  237. {
  238. Values: []*proto3.Value{
  239. {Kind: &proto3.Value_StringValue{StringValue: "bar"}},
  240. },
  241. },
  242. },
  243. wantF: []*Row{
  244. {
  245. fields: kvMeta.RowType.Fields,
  246. vals: []*proto3.Value{
  247. {Kind: &proto3.Value_StringValue{StringValue: "foo"}},
  248. {Kind: &proto3.Value_StringValue{StringValue: "bar"}},
  249. },
  250. },
  251. },
  252. wantTs: trxTs,
  253. wantD: true,
  254. },
  255. {
  256. // Multi-row example with splitted row in the middle.
  257. input: []*sppb.PartialResultSet{
  258. {
  259. Metadata: kvMeta,
  260. Values: []*proto3.Value{
  261. {Kind: &proto3.Value_StringValue{StringValue: "foo"}},
  262. {Kind: &proto3.Value_StringValue{StringValue: "bar"}},
  263. {Kind: &proto3.Value_StringValue{StringValue: "A"}},
  264. },
  265. },
  266. {
  267. Values: []*proto3.Value{
  268. {Kind: &proto3.Value_StringValue{StringValue: "1"}},
  269. {Kind: &proto3.Value_StringValue{StringValue: "B"}},
  270. {Kind: &proto3.Value_StringValue{StringValue: "2"}},
  271. },
  272. },
  273. },
  274. wantF: []*Row{
  275. {
  276. fields: kvMeta.RowType.Fields,
  277. vals: []*proto3.Value{
  278. {Kind: &proto3.Value_StringValue{StringValue: "foo"}},
  279. {Kind: &proto3.Value_StringValue{StringValue: "bar"}},
  280. },
  281. },
  282. {
  283. fields: kvMeta.RowType.Fields,
  284. vals: []*proto3.Value{
  285. {Kind: &proto3.Value_StringValue{StringValue: "A"}},
  286. {Kind: &proto3.Value_StringValue{StringValue: "1"}},
  287. },
  288. },
  289. {
  290. fields: kvMeta.RowType.Fields,
  291. vals: []*proto3.Value{
  292. {Kind: &proto3.Value_StringValue{StringValue: "B"}},
  293. {Kind: &proto3.Value_StringValue{StringValue: "2"}},
  294. },
  295. },
  296. },
  297. wantTs: trxTs,
  298. wantD: true,
  299. },
  300. {
  301. // Merging example in result_set.proto.
  302. input: []*sppb.PartialResultSet{
  303. {
  304. Metadata: kvMeta,
  305. Values: []*proto3.Value{
  306. {Kind: &proto3.Value_StringValue{StringValue: "Hello"}},
  307. {Kind: &proto3.Value_StringValue{StringValue: "W"}},
  308. },
  309. ChunkedValue: true,
  310. },
  311. {
  312. Values: []*proto3.Value{
  313. {Kind: &proto3.Value_StringValue{StringValue: "orl"}},
  314. },
  315. ChunkedValue: true,
  316. },
  317. {
  318. Values: []*proto3.Value{
  319. {Kind: &proto3.Value_StringValue{StringValue: "d"}},
  320. },
  321. },
  322. },
  323. wantF: []*Row{
  324. {
  325. fields: kvMeta.RowType.Fields,
  326. vals: []*proto3.Value{
  327. {Kind: &proto3.Value_StringValue{StringValue: "Hello"}},
  328. {Kind: &proto3.Value_StringValue{StringValue: "World"}},
  329. },
  330. },
  331. },
  332. wantTs: trxTs,
  333. wantD: true,
  334. },
  335. {
  336. // More complex example showing completing a merge and
  337. // starting a new merge in the same partialResultSet.
  338. input: []*sppb.PartialResultSet{
  339. {
  340. Metadata: kvMeta,
  341. Values: []*proto3.Value{
  342. {Kind: &proto3.Value_StringValue{StringValue: "Hello"}},
  343. {Kind: &proto3.Value_StringValue{StringValue: "W"}}, // start split in value
  344. },
  345. ChunkedValue: true,
  346. },
  347. {
  348. Values: []*proto3.Value{
  349. {Kind: &proto3.Value_StringValue{StringValue: "orld"}}, // complete value
  350. {Kind: &proto3.Value_StringValue{StringValue: "i"}}, // start split in key
  351. },
  352. ChunkedValue: true,
  353. },
  354. {
  355. Values: []*proto3.Value{
  356. {Kind: &proto3.Value_StringValue{StringValue: "s"}}, // complete key
  357. {Kind: &proto3.Value_StringValue{StringValue: "not"}},
  358. {Kind: &proto3.Value_StringValue{StringValue: "a"}},
  359. {Kind: &proto3.Value_StringValue{StringValue: "qu"}}, // split in value
  360. },
  361. ChunkedValue: true,
  362. },
  363. {
  364. Values: []*proto3.Value{
  365. {Kind: &proto3.Value_StringValue{StringValue: "estion"}}, // complete value
  366. },
  367. },
  368. },
  369. wantF: []*Row{
  370. {
  371. fields: kvMeta.RowType.Fields,
  372. vals: []*proto3.Value{
  373. {Kind: &proto3.Value_StringValue{StringValue: "Hello"}},
  374. {Kind: &proto3.Value_StringValue{StringValue: "World"}},
  375. },
  376. },
  377. {
  378. fields: kvMeta.RowType.Fields,
  379. vals: []*proto3.Value{
  380. {Kind: &proto3.Value_StringValue{StringValue: "is"}},
  381. {Kind: &proto3.Value_StringValue{StringValue: "not"}},
  382. },
  383. },
  384. {
  385. fields: kvMeta.RowType.Fields,
  386. vals: []*proto3.Value{
  387. {Kind: &proto3.Value_StringValue{StringValue: "a"}},
  388. {Kind: &proto3.Value_StringValue{StringValue: "question"}},
  389. },
  390. },
  391. },
  392. wantTs: trxTs,
  393. wantD: true,
  394. },
  395. // List merging examples.
  396. {
  397. // Non-splitting Lists.
  398. input: []*sppb.PartialResultSet{
  399. {
  400. Metadata: kvListMeta,
  401. Values: []*proto3.Value{
  402. {
  403. Kind: genProtoListValue("foo-1", "foo-2"),
  404. },
  405. },
  406. },
  407. {
  408. Values: []*proto3.Value{
  409. {
  410. Kind: genProtoListValue("bar-1", "bar-2"),
  411. },
  412. },
  413. },
  414. },
  415. wantF: []*Row{
  416. {
  417. fields: kvListMeta.RowType.Fields,
  418. vals: []*proto3.Value{
  419. {
  420. Kind: genProtoListValue("foo-1", "foo-2"),
  421. },
  422. {
  423. Kind: genProtoListValue("bar-1", "bar-2"),
  424. },
  425. },
  426. },
  427. },
  428. wantTxID: transactionID{5, 6, 7, 8, 9},
  429. wantTs: trxTs,
  430. wantD: true,
  431. },
  432. {
  433. // Simple List merge case: splitted string element.
  434. input: []*sppb.PartialResultSet{
  435. {
  436. Metadata: kvListMeta,
  437. Values: []*proto3.Value{
  438. {
  439. Kind: genProtoListValue("foo-1", "foo-"),
  440. },
  441. },
  442. ChunkedValue: true,
  443. },
  444. {
  445. Values: []*proto3.Value{
  446. {
  447. Kind: genProtoListValue("2"),
  448. },
  449. },
  450. },
  451. {
  452. Values: []*proto3.Value{
  453. {
  454. Kind: genProtoListValue("bar-1", "bar-2"),
  455. },
  456. },
  457. },
  458. },
  459. wantF: []*Row{
  460. {
  461. fields: kvListMeta.RowType.Fields,
  462. vals: []*proto3.Value{
  463. {
  464. Kind: genProtoListValue("foo-1", "foo-2"),
  465. },
  466. {
  467. Kind: genProtoListValue("bar-1", "bar-2"),
  468. },
  469. },
  470. },
  471. },
  472. wantTxID: transactionID{5, 6, 7, 8, 9},
  473. wantTs: trxTs,
  474. wantD: true,
  475. },
  476. {
  477. // Struct merging is also implemented by List merging. Note that
  478. // Cloud Spanner uses proto.ListValue to encode Structs as well.
  479. input: []*sppb.PartialResultSet{
  480. {
  481. Metadata: kvObjectMeta,
  482. Values: []*proto3.Value{
  483. {
  484. Kind: &proto3.Value_ListValue{
  485. ListValue: &proto3.ListValue{
  486. Values: []*proto3.Value{
  487. {Kind: &proto3.Value_NumberValue{NumberValue: 23}},
  488. {Kind: genProtoListValue("foo-1", "fo")},
  489. },
  490. },
  491. },
  492. },
  493. },
  494. ChunkedValue: true,
  495. },
  496. {
  497. Values: []*proto3.Value{
  498. {
  499. Kind: &proto3.Value_ListValue{
  500. ListValue: &proto3.ListValue{
  501. Values: []*proto3.Value{
  502. {Kind: genProtoListValue("o-2", "f")},
  503. },
  504. },
  505. },
  506. },
  507. },
  508. ChunkedValue: true,
  509. },
  510. {
  511. Values: []*proto3.Value{
  512. {
  513. Kind: &proto3.Value_ListValue{
  514. ListValue: &proto3.ListValue{
  515. Values: []*proto3.Value{
  516. {Kind: genProtoListValue("oo-3")},
  517. },
  518. },
  519. },
  520. },
  521. {
  522. Kind: &proto3.Value_ListValue{
  523. ListValue: &proto3.ListValue{
  524. Values: []*proto3.Value{
  525. {Kind: &proto3.Value_NumberValue{NumberValue: 45}},
  526. {Kind: genProtoListValue("bar-1")},
  527. },
  528. },
  529. },
  530. },
  531. },
  532. },
  533. },
  534. wantF: []*Row{
  535. {
  536. fields: kvObjectMeta.RowType.Fields,
  537. vals: []*proto3.Value{
  538. {
  539. Kind: &proto3.Value_ListValue{
  540. ListValue: &proto3.ListValue{
  541. Values: []*proto3.Value{
  542. {Kind: &proto3.Value_NumberValue{NumberValue: 23}},
  543. {Kind: genProtoListValue("foo-1", "foo-2", "foo-3")},
  544. },
  545. },
  546. },
  547. },
  548. {
  549. Kind: &proto3.Value_ListValue{
  550. ListValue: &proto3.ListValue{
  551. Values: []*proto3.Value{
  552. {Kind: &proto3.Value_NumberValue{NumberValue: 45}},
  553. {Kind: genProtoListValue("bar-1")},
  554. },
  555. },
  556. },
  557. },
  558. },
  559. },
  560. },
  561. wantTxID: transactionID{1, 2, 3, 4, 5},
  562. wantD: true,
  563. },
  564. }
  565. nextTest:
  566. for i, test := range tests {
  567. var rows []*Row
  568. p := &partialResultSetDecoder{}
  569. for j, v := range test.input {
  570. rs, err := p.add(v)
  571. if err != nil {
  572. t.Errorf("test %d.%d: partialResultSetDecoder.add(%v) = %v; want nil", i, j, v, err)
  573. continue nextTest
  574. }
  575. rows = append(rows, rs...)
  576. }
  577. if !testEqual(p.ts, test.wantTs) {
  578. t.Errorf("got transaction(%v), want %v", p.ts, test.wantTs)
  579. }
  580. if !testEqual(rows, test.wantF) {
  581. t.Errorf("test %d: rows=\n%v\n; want\n%v\n; p.row:\n%v\n", i, describeRows(rows), describeRows(test.wantF), p.row)
  582. }
  583. if got := p.done(); got != test.wantD {
  584. t.Errorf("test %d: partialResultSetDecoder.done() = %v", i, got)
  585. }
  586. }
  587. }
  // maxBuffers is the maximum number of PartialResultSets that tests expect to
  // be buffered; setMaxBytesBetweenResumeTokens scales its byte limit by it.
  const maxBuffers = 16
  591. // setMaxBytesBetweenResumeTokens sets the global maxBytesBetweenResumeTokens to a smaller
  592. // value more suitable for tests. It returns a function which should be called to restore
  593. // the maxBytesBetweenResumeTokens to its old value
  594. func setMaxBytesBetweenResumeTokens() func() {
  595. o := atomic.LoadInt32(&maxBytesBetweenResumeTokens)
  596. atomic.StoreInt32(&maxBytesBetweenResumeTokens, int32(maxBuffers*proto.Size(&sppb.PartialResultSet{
  597. Metadata: kvMeta,
  598. Values: []*proto3.Value{
  599. {Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
  600. {Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
  601. },
  602. })))
  603. return func() {
  604. atomic.StoreInt32(&maxBytesBetweenResumeTokens, o)
  605. }
  606. }
  // keyStr returns the key string for row i of the kvMeta schema: "foo-" followed
  // by i zero-padded to at least two digits.
  func keyStr(i int) string {
  	const keyFormat = "foo-%02d"
  	return fmt.Sprintf(keyFormat, i)
  }
  // valStr returns the value string for row i of the kvMeta schema: "bar-"
  // followed by i zero-padded to at least two digits.
  func valStr(i int) string {
  	const valFormat = "bar-%02d"
  	return fmt.Sprintf(valFormat, i)
  }
  615. // Test state transitions of resumableStreamDecoder where state machine
  616. // ends up to a non-blocking state(resumableStreamDecoder.Next returns
  617. // on non-blocking state).
  618. func TestRsdNonblockingStates(t *testing.T) {
  619. restore := setMaxBytesBetweenResumeTokens()
  620. defer restore()
  621. tests := []struct {
  622. name string
  623. msgs []testutil.MockCtlMsg
  624. rpc func(ct context.Context, resumeToken []byte) (streamingReceiver, error)
  625. sql string
  626. // Expected values
  627. want []*sppb.PartialResultSet // PartialResultSets that should be returned to caller
  628. queue []*sppb.PartialResultSet // PartialResultSets that should be buffered
  629. resumeToken []byte // Resume token that is maintained by resumableStreamDecoder
  630. stateHistory []resumableStreamDecoderState // State transition history of resumableStreamDecoder
  631. wantErr error
  632. }{
  633. {
  634. // unConnected->queueingRetryable->finished
  635. name: "unConnected->queueingRetryable->finished",
  636. msgs: []testutil.MockCtlMsg{
  637. {},
  638. {},
  639. {Err: io.EOF, ResumeToken: false},
  640. },
  641. sql: "SELECT t.key key, t.value value FROM t_mock t",
  642. want: []*sppb.PartialResultSet{
  643. {
  644. Metadata: kvMeta,
  645. Values: []*proto3.Value{
  646. {Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
  647. {Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
  648. },
  649. },
  650. },
  651. queue: []*sppb.PartialResultSet{
  652. {
  653. Metadata: kvMeta,
  654. Values: []*proto3.Value{
  655. {Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}},
  656. {Kind: &proto3.Value_StringValue{StringValue: valStr(1)}},
  657. },
  658. },
  659. },
  660. stateHistory: []resumableStreamDecoderState{
  661. queueingRetryable, // do RPC
  662. queueingRetryable, // got foo-00
  663. queueingRetryable, // got foo-01
  664. finished, // got EOF
  665. },
  666. },
  667. {
  668. // unConnected->queueingRetryable->aborted
  669. name: "unConnected->queueingRetryable->aborted",
  670. msgs: []testutil.MockCtlMsg{
  671. {},
  672. {Err: nil, ResumeToken: true},
  673. {},
  674. {Err: errors.New("I quit"), ResumeToken: false},
  675. },
  676. sql: "SELECT t.key key, t.value value FROM t_mock t",
  677. want: []*sppb.PartialResultSet{
  678. {
  679. Metadata: kvMeta,
  680. Values: []*proto3.Value{
  681. {Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
  682. {Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
  683. },
  684. },
  685. {
  686. Metadata: kvMeta,
  687. Values: []*proto3.Value{
  688. {Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}},
  689. {Kind: &proto3.Value_StringValue{StringValue: valStr(1)}},
  690. },
  691. ResumeToken: testutil.EncodeResumeToken(1),
  692. },
  693. },
  694. stateHistory: []resumableStreamDecoderState{
  695. queueingRetryable, // do RPC
  696. queueingRetryable, // got foo-00
  697. queueingRetryable, // got foo-01
  698. queueingRetryable, // foo-01, resume token
  699. queueingRetryable, // got foo-02
  700. aborted, // got error
  701. },
  702. wantErr: status.Errorf(codes.Unknown, "I quit"),
  703. },
  704. {
  705. // unConnected->queueingRetryable->queueingUnretryable->queueingUnretryable
  706. name: "unConnected->queueingRetryable->queueingUnretryable->queueingUnretryable",
  707. msgs: func() (m []testutil.MockCtlMsg) {
  708. for i := 0; i < maxBuffers+1; i++ {
  709. m = append(m, testutil.MockCtlMsg{})
  710. }
  711. return m
  712. }(),
  713. sql: "SELECT t.key key, t.value value FROM t_mock t",
  714. want: func() (s []*sppb.PartialResultSet) {
  715. for i := 0; i < maxBuffers+1; i++ {
  716. s = append(s, &sppb.PartialResultSet{
  717. Metadata: kvMeta,
  718. Values: []*proto3.Value{
  719. {Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
  720. {Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
  721. },
  722. })
  723. }
  724. return s
  725. }(),
  726. stateHistory: func() (s []resumableStreamDecoderState) {
  727. s = append(s, queueingRetryable) // RPC
  728. for i := 0; i < maxBuffers; i++ {
  729. s = append(s, queueingRetryable) // the internal queue of resumableStreamDecoder fills up
  730. }
  731. // the first item fills up the queue and triggers state transition;
  732. // the second item is received under queueingUnretryable state.
  733. s = append(s, queueingUnretryable)
  734. s = append(s, queueingUnretryable)
  735. return s
  736. }(),
  737. },
  738. {
  739. // unConnected->queueingRetryable->queueingUnretryable->aborted
  740. name: "unConnected->queueingRetryable->queueingUnretryable->aborted",
  741. msgs: func() (m []testutil.MockCtlMsg) {
  742. for i := 0; i < maxBuffers; i++ {
  743. m = append(m, testutil.MockCtlMsg{})
  744. }
  745. m = append(m, testutil.MockCtlMsg{Err: errors.New("Just Abort It"), ResumeToken: false})
  746. return m
  747. }(),
  748. sql: "SELECT t.key key, t.value value FROM t_mock t",
  749. want: func() (s []*sppb.PartialResultSet) {
  750. for i := 0; i < maxBuffers; i++ {
  751. s = append(s, &sppb.PartialResultSet{
  752. Metadata: kvMeta,
  753. Values: []*proto3.Value{
  754. {Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
  755. {Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
  756. },
  757. })
  758. }
  759. return s
  760. }(),
  761. stateHistory: func() (s []resumableStreamDecoderState) {
  762. s = append(s, queueingRetryable) // RPC
  763. for i := 0; i < maxBuffers; i++ {
  764. s = append(s, queueingRetryable) // internal queue of resumableStreamDecoder fills up
  765. }
  766. s = append(s, queueingUnretryable) // the last row triggers state change
  767. s = append(s, aborted) // Error happens
  768. return s
  769. }(),
  770. wantErr: status.Errorf(codes.Unknown, "Just Abort It"),
  771. },
  772. }
  773. nextTest:
  774. for _, test := range tests {
  775. ms := testutil.NewMockCloudSpanner(t, trxTs)
  776. ms.Serve()
  777. mc := sppb.NewSpannerClient(dialMock(t, ms))
  778. if test.rpc == nil {
  779. test.rpc = func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
  780. return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
  781. Sql: test.sql,
  782. ResumeToken: resumeToken,
  783. })
  784. }
  785. }
  786. ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
  787. defer cancel()
  788. r := newResumableStreamDecoder(
  789. ctx,
  790. test.rpc,
  791. )
  792. st := []resumableStreamDecoderState{}
  793. var lastErr error
  794. // Once the expected number of state transitions are observed,
  795. // send a signal by setting stateDone = true.
  796. stateDone := false
  797. // Set stateWitness to listen to state changes.
  798. hl := len(test.stateHistory) // To avoid data race on test.
  799. r.stateWitness = func(rs resumableStreamDecoderState) {
  800. if !stateDone {
  801. // Record state transitions.
  802. st = append(st, rs)
  803. if len(st) == hl {
  804. lastErr = r.lastErr()
  805. stateDone = true
  806. }
  807. }
  808. }
  809. // Let mock server stream given messages to resumableStreamDecoder.
  810. for _, m := range test.msgs {
  811. ms.AddMsg(m.Err, m.ResumeToken)
  812. }
  813. var rs []*sppb.PartialResultSet
  814. for {
  815. select {
  816. case <-ctx.Done():
  817. t.Errorf("context cancelled or timeout during test")
  818. continue nextTest
  819. default:
  820. }
  821. if stateDone {
  822. // Check if resumableStreamDecoder carried out expected
  823. // state transitions.
  824. if !testEqual(st, test.stateHistory) {
  825. t.Errorf("%v: observed state transitions: \n%v\n, want \n%v\n",
  826. test.name, st, test.stateHistory)
  827. }
  828. // Check if resumableStreamDecoder returns expected array of
  829. // PartialResultSets.
  830. if !testEqual(rs, test.want) {
  831. t.Errorf("%v: received PartialResultSets: \n%v\n, want \n%v\n", test.name, rs, test.want)
  832. }
  833. // Verify that resumableStreamDecoder's internal buffering is also correct.
  834. var q []*sppb.PartialResultSet
  835. for {
  836. item := r.q.pop()
  837. if item == nil {
  838. break
  839. }
  840. q = append(q, item)
  841. }
  842. if !testEqual(q, test.queue) {
  843. t.Errorf("%v: PartialResultSets still queued: \n%v\n, want \n%v\n", test.name, q, test.queue)
  844. }
  845. // Verify resume token.
  846. if test.resumeToken != nil && !testEqual(r.resumeToken, test.resumeToken) {
  847. t.Errorf("%v: Resume token is %v, want %v\n", test.name, r.resumeToken, test.resumeToken)
  848. }
  849. // Verify error message.
  850. if !testEqual(lastErr, test.wantErr) {
  851. t.Errorf("%v: got error %v, want %v", test.name, lastErr, test.wantErr)
  852. }
  853. // Proceed to next test
  854. continue nextTest
  855. }
  856. // Receive next decoded item.
  857. if r.next() {
  858. rs = append(rs, r.get())
  859. }
  860. }
  861. }
  862. }
  863. // Test state transitions of resumableStreamDecoder where state machine
  864. // ends up to a blocking state(resumableStreamDecoder.Next blocks
  865. // on blocking state).
  866. func TestRsdBlockingStates(t *testing.T) {
  867. restore := setMaxBytesBetweenResumeTokens()
  868. defer restore()
  869. tests := []struct {
  870. name string
  871. msgs []testutil.MockCtlMsg
  872. rpc func(ct context.Context, resumeToken []byte) (streamingReceiver, error)
  873. sql string
  874. // Expected values
  875. want []*sppb.PartialResultSet // PartialResultSets that should be returned to caller
  876. queue []*sppb.PartialResultSet // PartialResultSets that should be buffered
  877. resumeToken []byte // Resume token that is maintained by resumableStreamDecoder
  878. stateHistory []resumableStreamDecoderState // State transition history of resumableStreamDecoder
  879. wantErr error
  880. }{
  881. {
  882. // unConnected -> unConnected
  883. name: "unConnected -> unConnected",
  884. rpc: func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
  885. return nil, status.Errorf(codes.Unavailable, "trust me: server is unavailable")
  886. },
  887. sql: "SELECT * from t_whatever",
  888. stateHistory: []resumableStreamDecoderState{unConnected, unConnected, unConnected},
  889. wantErr: status.Errorf(codes.Unavailable, "trust me: server is unavailable"),
  890. },
  891. {
  892. // unConnected -> queueingRetryable
  893. name: "unConnected -> queueingRetryable",
  894. sql: "SELECT t.key key, t.value value FROM t_mock t",
  895. stateHistory: []resumableStreamDecoderState{queueingRetryable},
  896. },
  897. {
  898. // unConnected->queueingRetryable->queueingRetryable
  899. name: "unConnected->queueingRetryable->queueingRetryable",
  900. msgs: []testutil.MockCtlMsg{
  901. {},
  902. {Err: nil, ResumeToken: true},
  903. {Err: nil, ResumeToken: true},
  904. {},
  905. },
  906. sql: "SELECT t.key key, t.value value FROM t_mock t",
  907. want: []*sppb.PartialResultSet{
  908. {
  909. Metadata: kvMeta,
  910. Values: []*proto3.Value{
  911. {Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
  912. {Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
  913. },
  914. },
  915. {
  916. Metadata: kvMeta,
  917. Values: []*proto3.Value{
  918. {Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}},
  919. {Kind: &proto3.Value_StringValue{StringValue: valStr(1)}},
  920. },
  921. ResumeToken: testutil.EncodeResumeToken(1),
  922. },
  923. {
  924. Metadata: kvMeta,
  925. Values: []*proto3.Value{
  926. {Kind: &proto3.Value_StringValue{StringValue: keyStr(2)}},
  927. {Kind: &proto3.Value_StringValue{StringValue: valStr(2)}},
  928. },
  929. ResumeToken: testutil.EncodeResumeToken(2),
  930. },
  931. },
  932. queue: []*sppb.PartialResultSet{
  933. {
  934. Metadata: kvMeta,
  935. Values: []*proto3.Value{
  936. {Kind: &proto3.Value_StringValue{StringValue: keyStr(3)}},
  937. {Kind: &proto3.Value_StringValue{StringValue: valStr(3)}},
  938. },
  939. },
  940. },
  941. resumeToken: testutil.EncodeResumeToken(2),
  942. stateHistory: []resumableStreamDecoderState{
  943. queueingRetryable, // do RPC
  944. queueingRetryable, // got foo-00
  945. queueingRetryable, // got foo-01
  946. queueingRetryable, // foo-01, resume token
  947. queueingRetryable, // got foo-02
  948. queueingRetryable, // foo-02, resume token
  949. queueingRetryable, // got foo-03
  950. },
  951. },
  952. {
  953. // unConnected->queueingRetryable->queueingUnretryable->queueingRetryable->queueingRetryable
  954. name: "unConnected->queueingRetryable->queueingUnretryable->queueingRetryable->queueingRetryable",
  955. msgs: func() (m []testutil.MockCtlMsg) {
  956. for i := 0; i < maxBuffers+1; i++ {
  957. m = append(m, testutil.MockCtlMsg{})
  958. }
  959. m = append(m, testutil.MockCtlMsg{Err: nil, ResumeToken: true})
  960. m = append(m, testutil.MockCtlMsg{})
  961. return m
  962. }(),
  963. sql: "SELECT t.key key, t.value value FROM t_mock t",
  964. want: func() (s []*sppb.PartialResultSet) {
  965. for i := 0; i < maxBuffers+2; i++ {
  966. s = append(s, &sppb.PartialResultSet{
  967. Metadata: kvMeta,
  968. Values: []*proto3.Value{
  969. {Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
  970. {Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
  971. },
  972. })
  973. }
  974. s[maxBuffers+1].ResumeToken = testutil.EncodeResumeToken(maxBuffers + 1)
  975. return s
  976. }(),
  977. resumeToken: testutil.EncodeResumeToken(maxBuffers + 1),
  978. queue: []*sppb.PartialResultSet{
  979. {
  980. Metadata: kvMeta,
  981. Values: []*proto3.Value{
  982. {Kind: &proto3.Value_StringValue{StringValue: keyStr(maxBuffers + 2)}},
  983. {Kind: &proto3.Value_StringValue{StringValue: valStr(maxBuffers + 2)}},
  984. },
  985. },
  986. },
  987. stateHistory: func() (s []resumableStreamDecoderState) {
  988. s = append(s, queueingRetryable) // RPC
  989. for i := 0; i < maxBuffers; i++ {
  990. s = append(s, queueingRetryable) // internal queue of resumableStreamDecoder filles up
  991. }
  992. for i := maxBuffers - 1; i < maxBuffers+1; i++ {
  993. // the first item fills up the queue and triggers state change;
  994. // the second item is received under queueingUnretryable state.
  995. s = append(s, queueingUnretryable)
  996. }
  997. s = append(s, queueingUnretryable) // got (maxBuffers+1)th row under Unretryable state
  998. s = append(s, queueingRetryable) // (maxBuffers+1)th row has resume token
  999. s = append(s, queueingRetryable) // (maxBuffers+2)th row has no resume token
  1000. return s
  1001. }(),
  1002. },
  1003. {
  1004. // unConnected->queueingRetryable->queueingUnretryable->finished
  1005. name: "unConnected->queueingRetryable->queueingUnretryable->finished",
  1006. msgs: func() (m []testutil.MockCtlMsg) {
  1007. for i := 0; i < maxBuffers; i++ {
  1008. m = append(m, testutil.MockCtlMsg{})
  1009. }
  1010. m = append(m, testutil.MockCtlMsg{Err: io.EOF, ResumeToken: false})
  1011. return m
  1012. }(),
  1013. sql: "SELECT t.key key, t.value value FROM t_mock t",
  1014. want: func() (s []*sppb.PartialResultSet) {
  1015. for i := 0; i < maxBuffers; i++ {
  1016. s = append(s, &sppb.PartialResultSet{
  1017. Metadata: kvMeta,
  1018. Values: []*proto3.Value{
  1019. {Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
  1020. {Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
  1021. },
  1022. })
  1023. }
  1024. return s
  1025. }(),
  1026. stateHistory: func() (s []resumableStreamDecoderState) {
  1027. s = append(s, queueingRetryable) // RPC
  1028. for i := 0; i < maxBuffers; i++ {
  1029. s = append(s, queueingRetryable) // internal queue of resumableStreamDecoder fills up
  1030. }
  1031. s = append(s, queueingUnretryable) // last row triggers state change
  1032. s = append(s, finished) // query finishes
  1033. return s
  1034. }(),
  1035. },
  1036. }
  1037. for _, test := range tests {
  1038. ms := testutil.NewMockCloudSpanner(t, trxTs)
  1039. ms.Serve()
  1040. cc := dialMock(t, ms)
  1041. mc := sppb.NewSpannerClient(cc)
  1042. if test.rpc == nil {
  1043. // Avoid using test.sql directly in closure because for loop changes test.
  1044. sql := test.sql
  1045. test.rpc = func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
  1046. return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
  1047. Sql: sql,
  1048. ResumeToken: resumeToken,
  1049. })
  1050. }
  1051. }
  1052. ctx, cancel := context.WithCancel(context.Background())
  1053. defer cancel()
  1054. r := newResumableStreamDecoder(
  1055. ctx,
  1056. test.rpc,
  1057. )
  1058. // Override backoff to make the test run faster.
  1059. r.backoff = exponentialBackoff{1 * time.Nanosecond, 1 * time.Nanosecond}
  1060. // st is the set of observed state transitions.
  1061. st := []resumableStreamDecoderState{}
  1062. // q is the content of the decoder's partial result queue when expected number of state transitions are done.
  1063. q := []*sppb.PartialResultSet{}
  1064. var lastErr error
  1065. // Once the expected number of state transitions are observed,
  1066. // send a signal to channel stateDone.
  1067. stateDone := make(chan int)
  1068. // Set stateWitness to listen to state changes.
  1069. hl := len(test.stateHistory) // To avoid data race on test.
  1070. r.stateWitness = func(rs resumableStreamDecoderState) {
  1071. select {
  1072. case <-stateDone:
  1073. // Noop after expected number of state transitions
  1074. default:
  1075. // Record state transitions.
  1076. st = append(st, rs)
  1077. if len(st) == hl {
  1078. lastErr = r.lastErr()
  1079. q = r.q.dump()
  1080. close(stateDone)
  1081. }
  1082. }
  1083. }
  1084. // Let mock server stream given messages to resumableStreamDecoder.
  1085. for _, m := range test.msgs {
  1086. ms.AddMsg(m.Err, m.ResumeToken)
  1087. }
  1088. var rs []*sppb.PartialResultSet
  1089. go func() {
  1090. for {
  1091. if !r.next() {
  1092. // Note that r.Next also exits on context cancel/timeout.
  1093. return
  1094. }
  1095. rs = append(rs, r.get())
  1096. }
  1097. }()
  1098. // Verify that resumableStreamDecoder reaches expected state.
  1099. select {
  1100. case <-stateDone: // Note that at this point, receiver is still blocking on r.next().
  1101. // Check if resumableStreamDecoder carried out expected
  1102. // state transitions.
  1103. if !testEqual(st, test.stateHistory) {
  1104. t.Errorf("%v: observed state transitions: \n%v\n, want \n%v\n",
  1105. test.name, st, test.stateHistory)
  1106. }
  1107. // Check if resumableStreamDecoder returns expected array of
  1108. // PartialResultSets.
  1109. if !testEqual(rs, test.want) {
  1110. t.Errorf("%v: received PartialResultSets: \n%v\n, want \n%v\n", test.name, rs, test.want)
  1111. }
  1112. // Verify that resumableStreamDecoder's internal buffering is also correct.
  1113. if !testEqual(q, test.queue) {
  1114. t.Errorf("%v: PartialResultSets still queued: \n%v\n, want \n%v\n", test.name, q, test.queue)
  1115. }
  1116. // Verify resume token.
  1117. if test.resumeToken != nil && !testEqual(r.resumeToken, test.resumeToken) {
  1118. t.Errorf("%v: Resume token is %v, want %v\n", test.name, r.resumeToken, test.resumeToken)
  1119. }
  1120. // Verify error message.
  1121. if !testEqual(lastErr, test.wantErr) {
  1122. t.Errorf("%v: got error %v, want %v", test.name, lastErr, test.wantErr)
  1123. }
  1124. case <-time.After(1 * time.Second):
  1125. t.Errorf("%v: Timeout in waiting for state change", test.name)
  1126. }
  1127. ms.Stop()
  1128. cc.Close()
  1129. }
  1130. }
  1131. // sReceiver signals every receiving attempt through a channel,
  1132. // used by TestResumeToken to determine if the receiving of a certain
  1133. // PartialResultSet will be attempted next.
  1134. type sReceiver struct {
  1135. c chan int
  1136. rpcReceiver sppb.Spanner_ExecuteStreamingSqlClient
  1137. }
  1138. // Recv() implements streamingReceiver.Recv for sReceiver.
  1139. func (sr *sReceiver) Recv() (*sppb.PartialResultSet, error) {
  1140. sr.c <- 1
  1141. return sr.rpcReceiver.Recv()
  1142. }
  1143. // waitn waits for nth receiving attempt from now on, until
  1144. // the signal for nth Recv() attempts is received or timeout.
  1145. // Note that because the way stream() works, the signal for the
  1146. // nth Recv() means that the previous n - 1 PartialResultSets
  1147. // has already been returned to caller or queued, if no error happened.
  1148. func (sr *sReceiver) waitn(n int) error {
  1149. for i := 0; i < n; i++ {
  1150. select {
  1151. case <-sr.c:
  1152. case <-time.After(10 * time.Second):
  1153. return fmt.Errorf("timeout in waiting for %v-th Recv()", i+1)
  1154. }
  1155. }
  1156. return nil
  1157. }
  1158. // Test the handling of resumableStreamDecoder.bytesBetweenResumeTokens.
  1159. func TestQueueBytes(t *testing.T) {
  1160. restore := setMaxBytesBetweenResumeTokens()
  1161. defer restore()
  1162. ms := testutil.NewMockCloudSpanner(t, trxTs)
  1163. ms.Serve()
  1164. defer ms.Stop()
  1165. cc := dialMock(t, ms)
  1166. defer cc.Close()
  1167. mc := sppb.NewSpannerClient(cc)
  1168. sr := &sReceiver{
  1169. c: make(chan int, 1000), // will never block in this test
  1170. }
  1171. wantQueueBytes := 0
  1172. ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
  1173. defer cancel()
  1174. r := newResumableStreamDecoder(
  1175. ctx,
  1176. func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
  1177. r, err := mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
  1178. Sql: "SELECT t.key key, t.value value FROM t_mock t",
  1179. ResumeToken: resumeToken,
  1180. })
  1181. sr.rpcReceiver = r
  1182. return sr, err
  1183. },
  1184. )
  1185. go func() {
  1186. for r.next() {
  1187. }
  1188. }()
  1189. // Let server send maxBuffers / 2 rows.
  1190. for i := 0; i < maxBuffers/2; i++ {
  1191. wantQueueBytes += proto.Size(&sppb.PartialResultSet{
  1192. Metadata: kvMeta,
  1193. Values: []*proto3.Value{
  1194. {Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
  1195. {Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
  1196. },
  1197. })
  1198. ms.AddMsg(nil, false)
  1199. }
  1200. if err := sr.waitn(maxBuffers/2 + 1); err != nil {
  1201. t.Fatalf("failed to wait for the first %v recv() calls: %v", maxBuffers, err)
  1202. }
  1203. if int32(wantQueueBytes) != r.bytesBetweenResumeTokens {
  1204. t.Errorf("r.bytesBetweenResumeTokens = %v, want %v", r.bytesBetweenResumeTokens, wantQueueBytes)
  1205. }
  1206. // Now send a resume token to drain the queue.
  1207. ms.AddMsg(nil, true)
  1208. // Wait for all rows to be processes.
  1209. if err := sr.waitn(1); err != nil {
  1210. t.Fatalf("failed to wait for rows to be processed: %v", err)
  1211. }
  1212. if r.bytesBetweenResumeTokens != 0 {
  1213. t.Errorf("r.bytesBetweenResumeTokens = %v, want 0", r.bytesBetweenResumeTokens)
  1214. }
  1215. // Let server send maxBuffers - 1 rows.
  1216. wantQueueBytes = 0
  1217. for i := 0; i < maxBuffers-1; i++ {
  1218. wantQueueBytes += proto.Size(&sppb.PartialResultSet{
  1219. Metadata: kvMeta,
  1220. Values: []*proto3.Value{
  1221. {Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
  1222. {Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
  1223. },
  1224. })
  1225. ms.AddMsg(nil, false)
  1226. }
  1227. if err := sr.waitn(maxBuffers - 1); err != nil {
  1228. t.Fatalf("failed to wait for %v rows to be processed: %v", maxBuffers-1, err)
  1229. }
  1230. if int32(wantQueueBytes) != r.bytesBetweenResumeTokens {
  1231. t.Errorf("r.bytesBetweenResumeTokens = %v, want 0", r.bytesBetweenResumeTokens)
  1232. }
  1233. // Trigger a state transition: queueingRetryable -> queueingUnretryable.
  1234. ms.AddMsg(nil, false)
  1235. if err := sr.waitn(1); err != nil {
  1236. t.Fatalf("failed to wait for state transition: %v", err)
  1237. }
  1238. if r.bytesBetweenResumeTokens != 0 {
  1239. t.Errorf("r.bytesBetweenResumeTokens = %v, want 0", r.bytesBetweenResumeTokens)
  1240. }
  1241. }
  1242. // Verify that client can deal with resume token correctly
  1243. func TestResumeToken(t *testing.T) {
  1244. restore := setMaxBytesBetweenResumeTokens()
  1245. defer restore()
  1246. ms := testutil.NewMockCloudSpanner(t, trxTs)
  1247. ms.Serve()
  1248. defer ms.Stop()
  1249. cc := dialMock(t, ms)
  1250. defer cc.Close()
  1251. mc := sppb.NewSpannerClient(cc)
  1252. sr := &sReceiver{
  1253. c: make(chan int, 1000), // will never block in this test
  1254. }
  1255. rows := []*Row{}
  1256. done := make(chan error)
  1257. streaming := func() {
  1258. // Establish a stream to mock cloud spanner server.
  1259. iter := stream(context.Background(),
  1260. func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
  1261. r, err := mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
  1262. Sql: "SELECT t.key key, t.value value FROM t_mock t",
  1263. ResumeToken: resumeToken,
  1264. })
  1265. sr.rpcReceiver = r
  1266. return sr, err
  1267. },
  1268. nil,
  1269. func(error) {})
  1270. defer iter.Stop()
  1271. var err error
  1272. for {
  1273. var row *Row
  1274. row, err = iter.Next()
  1275. if err == iterator.Done {
  1276. err = nil
  1277. break
  1278. }
  1279. if err != nil {
  1280. break
  1281. }
  1282. rows = append(rows, row)
  1283. }
  1284. done <- err
  1285. }
  1286. go streaming()
  1287. // Server streaming row 0 - 2, only row 1 has resume token.
  1288. // Client will receive row 0 - 2, so it will try receiving for
  1289. // 4 times (the last recv will block), and only row 0 - 1 will
  1290. // be yielded.
  1291. for i := 0; i < 3; i++ {
  1292. if i == 1 {
  1293. ms.AddMsg(nil, true)
  1294. } else {
  1295. ms.AddMsg(nil, false)
  1296. }
  1297. }
  1298. // Wait for 4 receive attempts, as explained above.
  1299. if err := sr.waitn(4); err != nil {
  1300. t.Fatalf("failed to wait for row 0 - 2: %v", err)
  1301. }
  1302. want := []*Row{
  1303. {
  1304. fields: kvMeta.RowType.Fields,
  1305. vals: []*proto3.Value{
  1306. {Kind: &proto3.Value_StringValue{StringValue: keyStr(0)}},
  1307. {Kind: &proto3.Value_StringValue{StringValue: valStr(0)}},
  1308. },
  1309. },
  1310. {
  1311. fields: kvMeta.RowType.Fields,
  1312. vals: []*proto3.Value{
  1313. {Kind: &proto3.Value_StringValue{StringValue: keyStr(1)}},
  1314. {Kind: &proto3.Value_StringValue{StringValue: valStr(1)}},
  1315. },
  1316. },
  1317. }
  1318. if !testEqual(rows, want) {
  1319. t.Errorf("received rows: \n%v\n; but want\n%v\n", rows, want)
  1320. }
  1321. // Inject resumable failure.
  1322. ms.AddMsg(
  1323. status.Errorf(codes.Unavailable, "mock server unavailable"),
  1324. false,
  1325. )
  1326. // Test if client detects the resumable failure and retries.
  1327. if err := sr.waitn(1); err != nil {
  1328. t.Fatalf("failed to wait for client to retry: %v", err)
  1329. }
  1330. // Client has resumed the query, now server resend row 2.
  1331. ms.AddMsg(nil, true)
  1332. if err := sr.waitn(1); err != nil {
  1333. t.Fatalf("failed to wait for resending row 2: %v", err)
  1334. }
  1335. // Now client should have received row 0 - 2.
  1336. want = append(want, &Row{
  1337. fields: kvMeta.RowType.Fields,
  1338. vals: []*proto3.Value{
  1339. {Kind: &proto3.Value_StringValue{StringValue: keyStr(2)}},
  1340. {Kind: &proto3.Value_StringValue{StringValue: valStr(2)}},
  1341. },
  1342. })
  1343. if !testEqual(rows, want) {
  1344. t.Errorf("received rows: \n%v\n, want\n%v\n", rows, want)
  1345. }
  1346. // Sending 3rd - (maxBuffers+1)th rows without resume tokens, client should buffer them.
  1347. for i := 3; i < maxBuffers+2; i++ {
  1348. ms.AddMsg(nil, false)
  1349. }
  1350. if err := sr.waitn(maxBuffers - 1); err != nil {
  1351. t.Fatalf("failed to wait for row 3-%v: %v", maxBuffers+1, err)
  1352. }
  1353. // Received rows should be unchanged.
  1354. if !testEqual(rows, want) {
  1355. t.Errorf("receive rows: \n%v\n, want\n%v\n", rows, want)
  1356. }
  1357. // Send (maxBuffers+2)th row to trigger state change of resumableStreamDecoder:
  1358. // queueingRetryable -> queueingUnretryable
  1359. ms.AddMsg(nil, false)
  1360. if err := sr.waitn(1); err != nil {
  1361. t.Fatalf("failed to wait for row %v: %v", maxBuffers+2, err)
  1362. }
  1363. // Client should yield row 3rd - (maxBuffers+2)th to application. Therefore, application should
  1364. // see row 0 - (maxBuffers+2)th so far.
  1365. for i := 3; i < maxBuffers+3; i++ {
  1366. want = append(want, &Row{
  1367. fields: kvMeta.RowType.Fields,
  1368. vals: []*proto3.Value{
  1369. {Kind: &proto3.Value_StringValue{StringValue: keyStr(i)}},
  1370. {Kind: &proto3.Value_StringValue{StringValue: valStr(i)}},
  1371. },
  1372. })
  1373. }
  1374. if !testEqual(rows, want) {
  1375. t.Errorf("received rows: \n%v\n; want\n%v\n", rows, want)
  1376. }
  1377. // Inject resumable error, but since resumableStreamDecoder is already at queueingUnretryable
  1378. // state, query will just fail.
  1379. ms.AddMsg(
  1380. status.Errorf(codes.Unavailable, "mock server wants some sleep"),
  1381. false,
  1382. )
  1383. var gotErr error
  1384. select {
  1385. case gotErr = <-done:
  1386. case <-time.After(10 * time.Second):
  1387. t.Fatalf("timeout in waiting for failed query to return.")
  1388. }
  1389. if wantErr := toSpannerError(status.Errorf(codes.Unavailable, "mock server wants some sleep")); !testEqual(gotErr, wantErr) {
  1390. t.Fatalf("stream() returns error: %v, but want error: %v", gotErr, wantErr)
  1391. }
  1392. // Reconnect to mock Cloud Spanner.
  1393. rows = []*Row{}
  1394. go streaming()
  1395. // Let server send two rows without resume token.
  1396. for i := maxBuffers + 3; i < maxBuffers+5; i++ {
  1397. ms.AddMsg(nil, false)
  1398. }
  1399. if err := sr.waitn(3); err != nil {
  1400. t.Fatalf("failed to wait for row %v - %v: %v", maxBuffers+3, maxBuffers+5, err)
  1401. }
  1402. if len(rows) > 0 {
  1403. t.Errorf("client received some rows unexpectedly: %v, want nothing", rows)
  1404. }
  1405. // Let server end the query.
  1406. ms.AddMsg(io.EOF, false)
  1407. select {
  1408. case gotErr = <-done:
  1409. case <-time.After(10 * time.Second):
  1410. t.Fatalf("timeout in waiting for failed query to return")
  1411. }
  1412. if gotErr != nil {
  1413. t.Fatalf("stream() returns unexpected error: %v, but want no error", gotErr)
  1414. }
  1415. // Verify if a normal server side EOF flushes all queued rows.
  1416. want = []*Row{
  1417. {
  1418. fields: kvMeta.RowType.Fields,
  1419. vals: []*proto3.Value{
  1420. {Kind: &proto3.Value_StringValue{StringValue: keyStr(maxBuffers + 3)}},
  1421. {Kind: &proto3.Value_StringValue{StringValue: valStr(maxBuffers + 3)}},
  1422. },
  1423. },
  1424. {
  1425. fields: kvMeta.RowType.Fields,
  1426. vals: []*proto3.Value{
  1427. {Kind: &proto3.Value_StringValue{StringValue: keyStr(maxBuffers + 4)}},
  1428. {Kind: &proto3.Value_StringValue{StringValue: valStr(maxBuffers + 4)}},
  1429. },
  1430. },
  1431. }
  1432. if !testEqual(rows, want) {
  1433. t.Errorf("received rows: \n%v\n; but want\n%v\n", rows, want)
  1434. }
  1435. }
  1436. // Verify that streaming query get retried upon real gRPC server transport failures.
  1437. func TestGrpcReconnect(t *testing.T) {
  1438. restore := setMaxBytesBetweenResumeTokens()
  1439. defer restore()
  1440. ms := testutil.NewMockCloudSpanner(t, trxTs)
  1441. ms.Serve()
  1442. defer ms.Stop()
  1443. cc := dialMock(t, ms)
  1444. defer cc.Close()
  1445. mc := sppb.NewSpannerClient(cc)
  1446. retry := make(chan int)
  1447. row := make(chan int)
  1448. var err error
  1449. go func() {
  1450. r := 0
  1451. // Establish a stream to mock cloud spanner server.
  1452. iter := stream(context.Background(),
  1453. func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
  1454. if r > 0 {
  1455. // This RPC attempt is a retry, signal it.
  1456. retry <- r
  1457. }
  1458. r++
  1459. return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
  1460. Sql: "SELECT t.key key, t.value value FROM t_mock t",
  1461. ResumeToken: resumeToken,
  1462. })
  1463. },
  1464. nil,
  1465. func(error) {})
  1466. defer iter.Stop()
  1467. for {
  1468. _, err = iter.Next()
  1469. if err == iterator.Done {
  1470. err = nil
  1471. break
  1472. }
  1473. if err != nil {
  1474. break
  1475. }
  1476. row <- 0
  1477. }
  1478. }()
  1479. // Add a message and wait for the receipt.
  1480. ms.AddMsg(nil, true)
  1481. select {
  1482. case <-row:
  1483. case <-time.After(10 * time.Second):
  1484. t.Fatalf("expect stream to be established within 10 seconds, but it didn't")
  1485. }
  1486. // Error injection: force server to close all connections.
  1487. ms.Stop()
  1488. // Test to see if client respond to the real RPC failure correctly by
  1489. // retrying RPC.
  1490. select {
  1491. case r, ok := <-retry:
  1492. if ok && r == 1 {
  1493. break
  1494. }
  1495. t.Errorf("retry count = %v, want 1", r)
  1496. case <-time.After(10 * time.Second):
  1497. t.Errorf("client library failed to respond after 10 seconds, aborting")
  1498. return
  1499. }
  1500. }
  1501. // Test cancel/timeout for client operations.
  1502. func TestCancelTimeout(t *testing.T) {
  1503. restore := setMaxBytesBetweenResumeTokens()
  1504. defer restore()
  1505. ms := testutil.NewMockCloudSpanner(t, trxTs)
  1506. ms.Serve()
  1507. defer ms.Stop()
  1508. cc := dialMock(t, ms)
  1509. defer cc.Close()
  1510. mc := sppb.NewSpannerClient(cc)
  1511. done := make(chan int)
  1512. go func() {
  1513. for {
  1514. ms.AddMsg(nil, true)
  1515. }
  1516. }()
  1517. // Test cancelling query.
  1518. ctx, cancel := context.WithCancel(context.Background())
  1519. var err error
  1520. go func() {
  1521. // Establish a stream to mock cloud spanner server.
  1522. iter := stream(ctx,
  1523. func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
  1524. return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
  1525. Sql: "SELECT t.key key, t.value value FROM t_mock t",
  1526. ResumeToken: resumeToken,
  1527. })
  1528. },
  1529. nil,
  1530. func(error) {})
  1531. defer iter.Stop()
  1532. for {
  1533. _, err = iter.Next()
  1534. if err == iterator.Done {
  1535. break
  1536. }
  1537. if err != nil {
  1538. done <- 0
  1539. break
  1540. }
  1541. }
  1542. }()
  1543. cancel()
  1544. select {
  1545. case <-done:
  1546. if ErrCode(err) != codes.Canceled {
  1547. t.Errorf("streaming query is canceled and returns error %v, want error code %v", err, codes.Canceled)
  1548. }
  1549. case <-time.After(1 * time.Second):
  1550. t.Errorf("query doesn't exit timely after being cancelled")
  1551. }
  1552. // Test query timeout.
  1553. ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second)
  1554. go func() {
  1555. // Establish a stream to mock cloud spanner server.
  1556. iter := stream(ctx,
  1557. func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
  1558. return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
  1559. Sql: "SELECT t.key key, t.value value FROM t_mock t",
  1560. ResumeToken: resumeToken,
  1561. })
  1562. },
  1563. nil,
  1564. func(error) {})
  1565. defer iter.Stop()
  1566. for {
  1567. _, err = iter.Next()
  1568. if err == iterator.Done {
  1569. err = nil
  1570. break
  1571. }
  1572. if err != nil {
  1573. break
  1574. }
  1575. }
  1576. done <- 0
  1577. }()
  1578. select {
  1579. case <-done:
  1580. if wantErr := codes.DeadlineExceeded; ErrCode(err) != wantErr {
  1581. t.Errorf("streaming query timeout returns error %v, want error code %v", err, wantErr)
  1582. }
  1583. case <-time.After(2 * time.Second):
  1584. t.Errorf("query doesn't timeout as expected")
  1585. }
  1586. }
  1587. func TestRowIteratorDo(t *testing.T) {
  1588. restore := setMaxBytesBetweenResumeTokens()
  1589. defer restore()
  1590. ms := testutil.NewMockCloudSpanner(t, trxTs)
  1591. ms.Serve()
  1592. defer ms.Stop()
  1593. cc := dialMock(t, ms)
  1594. defer cc.Close()
  1595. mc := sppb.NewSpannerClient(cc)
  1596. for i := 0; i < 3; i++ {
  1597. ms.AddMsg(nil, false)
  1598. }
  1599. ms.AddMsg(io.EOF, true)
  1600. nRows := 0
  1601. iter := stream(context.Background(),
  1602. func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
  1603. return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
  1604. Sql: "SELECT t.key key, t.value value FROM t_mock t",
  1605. ResumeToken: resumeToken,
  1606. })
  1607. },
  1608. nil,
  1609. func(error) {})
  1610. err := iter.Do(func(r *Row) error { nRows++; return nil })
  1611. if err != nil {
  1612. t.Errorf("Using Do: %v", err)
  1613. }
  1614. if nRows != 3 {
  1615. t.Errorf("got %d rows, want 3", nRows)
  1616. }
  1617. }
  1618. func TestRowIteratorDoWithError(t *testing.T) {
  1619. restore := setMaxBytesBetweenResumeTokens()
  1620. defer restore()
  1621. ms := testutil.NewMockCloudSpanner(t, trxTs)
  1622. ms.Serve()
  1623. defer ms.Stop()
  1624. cc := dialMock(t, ms)
  1625. defer cc.Close()
  1626. mc := sppb.NewSpannerClient(cc)
  1627. for i := 0; i < 3; i++ {
  1628. ms.AddMsg(nil, false)
  1629. }
  1630. ms.AddMsg(io.EOF, true)
  1631. iter := stream(context.Background(),
  1632. func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
  1633. return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
  1634. Sql: "SELECT t.key key, t.value value FROM t_mock t",
  1635. ResumeToken: resumeToken,
  1636. })
  1637. },
  1638. nil,
  1639. func(error) {})
  1640. injected := errors.New("Failed iterator")
  1641. err := iter.Do(func(r *Row) error { return injected })
  1642. if err != injected {
  1643. t.Errorf("got <%v>, want <%v>", err, injected)
  1644. }
  1645. }
  1646. func TestIteratorStopEarly(t *testing.T) {
  1647. ctx := context.Background()
  1648. restore := setMaxBytesBetweenResumeTokens()
  1649. defer restore()
  1650. ms := testutil.NewMockCloudSpanner(t, trxTs)
  1651. ms.Serve()
  1652. defer ms.Stop()
  1653. cc := dialMock(t, ms)
  1654. defer cc.Close()
  1655. mc := sppb.NewSpannerClient(cc)
  1656. ms.AddMsg(nil, false)
  1657. ms.AddMsg(nil, false)
  1658. ms.AddMsg(io.EOF, true)
  1659. iter := stream(ctx,
  1660. func(ct context.Context, resumeToken []byte) (streamingReceiver, error) {
  1661. return mc.ExecuteStreamingSql(ct, &sppb.ExecuteSqlRequest{
  1662. Sql: "SELECT t.key key, t.value value FROM t_mock t",
  1663. ResumeToken: resumeToken,
  1664. })
  1665. },
  1666. nil,
  1667. func(error) {})
  1668. _, err := iter.Next()
  1669. if err != nil {
  1670. t.Fatalf("before Stop: %v", err)
  1671. }
  1672. iter.Stop()
  1673. // Stop sets r.err to the FailedPrecondition error "Next called after Stop".
  1674. // Override that here so this test can observe the Canceled error from the stream.
  1675. iter.err = nil
  1676. iter.Next()
  1677. if ErrCode(iter.streamd.lastErr()) != codes.Canceled {
  1678. t.Errorf("after Stop: got %v, wanted Canceled", err)
  1679. }
  1680. }
  1681. func TestIteratorWithError(t *testing.T) {
  1682. injected := errors.New("Failed iterator")
  1683. iter := RowIterator{err: injected}
  1684. defer iter.Stop()
  1685. if _, err := iter.Next(); err != injected {
  1686. t.Fatalf("Expected error: %v, got %v", injected, err)
  1687. }
  1688. }
  1689. func dialMock(t *testing.T, ms *testutil.MockCloudSpanner) *grpc.ClientConn {
  1690. cc, err := grpc.Dial(ms.Addr(), grpc.WithInsecure(), grpc.WithBlock())
  1691. if err != nil {
  1692. t.Fatalf("Dial(%q) = %v", ms.Addr(), err)
  1693. }
  1694. return cc
  1695. }