You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

208 lines
5.4 KiB

  1. // Copyright 2017 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package firestore
  15. // A simple mock server.
  16. import (
  17. "fmt"
  18. "strings"
  19. "cloud.google.com/go/internal/testutil"
  20. pb "google.golang.org/genproto/googleapis/firestore/v1beta1"
  21. "google.golang.org/grpc/codes"
  22. "google.golang.org/grpc/status"
  23. "github.com/golang/protobuf/proto"
  24. "github.com/golang/protobuf/ptypes/empty"
  25. "golang.org/x/net/context"
  26. )
  27. type mockServer struct {
  28. pb.FirestoreServer
  29. Addr string
  30. reqItems []reqItem
  31. resps []interface{}
  32. }
  33. type reqItem struct {
  34. wantReq proto.Message
  35. adjust func(gotReq proto.Message)
  36. }
  37. func newMockServer() (*mockServer, error) {
  38. srv, err := testutil.NewServer()
  39. if err != nil {
  40. return nil, err
  41. }
  42. mock := &mockServer{Addr: srv.Addr}
  43. pb.RegisterFirestoreServer(srv.Gsrv, mock)
  44. srv.Start()
  45. return mock, nil
  46. }
  47. // addRPC adds a (request, response) pair to the server's list of expected
  48. // interactions. The server will compare the incoming request with wantReq
  49. // using proto.Equal. The response can be a message or an error.
  50. //
  51. // For the Listen RPC, resp should be a []interface{}, where each element
  52. // is either ListenResponse or an error.
  53. //
  54. // Passing nil for wantReq disables the request check.
  55. func (s *mockServer) addRPC(wantReq proto.Message, resp interface{}) {
  56. s.addRPCAdjust(wantReq, resp, nil)
  57. }
  58. // addRPCAdjust is like addRPC, but accepts a function that can be used
  59. // to tweak the requests before comparison, for example to adjust for
  60. // randomness.
  61. func (s *mockServer) addRPCAdjust(wantReq proto.Message, resp interface{}, adjust func(proto.Message)) {
  62. s.reqItems = append(s.reqItems, reqItem{wantReq, adjust})
  63. s.resps = append(s.resps, resp)
  64. }
  65. // popRPC compares the request with the next expected (request, response) pair.
  66. // It returns the response, or an error if the request doesn't match what
  67. // was expected or there are no expected rpcs.
  68. func (s *mockServer) popRPC(gotReq proto.Message) (interface{}, error) {
  69. if len(s.reqItems) == 0 {
  70. panic("out of RPCs")
  71. }
  72. ri := s.reqItems[0]
  73. s.reqItems = s.reqItems[1:]
  74. if ri.wantReq != nil {
  75. if ri.adjust != nil {
  76. ri.adjust(gotReq)
  77. }
  78. if !proto.Equal(gotReq, ri.wantReq) {
  79. return nil, fmt.Errorf("mockServer: bad request\ngot: %T\n%s\nwant: %T\n%s",
  80. gotReq, proto.MarshalTextString(gotReq),
  81. ri.wantReq, proto.MarshalTextString(ri.wantReq))
  82. }
  83. }
  84. resp := s.resps[0]
  85. s.resps = s.resps[1:]
  86. if err, ok := resp.(error); ok {
  87. return nil, err
  88. }
  89. return resp, nil
  90. }
  91. func (s *mockServer) reset() {
  92. s.reqItems = nil
  93. s.resps = nil
  94. }
  95. func (s *mockServer) GetDocument(_ context.Context, req *pb.GetDocumentRequest) (*pb.Document, error) {
  96. res, err := s.popRPC(req)
  97. if err != nil {
  98. return nil, err
  99. }
  100. return res.(*pb.Document), nil
  101. }
  102. func (s *mockServer) Commit(_ context.Context, req *pb.CommitRequest) (*pb.CommitResponse, error) {
  103. res, err := s.popRPC(req)
  104. if err != nil {
  105. return nil, err
  106. }
  107. return res.(*pb.CommitResponse), nil
  108. }
  109. func (s *mockServer) BatchGetDocuments(req *pb.BatchGetDocumentsRequest, bs pb.Firestore_BatchGetDocumentsServer) error {
  110. res, err := s.popRPC(req)
  111. if err != nil {
  112. return err
  113. }
  114. responses := res.([]interface{})
  115. for _, res := range responses {
  116. switch res := res.(type) {
  117. case *pb.BatchGetDocumentsResponse:
  118. if err := bs.Send(res); err != nil {
  119. return err
  120. }
  121. case error:
  122. return res
  123. default:
  124. panic(fmt.Sprintf("bad response type in BatchGetDocuments: %+v", res))
  125. }
  126. }
  127. return nil
  128. }
  129. func (s *mockServer) RunQuery(req *pb.RunQueryRequest, qs pb.Firestore_RunQueryServer) error {
  130. res, err := s.popRPC(req)
  131. if err != nil {
  132. return err
  133. }
  134. responses := res.([]interface{})
  135. for _, res := range responses {
  136. switch res := res.(type) {
  137. case *pb.RunQueryResponse:
  138. if err := qs.Send(res); err != nil {
  139. return err
  140. }
  141. case error:
  142. return res
  143. default:
  144. panic(fmt.Sprintf("bad response type in RunQuery: %+v", res))
  145. }
  146. }
  147. return nil
  148. }
  149. func (s *mockServer) BeginTransaction(_ context.Context, req *pb.BeginTransactionRequest) (*pb.BeginTransactionResponse, error) {
  150. res, err := s.popRPC(req)
  151. if err != nil {
  152. return nil, err
  153. }
  154. return res.(*pb.BeginTransactionResponse), nil
  155. }
  156. func (s *mockServer) Rollback(_ context.Context, req *pb.RollbackRequest) (*empty.Empty, error) {
  157. res, err := s.popRPC(req)
  158. if err != nil {
  159. return nil, err
  160. }
  161. return res.(*empty.Empty), nil
  162. }
  163. func (s *mockServer) Listen(stream pb.Firestore_ListenServer) error {
  164. req, err := stream.Recv()
  165. if err != nil {
  166. return err
  167. }
  168. responses, err := s.popRPC(req)
  169. if err != nil {
  170. if status.Code(err) == codes.Unknown && strings.Contains(err.Error(), "mockServer") {
  171. // The stream will retry on Unknown, but we don't want that to happen if
  172. // the error comes from us.
  173. panic(err)
  174. }
  175. return err
  176. }
  177. for _, res := range responses.([]interface{}) {
  178. if err, ok := res.(error); ok {
  179. return err
  180. }
  181. if err := stream.Send(res.(*pb.ListenResponse)); err != nil {
  182. return err
  183. }
  184. }
  185. return nil
  186. }