You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

228 lines
6.1 KiB

  1. // Copyright 2017 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package firestore
  15. // A simple mock server.
  16. import (
  17. "context"
  18. "fmt"
  19. "reflect"
  20. "sort"
  21. "strings"
  22. "cloud.google.com/go/internal/testutil"
  23. "github.com/golang/protobuf/proto"
  24. "github.com/golang/protobuf/ptypes/empty"
  25. pb "google.golang.org/genproto/googleapis/firestore/v1"
  26. "google.golang.org/grpc/codes"
  27. "google.golang.org/grpc/status"
  28. )
  // mockServer is a fake Firestore gRPC service for tests. It replays a
  // scripted queue of expected requests and canned responses that were
  // registered with addRPC / addRPCAdjust.
  type mockServer struct {
  	// Methods that mockServer does not define are promoted from this
  	// embedded interface. newMockServer leaves it nil, so calling such a
  	// method panics.
  	pb.FirestoreServer
  	// Addr is the address of the underlying test server.
  	Addr string
  	// reqItems and resps are parallel queues: the i-th expected request is
  	// paired with the i-th scripted response.
  	reqItems []reqItem
  	resps    []interface{}
  }
  // reqItem describes one expected incoming request in a mockServer script.
  type reqItem struct {
  	// wantReq is the request the server expects; nil disables the check.
  	wantReq proto.Message
  	// adjust, if non-nil, is applied to the incoming request before it is
  	// compared with wantReq (for example, to neutralize random values).
  	adjust func(gotReq proto.Message)
  }
  39. func newMockServer() (*mockServer, error) {
  40. srv, err := testutil.NewServer()
  41. if err != nil {
  42. return nil, err
  43. }
  44. mock := &mockServer{Addr: srv.Addr}
  45. pb.RegisterFirestoreServer(srv.Gsrv, mock)
  46. srv.Start()
  47. return mock, nil
  48. }
  49. // addRPC adds a (request, response) pair to the server's list of expected
  50. // interactions. The server will compare the incoming request with wantReq
  51. // using proto.Equal. The response can be a message or an error.
  52. //
  53. // For the Listen RPC, resp should be a []interface{}, where each element
  54. // is either ListenResponse or an error.
  55. //
  56. // Passing nil for wantReq disables the request check.
  57. func (s *mockServer) addRPC(wantReq proto.Message, resp interface{}) {
  58. s.addRPCAdjust(wantReq, resp, nil)
  59. }
  60. // addRPCAdjust is like addRPC, but accepts a function that can be used
  61. // to tweak the requests before comparison, for example to adjust for
  62. // randomness.
  63. func (s *mockServer) addRPCAdjust(wantReq proto.Message, resp interface{}, adjust func(proto.Message)) {
  64. s.reqItems = append(s.reqItems, reqItem{wantReq, adjust})
  65. s.resps = append(s.resps, resp)
  66. }
  67. // popRPC compares the request with the next expected (request, response) pair.
  68. // It returns the response, or an error if the request doesn't match what
  69. // was expected or there are no expected rpcs.
  70. func (s *mockServer) popRPC(gotReq proto.Message) (interface{}, error) {
  71. if len(s.reqItems) == 0 {
  72. panic(fmt.Sprintf("out of RPCs, saw %v", reflect.TypeOf(gotReq)))
  73. }
  74. ri := s.reqItems[0]
  75. s.reqItems = s.reqItems[1:]
  76. if ri.wantReq != nil {
  77. if ri.adjust != nil {
  78. ri.adjust(gotReq)
  79. }
  80. // Sort FieldTransforms by FieldPath, since slice order is undefined and proto.Equal
  81. // is strict about order.
  82. switch gotReqTyped := gotReq.(type) {
  83. case *pb.CommitRequest:
  84. for _, w := range gotReqTyped.Writes {
  85. switch opTyped := w.Operation.(type) {
  86. case *pb.Write_Transform:
  87. sort.Sort(ByFieldPath(opTyped.Transform.FieldTransforms))
  88. }
  89. }
  90. }
  91. if !proto.Equal(gotReq, ri.wantReq) {
  92. return nil, fmt.Errorf("mockServer: bad request\ngot:\n%T\n%s\nwant:\n%T\n%s",
  93. gotReq, proto.MarshalTextString(gotReq),
  94. ri.wantReq, proto.MarshalTextString(ri.wantReq))
  95. }
  96. }
  97. resp := s.resps[0]
  98. s.resps = s.resps[1:]
  99. if err, ok := resp.(error); ok {
  100. return nil, err
  101. }
  102. return resp, nil
  103. }
  104. func (a ByFieldPath) Len() int { return len(a) }
  105. func (a ByFieldPath) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
  106. func (a ByFieldPath) Less(i, j int) bool { return a[i].FieldPath < a[j].FieldPath }
  107. type ByFieldPath []*pb.DocumentTransform_FieldTransform
  108. func (s *mockServer) reset() {
  109. s.reqItems = nil
  110. s.resps = nil
  111. }
  112. func (s *mockServer) GetDocument(_ context.Context, req *pb.GetDocumentRequest) (*pb.Document, error) {
  113. res, err := s.popRPC(req)
  114. if err != nil {
  115. return nil, err
  116. }
  117. return res.(*pb.Document), nil
  118. }
  119. func (s *mockServer) Commit(_ context.Context, req *pb.CommitRequest) (*pb.CommitResponse, error) {
  120. res, err := s.popRPC(req)
  121. if err != nil {
  122. return nil, err
  123. }
  124. return res.(*pb.CommitResponse), nil
  125. }
  126. func (s *mockServer) BatchGetDocuments(req *pb.BatchGetDocumentsRequest, bs pb.Firestore_BatchGetDocumentsServer) error {
  127. res, err := s.popRPC(req)
  128. if err != nil {
  129. return err
  130. }
  131. responses := res.([]interface{})
  132. for _, res := range responses {
  133. switch res := res.(type) {
  134. case *pb.BatchGetDocumentsResponse:
  135. if err := bs.Send(res); err != nil {
  136. return err
  137. }
  138. case error:
  139. return res
  140. default:
  141. panic(fmt.Sprintf("bad response type in BatchGetDocuments: %+v", res))
  142. }
  143. }
  144. return nil
  145. }
  146. func (s *mockServer) RunQuery(req *pb.RunQueryRequest, qs pb.Firestore_RunQueryServer) error {
  147. res, err := s.popRPC(req)
  148. if err != nil {
  149. return err
  150. }
  151. responses := res.([]interface{})
  152. for _, res := range responses {
  153. switch res := res.(type) {
  154. case *pb.RunQueryResponse:
  155. if err := qs.Send(res); err != nil {
  156. return err
  157. }
  158. case error:
  159. return res
  160. default:
  161. panic(fmt.Sprintf("bad response type in RunQuery: %+v", res))
  162. }
  163. }
  164. return nil
  165. }
  166. func (s *mockServer) BeginTransaction(_ context.Context, req *pb.BeginTransactionRequest) (*pb.BeginTransactionResponse, error) {
  167. res, err := s.popRPC(req)
  168. if err != nil {
  169. return nil, err
  170. }
  171. return res.(*pb.BeginTransactionResponse), nil
  172. }
  173. func (s *mockServer) Rollback(_ context.Context, req *pb.RollbackRequest) (*empty.Empty, error) {
  174. res, err := s.popRPC(req)
  175. if err != nil {
  176. return nil, err
  177. }
  178. return res.(*empty.Empty), nil
  179. }
  180. func (s *mockServer) Listen(stream pb.Firestore_ListenServer) error {
  181. req, err := stream.Recv()
  182. if err != nil {
  183. return err
  184. }
  185. responses, err := s.popRPC(req)
  186. if err != nil {
  187. if status.Code(err) == codes.Unknown && strings.Contains(err.Error(), "mockServer") {
  188. // The stream will retry on Unknown, but we don't want that to happen if
  189. // the error comes from us.
  190. panic(err)
  191. }
  192. return err
  193. }
  194. for _, res := range responses.([]interface{}) {
  195. if err, ok := res.(error); ok {
  196. return err
  197. }
  198. if err := stream.Send(res.(*pb.ListenResponse)); err != nil {
  199. return err
  200. }
  201. }
  202. return nil
  203. }