|
- // Copyright 2017 Google LLC
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
-
- package rpcreplay
-
- import (
- "bufio"
- "encoding/binary"
- "errors"
- "fmt"
- "io"
- "log"
- "os"
- "sync"
-
- "golang.org/x/net/context"
-
- "google.golang.org/grpc"
- "google.golang.org/grpc/metadata"
- "google.golang.org/grpc/status"
-
- pb "cloud.google.com/go/rpcreplay/proto/rpcreplay"
- "github.com/golang/protobuf/proto"
- "github.com/golang/protobuf/ptypes"
- "github.com/golang/protobuf/ptypes/any"
- spb "google.golang.org/genproto/googleapis/rpc/status"
- )
-
// A Recorder records RPCs for later playback.
type Recorder struct {
	mu   sync.Mutex    // held by writeEntry and Close, and while interceptUnary sets err
	w    *bufio.Writer // buffered destination for the header and all entries
	f    *os.File      // set only by NewRecorder; Close closes it when non-nil
	next int           // index handed to the next successfully written entry (starts at 1)
	err  error         // sticky error; once set, writeEntry and Close return it
	// BeforeFunc defines a function that can inspect and modify requests and responses
	// written to the replay file. It does not modify messages sent to the service.
	// It is run once before a request is written to the replay file, and once before a response
	// is written to the replay file.
	// The function is called with the method name and the message that triggered the callback.
	// If the function returns an error, the error will be returned to the client.
	// This is only executed for unary RPCs; streaming RPCs are not supported.
	BeforeFunc func(string, proto.Message) error
}
-
- // NewRecorder creates a recorder that writes to filename. The file will
- // also store the initial bytes for retrieval during replay.
- //
- // You must call Close on the Recorder to ensure that all data is written.
- func NewRecorder(filename string, initial []byte) (*Recorder, error) {
- f, err := os.Create(filename)
- if err != nil {
- return nil, err
- }
- rec, err := NewRecorderWriter(f, initial)
- if err != nil {
- _ = f.Close()
- return nil, err
- }
- rec.f = f
- return rec, nil
- }
-
- // NewRecorderWriter creates a recorder that writes to w. The initial
- // bytes will also be written to w for retrieval during replay.
- //
- // You must call Close on the Recorder to ensure that all data is written.
- func NewRecorderWriter(w io.Writer, initial []byte) (*Recorder, error) {
- bw := bufio.NewWriter(w)
- if err := writeHeader(bw, initial); err != nil {
- return nil, err
- }
- return &Recorder{w: bw, next: 1}, nil
- }
-
- // DialOptions returns the options that must be passed to grpc.Dial
- // to enable recording.
- func (r *Recorder) DialOptions() []grpc.DialOption {
- return []grpc.DialOption{
- grpc.WithUnaryInterceptor(r.interceptUnary),
- grpc.WithStreamInterceptor(r.interceptStream),
- }
- }
-
- // Close saves any unwritten information.
- func (r *Recorder) Close() error {
- r.mu.Lock()
- defer r.mu.Unlock()
- if r.err != nil {
- return r.err
- }
- err := r.w.Flush()
- if r.f != nil {
- if err2 := r.f.Close(); err == nil {
- err = err2
- }
- }
- return err
- }
-
- // Intercepts all unary (non-stream) RPCs.
- func (r *Recorder) interceptUnary(ctx context.Context, method string, req, res interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
- ereq := &entry{
- kind: pb.Entry_REQUEST,
- method: method,
- msg: message{msg: proto.Clone(req.(proto.Message))},
- }
-
- if r.BeforeFunc != nil {
- if err := r.BeforeFunc(method, ereq.msg.msg); err != nil {
- return err
- }
- }
- refIndex, err := r.writeEntry(ereq)
- if err != nil {
- return err
- }
- ierr := invoker(ctx, method, req, res, cc, opts...)
- eres := &entry{
- kind: pb.Entry_RESPONSE,
- refIndex: refIndex,
- }
- // If the error is not a gRPC status, then something more
- // serious is wrong. More significantly, we have no way
- // of serializing an arbitrary error. So just return it
- // without recording the response.
- if _, ok := status.FromError(ierr); !ok {
- r.mu.Lock()
- r.err = fmt.Errorf("saw non-status error in %s response: %v (%T)", method, ierr, ierr)
- r.mu.Unlock()
- return ierr
- }
- eres.msg.set(proto.Clone(res.(proto.Message)), ierr)
- if r.BeforeFunc != nil {
- if err := r.BeforeFunc(method, eres.msg.msg); err != nil {
- return err
- }
- }
- if _, err := r.writeEntry(eres); err != nil {
- return err
- }
- return ierr
- }
-
- func (r *Recorder) writeEntry(e *entry) (int, error) {
- r.mu.Lock()
- defer r.mu.Unlock()
- if r.err != nil {
- return 0, r.err
- }
- err := writeEntry(r.w, e)
- if err != nil {
- r.err = err
- return 0, err
- }
- n := r.next
- r.next++
- return n, nil
- }
-
- func (r *Recorder) interceptStream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
- cstream, serr := streamer(ctx, desc, cc, method, opts...)
- e := &entry{
- kind: pb.Entry_CREATE_STREAM,
- method: method,
- }
- e.msg.set(nil, serr)
- refIndex, err := r.writeEntry(e)
- if err != nil {
- return nil, err
- }
- return &recClientStream{
- ctx: ctx,
- rec: r,
- cstream: cstream,
- refIndex: refIndex,
- }, serr
- }
-
// A recClientStream implements the grpc.ClientStream interface.
// It behaves exactly like the default ClientStream, but also
// records all messages sent and received.
type recClientStream struct {
	ctx      context.Context   // returned by Context
	rec      *Recorder         // receives one entry per SendMsg and RecvMsg
	cstream  grpc.ClientStream // the real stream that every call is forwarded to
	refIndex int               // index of the CREATE_STREAM entry for this stream
}
-
- func (rcs *recClientStream) Context() context.Context { return rcs.ctx }
-
- func (rcs *recClientStream) SendMsg(m interface{}) error {
- serr := rcs.cstream.SendMsg(m)
- e := &entry{
- kind: pb.Entry_SEND,
- refIndex: rcs.refIndex,
- }
- e.msg.set(m, serr)
- if _, err := rcs.rec.writeEntry(e); err != nil {
- return err
- }
- return serr
- }
-
- func (rcs *recClientStream) RecvMsg(m interface{}) error {
- serr := rcs.cstream.RecvMsg(m)
- e := &entry{
- kind: pb.Entry_RECV,
- refIndex: rcs.refIndex,
- }
- e.msg.set(m, serr)
- if _, err := rcs.rec.writeEntry(e); err != nil {
- return err
- }
- return serr
- }
-
- func (rcs *recClientStream) Header() (metadata.MD, error) {
- // TODO(jba): record.
- return rcs.cstream.Header()
- }
-
- func (rcs *recClientStream) Trailer() metadata.MD {
- // TODO(jba): record.
- return rcs.cstream.Trailer()
- }
-
- func (rcs *recClientStream) CloseSend() error {
- // TODO(jba): record.
- return rcs.cstream.CloseSend()
- }
-
// A Replayer replays a set of RPCs saved by a Recorder.
type Replayer struct {
	initial []byte                                // initial state
	log     func(format string, v ...interface{}) // for debugging

	mu      sync.Mutex // held by extractCall and extractStream while they scan the slices below
	calls   []*call    // recorded unary calls; a slot is set to nil once it has been replayed
	streams []*stream  // recorded streams; a slot is set to nil once it has been handed out
	// BeforeFunc defines a function that can inspect and modify requests before they
	// are matched for responses from the replay file.
	// The function is called with the method name and the message that triggered the callback.
	// If the function returns an error, the error will be returned to the client.
	// This is only executed for unary RPCs; streaming RPCs are not supported.
	BeforeFunc func(string, proto.Message) error
}
-
// A call represents a unary RPC, with a request and response (or error).
type call struct {
	method   string        // method name taken from the REQUEST entry
	request  proto.Message // recorded request; compared against incoming requests on replay
	response message       // recorded response message or error
}
-
// A stream represents a gRPC stream, with an initial create-stream call, followed by
// zero or more sends and/or receives.
type stream struct {
	method      string    // method name taken from the CREATE_STREAM entry
	createIndex int       // 1-based position of the CREATE_STREAM entry in the replay file
	createErr   error     // error from create call
	sends       []message // recorded sends, consumed from the front on replay
	recvs       []message // recorded receives, consumed from the front on replay
}
-
- // NewReplayer creates a Replayer that reads from filename.
- func NewReplayer(filename string) (*Replayer, error) {
- f, err := os.Open(filename)
- if err != nil {
- return nil, err
- }
- defer f.Close()
- return NewReplayerReader(f)
- }
-
- // NewReplayerReader creates a Replayer that reads from r.
- func NewReplayerReader(r io.Reader) (*Replayer, error) {
- rep := &Replayer{
- log: func(string, ...interface{}) {},
- }
- if err := rep.read(r); err != nil {
- return nil, err
- }
- return rep, nil
- }
-
- // read reads the stream of recorded entries.
- // It matches requests with responses, with each pair grouped
- // into a call struct.
- func (rep *Replayer) read(r io.Reader) error {
- r = bufio.NewReader(r)
- bytes, err := readHeader(r)
- if err != nil {
- return err
- }
- rep.initial = bytes
-
- callsByIndex := map[int]*call{}
- streamsByIndex := map[int]*stream{}
- for i := 1; ; i++ {
- e, err := readEntry(r)
- if err != nil {
- return err
- }
- if e == nil {
- break
- }
- switch e.kind {
- case pb.Entry_REQUEST:
- callsByIndex[i] = &call{
- method: e.method,
- request: e.msg.msg,
- }
-
- case pb.Entry_RESPONSE:
- call := callsByIndex[e.refIndex]
- if call == nil {
- return fmt.Errorf("replayer: no request for response #%d", i)
- }
- delete(callsByIndex, e.refIndex)
- call.response = e.msg
- rep.calls = append(rep.calls, call)
-
- case pb.Entry_CREATE_STREAM:
- s := &stream{method: e.method, createIndex: i}
- s.createErr = e.msg.err
- streamsByIndex[i] = s
- rep.streams = append(rep.streams, s)
-
- case pb.Entry_SEND:
- s := streamsByIndex[e.refIndex]
- if s == nil {
- return fmt.Errorf("replayer: no stream for send #%d", i)
- }
- s.sends = append(s.sends, e.msg)
-
- case pb.Entry_RECV:
- s := streamsByIndex[e.refIndex]
- if s == nil {
- return fmt.Errorf("replayer: no stream for recv #%d", i)
- }
- s.recvs = append(s.recvs, e.msg)
-
- default:
- return fmt.Errorf("replayer: unknown kind %s", e.kind)
- }
- }
- if len(callsByIndex) > 0 {
- return fmt.Errorf("replayer: %d unmatched requests", len(callsByIndex))
- }
- return nil
- }
-
- // DialOptions returns the options that must be passed to grpc.Dial
- // to enable replaying.
- func (r *Replayer) DialOptions() []grpc.DialOption {
- return []grpc.DialOption{
- // On replay, we make no RPCs, which means the connection may be closed
- // before the normally async Dial completes. Making the Dial synchronous
- // fixes that.
- grpc.WithBlock(),
- grpc.WithUnaryInterceptor(r.interceptUnary),
- grpc.WithStreamInterceptor(r.interceptStream),
- }
- }
-
- // Initial returns the initial state saved by the Recorder.
- func (r *Replayer) Initial() []byte { return r.initial }
-
- // SetLogFunc sets a function to be used for debug logging. The function
- // should be safe to be called from multiple goroutines.
- func (r *Replayer) SetLogFunc(f func(format string, v ...interface{})) {
- r.log = f
- }
-
// Close closes the Replayer. It currently does nothing and always returns
// nil.
func (r *Replayer) Close() error {
	return nil
}
-
- func (r *Replayer) interceptUnary(_ context.Context, method string, req, res interface{}, _ *grpc.ClientConn, _ grpc.UnaryInvoker, _ ...grpc.CallOption) error {
- mreq := req.(proto.Message)
- if r.BeforeFunc != nil {
- if err := r.BeforeFunc(method, mreq); err != nil {
- return err
- }
- }
- r.log("request %s (%s)", method, req)
- call := r.extractCall(method, mreq)
- if call == nil {
- return fmt.Errorf("replayer: request not found: %s", mreq)
- }
- r.log("returning %v", call.response)
- if call.response.err != nil {
- return call.response.err
- }
- proto.Merge(res.(proto.Message), call.response.msg) // copy msg into res
- return nil
- }
-
- func (r *Replayer) interceptStream(ctx context.Context, _ *grpc.StreamDesc, _ *grpc.ClientConn, method string, _ grpc.Streamer, _ ...grpc.CallOption) (grpc.ClientStream, error) {
- r.log("create-stream %s", method)
- str := r.extractStream(method)
- if str == nil {
- return nil, fmt.Errorf("replayer: stream not found for method %s", method)
- }
- if str.createErr != nil {
- return nil, str.createErr
- }
- return &repClientStream{ctx: ctx, str: str}, nil
- }
-
// A repClientStream is the client stream handed out during replay. Its
// SendMsg and RecvMsg consume the sends and receives recorded for str.
type repClientStream struct {
	ctx context.Context // returned by Context
	str *stream         // recorded stream being replayed
}
-
- func (rcs *repClientStream) Context() context.Context { return rcs.ctx }
-
- func (rcs *repClientStream) SendMsg(m interface{}) error {
- if len(rcs.str.sends) == 0 {
- return fmt.Errorf("replayer: no more sends for stream %s, created at index %d",
- rcs.str.method, rcs.str.createIndex)
- }
- // TODO(jba): Do not assume that the sends happen in the same order on replay.
- msg := rcs.str.sends[0]
- rcs.str.sends = rcs.str.sends[1:]
- return msg.err
- }
-
- func (rcs *repClientStream) RecvMsg(m interface{}) error {
- if len(rcs.str.recvs) == 0 {
- return fmt.Errorf("replayer: no more receives for stream %s, created at index %d",
- rcs.str.method, rcs.str.createIndex)
- }
- msg := rcs.str.recvs[0]
- rcs.str.recvs = rcs.str.recvs[1:]
- if msg.err != nil {
- return msg.err
- }
- proto.Merge(m.(proto.Message), msg.msg) // copy msg into m
- return nil
- }
-
- func (rcs *repClientStream) Header() (metadata.MD, error) {
- log.Printf("replay: stream metadata not supported")
- return nil, nil
- }
-
- func (rcs *repClientStream) Trailer() metadata.MD {
- log.Printf("replay: stream metadata not supported")
- return nil
- }
-
// CloseSend does nothing during replay and always returns nil.
func (rcs *repClientStream) CloseSend() error {
	return nil
}
-
- // extractCall finds the first call in the list with the same method
- // and request. It returns nil if it can't find such a call.
- func (r *Replayer) extractCall(method string, req proto.Message) *call {
- r.mu.Lock()
- defer r.mu.Unlock()
- for i, call := range r.calls {
- if call == nil {
- continue
- }
- if method == call.method && proto.Equal(req, call.request) {
- r.calls[i] = nil // nil out this call so we don't reuse it
- return call
- }
- }
- return nil
- }
-
- func (r *Replayer) extractStream(method string) *stream {
- r.mu.Lock()
- defer r.mu.Unlock()
- for i, stream := range r.streams {
- if stream == nil {
- continue
- }
- if method == stream.method {
- r.streams[i] = nil
- return stream
- }
- }
- return nil
- }
-
- // Fprint reads the entries from filename and writes them to w in human-readable form.
- // It is intended for debugging.
- func Fprint(w io.Writer, filename string) error {
- f, err := os.Open(filename)
- if err != nil {
- return err
- }
- defer f.Close()
- return FprintReader(w, f)
- }
-
- // FprintReader reads the entries from r and writes them to w in human-readable form.
- // It is intended for debugging.
- func FprintReader(w io.Writer, r io.Reader) error {
- initial, err := readHeader(r)
- if err != nil {
- return err
- }
- fmt.Fprintf(w, "initial state: %q\n", string(initial))
- for i := 1; ; i++ {
- e, err := readEntry(r)
- if err != nil {
- return err
- }
- if e == nil {
- return nil
- }
-
- s := "message"
- if e.msg.err != nil {
- s = "error"
- }
- fmt.Fprintf(w, "#%d: kind: %s, method: %s, ref index: %d, %s:\n",
- i, e.kind, e.method, e.refIndex, s)
- if e.msg.err == nil {
- if err := proto.MarshalText(w, e.msg.msg); err != nil {
- return err
- }
- } else {
- fmt.Fprintf(w, "%v\n", e.msg.err)
- }
- }
- }
-
// An entry holds one gRPC action (request, response, etc.).
type entry struct {
	kind     pb.Entry_Kind // which action this is (REQUEST, RESPONSE, CREATE_STREAM, SEND, RECV)
	method   string        // method name; the Recorder sets it on REQUEST and CREATE_STREAM entries
	msg      message       // the message or error carried by the action
	refIndex int           // index of corresponding request or create-stream
}
-
- func (e1 *entry) equal(e2 *entry) bool {
- if e1 == nil && e2 == nil {
- return true
- }
- if e1 == nil || e2 == nil {
- return false
- }
- return e1.kind == e2.kind &&
- e1.method == e2.method &&
- proto.Equal(e1.msg.msg, e2.msg.msg) &&
- errEqual(e1.msg.err, e2.msg.err) &&
- e1.refIndex == e2.refIndex
- }
-
- func errEqual(e1, e2 error) bool {
- if e1 == e2 {
- return true
- }
- s1, ok1 := status.FromError(e1)
- s2, ok2 := status.FromError(e2)
- if !ok1 || !ok2 {
- return false
- }
- return proto.Equal(s1.Proto(), s2.Proto())
- }
-
// message holds either a single proto.Message or an error.
type message struct {
	msg proto.Message // the payload; may be nil
	err error         // the error; nil when the action succeeded
}
-
- func (m *message) set(msg interface{}, err error) {
- m.err = err
- if err != io.EOF && msg != nil {
- m.msg = msg.(proto.Message)
- }
- }
-
- // File format:
- // header
- // sequence of Entry protos
- //
- // Header format:
- // magic string
- // a record containing the bytes of the initial state
-
// magic is the fixed string that writeHeader emits at the very start of a
// replay file and that readHeader requires before it reads anything else.
const magic = "RPCReplay"
-
- func writeHeader(w io.Writer, initial []byte) error {
- if _, err := io.WriteString(w, magic); err != nil {
- return err
- }
- return writeRecord(w, initial)
- }
-
- func readHeader(r io.Reader) ([]byte, error) {
- var buf [len(magic)]byte
- if _, err := io.ReadFull(r, buf[:]); err != nil {
- if err == io.EOF {
- err = errors.New("rpcreplay: empty replay file")
- }
- return nil, err
- }
- if string(buf[:]) != magic {
- return nil, errors.New("rpcreplay: not a replay file (does not begin with magic string)")
- }
- bytes, err := readRecord(r)
- if err == io.EOF {
- err = errors.New("rpcreplay: missing initial state")
- }
- return bytes, err
- }
-
- func writeEntry(w io.Writer, e *entry) error {
- var m proto.Message
- if e.msg.err != nil && e.msg.err != io.EOF {
- s, ok := status.FromError(e.msg.err)
- if !ok {
- return fmt.Errorf("rpcreplay: error %v is not a Status", e.msg.err)
- }
- m = s.Proto()
- } else {
- m = e.msg.msg
- }
- var a *any.Any
- var err error
- if m != nil {
- a, err = ptypes.MarshalAny(m)
- if err != nil {
- return err
- }
- }
- pe := &pb.Entry{
- Kind: e.kind,
- Method: e.method,
- Message: a,
- IsError: e.msg.err != nil,
- RefIndex: int32(e.refIndex),
- }
- bytes, err := proto.Marshal(pe)
- if err != nil {
- return err
- }
- return writeRecord(w, bytes)
- }
-
- func readEntry(r io.Reader) (*entry, error) {
- buf, err := readRecord(r)
- if err == io.EOF {
- return nil, nil
- }
- if err != nil {
- return nil, err
- }
- var pe pb.Entry
- if err := proto.Unmarshal(buf, &pe); err != nil {
- return nil, err
- }
- var msg message
- if pe.Message != nil {
- var any ptypes.DynamicAny
- if err := ptypes.UnmarshalAny(pe.Message, &any); err != nil {
- return nil, err
- }
- if pe.IsError {
- msg.err = status.ErrorProto(any.Message.(*spb.Status))
- } else {
- msg.msg = any.Message
- }
- } else if pe.IsError {
- msg.err = io.EOF
- } else if pe.Kind != pb.Entry_CREATE_STREAM {
- return nil, errors.New("rpcreplay: entry with nil message and false is_error")
- }
- return &entry{
- kind: pe.Kind,
- method: pe.Method,
- msg: msg,
- refIndex: int(pe.RefIndex),
- }, nil
- }
-
- // A record consists of an unsigned 32-bit little-endian length L followed by L
- // bytes.
-
// writeRecord writes data to w as one record: an unsigned 32-bit
// little-endian length followed by the bytes themselves.
//
// Data too long for the 32-bit length is rejected before anything is
// written. Converting the length blindly would wrap it around and corrupt
// every record that follows.
func writeRecord(w io.Writer, data []byte) error {
	if uint64(len(data)) > 0xFFFFFFFF {
		return fmt.Errorf("rpcreplay: record of %d bytes exceeds the 32-bit length limit", len(data))
	}
	var prefix [4]byte
	binary.LittleEndian.PutUint32(prefix[:], uint32(len(data)))
	if _, err := w.Write(prefix[:]); err != nil {
		return err
	}
	_, err := w.Write(data)
	return err
}
-
// readRecord reads one record from r: an unsigned 32-bit little-endian
// length followed by that many bytes, which it returns.
//
// io.EOF is returned only when the input ends cleanly before the length
// prefix. Input that ends inside the prefix or anywhere in the body,
// including directly after the prefix, yields io.ErrUnexpectedEOF, so a
// truncated file is never mistaken for a complete one.
func readRecord(r io.Reader) ([]byte, error) {
	var prefix [4]byte
	if _, err := io.ReadFull(r, prefix[:]); err != nil {
		return nil, err
	}
	size := binary.LittleEndian.Uint32(prefix[:])
	buf := make([]byte, size)
	if _, err := io.ReadFull(r, buf); err != nil {
		if err == io.EOF {
			// The prefix promised a body that is not there.
			err = io.ErrUnexpectedEOF
		}
		return nil, err
	}
	return buf, nil
}
|