You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

482 lines
13 KiB

  1. /*
  2. *
  3. * Copyright 2016 gRPC authors.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. *
  17. */
  18. package transport
  19. import (
  20. "context"
  21. "errors"
  22. "fmt"
  23. "io"
  24. "net/http"
  25. "net/http/httptest"
  26. "net/url"
  27. "reflect"
  28. "sync"
  29. "testing"
  30. "time"
  31. "github.com/golang/protobuf/proto"
  32. dpb "github.com/golang/protobuf/ptypes/duration"
  33. epb "google.golang.org/genproto/googleapis/rpc/errdetails"
  34. "google.golang.org/grpc/codes"
  35. "google.golang.org/grpc/metadata"
  36. "google.golang.org/grpc/status"
  37. )
  38. func TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
  39. type testCase struct {
  40. name string
  41. req *http.Request
  42. wantErr string
  43. modrw func(http.ResponseWriter) http.ResponseWriter
  44. check func(*serverHandlerTransport, *testCase) error
  45. }
  46. tests := []testCase{
  47. {
  48. name: "http/1.1",
  49. req: &http.Request{
  50. ProtoMajor: 1,
  51. ProtoMinor: 1,
  52. },
  53. wantErr: "gRPC requires HTTP/2",
  54. },
  55. {
  56. name: "bad method",
  57. req: &http.Request{
  58. ProtoMajor: 2,
  59. Method: "GET",
  60. Header: http.Header{},
  61. RequestURI: "/",
  62. },
  63. wantErr: "invalid gRPC request method",
  64. },
  65. {
  66. name: "bad content type",
  67. req: &http.Request{
  68. ProtoMajor: 2,
  69. Method: "POST",
  70. Header: http.Header{
  71. "Content-Type": {"application/foo"},
  72. },
  73. RequestURI: "/service/foo.bar",
  74. },
  75. wantErr: "invalid gRPC request content-type",
  76. },
  77. {
  78. name: "not flusher",
  79. req: &http.Request{
  80. ProtoMajor: 2,
  81. Method: "POST",
  82. Header: http.Header{
  83. "Content-Type": {"application/grpc"},
  84. },
  85. RequestURI: "/service/foo.bar",
  86. },
  87. modrw: func(w http.ResponseWriter) http.ResponseWriter {
  88. // Return w without its Flush method
  89. type onlyCloseNotifier interface {
  90. http.ResponseWriter
  91. http.CloseNotifier
  92. }
  93. return struct{ onlyCloseNotifier }{w.(onlyCloseNotifier)}
  94. },
  95. wantErr: "gRPC requires a ResponseWriter supporting http.Flusher",
  96. },
  97. {
  98. name: "not closenotifier",
  99. req: &http.Request{
  100. ProtoMajor: 2,
  101. Method: "POST",
  102. Header: http.Header{
  103. "Content-Type": {"application/grpc"},
  104. },
  105. RequestURI: "/service/foo.bar",
  106. },
  107. modrw: func(w http.ResponseWriter) http.ResponseWriter {
  108. // Return w without its CloseNotify method
  109. type onlyFlusher interface {
  110. http.ResponseWriter
  111. http.Flusher
  112. }
  113. return struct{ onlyFlusher }{w.(onlyFlusher)}
  114. },
  115. wantErr: "gRPC requires a ResponseWriter supporting http.CloseNotifier",
  116. },
  117. {
  118. name: "valid",
  119. req: &http.Request{
  120. ProtoMajor: 2,
  121. Method: "POST",
  122. Header: http.Header{
  123. "Content-Type": {"application/grpc"},
  124. },
  125. URL: &url.URL{
  126. Path: "/service/foo.bar",
  127. },
  128. RequestURI: "/service/foo.bar",
  129. },
  130. check: func(t *serverHandlerTransport, tt *testCase) error {
  131. if t.req != tt.req {
  132. return fmt.Errorf("t.req = %p; want %p", t.req, tt.req)
  133. }
  134. if t.rw == nil {
  135. return errors.New("t.rw = nil; want non-nil")
  136. }
  137. return nil
  138. },
  139. },
  140. {
  141. name: "with timeout",
  142. req: &http.Request{
  143. ProtoMajor: 2,
  144. Method: "POST",
  145. Header: http.Header{
  146. "Content-Type": []string{"application/grpc"},
  147. "Grpc-Timeout": {"200m"},
  148. },
  149. URL: &url.URL{
  150. Path: "/service/foo.bar",
  151. },
  152. RequestURI: "/service/foo.bar",
  153. },
  154. check: func(t *serverHandlerTransport, tt *testCase) error {
  155. if !t.timeoutSet {
  156. return errors.New("timeout not set")
  157. }
  158. if want := 200 * time.Millisecond; t.timeout != want {
  159. return fmt.Errorf("timeout = %v; want %v", t.timeout, want)
  160. }
  161. return nil
  162. },
  163. },
  164. {
  165. name: "with bad timeout",
  166. req: &http.Request{
  167. ProtoMajor: 2,
  168. Method: "POST",
  169. Header: http.Header{
  170. "Content-Type": []string{"application/grpc"},
  171. "Grpc-Timeout": {"tomorrow"},
  172. },
  173. URL: &url.URL{
  174. Path: "/service/foo.bar",
  175. },
  176. RequestURI: "/service/foo.bar",
  177. },
  178. wantErr: `rpc error: code = Internal desc = malformed time-out: transport: timeout unit is not recognized: "tomorrow"`,
  179. },
  180. {
  181. name: "with metadata",
  182. req: &http.Request{
  183. ProtoMajor: 2,
  184. Method: "POST",
  185. Header: http.Header{
  186. "Content-Type": []string{"application/grpc"},
  187. "meta-foo": {"foo-val"},
  188. "meta-bar": {"bar-val1", "bar-val2"},
  189. "user-agent": {"x/y a/b"},
  190. },
  191. URL: &url.URL{
  192. Path: "/service/foo.bar",
  193. },
  194. RequestURI: "/service/foo.bar",
  195. },
  196. check: func(ht *serverHandlerTransport, tt *testCase) error {
  197. want := metadata.MD{
  198. "meta-bar": {"bar-val1", "bar-val2"},
  199. "user-agent": {"x/y a/b"},
  200. "meta-foo": {"foo-val"},
  201. "content-type": {"application/grpc"},
  202. }
  203. if !reflect.DeepEqual(ht.headerMD, want) {
  204. return fmt.Errorf("metdata = %#v; want %#v", ht.headerMD, want)
  205. }
  206. return nil
  207. },
  208. },
  209. }
  210. for _, tt := range tests {
  211. rw := newTestHandlerResponseWriter()
  212. if tt.modrw != nil {
  213. rw = tt.modrw(rw)
  214. }
  215. got, gotErr := NewServerHandlerTransport(rw, tt.req, nil)
  216. if (gotErr != nil) != (tt.wantErr != "") || (gotErr != nil && gotErr.Error() != tt.wantErr) {
  217. t.Errorf("%s: error = %q; want %q", tt.name, gotErr.Error(), tt.wantErr)
  218. continue
  219. }
  220. if gotErr != nil {
  221. continue
  222. }
  223. if tt.check != nil {
  224. if err := tt.check(got.(*serverHandlerTransport), &tt); err != nil {
  225. t.Errorf("%s: %v", tt.name, err)
  226. }
  227. }
  228. }
  229. }
  230. type testHandlerResponseWriter struct {
  231. *httptest.ResponseRecorder
  232. closeNotify chan bool
  233. }
  234. func (w testHandlerResponseWriter) CloseNotify() <-chan bool { return w.closeNotify }
  235. func (w testHandlerResponseWriter) Flush() {}
  236. func newTestHandlerResponseWriter() http.ResponseWriter {
  237. return testHandlerResponseWriter{
  238. ResponseRecorder: httptest.NewRecorder(),
  239. closeNotify: make(chan bool, 1),
  240. }
  241. }
  242. type handleStreamTest struct {
  243. t *testing.T
  244. bodyw *io.PipeWriter
  245. rw testHandlerResponseWriter
  246. ht *serverHandlerTransport
  247. }
  248. func newHandleStreamTest(t *testing.T) *handleStreamTest {
  249. bodyr, bodyw := io.Pipe()
  250. req := &http.Request{
  251. ProtoMajor: 2,
  252. Method: "POST",
  253. Header: http.Header{
  254. "Content-Type": {"application/grpc"},
  255. },
  256. URL: &url.URL{
  257. Path: "/service/foo.bar",
  258. },
  259. RequestURI: "/service/foo.bar",
  260. Body: bodyr,
  261. }
  262. rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
  263. ht, err := NewServerHandlerTransport(rw, req, nil)
  264. if err != nil {
  265. t.Fatal(err)
  266. }
  267. return &handleStreamTest{
  268. t: t,
  269. bodyw: bodyw,
  270. ht: ht.(*serverHandlerTransport),
  271. rw: rw,
  272. }
  273. }
  274. func TestHandlerTransport_HandleStreams(t *testing.T) {
  275. st := newHandleStreamTest(t)
  276. handleStream := func(s *Stream) {
  277. if want := "/service/foo.bar"; s.method != want {
  278. t.Errorf("stream method = %q; want %q", s.method, want)
  279. }
  280. st.bodyw.Close() // no body
  281. st.ht.WriteStatus(s, status.New(codes.OK, ""))
  282. }
  283. st.ht.HandleStreams(
  284. func(s *Stream) { go handleStream(s) },
  285. func(ctx context.Context, method string) context.Context { return ctx },
  286. )
  287. wantHeader := http.Header{
  288. "Date": nil,
  289. "Content-Type": {"application/grpc"},
  290. "Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
  291. "Grpc-Status": {"0"},
  292. }
  293. if !reflect.DeepEqual(st.rw.HeaderMap, wantHeader) {
  294. t.Errorf("Header+Trailer Map: %#v; want %#v", st.rw.HeaderMap, wantHeader)
  295. }
  296. }
  297. // Tests that codes.Unimplemented will close the body, per comment in handler_server.go.
  298. func TestHandlerTransport_HandleStreams_Unimplemented(t *testing.T) {
  299. handleStreamCloseBodyTest(t, codes.Unimplemented, "thingy is unimplemented")
  300. }
  301. // Tests that codes.InvalidArgument will close the body, per comment in handler_server.go.
  302. func TestHandlerTransport_HandleStreams_InvalidArgument(t *testing.T) {
  303. handleStreamCloseBodyTest(t, codes.InvalidArgument, "bad arg")
  304. }
  305. func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) {
  306. st := newHandleStreamTest(t)
  307. handleStream := func(s *Stream) {
  308. st.ht.WriteStatus(s, status.New(statusCode, msg))
  309. }
  310. st.ht.HandleStreams(
  311. func(s *Stream) { go handleStream(s) },
  312. func(ctx context.Context, method string) context.Context { return ctx },
  313. )
  314. wantHeader := http.Header{
  315. "Date": nil,
  316. "Content-Type": {"application/grpc"},
  317. "Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
  318. "Grpc-Status": {fmt.Sprint(uint32(statusCode))},
  319. "Grpc-Message": {encodeGrpcMessage(msg)},
  320. }
  321. if !reflect.DeepEqual(st.rw.HeaderMap, wantHeader) {
  322. t.Errorf("Header+Trailer mismatch.\n got: %#v\nwant: %#v", st.rw.HeaderMap, wantHeader)
  323. }
  324. }
  325. func TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
  326. bodyr, bodyw := io.Pipe()
  327. req := &http.Request{
  328. ProtoMajor: 2,
  329. Method: "POST",
  330. Header: http.Header{
  331. "Content-Type": {"application/grpc"},
  332. "Grpc-Timeout": {"200m"},
  333. },
  334. URL: &url.URL{
  335. Path: "/service/foo.bar",
  336. },
  337. RequestURI: "/service/foo.bar",
  338. Body: bodyr,
  339. }
  340. rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
  341. ht, err := NewServerHandlerTransport(rw, req, nil)
  342. if err != nil {
  343. t.Fatal(err)
  344. }
  345. runStream := func(s *Stream) {
  346. defer bodyw.Close()
  347. select {
  348. case <-s.ctx.Done():
  349. case <-time.After(5 * time.Second):
  350. t.Errorf("timeout waiting for ctx.Done")
  351. return
  352. }
  353. err := s.ctx.Err()
  354. if err != context.DeadlineExceeded {
  355. t.Errorf("ctx.Err = %v; want %v", err, context.DeadlineExceeded)
  356. return
  357. }
  358. ht.WriteStatus(s, status.New(codes.DeadlineExceeded, "too slow"))
  359. }
  360. ht.HandleStreams(
  361. func(s *Stream) { go runStream(s) },
  362. func(ctx context.Context, method string) context.Context { return ctx },
  363. )
  364. wantHeader := http.Header{
  365. "Date": nil,
  366. "Content-Type": {"application/grpc"},
  367. "Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
  368. "Grpc-Status": {"4"},
  369. "Grpc-Message": {encodeGrpcMessage("too slow")},
  370. }
  371. if !reflect.DeepEqual(rw.HeaderMap, wantHeader) {
  372. t.Errorf("Header+Trailer Map mismatch.\n got: %#v\nwant: %#v", rw.HeaderMap, wantHeader)
  373. }
  374. }
  375. // TestHandlerTransport_HandleStreams_MultiWriteStatus ensures that
  376. // concurrent "WriteStatus"s do not panic writing to closed "writes" channel.
  377. func TestHandlerTransport_HandleStreams_MultiWriteStatus(t *testing.T) {
  378. testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *Stream) {
  379. if want := "/service/foo.bar"; s.method != want {
  380. t.Errorf("stream method = %q; want %q", s.method, want)
  381. }
  382. st.bodyw.Close() // no body
  383. var wg sync.WaitGroup
  384. wg.Add(5)
  385. for i := 0; i < 5; i++ {
  386. go func() {
  387. defer wg.Done()
  388. st.ht.WriteStatus(s, status.New(codes.OK, ""))
  389. }()
  390. }
  391. wg.Wait()
  392. })
  393. }
  394. // TestHandlerTransport_HandleStreams_WriteStatusWrite ensures that "Write"
  395. // following "WriteStatus" does not panic writing to closed "writes" channel.
  396. func TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) {
  397. testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *Stream) {
  398. if want := "/service/foo.bar"; s.method != want {
  399. t.Errorf("stream method = %q; want %q", s.method, want)
  400. }
  401. st.bodyw.Close() // no body
  402. st.ht.WriteStatus(s, status.New(codes.OK, ""))
  403. st.ht.Write(s, []byte("hdr"), []byte("data"), &Options{})
  404. })
  405. }
  406. func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *Stream)) {
  407. st := newHandleStreamTest(t)
  408. st.ht.HandleStreams(
  409. func(s *Stream) { go handleStream(st, s) },
  410. func(ctx context.Context, method string) context.Context { return ctx },
  411. )
  412. }
  413. func TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
  414. errDetails := []proto.Message{
  415. &epb.RetryInfo{
  416. RetryDelay: &dpb.Duration{Seconds: 60},
  417. },
  418. &epb.ResourceInfo{
  419. ResourceType: "foo bar",
  420. ResourceName: "service.foo.bar",
  421. Owner: "User",
  422. },
  423. }
  424. statusCode := codes.ResourceExhausted
  425. msg := "you are being throttled"
  426. st, err := status.New(statusCode, msg).WithDetails(errDetails...)
  427. if err != nil {
  428. t.Fatal(err)
  429. }
  430. stBytes, err := proto.Marshal(st.Proto())
  431. if err != nil {
  432. t.Fatal(err)
  433. }
  434. hst := newHandleStreamTest(t)
  435. handleStream := func(s *Stream) {
  436. hst.ht.WriteStatus(s, st)
  437. }
  438. hst.ht.HandleStreams(
  439. func(s *Stream) { go handleStream(s) },
  440. func(ctx context.Context, method string) context.Context { return ctx },
  441. )
  442. wantHeader := http.Header{
  443. "Date": nil,
  444. "Content-Type": {"application/grpc"},
  445. "Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
  446. "Grpc-Status": {fmt.Sprint(uint32(statusCode))},
  447. "Grpc-Message": {encodeGrpcMessage(msg)},
  448. "Grpc-Status-Details-Bin": {encodeBinHeader(stBytes)},
  449. }
  450. if !reflect.DeepEqual(hst.rw.HeaderMap, wantHeader) {
  451. t.Errorf("Header+Trailer mismatch.\n got: %#v\nwant: %#v", hst.rw.HeaderMap, wantHeader)
  452. }
  453. }