You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

719 lines
18 KiB

  1. // Copyright 2017 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package rpcreplay
  15. import (
  16. "bufio"
  17. "encoding/binary"
  18. "errors"
  19. "fmt"
  20. "io"
  21. "log"
  22. "os"
  23. "sync"
  24. "golang.org/x/net/context"
  25. "google.golang.org/grpc"
  26. "google.golang.org/grpc/metadata"
  27. "google.golang.org/grpc/status"
  28. pb "cloud.google.com/go/rpcreplay/proto/rpcreplay"
  29. "github.com/golang/protobuf/proto"
  30. "github.com/golang/protobuf/ptypes"
  31. "github.com/golang/protobuf/ptypes/any"
  32. spb "google.golang.org/genproto/googleapis/rpc/status"
  33. )
  34. // A Recorder records RPCs for later playback.
  35. type Recorder struct {
  36. mu sync.Mutex
  37. w *bufio.Writer
  38. f *os.File
  39. next int
  40. err error
  41. // BeforeFunc defines a function that can inspect and modify requests and responses
  42. // written to the replay file. It does not modify messages sent to the service.
  43. // It is run once before a request is written to the replay file, and once before a response
  44. // is written to the replay file.
  45. // The function is called with the method name and the message that triggered the callback.
  46. // If the function returns an error, the error will be returned to the client.
  47. // This is only executed for unary RPCs; streaming RPCs are not supported.
  48. BeforeFunc func(string, proto.Message) error
  49. }
  50. // NewRecorder creates a recorder that writes to filename. The file will
  51. // also store the initial bytes for retrieval during replay.
  52. //
  53. // You must call Close on the Recorder to ensure that all data is written.
  54. func NewRecorder(filename string, initial []byte) (*Recorder, error) {
  55. f, err := os.Create(filename)
  56. if err != nil {
  57. return nil, err
  58. }
  59. rec, err := NewRecorderWriter(f, initial)
  60. if err != nil {
  61. _ = f.Close()
  62. return nil, err
  63. }
  64. rec.f = f
  65. return rec, nil
  66. }
  67. // NewRecorderWriter creates a recorder that writes to w. The initial
  68. // bytes will also be written to w for retrieval during replay.
  69. //
  70. // You must call Close on the Recorder to ensure that all data is written.
  71. func NewRecorderWriter(w io.Writer, initial []byte) (*Recorder, error) {
  72. bw := bufio.NewWriter(w)
  73. if err := writeHeader(bw, initial); err != nil {
  74. return nil, err
  75. }
  76. return &Recorder{w: bw, next: 1}, nil
  77. }
  78. // DialOptions returns the options that must be passed to grpc.Dial
  79. // to enable recording.
  80. func (r *Recorder) DialOptions() []grpc.DialOption {
  81. return []grpc.DialOption{
  82. grpc.WithUnaryInterceptor(r.interceptUnary),
  83. grpc.WithStreamInterceptor(r.interceptStream),
  84. }
  85. }
  86. // Close saves any unwritten information.
  87. func (r *Recorder) Close() error {
  88. r.mu.Lock()
  89. defer r.mu.Unlock()
  90. if r.err != nil {
  91. return r.err
  92. }
  93. err := r.w.Flush()
  94. if r.f != nil {
  95. if err2 := r.f.Close(); err == nil {
  96. err = err2
  97. }
  98. }
  99. return err
  100. }
  101. // Intercepts all unary (non-stream) RPCs.
  102. func (r *Recorder) interceptUnary(ctx context.Context, method string, req, res interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
  103. ereq := &entry{
  104. kind: pb.Entry_REQUEST,
  105. method: method,
  106. msg: message{msg: proto.Clone(req.(proto.Message))},
  107. }
  108. if r.BeforeFunc != nil {
  109. if err := r.BeforeFunc(method, ereq.msg.msg); err != nil {
  110. return err
  111. }
  112. }
  113. refIndex, err := r.writeEntry(ereq)
  114. if err != nil {
  115. return err
  116. }
  117. ierr := invoker(ctx, method, req, res, cc, opts...)
  118. eres := &entry{
  119. kind: pb.Entry_RESPONSE,
  120. refIndex: refIndex,
  121. }
  122. // If the error is not a gRPC status, then something more
  123. // serious is wrong. More significantly, we have no way
  124. // of serializing an arbitrary error. So just return it
  125. // without recording the response.
  126. if _, ok := status.FromError(ierr); !ok {
  127. r.mu.Lock()
  128. r.err = fmt.Errorf("saw non-status error in %s response: %v (%T)", method, ierr, ierr)
  129. r.mu.Unlock()
  130. return ierr
  131. }
  132. eres.msg.set(proto.Clone(res.(proto.Message)), ierr)
  133. if r.BeforeFunc != nil {
  134. if err := r.BeforeFunc(method, eres.msg.msg); err != nil {
  135. return err
  136. }
  137. }
  138. if _, err := r.writeEntry(eres); err != nil {
  139. return err
  140. }
  141. return ierr
  142. }
  143. func (r *Recorder) writeEntry(e *entry) (int, error) {
  144. r.mu.Lock()
  145. defer r.mu.Unlock()
  146. if r.err != nil {
  147. return 0, r.err
  148. }
  149. err := writeEntry(r.w, e)
  150. if err != nil {
  151. r.err = err
  152. return 0, err
  153. }
  154. n := r.next
  155. r.next++
  156. return n, nil
  157. }
  158. func (r *Recorder) interceptStream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
  159. cstream, serr := streamer(ctx, desc, cc, method, opts...)
  160. e := &entry{
  161. kind: pb.Entry_CREATE_STREAM,
  162. method: method,
  163. }
  164. e.msg.set(nil, serr)
  165. refIndex, err := r.writeEntry(e)
  166. if err != nil {
  167. return nil, err
  168. }
  169. return &recClientStream{
  170. ctx: ctx,
  171. rec: r,
  172. cstream: cstream,
  173. refIndex: refIndex,
  174. }, serr
  175. }
  176. // A recClientStream implements the gprc.ClientStream interface.
  177. // It behaves exactly like the default ClientStream, but also
  178. // records all messages sent and received.
  179. type recClientStream struct {
  180. ctx context.Context
  181. rec *Recorder
  182. cstream grpc.ClientStream
  183. refIndex int
  184. }
  185. func (rcs *recClientStream) Context() context.Context { return rcs.ctx }
  186. func (rcs *recClientStream) SendMsg(m interface{}) error {
  187. serr := rcs.cstream.SendMsg(m)
  188. e := &entry{
  189. kind: pb.Entry_SEND,
  190. refIndex: rcs.refIndex,
  191. }
  192. e.msg.set(m, serr)
  193. if _, err := rcs.rec.writeEntry(e); err != nil {
  194. return err
  195. }
  196. return serr
  197. }
  198. func (rcs *recClientStream) RecvMsg(m interface{}) error {
  199. serr := rcs.cstream.RecvMsg(m)
  200. e := &entry{
  201. kind: pb.Entry_RECV,
  202. refIndex: rcs.refIndex,
  203. }
  204. e.msg.set(m, serr)
  205. if _, err := rcs.rec.writeEntry(e); err != nil {
  206. return err
  207. }
  208. return serr
  209. }
  210. func (rcs *recClientStream) Header() (metadata.MD, error) {
  211. // TODO(jba): record.
  212. return rcs.cstream.Header()
  213. }
  214. func (rcs *recClientStream) Trailer() metadata.MD {
  215. // TODO(jba): record.
  216. return rcs.cstream.Trailer()
  217. }
  218. func (rcs *recClientStream) CloseSend() error {
  219. // TODO(jba): record.
  220. return rcs.cstream.CloseSend()
  221. }
  222. // A Replayer replays a set of RPCs saved by a Recorder.
  223. type Replayer struct {
  224. initial []byte // initial state
  225. log func(format string, v ...interface{}) // for debugging
  226. mu sync.Mutex
  227. calls []*call
  228. streams []*stream
  229. // BeforeFunc defines a function that can inspect and modify requests before they
  230. // are matched for responses from the replay file.
  231. // The function is called with the method name and the message that triggered the callback.
  232. // If the function returns an error, the error will be returned to the client.
  233. // This is only executed for unary RPCs; streaming RPCs are not supported.
  234. BeforeFunc func(string, proto.Message) error
  235. }
  236. // A call represents a unary RPC, with a request and response (or error).
  237. type call struct {
  238. method string
  239. request proto.Message
  240. response message
  241. }
  242. // A stream represents a gRPC stream, with an initial create-stream call, followed by
  243. // zero or more sends and/or receives.
  244. type stream struct {
  245. method string
  246. createIndex int
  247. createErr error // error from create call
  248. sends []message
  249. recvs []message
  250. }
  251. // NewReplayer creates a Replayer that reads from filename.
  252. func NewReplayer(filename string) (*Replayer, error) {
  253. f, err := os.Open(filename)
  254. if err != nil {
  255. return nil, err
  256. }
  257. defer f.Close()
  258. return NewReplayerReader(f)
  259. }
  260. // NewReplayerReader creates a Replayer that reads from r.
  261. func NewReplayerReader(r io.Reader) (*Replayer, error) {
  262. rep := &Replayer{
  263. log: func(string, ...interface{}) {},
  264. }
  265. if err := rep.read(r); err != nil {
  266. return nil, err
  267. }
  268. return rep, nil
  269. }
  270. // read reads the stream of recorded entries.
  271. // It matches requests with responses, with each pair grouped
  272. // into a call struct.
  273. func (rep *Replayer) read(r io.Reader) error {
  274. r = bufio.NewReader(r)
  275. bytes, err := readHeader(r)
  276. if err != nil {
  277. return err
  278. }
  279. rep.initial = bytes
  280. callsByIndex := map[int]*call{}
  281. streamsByIndex := map[int]*stream{}
  282. for i := 1; ; i++ {
  283. e, err := readEntry(r)
  284. if err != nil {
  285. return err
  286. }
  287. if e == nil {
  288. break
  289. }
  290. switch e.kind {
  291. case pb.Entry_REQUEST:
  292. callsByIndex[i] = &call{
  293. method: e.method,
  294. request: e.msg.msg,
  295. }
  296. case pb.Entry_RESPONSE:
  297. call := callsByIndex[e.refIndex]
  298. if call == nil {
  299. return fmt.Errorf("replayer: no request for response #%d", i)
  300. }
  301. delete(callsByIndex, e.refIndex)
  302. call.response = e.msg
  303. rep.calls = append(rep.calls, call)
  304. case pb.Entry_CREATE_STREAM:
  305. s := &stream{method: e.method, createIndex: i}
  306. s.createErr = e.msg.err
  307. streamsByIndex[i] = s
  308. rep.streams = append(rep.streams, s)
  309. case pb.Entry_SEND:
  310. s := streamsByIndex[e.refIndex]
  311. if s == nil {
  312. return fmt.Errorf("replayer: no stream for send #%d", i)
  313. }
  314. s.sends = append(s.sends, e.msg)
  315. case pb.Entry_RECV:
  316. s := streamsByIndex[e.refIndex]
  317. if s == nil {
  318. return fmt.Errorf("replayer: no stream for recv #%d", i)
  319. }
  320. s.recvs = append(s.recvs, e.msg)
  321. default:
  322. return fmt.Errorf("replayer: unknown kind %s", e.kind)
  323. }
  324. }
  325. if len(callsByIndex) > 0 {
  326. return fmt.Errorf("replayer: %d unmatched requests", len(callsByIndex))
  327. }
  328. return nil
  329. }
  330. // DialOptions returns the options that must be passed to grpc.Dial
  331. // to enable replaying.
  332. func (r *Replayer) DialOptions() []grpc.DialOption {
  333. return []grpc.DialOption{
  334. // On replay, we make no RPCs, which means the connection may be closed
  335. // before the normally async Dial completes. Making the Dial synchronous
  336. // fixes that.
  337. grpc.WithBlock(),
  338. grpc.WithUnaryInterceptor(r.interceptUnary),
  339. grpc.WithStreamInterceptor(r.interceptStream),
  340. }
  341. }
  342. // Initial returns the initial state saved by the Recorder.
  343. func (r *Replayer) Initial() []byte { return r.initial }
  344. // SetLogFunc sets a function to be used for debug logging. The function
  345. // should be safe to be called from multiple goroutines.
  346. func (r *Replayer) SetLogFunc(f func(format string, v ...interface{})) {
  347. r.log = f
  348. }
  349. // Close closes the Replayer.
  350. func (r *Replayer) Close() error {
  351. return nil
  352. }
  353. func (r *Replayer) interceptUnary(_ context.Context, method string, req, res interface{}, _ *grpc.ClientConn, _ grpc.UnaryInvoker, _ ...grpc.CallOption) error {
  354. mreq := req.(proto.Message)
  355. if r.BeforeFunc != nil {
  356. if err := r.BeforeFunc(method, mreq); err != nil {
  357. return err
  358. }
  359. }
  360. r.log("request %s (%s)", method, req)
  361. call := r.extractCall(method, mreq)
  362. if call == nil {
  363. return fmt.Errorf("replayer: request not found: %s", mreq)
  364. }
  365. r.log("returning %v", call.response)
  366. if call.response.err != nil {
  367. return call.response.err
  368. }
  369. proto.Merge(res.(proto.Message), call.response.msg) // copy msg into res
  370. return nil
  371. }
  372. func (r *Replayer) interceptStream(ctx context.Context, _ *grpc.StreamDesc, _ *grpc.ClientConn, method string, _ grpc.Streamer, _ ...grpc.CallOption) (grpc.ClientStream, error) {
  373. r.log("create-stream %s", method)
  374. str := r.extractStream(method)
  375. if str == nil {
  376. return nil, fmt.Errorf("replayer: stream not found for method %s", method)
  377. }
  378. if str.createErr != nil {
  379. return nil, str.createErr
  380. }
  381. return &repClientStream{ctx: ctx, str: str}, nil
  382. }
  383. type repClientStream struct {
  384. ctx context.Context
  385. str *stream
  386. }
  387. func (rcs *repClientStream) Context() context.Context { return rcs.ctx }
  388. func (rcs *repClientStream) SendMsg(m interface{}) error {
  389. if len(rcs.str.sends) == 0 {
  390. return fmt.Errorf("replayer: no more sends for stream %s, created at index %d",
  391. rcs.str.method, rcs.str.createIndex)
  392. }
  393. // TODO(jba): Do not assume that the sends happen in the same order on replay.
  394. msg := rcs.str.sends[0]
  395. rcs.str.sends = rcs.str.sends[1:]
  396. return msg.err
  397. }
  398. func (rcs *repClientStream) RecvMsg(m interface{}) error {
  399. if len(rcs.str.recvs) == 0 {
  400. return fmt.Errorf("replayer: no more receives for stream %s, created at index %d",
  401. rcs.str.method, rcs.str.createIndex)
  402. }
  403. msg := rcs.str.recvs[0]
  404. rcs.str.recvs = rcs.str.recvs[1:]
  405. if msg.err != nil {
  406. return msg.err
  407. }
  408. proto.Merge(m.(proto.Message), msg.msg) // copy msg into m
  409. return nil
  410. }
  411. func (rcs *repClientStream) Header() (metadata.MD, error) {
  412. log.Printf("replay: stream metadata not supported")
  413. return nil, nil
  414. }
  415. func (rcs *repClientStream) Trailer() metadata.MD {
  416. log.Printf("replay: stream metadata not supported")
  417. return nil
  418. }
  419. func (rcs *repClientStream) CloseSend() error {
  420. return nil
  421. }
  422. // extractCall finds the first call in the list with the same method
  423. // and request. It returns nil if it can't find such a call.
  424. func (r *Replayer) extractCall(method string, req proto.Message) *call {
  425. r.mu.Lock()
  426. defer r.mu.Unlock()
  427. for i, call := range r.calls {
  428. if call == nil {
  429. continue
  430. }
  431. if method == call.method && proto.Equal(req, call.request) {
  432. r.calls[i] = nil // nil out this call so we don't reuse it
  433. return call
  434. }
  435. }
  436. return nil
  437. }
  438. func (r *Replayer) extractStream(method string) *stream {
  439. r.mu.Lock()
  440. defer r.mu.Unlock()
  441. for i, stream := range r.streams {
  442. if stream == nil {
  443. continue
  444. }
  445. if method == stream.method {
  446. r.streams[i] = nil
  447. return stream
  448. }
  449. }
  450. return nil
  451. }
  452. // Fprint reads the entries from filename and writes them to w in human-readable form.
  453. // It is intended for debugging.
  454. func Fprint(w io.Writer, filename string) error {
  455. f, err := os.Open(filename)
  456. if err != nil {
  457. return err
  458. }
  459. defer f.Close()
  460. return FprintReader(w, f)
  461. }
  462. // FprintReader reads the entries from r and writes them to w in human-readable form.
  463. // It is intended for debugging.
  464. func FprintReader(w io.Writer, r io.Reader) error {
  465. initial, err := readHeader(r)
  466. if err != nil {
  467. return err
  468. }
  469. fmt.Fprintf(w, "initial state: %q\n", string(initial))
  470. for i := 1; ; i++ {
  471. e, err := readEntry(r)
  472. if err != nil {
  473. return err
  474. }
  475. if e == nil {
  476. return nil
  477. }
  478. s := "message"
  479. if e.msg.err != nil {
  480. s = "error"
  481. }
  482. fmt.Fprintf(w, "#%d: kind: %s, method: %s, ref index: %d, %s:\n",
  483. i, e.kind, e.method, e.refIndex, s)
  484. if e.msg.err == nil {
  485. if err := proto.MarshalText(w, e.msg.msg); err != nil {
  486. return err
  487. }
  488. } else {
  489. fmt.Fprintf(w, "%v\n", e.msg.err)
  490. }
  491. }
  492. }
  493. // An entry holds one gRPC action (request, response, etc.).
  494. type entry struct {
  495. kind pb.Entry_Kind
  496. method string
  497. msg message
  498. refIndex int // index of corresponding request or create-stream
  499. }
  500. func (e1 *entry) equal(e2 *entry) bool {
  501. if e1 == nil && e2 == nil {
  502. return true
  503. }
  504. if e1 == nil || e2 == nil {
  505. return false
  506. }
  507. return e1.kind == e2.kind &&
  508. e1.method == e2.method &&
  509. proto.Equal(e1.msg.msg, e2.msg.msg) &&
  510. errEqual(e1.msg.err, e2.msg.err) &&
  511. e1.refIndex == e2.refIndex
  512. }
  513. func errEqual(e1, e2 error) bool {
  514. if e1 == e2 {
  515. return true
  516. }
  517. s1, ok1 := status.FromError(e1)
  518. s2, ok2 := status.FromError(e2)
  519. if !ok1 || !ok2 {
  520. return false
  521. }
  522. return proto.Equal(s1.Proto(), s2.Proto())
  523. }
  524. // message holds either a single proto.Message or an error.
  525. type message struct {
  526. msg proto.Message
  527. err error
  528. }
  529. func (m *message) set(msg interface{}, err error) {
  530. m.err = err
  531. if err != io.EOF && msg != nil {
  532. m.msg = msg.(proto.Message)
  533. }
  534. }
  // File format:
//   header
//   sequence of Entry protos, each stored as a record
//
// Header format:
//   magic string
//   a record containing the bytes of the initial state
const (
	magic = "RPCReplay"
)
  543. func writeHeader(w io.Writer, initial []byte) error {
  544. if _, err := io.WriteString(w, magic); err != nil {
  545. return err
  546. }
  547. return writeRecord(w, initial)
  548. }
  549. func readHeader(r io.Reader) ([]byte, error) {
  550. var buf [len(magic)]byte
  551. if _, err := io.ReadFull(r, buf[:]); err != nil {
  552. if err == io.EOF {
  553. err = errors.New("rpcreplay: empty replay file")
  554. }
  555. return nil, err
  556. }
  557. if string(buf[:]) != magic {
  558. return nil, errors.New("rpcreplay: not a replay file (does not begin with magic string)")
  559. }
  560. bytes, err := readRecord(r)
  561. if err == io.EOF {
  562. err = errors.New("rpcreplay: missing initial state")
  563. }
  564. return bytes, err
  565. }
  566. func writeEntry(w io.Writer, e *entry) error {
  567. var m proto.Message
  568. if e.msg.err != nil && e.msg.err != io.EOF {
  569. s, ok := status.FromError(e.msg.err)
  570. if !ok {
  571. return fmt.Errorf("rpcreplay: error %v is not a Status", e.msg.err)
  572. }
  573. m = s.Proto()
  574. } else {
  575. m = e.msg.msg
  576. }
  577. var a *any.Any
  578. var err error
  579. if m != nil {
  580. a, err = ptypes.MarshalAny(m)
  581. if err != nil {
  582. return err
  583. }
  584. }
  585. pe := &pb.Entry{
  586. Kind: e.kind,
  587. Method: e.method,
  588. Message: a,
  589. IsError: e.msg.err != nil,
  590. RefIndex: int32(e.refIndex),
  591. }
  592. bytes, err := proto.Marshal(pe)
  593. if err != nil {
  594. return err
  595. }
  596. return writeRecord(w, bytes)
  597. }
  598. func readEntry(r io.Reader) (*entry, error) {
  599. buf, err := readRecord(r)
  600. if err == io.EOF {
  601. return nil, nil
  602. }
  603. if err != nil {
  604. return nil, err
  605. }
  606. var pe pb.Entry
  607. if err := proto.Unmarshal(buf, &pe); err != nil {
  608. return nil, err
  609. }
  610. var msg message
  611. if pe.Message != nil {
  612. var any ptypes.DynamicAny
  613. if err := ptypes.UnmarshalAny(pe.Message, &any); err != nil {
  614. return nil, err
  615. }
  616. if pe.IsError {
  617. msg.err = status.ErrorProto(any.Message.(*spb.Status))
  618. } else {
  619. msg.msg = any.Message
  620. }
  621. } else if pe.IsError {
  622. msg.err = io.EOF
  623. } else if pe.Kind != pb.Entry_CREATE_STREAM {
  624. return nil, errors.New("rpcreplay: entry with nil message and false is_error")
  625. }
  626. return &entry{
  627. kind: pe.Kind,
  628. method: pe.Method,
  629. msg: msg,
  630. refIndex: int(pe.RefIndex),
  631. }, nil
  632. }
  // A record consists of an unsigned 32-bit little-endian length L followed by L
// bytes.
//
// writeRecord writes data to w as one record. The length prefix and the data
// are written with two separate Write calls.
func writeRecord(w io.Writer, data []byte) error {
	var prefix [4]byte
	binary.LittleEndian.PutUint32(prefix[:], uint32(len(data)))
	if _, err := w.Write(prefix[:]); err != nil {
		return err
	}
	_, err := w.Write(data)
	return err
}
  // readRecord reads one record (see writeRecord) from r.
//
// It returns io.EOF only when r is exhausted before the record begins, so
// callers can treat that as a clean end of input. A record cut short anywhere
// after it begins yields io.ErrUnexpectedEOF, including a length prefix with
// no body at all.
func readRecord(r io.Reader) ([]byte, error) {
	var size uint32
	if err := binary.Read(r, binary.LittleEndian, &size); err != nil {
		return nil, err
	}
	buf := make([]byte, size)
	if _, err := io.ReadFull(r, buf); err != nil {
		// io.ReadFull reports io.EOF when it reads zero bytes. Here the length
		// prefix has already been consumed, so the record is truncated.
		if err == io.EOF {
			err = io.ErrUnexpectedEOF
		}
		return nil, err
	}
	return buf, nil
}