您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
 
 
 

291 行
7.4 KiB

  1. /*
  2. *
  3. * Copyright 2018 gRPC authors.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. *
  17. */
  18. package alts
  19. import (
  20. "reflect"
  21. "testing"
  22. "github.com/golang/protobuf/proto"
  23. altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
  24. )
  25. func TestInfoServerName(t *testing.T) {
  26. // This is not testing any handshaker functionality, so it's fine to only
  27. // use NewServerCreds and not NewClientCreds.
  28. alts := NewServerCreds(DefaultServerOptions())
  29. if got, want := alts.Info().ServerName, ""; got != want {
  30. t.Fatalf("%v.Info().ServerName = %v, want %v", alts, got, want)
  31. }
  32. }
  33. func TestOverrideServerName(t *testing.T) {
  34. wantServerName := "server.name"
  35. // This is not testing any handshaker functionality, so it's fine to only
  36. // use NewServerCreds and not NewClientCreds.
  37. c := NewServerCreds(DefaultServerOptions())
  38. c.OverrideServerName(wantServerName)
  39. if got, want := c.Info().ServerName, wantServerName; got != want {
  40. t.Fatalf("c.Info().ServerName = %v, want %v", got, want)
  41. }
  42. }
  43. func TestCloneClient(t *testing.T) {
  44. wantServerName := "server.name"
  45. opt := DefaultClientOptions()
  46. opt.TargetServiceAccounts = []string{"not", "empty"}
  47. c := NewClientCreds(opt)
  48. c.OverrideServerName(wantServerName)
  49. cc := c.Clone()
  50. if got, want := cc.Info().ServerName, wantServerName; got != want {
  51. t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
  52. }
  53. cc.OverrideServerName("")
  54. if got, want := c.Info().ServerName, wantServerName; got != want {
  55. t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", got, want)
  56. }
  57. if got, want := cc.Info().ServerName, ""; got != want {
  58. t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
  59. }
  60. ct := c.(*altsTC)
  61. cct := cc.(*altsTC)
  62. if ct.side != cct.side {
  63. t.Errorf("cc.side = %q, want %q", cct.side, ct.side)
  64. }
  65. if ct.hsAddress != cct.hsAddress {
  66. t.Errorf("cc.hsAddress = %q, want %q", cct.hsAddress, ct.hsAddress)
  67. }
  68. if !reflect.DeepEqual(ct.accounts, cct.accounts) {
  69. t.Errorf("cc.accounts = %q, want %q", cct.accounts, ct.accounts)
  70. }
  71. }
  72. func TestCloneServer(t *testing.T) {
  73. wantServerName := "server.name"
  74. c := NewServerCreds(DefaultServerOptions())
  75. c.OverrideServerName(wantServerName)
  76. cc := c.Clone()
  77. if got, want := cc.Info().ServerName, wantServerName; got != want {
  78. t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
  79. }
  80. cc.OverrideServerName("")
  81. if got, want := c.Info().ServerName, wantServerName; got != want {
  82. t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", got, want)
  83. }
  84. if got, want := cc.Info().ServerName, ""; got != want {
  85. t.Fatalf("cc.Info().ServerName = %v, want %v", got, want)
  86. }
  87. ct := c.(*altsTC)
  88. cct := cc.(*altsTC)
  89. if ct.side != cct.side {
  90. t.Errorf("cc.side = %q, want %q", cct.side, ct.side)
  91. }
  92. if ct.hsAddress != cct.hsAddress {
  93. t.Errorf("cc.hsAddress = %q, want %q", cct.hsAddress, ct.hsAddress)
  94. }
  95. if !reflect.DeepEqual(ct.accounts, cct.accounts) {
  96. t.Errorf("cc.accounts = %q, want %q", cct.accounts, ct.accounts)
  97. }
  98. }
  99. func TestInfo(t *testing.T) {
  100. // This is not testing any handshaker functionality, so it's fine to only
  101. // use NewServerCreds and not NewClientCreds.
  102. c := NewServerCreds(DefaultServerOptions())
  103. info := c.Info()
  104. if got, want := info.ProtocolVersion, ""; got != want {
  105. t.Errorf("info.ProtocolVersion=%v, want %v", got, want)
  106. }
  107. if got, want := info.SecurityProtocol, "alts"; got != want {
  108. t.Errorf("info.SecurityProtocol=%v, want %v", got, want)
  109. }
  110. if got, want := info.SecurityVersion, "1.0"; got != want {
  111. t.Errorf("info.SecurityVersion=%v, want %v", got, want)
  112. }
  113. if got, want := info.ServerName, ""; got != want {
  114. t.Errorf("info.ServerName=%v, want %v", got, want)
  115. }
  116. }
  117. func TestCompareRPCVersions(t *testing.T) {
  118. for _, tc := range []struct {
  119. v1 *altspb.RpcProtocolVersions_Version
  120. v2 *altspb.RpcProtocolVersions_Version
  121. output int
  122. }{
  123. {
  124. version(3, 2),
  125. version(2, 1),
  126. 1,
  127. },
  128. {
  129. version(3, 2),
  130. version(3, 1),
  131. 1,
  132. },
  133. {
  134. version(2, 1),
  135. version(3, 2),
  136. -1,
  137. },
  138. {
  139. version(3, 1),
  140. version(3, 2),
  141. -1,
  142. },
  143. {
  144. version(3, 2),
  145. version(3, 2),
  146. 0,
  147. },
  148. } {
  149. if got, want := compareRPCVersions(tc.v1, tc.v2), tc.output; got != want {
  150. t.Errorf("compareRPCVersions(%v, %v)=%v, want %v", tc.v1, tc.v2, got, want)
  151. }
  152. }
  153. }
  154. func TestCheckRPCVersions(t *testing.T) {
  155. for _, tc := range []struct {
  156. desc string
  157. local *altspb.RpcProtocolVersions
  158. peer *altspb.RpcProtocolVersions
  159. output bool
  160. maxCommonVersion *altspb.RpcProtocolVersions_Version
  161. }{
  162. {
  163. "local.max > peer.max and local.min > peer.min",
  164. versions(2, 1, 3, 2),
  165. versions(1, 2, 2, 1),
  166. true,
  167. version(2, 1),
  168. },
  169. {
  170. "local.max > peer.max and local.min < peer.min",
  171. versions(1, 2, 3, 2),
  172. versions(2, 1, 2, 1),
  173. true,
  174. version(2, 1),
  175. },
  176. {
  177. "local.max > peer.max and local.min = peer.min",
  178. versions(2, 1, 3, 2),
  179. versions(2, 1, 2, 1),
  180. true,
  181. version(2, 1),
  182. },
  183. {
  184. "local.max < peer.max and local.min > peer.min",
  185. versions(2, 1, 2, 1),
  186. versions(1, 2, 3, 2),
  187. true,
  188. version(2, 1),
  189. },
  190. {
  191. "local.max = peer.max and local.min > peer.min",
  192. versions(2, 1, 2, 1),
  193. versions(1, 2, 2, 1),
  194. true,
  195. version(2, 1),
  196. },
  197. {
  198. "local.max < peer.max and local.min < peer.min",
  199. versions(1, 2, 2, 1),
  200. versions(2, 1, 3, 2),
  201. true,
  202. version(2, 1),
  203. },
  204. {
  205. "local.max < peer.max and local.min = peer.min",
  206. versions(1, 2, 2, 1),
  207. versions(1, 2, 3, 2),
  208. true,
  209. version(2, 1),
  210. },
  211. {
  212. "local.max = peer.max and local.min < peer.min",
  213. versions(1, 2, 2, 1),
  214. versions(2, 1, 2, 1),
  215. true,
  216. version(2, 1),
  217. },
  218. {
  219. "all equal",
  220. versions(2, 1, 2, 1),
  221. versions(2, 1, 2, 1),
  222. true,
  223. version(2, 1),
  224. },
  225. {
  226. "max is smaller than min",
  227. versions(2, 1, 1, 2),
  228. versions(2, 1, 1, 2),
  229. false,
  230. nil,
  231. },
  232. {
  233. "no overlap, local > peer",
  234. versions(4, 3, 6, 5),
  235. versions(1, 0, 2, 1),
  236. false,
  237. nil,
  238. },
  239. {
  240. "no overlap, local < peer",
  241. versions(1, 0, 2, 1),
  242. versions(4, 3, 6, 5),
  243. false,
  244. nil,
  245. },
  246. {
  247. "no overlap, max < min",
  248. versions(6, 5, 4, 3),
  249. versions(2, 1, 1, 0),
  250. false,
  251. nil,
  252. },
  253. } {
  254. output, maxCommonVersion := checkRPCVersions(tc.local, tc.peer)
  255. if got, want := output, tc.output; got != want {
  256. t.Errorf("%v: checkRPCVersions(%v, %v)=(%v, _), want (%v, _)", tc.desc, tc.local, tc.peer, got, want)
  257. }
  258. if got, want := maxCommonVersion, tc.maxCommonVersion; !proto.Equal(got, want) {
  259. t.Errorf("%v: checkRPCVersions(%v, %v)=(_, %v), want (_, %v)", tc.desc, tc.local, tc.peer, got, want)
  260. }
  261. }
  262. }
  263. func version(major, minor uint32) *altspb.RpcProtocolVersions_Version {
  264. return &altspb.RpcProtocolVersions_Version{
  265. Major: major,
  266. Minor: minor,
  267. }
  268. }
  269. func versions(minMajor, minMinor, maxMajor, maxMinor uint32) *altspb.RpcProtocolVersions {
  270. return &altspb.RpcProtocolVersions{
  271. MinRpcVersion: version(minMajor, minMinor),
  272. MaxRpcVersion: version(maxMajor, maxMinor),
  273. }
  274. }