You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

479 lines
13 KiB

  1. // Copyright 2017 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package firestore
  15. import (
  16. "context"
  17. "testing"
  18. "github.com/golang/protobuf/ptypes/empty"
  19. "google.golang.org/api/iterator"
  20. pb "google.golang.org/genproto/googleapis/firestore/v1"
  21. "google.golang.org/grpc"
  22. "google.golang.org/grpc/codes"
  23. "google.golang.org/grpc/status"
  24. )
  25. func TestRunTransaction(t *testing.T) {
  26. ctx := context.Background()
  27. const db = "projects/projectID/databases/(default)"
  28. tid := []byte{1}
  29. c, srv := newMock(t)
  30. beginReq := &pb.BeginTransactionRequest{Database: db}
  31. beginRes := &pb.BeginTransactionResponse{Transaction: tid}
  32. commitReq := &pb.CommitRequest{Database: db, Transaction: tid}
  33. // Empty transaction.
  34. srv.addRPC(beginReq, beginRes)
  35. srv.addRPC(commitReq, &pb.CommitResponse{CommitTime: aTimestamp})
  36. err := c.RunTransaction(ctx, func(context.Context, *Transaction) error { return nil })
  37. if err != nil {
  38. t.Fatal(err)
  39. }
  40. // Transaction with read and write.
  41. srv.reset()
  42. srv.addRPC(beginReq, beginRes)
  43. aDoc := &pb.Document{
  44. Name: db + "/documents/C/a",
  45. CreateTime: aTimestamp,
  46. UpdateTime: aTimestamp2,
  47. Fields: map[string]*pb.Value{"count": intval(1)},
  48. }
  49. srv.addRPC(
  50. &pb.BatchGetDocumentsRequest{
  51. Database: c.path(),
  52. Documents: []string{db + "/documents/C/a"},
  53. ConsistencySelector: &pb.BatchGetDocumentsRequest_Transaction{tid},
  54. }, []interface{}{
  55. &pb.BatchGetDocumentsResponse{
  56. Result: &pb.BatchGetDocumentsResponse_Found{aDoc},
  57. ReadTime: aTimestamp2,
  58. },
  59. })
  60. aDoc2 := &pb.Document{
  61. Name: aDoc.Name,
  62. Fields: map[string]*pb.Value{"count": intval(2)},
  63. }
  64. srv.addRPC(
  65. &pb.CommitRequest{
  66. Database: db,
  67. Transaction: tid,
  68. Writes: []*pb.Write{{
  69. Operation: &pb.Write_Update{aDoc2},
  70. UpdateMask: &pb.DocumentMask{FieldPaths: []string{"count"}},
  71. CurrentDocument: &pb.Precondition{
  72. ConditionType: &pb.Precondition_Exists{true},
  73. },
  74. }},
  75. },
  76. &pb.CommitResponse{CommitTime: aTimestamp3},
  77. )
  78. err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error {
  79. docref := c.Collection("C").Doc("a")
  80. doc, err := tx.Get(docref)
  81. if err != nil {
  82. return err
  83. }
  84. count, err := doc.DataAt("count")
  85. if err != nil {
  86. return err
  87. }
  88. return tx.Update(docref, []Update{{Path: "count", Value: count.(int64) + 1}})
  89. })
  90. if err != nil {
  91. t.Fatal(err)
  92. }
  93. // Query
  94. srv.reset()
  95. srv.addRPC(beginReq, beginRes)
  96. srv.addRPC(
  97. &pb.RunQueryRequest{
  98. Parent: db + "/documents",
  99. QueryType: &pb.RunQueryRequest_StructuredQuery{
  100. &pb.StructuredQuery{
  101. From: []*pb.StructuredQuery_CollectionSelector{{CollectionId: "C"}},
  102. },
  103. },
  104. ConsistencySelector: &pb.RunQueryRequest_Transaction{tid},
  105. },
  106. []interface{}{},
  107. )
  108. srv.addRPC(commitReq, &pb.CommitResponse{CommitTime: aTimestamp3})
  109. err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error {
  110. it := tx.Documents(c.Collection("C"))
  111. defer it.Stop()
  112. _, err := it.Next()
  113. if err != iterator.Done {
  114. return err
  115. }
  116. return nil
  117. })
  118. if err != nil {
  119. t.Fatal(err)
  120. }
  121. // Retry entire transaction.
  122. srv.reset()
  123. srv.addRPC(beginReq, beginRes)
  124. srv.addRPC(commitReq, status.Errorf(codes.Aborted, ""))
  125. srv.addRPC(
  126. &pb.BeginTransactionRequest{
  127. Database: db,
  128. Options: &pb.TransactionOptions{
  129. Mode: &pb.TransactionOptions_ReadWrite_{
  130. &pb.TransactionOptions_ReadWrite{RetryTransaction: tid},
  131. },
  132. },
  133. },
  134. beginRes,
  135. )
  136. srv.addRPC(commitReq, &pb.CommitResponse{CommitTime: aTimestamp})
  137. err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error { return nil })
  138. if err != nil {
  139. t.Fatal(err)
  140. }
  141. }
  142. func TestTransactionErrors(t *testing.T) {
  143. ctx := context.Background()
  144. const db = "projects/projectID/databases/(default)"
  145. c, srv := newMock(t)
  146. var (
  147. tid = []byte{1}
  148. internalErr = status.Errorf(codes.Internal, "so sad")
  149. beginReq = &pb.BeginTransactionRequest{
  150. Database: db,
  151. }
  152. beginRes = &pb.BeginTransactionResponse{Transaction: tid}
  153. getReq = &pb.BatchGetDocumentsRequest{
  154. Database: c.path(),
  155. Documents: []string{db + "/documents/C/a"},
  156. ConsistencySelector: &pb.BatchGetDocumentsRequest_Transaction{tid},
  157. }
  158. rollbackReq = &pb.RollbackRequest{Database: db, Transaction: tid}
  159. commitReq = &pb.CommitRequest{Database: db, Transaction: tid}
  160. )
  161. // BeginTransaction has a permanent error.
  162. srv.addRPC(beginReq, internalErr)
  163. err := c.RunTransaction(ctx, func(context.Context, *Transaction) error { return nil })
  164. if grpc.Code(err) != codes.Internal {
  165. t.Errorf("got <%v>, want Internal", err)
  166. }
  167. // Get has a permanent error.
  168. get := func(_ context.Context, tx *Transaction) error {
  169. _, err := tx.Get(c.Doc("C/a"))
  170. return err
  171. }
  172. srv.reset()
  173. srv.addRPC(beginReq, beginRes)
  174. srv.addRPC(getReq, internalErr)
  175. srv.addRPC(rollbackReq, &empty.Empty{})
  176. err = c.RunTransaction(ctx, get)
  177. if grpc.Code(err) != codes.Internal {
  178. t.Errorf("got <%v>, want Internal", err)
  179. }
  180. // Get has a permanent error, but the rollback fails. We still
  181. // return Get's error.
  182. srv.reset()
  183. srv.addRPC(beginReq, beginRes)
  184. srv.addRPC(getReq, internalErr)
  185. srv.addRPC(rollbackReq, status.Errorf(codes.FailedPrecondition, ""))
  186. err = c.RunTransaction(ctx, get)
  187. if grpc.Code(err) != codes.Internal {
  188. t.Errorf("got <%v>, want Internal", err)
  189. }
  190. // Commit has a permanent error.
  191. srv.reset()
  192. srv.addRPC(beginReq, beginRes)
  193. srv.addRPC(getReq, []interface{}{
  194. &pb.BatchGetDocumentsResponse{
  195. Result: &pb.BatchGetDocumentsResponse_Found{&pb.Document{
  196. Name: "projects/projectID/databases/(default)/documents/C/a",
  197. CreateTime: aTimestamp,
  198. UpdateTime: aTimestamp2,
  199. }},
  200. ReadTime: aTimestamp2,
  201. },
  202. })
  203. srv.addRPC(commitReq, internalErr)
  204. err = c.RunTransaction(ctx, get)
  205. if grpc.Code(err) != codes.Internal {
  206. t.Errorf("got <%v>, want Internal", err)
  207. }
  208. // Read after write.
  209. srv.reset()
  210. srv.addRPC(beginReq, beginRes)
  211. srv.addRPC(rollbackReq, &empty.Empty{})
  212. err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error {
  213. if err := tx.Delete(c.Doc("C/a")); err != nil {
  214. return err
  215. }
  216. if _, err := tx.Get(c.Doc("C/a")); err != nil {
  217. return err
  218. }
  219. return nil
  220. })
  221. if err != errReadAfterWrite {
  222. t.Errorf("got <%v>, want <%v>", err, errReadAfterWrite)
  223. }
  224. // Read after write, with query.
  225. srv.reset()
  226. srv.addRPC(beginReq, beginRes)
  227. srv.addRPC(rollbackReq, &empty.Empty{})
  228. err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error {
  229. if err := tx.Delete(c.Doc("C/a")); err != nil {
  230. return err
  231. }
  232. it := tx.Documents(c.Collection("C").Select("x"))
  233. defer it.Stop()
  234. if _, err := it.Next(); err != iterator.Done {
  235. return err
  236. }
  237. return nil
  238. })
  239. if err != errReadAfterWrite {
  240. t.Errorf("got <%v>, want <%v>", err, errReadAfterWrite)
  241. }
  242. // Read after write fails even if the user ignores the read's error.
  243. srv.reset()
  244. srv.addRPC(beginReq, beginRes)
  245. srv.addRPC(rollbackReq, &empty.Empty{})
  246. err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error {
  247. if err := tx.Delete(c.Doc("C/a")); err != nil {
  248. return err
  249. }
  250. if _, err := tx.Get(c.Doc("C/a")); err != nil {
  251. return err
  252. }
  253. return nil
  254. })
  255. if err != errReadAfterWrite {
  256. t.Errorf("got <%v>, want <%v>", err, errReadAfterWrite)
  257. }
  258. // Write in read-only transaction.
  259. srv.reset()
  260. srv.addRPC(
  261. &pb.BeginTransactionRequest{
  262. Database: db,
  263. Options: &pb.TransactionOptions{
  264. Mode: &pb.TransactionOptions_ReadOnly_{&pb.TransactionOptions_ReadOnly{}},
  265. },
  266. },
  267. beginRes,
  268. )
  269. srv.addRPC(rollbackReq, &empty.Empty{})
  270. err = c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error {
  271. return tx.Delete(c.Doc("C/a"))
  272. }, ReadOnly)
  273. if err != errWriteReadOnly {
  274. t.Errorf("got <%v>, want <%v>", err, errWriteReadOnly)
  275. }
  276. // Too many retries.
  277. srv.reset()
  278. srv.addRPC(beginReq, beginRes)
  279. srv.addRPC(commitReq, status.Errorf(codes.Aborted, ""))
  280. srv.addRPC(
  281. &pb.BeginTransactionRequest{
  282. Database: db,
  283. Options: &pb.TransactionOptions{
  284. Mode: &pb.TransactionOptions_ReadWrite_{
  285. &pb.TransactionOptions_ReadWrite{RetryTransaction: tid},
  286. },
  287. },
  288. },
  289. beginRes,
  290. )
  291. srv.addRPC(commitReq, status.Errorf(codes.Aborted, ""))
  292. srv.addRPC(rollbackReq, &empty.Empty{})
  293. err = c.RunTransaction(ctx, func(context.Context, *Transaction) error { return nil },
  294. MaxAttempts(2))
  295. if grpc.Code(err) != codes.Aborted {
  296. t.Errorf("got <%v>, want Aborted", err)
  297. }
  298. // Nested transaction.
  299. srv.reset()
  300. srv.addRPC(beginReq, beginRes)
  301. srv.addRPC(rollbackReq, &empty.Empty{})
  302. err = c.RunTransaction(ctx, func(ctx context.Context, tx *Transaction) error {
  303. return c.RunTransaction(ctx, func(context.Context, *Transaction) error { return nil })
  304. })
  305. if got, want := err, errNestedTransaction; got != want {
  306. t.Errorf("got <%v>, want <%v>", got, want)
  307. }
  308. }
  309. func TestTransactionGetAll(t *testing.T) {
  310. c, srv := newMock(t)
  311. defer c.Close()
  312. const dbPath = "projects/projectID/databases/(default)"
  313. tid := []byte{1}
  314. beginReq := &pb.BeginTransactionRequest{Database: dbPath}
  315. beginRes := &pb.BeginTransactionResponse{Transaction: tid}
  316. srv.addRPC(beginReq, beginRes)
  317. req := &pb.BatchGetDocumentsRequest{
  318. Database: dbPath,
  319. Documents: []string{
  320. dbPath + "/documents/C/a",
  321. dbPath + "/documents/C/b",
  322. dbPath + "/documents/C/c",
  323. },
  324. ConsistencySelector: &pb.BatchGetDocumentsRequest_Transaction{tid},
  325. }
  326. err := c.RunTransaction(context.Background(), func(_ context.Context, tx *Transaction) error {
  327. testGetAll(t, c, srv, dbPath,
  328. func(drs []*DocumentRef) ([]*DocumentSnapshot, error) { return tx.GetAll(drs) },
  329. req)
  330. commitReq := &pb.CommitRequest{Database: dbPath, Transaction: tid}
  331. srv.addRPC(commitReq, &pb.CommitResponse{CommitTime: aTimestamp})
  332. return nil
  333. })
  334. if err != nil {
  335. t.Fatal(err)
  336. }
  337. }
  338. // Each retry attempt has the same amount of commit writes.
  339. func TestRunTransaction_Retries(t *testing.T) {
  340. ctx := context.Background()
  341. const db = "projects/projectID/databases/(default)"
  342. tid := []byte{1}
  343. c, srv := newMock(t)
  344. srv.addRPC(
  345. &pb.BeginTransactionRequest{Database: db},
  346. &pb.BeginTransactionResponse{Transaction: tid},
  347. )
  348. aDoc := &pb.Document{
  349. Name: db + "/documents/C/a",
  350. CreateTime: aTimestamp,
  351. UpdateTime: aTimestamp2,
  352. Fields: map[string]*pb.Value{"count": intval(1)},
  353. }
  354. aDoc2 := &pb.Document{
  355. Name: aDoc.Name,
  356. Fields: map[string]*pb.Value{"count": intval(7)},
  357. }
  358. srv.addRPC(
  359. &pb.CommitRequest{
  360. Database: db,
  361. Transaction: tid,
  362. Writes: []*pb.Write{{
  363. Operation: &pb.Write_Update{aDoc2},
  364. UpdateMask: &pb.DocumentMask{FieldPaths: []string{"count"}},
  365. CurrentDocument: &pb.Precondition{
  366. ConditionType: &pb.Precondition_Exists{true},
  367. },
  368. }},
  369. },
  370. status.Errorf(codes.Aborted, "something failed! please retry me!"),
  371. )
  372. srv.addRPC(
  373. &pb.BeginTransactionRequest{
  374. Database: db,
  375. Options: &pb.TransactionOptions{
  376. Mode: &pb.TransactionOptions_ReadWrite_{
  377. &pb.TransactionOptions_ReadWrite{RetryTransaction: tid},
  378. },
  379. },
  380. },
  381. &pb.BeginTransactionResponse{Transaction: tid},
  382. )
  383. srv.addRPC(
  384. &pb.CommitRequest{
  385. Database: db,
  386. Transaction: tid,
  387. Writes: []*pb.Write{{
  388. Operation: &pb.Write_Update{aDoc2},
  389. UpdateMask: &pb.DocumentMask{FieldPaths: []string{"count"}},
  390. CurrentDocument: &pb.Precondition{
  391. ConditionType: &pb.Precondition_Exists{true},
  392. },
  393. }},
  394. },
  395. &pb.CommitResponse{CommitTime: aTimestamp3},
  396. )
  397. err := c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error {
  398. docref := c.Collection("C").Doc("a")
  399. return tx.Update(docref, []Update{{Path: "count", Value: 7}})
  400. })
  401. if err != nil {
  402. t.Fatal(err)
  403. }
  404. }
  405. // Non-transactional operations are allowed in transactions (although
  406. // discouraged).
  407. func TestRunTransaction_NonTransactionalOp(t *testing.T) {
  408. ctx := context.Background()
  409. const db = "projects/projectID/databases/(default)"
  410. tid := []byte{1}
  411. c, srv := newMock(t)
  412. beginReq := &pb.BeginTransactionRequest{Database: db}
  413. beginRes := &pb.BeginTransactionResponse{Transaction: tid}
  414. srv.reset()
  415. srv.addRPC(beginReq, beginRes)
  416. aDoc := &pb.Document{
  417. Name: db + "/documents/C/a",
  418. CreateTime: aTimestamp,
  419. UpdateTime: aTimestamp2,
  420. Fields: map[string]*pb.Value{"count": intval(1)},
  421. }
  422. srv.addRPC(
  423. &pb.BatchGetDocumentsRequest{
  424. Database: c.path(),
  425. Documents: []string{db + "/documents/C/a"},
  426. }, []interface{}{
  427. &pb.BatchGetDocumentsResponse{
  428. Result: &pb.BatchGetDocumentsResponse_Found{aDoc},
  429. ReadTime: aTimestamp2,
  430. },
  431. })
  432. srv.addRPC(
  433. &pb.CommitRequest{
  434. Database: db,
  435. Transaction: tid,
  436. },
  437. &pb.CommitResponse{CommitTime: aTimestamp3},
  438. )
  439. if err := c.RunTransaction(ctx, func(ctx2 context.Context, tx *Transaction) error {
  440. docref := c.Collection("C").Doc("a")
  441. if _, err := c.GetAll(ctx2, []*DocumentRef{docref}); err != nil {
  442. t.Fatal(err)
  443. }
  444. return nil
  445. }); err != nil {
  446. t.Fatal(err)
  447. }
  448. }