|
- // Copyright 2017 Google LLC
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
-
- package pubsub
-
- // This file provides a mock in-memory pubsub server for streaming pull testing.
-
- import (
- "context"
- "io"
- "sync"
- "time"
-
- "cloud.google.com/go/internal/testutil"
- emptypb "github.com/golang/protobuf/ptypes/empty"
- pb "google.golang.org/genproto/googleapis/pubsub/v1"
- )
-
- type mockServer struct {
- srv *testutil.Server
-
- pb.SubscriberServer
-
- Addr string
-
- mu sync.Mutex
- Acked map[string]bool // acked message IDs
- Deadlines map[string]int32 // deadlines by message ID
- pullResponses []*pullResponse
- ackErrs []error
- modAckErrs []error
- wg sync.WaitGroup
- sub *pb.Subscription
- }
-
// pullResponse is one scripted StreamingPull outcome. It is either a batch of
// messages to send as a single StreamingPullResponse, or an error that ends
// the stream.
//
// Field order matters: the add* helpers build this with a positional literal.
type pullResponse struct {
	msgs []*pb.ReceivedMessage // messages for one StreamingPullResponse
	err  error                 // if non-nil, ends the stream; io.EOF ends it with a nil error
}
-
- func newMockServer(port int) (*mockServer, error) {
- srv, err := testutil.NewServerWithPort(port)
- if err != nil {
- return nil, err
- }
- mock := &mockServer{
- srv: srv,
- Addr: srv.Addr,
- Acked: map[string]bool{},
- Deadlines: map[string]int32{},
- sub: &pb.Subscription{
- AckDeadlineSeconds: 10,
- PushConfig: &pb.PushConfig{},
- },
- }
- pb.RegisterSubscriberServer(srv.Gsrv, mock)
- srv.Start()
- return mock, nil
- }
-
- // Each call to addStreamingPullMessages results in one StreamingPullResponse.
- func (s *mockServer) addStreamingPullMessages(msgs []*pb.ReceivedMessage) {
- s.mu.Lock()
- s.pullResponses = append(s.pullResponses, &pullResponse{msgs, nil})
- s.mu.Unlock()
- }
-
- func (s *mockServer) addStreamingPullError(err error) {
- s.mu.Lock()
- s.pullResponses = append(s.pullResponses, &pullResponse{nil, err})
- s.mu.Unlock()
- }
-
- func (s *mockServer) addAckResponse(err error) {
- s.mu.Lock()
- s.ackErrs = append(s.ackErrs, err)
- s.mu.Unlock()
- }
-
- func (s *mockServer) addModAckResponse(err error) {
- s.mu.Lock()
- s.modAckErrs = append(s.modAckErrs, err)
- s.mu.Unlock()
- }
-
- func (s *mockServer) wait() {
- s.wg.Wait()
- }
-
- func (s *mockServer) StreamingPull(stream pb.Subscriber_StreamingPullServer) error {
- s.wg.Add(1)
- defer s.wg.Done()
- errc := make(chan error, 1)
- s.wg.Add(1)
- go func() {
- defer s.wg.Done()
- for {
- req, err := stream.Recv()
- if err != nil {
- errc <- err
- return
- }
- s.mu.Lock()
- for _, id := range req.AckIds {
- s.Acked[id] = true
- }
- for i, id := range req.ModifyDeadlineAckIds {
- s.Deadlines[id] = req.ModifyDeadlineSeconds[i]
- }
- s.mu.Unlock()
- }
- }()
- // Send responses.
- for {
- s.mu.Lock()
- if len(s.pullResponses) == 0 {
- s.mu.Unlock()
- // Nothing to send, so wait for the client to shut down the stream.
- err := <-errc // a real error, or at least EOF
- if err == io.EOF {
- return nil
- }
- return err
- }
- pr := s.pullResponses[0]
- s.pullResponses = s.pullResponses[1:]
- s.mu.Unlock()
- if pr.err != nil {
- // Add a slight delay to ensure the server receives any
- // messages en route from the client before shutting down the stream.
- // This reduces flakiness of tests involving retry.
- time.Sleep(200 * time.Millisecond)
- }
- if pr.err == io.EOF {
- return nil
- }
- if pr.err != nil {
- return pr.err
- }
- // Return any error from Recv.
- select {
- case err := <-errc:
- return err
- default:
- }
- res := &pb.StreamingPullResponse{ReceivedMessages: pr.msgs}
- if err := stream.Send(res); err != nil {
- return err
- }
- }
- }
-
- func (s *mockServer) Acknowledge(ctx context.Context, req *pb.AcknowledgeRequest) (*emptypb.Empty, error) {
- var err error
- s.mu.Lock()
- if len(s.ackErrs) > 0 {
- err = s.ackErrs[0]
- s.ackErrs = s.ackErrs[1:]
- }
- s.mu.Unlock()
- if err != nil {
- return nil, err
- }
- for _, id := range req.AckIds {
- s.Acked[id] = true
- }
- return &emptypb.Empty{}, nil
- }
-
- func (s *mockServer) ModifyAckDeadline(ctx context.Context, req *pb.ModifyAckDeadlineRequest) (*emptypb.Empty, error) {
- var err error
- s.mu.Lock()
- if len(s.modAckErrs) > 0 {
- err = s.modAckErrs[0]
- s.modAckErrs = s.modAckErrs[1:]
- }
- s.mu.Unlock()
- if err != nil {
- return nil, err
- }
- for _, id := range req.AckIds {
- s.Deadlines[id] = req.AckDeadlineSeconds
- }
- return &emptypb.Empty{}, nil
- }
-
- func (s *mockServer) GetSubscription(ctx context.Context, req *pb.GetSubscriptionRequest) (*pb.Subscription, error) {
- return s.sub, nil
- }
|