You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

201 lines
4.6 KiB

  1. // Copyright 2017 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package pubsub
  15. // This file provides a mock in-memory pubsub server for streaming pull testing.
  16. import (
  17. "context"
  18. "io"
  19. "sync"
  20. "time"
  21. "cloud.google.com/go/internal/testutil"
  22. emptypb "github.com/golang/protobuf/ptypes/empty"
  23. pb "google.golang.org/genproto/googleapis/pubsub/v1"
  24. )
  25. type mockServer struct {
  26. srv *testutil.Server
  27. pb.SubscriberServer
  28. Addr string
  29. mu sync.Mutex
  30. Acked map[string]bool // acked message IDs
  31. Deadlines map[string]int32 // deadlines by message ID
  32. pullResponses []*pullResponse
  33. ackErrs []error
  34. modAckErrs []error
  35. wg sync.WaitGroup
  36. sub *pb.Subscription
  37. }
  38. type pullResponse struct {
  39. msgs []*pb.ReceivedMessage
  40. err error
  41. }
  42. func newMockServer(port int) (*mockServer, error) {
  43. srv, err := testutil.NewServerWithPort(port)
  44. if err != nil {
  45. return nil, err
  46. }
  47. mock := &mockServer{
  48. srv: srv,
  49. Addr: srv.Addr,
  50. Acked: map[string]bool{},
  51. Deadlines: map[string]int32{},
  52. sub: &pb.Subscription{
  53. AckDeadlineSeconds: 10,
  54. PushConfig: &pb.PushConfig{},
  55. },
  56. }
  57. pb.RegisterSubscriberServer(srv.Gsrv, mock)
  58. srv.Start()
  59. return mock, nil
  60. }
  61. // Each call to addStreamingPullMessages results in one StreamingPullResponse.
  62. func (s *mockServer) addStreamingPullMessages(msgs []*pb.ReceivedMessage) {
  63. s.mu.Lock()
  64. s.pullResponses = append(s.pullResponses, &pullResponse{msgs, nil})
  65. s.mu.Unlock()
  66. }
  67. func (s *mockServer) addStreamingPullError(err error) {
  68. s.mu.Lock()
  69. s.pullResponses = append(s.pullResponses, &pullResponse{nil, err})
  70. s.mu.Unlock()
  71. }
  72. func (s *mockServer) addAckResponse(err error) {
  73. s.mu.Lock()
  74. s.ackErrs = append(s.ackErrs, err)
  75. s.mu.Unlock()
  76. }
  77. func (s *mockServer) addModAckResponse(err error) {
  78. s.mu.Lock()
  79. s.modAckErrs = append(s.modAckErrs, err)
  80. s.mu.Unlock()
  81. }
  82. func (s *mockServer) wait() {
  83. s.wg.Wait()
  84. }
  85. func (s *mockServer) StreamingPull(stream pb.Subscriber_StreamingPullServer) error {
  86. s.wg.Add(1)
  87. defer s.wg.Done()
  88. errc := make(chan error, 1)
  89. s.wg.Add(1)
  90. go func() {
  91. defer s.wg.Done()
  92. for {
  93. req, err := stream.Recv()
  94. if err != nil {
  95. errc <- err
  96. return
  97. }
  98. s.mu.Lock()
  99. for _, id := range req.AckIds {
  100. s.Acked[id] = true
  101. }
  102. for i, id := range req.ModifyDeadlineAckIds {
  103. s.Deadlines[id] = req.ModifyDeadlineSeconds[i]
  104. }
  105. s.mu.Unlock()
  106. }
  107. }()
  108. // Send responses.
  109. for {
  110. s.mu.Lock()
  111. if len(s.pullResponses) == 0 {
  112. s.mu.Unlock()
  113. // Nothing to send, so wait for the client to shut down the stream.
  114. err := <-errc // a real error, or at least EOF
  115. if err == io.EOF {
  116. return nil
  117. }
  118. return err
  119. }
  120. pr := s.pullResponses[0]
  121. s.pullResponses = s.pullResponses[1:]
  122. s.mu.Unlock()
  123. if pr.err != nil {
  124. // Add a slight delay to ensure the server receives any
  125. // messages en route from the client before shutting down the stream.
  126. // This reduces flakiness of tests involving retry.
  127. time.Sleep(200 * time.Millisecond)
  128. }
  129. if pr.err == io.EOF {
  130. return nil
  131. }
  132. if pr.err != nil {
  133. return pr.err
  134. }
  135. // Return any error from Recv.
  136. select {
  137. case err := <-errc:
  138. return err
  139. default:
  140. }
  141. res := &pb.StreamingPullResponse{ReceivedMessages: pr.msgs}
  142. if err := stream.Send(res); err != nil {
  143. return err
  144. }
  145. }
  146. }
  147. func (s *mockServer) Acknowledge(ctx context.Context, req *pb.AcknowledgeRequest) (*emptypb.Empty, error) {
  148. var err error
  149. s.mu.Lock()
  150. if len(s.ackErrs) > 0 {
  151. err = s.ackErrs[0]
  152. s.ackErrs = s.ackErrs[1:]
  153. }
  154. s.mu.Unlock()
  155. if err != nil {
  156. return nil, err
  157. }
  158. for _, id := range req.AckIds {
  159. s.Acked[id] = true
  160. }
  161. return &emptypb.Empty{}, nil
  162. }
  163. func (s *mockServer) ModifyAckDeadline(ctx context.Context, req *pb.ModifyAckDeadlineRequest) (*emptypb.Empty, error) {
  164. var err error
  165. s.mu.Lock()
  166. if len(s.modAckErrs) > 0 {
  167. err = s.modAckErrs[0]
  168. s.modAckErrs = s.modAckErrs[1:]
  169. }
  170. s.mu.Unlock()
  171. if err != nil {
  172. return nil, err
  173. }
  174. for _, id := range req.AckIds {
  175. s.Deadlines[id] = req.AckDeadlineSeconds
  176. }
  177. return &emptypb.Empty{}, nil
  178. }
  179. func (s *mockServer) GetSubscription(ctx context.Context, req *pb.GetSubscriptionRequest) (*pb.Subscription, error) {
  180. return s.sub, nil
  181. }