You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

337 rivejä
9.4 KiB

  1. // Copyright 2017 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package pubsub
  15. // This file provides a fake/mock in-memory pubsub server.
  16. import (
  17. "io"
  18. "sort"
  19. "strings"
  20. "sync"
  21. "time"
  22. "cloud.google.com/go/internal/testutil"
  23. "github.com/golang/protobuf/proto"
  24. "github.com/golang/protobuf/ptypes"
  25. durpb "github.com/golang/protobuf/ptypes/duration"
  26. emptypb "github.com/golang/protobuf/ptypes/empty"
  27. "golang.org/x/net/context"
  28. pb "google.golang.org/genproto/googleapis/pubsub/v1"
  29. "google.golang.org/grpc/codes"
  30. "google.golang.org/grpc/status"
  31. )
  32. type fakeServer struct {
  33. pb.PublisherServer
  34. pb.SubscriberServer
  35. Addr string
  36. mu sync.Mutex
  37. Acked map[string]bool // acked message IDs
  38. Deadlines map[string]int32 // deadlines by message ID
  39. pullResponses []*pullResponse
  40. wg sync.WaitGroup
  41. subs map[string]*pb.Subscription
  42. topics map[string]*pb.Topic
  43. }
  44. type pullResponse struct {
  45. msgs []*pb.ReceivedMessage
  46. err error
  47. }
  48. func newFakeServer() (*fakeServer, error) {
  49. srv, err := testutil.NewServer()
  50. if err != nil {
  51. return nil, err
  52. }
  53. fake := &fakeServer{
  54. Addr: srv.Addr,
  55. Acked: map[string]bool{},
  56. Deadlines: map[string]int32{},
  57. subs: map[string]*pb.Subscription{},
  58. topics: map[string]*pb.Topic{},
  59. }
  60. pb.RegisterPublisherServer(srv.Gsrv, fake)
  61. pb.RegisterSubscriberServer(srv.Gsrv, fake)
  62. srv.Start()
  63. return fake, nil
  64. }
  65. // Each call to addStreamingPullMessages results in one StreamingPullResponse.
  66. func (s *fakeServer) addStreamingPullMessages(msgs []*pb.ReceivedMessage) {
  67. s.pullResponses = append(s.pullResponses, &pullResponse{msgs, nil})
  68. }
  69. func (s *fakeServer) addStreamingPullError(err error) {
  70. s.pullResponses = append(s.pullResponses, &pullResponse{nil, err})
  71. }
  72. func (s *fakeServer) wait() {
  73. s.wg.Wait()
  74. }
  75. func (s *fakeServer) StreamingPull(stream pb.Subscriber_StreamingPullServer) error {
  76. s.wg.Add(1)
  77. defer s.wg.Done()
  78. errc := make(chan error, 1)
  79. s.wg.Add(1)
  80. go func() {
  81. defer s.wg.Done()
  82. for {
  83. req, err := stream.Recv()
  84. if err != nil {
  85. errc <- err
  86. return
  87. }
  88. s.mu.Lock()
  89. for _, id := range req.AckIds {
  90. s.Acked[id] = true
  91. }
  92. for i, id := range req.ModifyDeadlineAckIds {
  93. s.Deadlines[id] = req.ModifyDeadlineSeconds[i]
  94. }
  95. s.mu.Unlock()
  96. }
  97. }()
  98. // Send responses.
  99. for {
  100. s.mu.Lock()
  101. if len(s.pullResponses) == 0 {
  102. s.mu.Unlock()
  103. // Nothing to send, so wait for the client to shut down the stream.
  104. err := <-errc // a real error, or at least EOF
  105. if err == io.EOF {
  106. return nil
  107. }
  108. return err
  109. }
  110. pr := s.pullResponses[0]
  111. s.pullResponses = s.pullResponses[1:]
  112. s.mu.Unlock()
  113. if pr.err != nil {
  114. // Add a slight delay to ensure the server receives any
  115. // messages en route from the client before shutting down the stream.
  116. // This reduces flakiness of tests involving retry.
  117. time.Sleep(200 * time.Millisecond)
  118. }
  119. if pr.err == io.EOF {
  120. return nil
  121. }
  122. if pr.err != nil {
  123. return pr.err
  124. }
  125. // Return any error from Recv.
  126. select {
  127. case err := <-errc:
  128. return err
  129. default:
  130. }
  131. res := &pb.StreamingPullResponse{ReceivedMessages: pr.msgs}
  132. if err := stream.Send(res); err != nil {
  133. return err
  134. }
  135. }
  136. }
  137. const (
  138. minMessageRetentionDuration = 10 * time.Minute
  139. maxMessageRetentionDuration = 168 * time.Hour
  140. )
  141. var defaultMessageRetentionDuration = ptypes.DurationProto(maxMessageRetentionDuration)
  142. func checkMRD(pmrd *durpb.Duration) error {
  143. mrd, err := ptypes.Duration(pmrd)
  144. if err != nil || mrd < minMessageRetentionDuration || mrd > maxMessageRetentionDuration {
  145. return status.Errorf(codes.InvalidArgument, "bad message_retention_duration %+v", pmrd)
  146. }
  147. return nil
  148. }
  149. func checkAckDeadline(ads int32) error {
  150. if ads < 10 || ads > 600 {
  151. // PubSub service returns Unknown.
  152. return status.Errorf(codes.Unknown, "bad ack_deadline_seconds: %d", ads)
  153. }
  154. return nil
  155. }
  156. func (s *fakeServer) Acknowledge(ctx context.Context, req *pb.AcknowledgeRequest) (*emptypb.Empty, error) {
  157. for _, id := range req.AckIds {
  158. s.Acked[id] = true
  159. }
  160. return &emptypb.Empty{}, nil
  161. }
  162. func (s *fakeServer) ModifyAckDeadline(ctx context.Context, req *pb.ModifyAckDeadlineRequest) (*emptypb.Empty, error) {
  163. for _, id := range req.AckIds {
  164. s.Deadlines[id] = req.AckDeadlineSeconds
  165. }
  166. return &emptypb.Empty{}, nil
  167. }
  168. func (s *fakeServer) CreateSubscription(ctx context.Context, sub *pb.Subscription) (*pb.Subscription, error) {
  169. if s.subs[sub.Name] != nil {
  170. return nil, status.Errorf(codes.AlreadyExists, "subscription %q", sub.Name)
  171. }
  172. sub2 := proto.Clone(sub).(*pb.Subscription)
  173. if err := checkAckDeadline(sub.AckDeadlineSeconds); err != nil {
  174. return nil, err
  175. }
  176. if sub.MessageRetentionDuration == nil {
  177. sub2.MessageRetentionDuration = defaultMessageRetentionDuration
  178. }
  179. if err := checkMRD(sub2.MessageRetentionDuration); err != nil {
  180. return nil, err
  181. }
  182. if sub.PushConfig == nil {
  183. sub2.PushConfig = &pb.PushConfig{}
  184. }
  185. s.subs[sub.Name] = sub2
  186. return sub2, nil
  187. }
  188. func (s *fakeServer) GetSubscription(ctx context.Context, req *pb.GetSubscriptionRequest) (*pb.Subscription, error) {
  189. if sub := s.subs[req.Subscription]; sub != nil {
  190. return sub, nil
  191. }
  192. return nil, status.Errorf(codes.NotFound, "subscription %q", req.Subscription)
  193. }
  194. func (s *fakeServer) UpdateSubscription(ctx context.Context, req *pb.UpdateSubscriptionRequest) (*pb.Subscription, error) {
  195. sub := s.subs[req.Subscription.Name]
  196. if sub == nil {
  197. return nil, status.Errorf(codes.NotFound, "subscription %q", req.Subscription.Name)
  198. }
  199. for _, path := range req.UpdateMask.Paths {
  200. switch path {
  201. case "push_config":
  202. sub.PushConfig = req.Subscription.PushConfig
  203. case "ack_deadline_seconds":
  204. a := req.Subscription.AckDeadlineSeconds
  205. if err := checkAckDeadline(a); err != nil {
  206. return nil, err
  207. }
  208. sub.AckDeadlineSeconds = a
  209. case "retain_acked_messages":
  210. sub.RetainAckedMessages = req.Subscription.RetainAckedMessages
  211. case "message_retention_duration":
  212. if err := checkMRD(req.Subscription.MessageRetentionDuration); err != nil {
  213. return nil, err
  214. }
  215. sub.MessageRetentionDuration = req.Subscription.MessageRetentionDuration
  216. // TODO(jba): labels
  217. default:
  218. return nil, status.Errorf(codes.InvalidArgument, "unknown field name %q", path)
  219. }
  220. }
  221. return sub, nil
  222. }
  223. func (s *fakeServer) DeleteSubscription(_ context.Context, req *pb.DeleteSubscriptionRequest) (*emptypb.Empty, error) {
  224. if s.subs[req.Subscription] == nil {
  225. return nil, status.Errorf(codes.NotFound, "subscription %q", req.Subscription)
  226. }
  227. delete(s.subs, req.Subscription)
  228. return &emptypb.Empty{}, nil
  229. }
  230. func (s *fakeServer) CreateTopic(_ context.Context, t *pb.Topic) (*pb.Topic, error) {
  231. if s.topics[t.Name] != nil {
  232. return nil, status.Errorf(codes.AlreadyExists, "topic %q", t.Name)
  233. }
  234. t2 := proto.Clone(t).(*pb.Topic)
  235. s.topics[t.Name] = t2
  236. return t2, nil
  237. }
  238. func (s *fakeServer) GetTopic(_ context.Context, req *pb.GetTopicRequest) (*pb.Topic, error) {
  239. if t := s.topics[req.Topic]; t != nil {
  240. return t, nil
  241. }
  242. return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
  243. }
  244. func (s *fakeServer) DeleteTopic(_ context.Context, req *pb.DeleteTopicRequest) (*emptypb.Empty, error) {
  245. if s.topics[req.Topic] == nil {
  246. return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
  247. }
  248. delete(s.topics, req.Topic)
  249. return &emptypb.Empty{}, nil
  250. }
  251. func (s *fakeServer) ListTopics(_ context.Context, req *pb.ListTopicsRequest) (*pb.ListTopicsResponse, error) {
  252. var names []string
  253. for n := range s.topics {
  254. if strings.HasPrefix(n, req.Project) {
  255. names = append(names, n)
  256. }
  257. }
  258. sort.Strings(names)
  259. from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names))
  260. if err != nil {
  261. return nil, err
  262. }
  263. res := &pb.ListTopicsResponse{NextPageToken: nextToken}
  264. for i := from; i < to; i++ {
  265. res.Topics = append(res.Topics, s.topics[names[i]])
  266. }
  267. return res, nil
  268. }
  269. func (s *fakeServer) ListSubscriptions(_ context.Context, req *pb.ListSubscriptionsRequest) (*pb.ListSubscriptionsResponse, error) {
  270. var names []string
  271. for _, sub := range s.subs {
  272. if strings.HasPrefix(sub.Name, req.Project) {
  273. names = append(names, sub.Name)
  274. }
  275. }
  276. sort.Strings(names)
  277. from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names))
  278. if err != nil {
  279. return nil, err
  280. }
  281. res := &pb.ListSubscriptionsResponse{NextPageToken: nextToken}
  282. for i := from; i < to; i++ {
  283. res.Subscriptions = append(res.Subscriptions, s.subs[names[i]])
  284. }
  285. return res, nil
  286. }
  287. func (s *fakeServer) ListTopicSubscriptions(_ context.Context, req *pb.ListTopicSubscriptionsRequest) (*pb.ListTopicSubscriptionsResponse, error) {
  288. var names []string
  289. for _, sub := range s.subs {
  290. if sub.Topic == req.Topic {
  291. names = append(names, sub.Name)
  292. }
  293. }
  294. sort.Strings(names)
  295. from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names))
  296. if err != nil {
  297. return nil, err
  298. }
  299. return &pb.ListTopicSubscriptionsResponse{
  300. Subscriptions: names[from:to],
  301. NextPageToken: nextToken,
  302. }, nil
  303. }