|
- // Copyright 2017 Google LLC
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
-
- package firestore
-
- // A simple mock server.
-
- import (
- "fmt"
- "strings"
-
- "cloud.google.com/go/internal/testutil"
- pb "google.golang.org/genproto/googleapis/firestore/v1beta1"
- "google.golang.org/grpc/codes"
- "google.golang.org/grpc/status"
-
- "github.com/golang/protobuf/proto"
- "github.com/golang/protobuf/ptypes/empty"
- "golang.org/x/net/context"
- )
-
- type mockServer struct {
- pb.FirestoreServer
-
- Addr string
-
- reqItems []reqItem
- resps []interface{}
- }
-
- type reqItem struct {
- wantReq proto.Message
- adjust func(gotReq proto.Message)
- }
-
- func newMockServer() (*mockServer, error) {
- srv, err := testutil.NewServer()
- if err != nil {
- return nil, err
- }
- mock := &mockServer{Addr: srv.Addr}
- pb.RegisterFirestoreServer(srv.Gsrv, mock)
- srv.Start()
- return mock, nil
- }
-
- // addRPC adds a (request, response) pair to the server's list of expected
- // interactions. The server will compare the incoming request with wantReq
- // using proto.Equal. The response can be a message or an error.
- //
- // For the Listen RPC, resp should be a []interface{}, where each element
- // is either ListenResponse or an error.
- //
- // Passing nil for wantReq disables the request check.
- func (s *mockServer) addRPC(wantReq proto.Message, resp interface{}) {
- s.addRPCAdjust(wantReq, resp, nil)
- }
-
- // addRPCAdjust is like addRPC, but accepts a function that can be used
- // to tweak the requests before comparison, for example to adjust for
- // randomness.
- func (s *mockServer) addRPCAdjust(wantReq proto.Message, resp interface{}, adjust func(proto.Message)) {
- s.reqItems = append(s.reqItems, reqItem{wantReq, adjust})
- s.resps = append(s.resps, resp)
- }
-
- // popRPC compares the request with the next expected (request, response) pair.
- // It returns the response, or an error if the request doesn't match what
- // was expected or there are no expected rpcs.
- func (s *mockServer) popRPC(gotReq proto.Message) (interface{}, error) {
- if len(s.reqItems) == 0 {
- panic("out of RPCs")
- }
- ri := s.reqItems[0]
- s.reqItems = s.reqItems[1:]
- if ri.wantReq != nil {
- if ri.adjust != nil {
- ri.adjust(gotReq)
- }
- if !proto.Equal(gotReq, ri.wantReq) {
- return nil, fmt.Errorf("mockServer: bad request\ngot: %T\n%s\nwant: %T\n%s",
- gotReq, proto.MarshalTextString(gotReq),
- ri.wantReq, proto.MarshalTextString(ri.wantReq))
- }
- }
- resp := s.resps[0]
- s.resps = s.resps[1:]
- if err, ok := resp.(error); ok {
- return nil, err
- }
- return resp, nil
- }
-
- func (s *mockServer) reset() {
- s.reqItems = nil
- s.resps = nil
- }
-
- func (s *mockServer) GetDocument(_ context.Context, req *pb.GetDocumentRequest) (*pb.Document, error) {
- res, err := s.popRPC(req)
- if err != nil {
- return nil, err
- }
- return res.(*pb.Document), nil
- }
-
- func (s *mockServer) Commit(_ context.Context, req *pb.CommitRequest) (*pb.CommitResponse, error) {
- res, err := s.popRPC(req)
- if err != nil {
- return nil, err
- }
- return res.(*pb.CommitResponse), nil
- }
-
- func (s *mockServer) BatchGetDocuments(req *pb.BatchGetDocumentsRequest, bs pb.Firestore_BatchGetDocumentsServer) error {
- res, err := s.popRPC(req)
- if err != nil {
- return err
- }
- responses := res.([]interface{})
- for _, res := range responses {
- switch res := res.(type) {
- case *pb.BatchGetDocumentsResponse:
- if err := bs.Send(res); err != nil {
- return err
- }
- case error:
- return res
- default:
- panic(fmt.Sprintf("bad response type in BatchGetDocuments: %+v", res))
- }
- }
- return nil
- }
-
- func (s *mockServer) RunQuery(req *pb.RunQueryRequest, qs pb.Firestore_RunQueryServer) error {
- res, err := s.popRPC(req)
- if err != nil {
- return err
- }
- responses := res.([]interface{})
- for _, res := range responses {
- switch res := res.(type) {
- case *pb.RunQueryResponse:
- if err := qs.Send(res); err != nil {
- return err
- }
- case error:
- return res
- default:
- panic(fmt.Sprintf("bad response type in RunQuery: %+v", res))
- }
- }
- return nil
- }
-
- func (s *mockServer) BeginTransaction(_ context.Context, req *pb.BeginTransactionRequest) (*pb.BeginTransactionResponse, error) {
- res, err := s.popRPC(req)
- if err != nil {
- return nil, err
- }
- return res.(*pb.BeginTransactionResponse), nil
- }
-
- func (s *mockServer) Rollback(_ context.Context, req *pb.RollbackRequest) (*empty.Empty, error) {
- res, err := s.popRPC(req)
- if err != nil {
- return nil, err
- }
- return res.(*empty.Empty), nil
- }
-
- func (s *mockServer) Listen(stream pb.Firestore_ListenServer) error {
- req, err := stream.Recv()
- if err != nil {
- return err
- }
- responses, err := s.popRPC(req)
- if err != nil {
- if status.Code(err) == codes.Unknown && strings.Contains(err.Error(), "mockServer") {
- // The stream will retry on Unknown, but we don't want that to happen if
- // the error comes from us.
- panic(err)
- }
- return err
- }
- for _, res := range responses.([]interface{}) {
- if err, ok := res.(error); ok {
- return err
- }
- if err := stream.Send(res.(*pb.ListenResponse)); err != nil {
- return err
- }
- }
- return nil
- }
|