Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.
 
 
 

331 rader
11 KiB

  1. /*
  2. *
  3. * Copyright 2018 gRPC authors.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. *
  17. */
// Package alts implements ALTS credential support for the gRPC library. The
// credentials encapsulate all the state a client needs to authenticate with a
// server using ALTS and to make various assertions, e.g., about the client's
// identity, role, or whether it is authorized to make a particular call.
  22. // This package is experimental.
  23. package alts
  24. import (
  25. "context"
  26. "errors"
  27. "fmt"
  28. "net"
  29. "sync"
  30. "time"
  31. "google.golang.org/grpc/credentials"
  32. core "google.golang.org/grpc/credentials/alts/internal"
  33. "google.golang.org/grpc/credentials/alts/internal/handshaker"
  34. "google.golang.org/grpc/credentials/alts/internal/handshaker/service"
  35. altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
  36. "google.golang.org/grpc/grpclog"
  37. )
  38. const (
  39. // hypervisorHandshakerServiceAddress represents the default ALTS gRPC
  40. // handshaker service address in the hypervisor.
  41. hypervisorHandshakerServiceAddress = "metadata.google.internal:8080"
  42. // defaultTimeout specifies the server handshake timeout.
  43. defaultTimeout = 30.0 * time.Second
  44. // The following constants specify the minimum and maximum acceptable
  45. // protocol versions.
  46. protocolVersionMaxMajor = 2
  47. protocolVersionMaxMinor = 1
  48. protocolVersionMinMajor = 2
  49. protocolVersionMinMinor = 1
  50. )
  51. var (
  52. once sync.Once
  53. maxRPCVersion = &altspb.RpcProtocolVersions_Version{
  54. Major: protocolVersionMaxMajor,
  55. Minor: protocolVersionMaxMinor,
  56. }
  57. minRPCVersion = &altspb.RpcProtocolVersions_Version{
  58. Major: protocolVersionMinMajor,
  59. Minor: protocolVersionMinMinor,
  60. }
  61. // ErrUntrustedPlatform is returned from ClientHandshake and
  62. // ServerHandshake is running on a platform where the trustworthiness of
  63. // the handshaker service is not guaranteed.
  64. ErrUntrustedPlatform = errors.New("ALTS: untrusted platform. ALTS is only supported on GCP")
  65. )
  66. // AuthInfo exposes security information from the ALTS handshake to the
  67. // application. This interface is to be implemented by ALTS. Users should not
  68. // need a brand new implementation of this interface. For situations like
  69. // testing, any new implementation should embed this interface. This allows
  70. // ALTS to add new methods to this interface.
  71. type AuthInfo interface {
  72. // ApplicationProtocol returns application protocol negotiated for the
  73. // ALTS connection.
  74. ApplicationProtocol() string
  75. // RecordProtocol returns the record protocol negotiated for the ALTS
  76. // connection.
  77. RecordProtocol() string
  78. // SecurityLevel returns the security level of the created ALTS secure
  79. // channel.
  80. SecurityLevel() altspb.SecurityLevel
  81. // PeerServiceAccount returns the peer service account.
  82. PeerServiceAccount() string
  83. // LocalServiceAccount returns the local service account.
  84. LocalServiceAccount() string
  85. // PeerRPCVersions returns the RPC version supported by the peer.
  86. PeerRPCVersions() *altspb.RpcProtocolVersions
  87. }
// ClientOptions contains the client-side options of an ALTS channel. These
// options will be passed to the underlying ALTS handshaker.
type ClientOptions struct {
	// TargetServiceAccounts contains a list of expected target service
	// accounts. It is handed to the client handshaker as its target service
	// accounts on every ClientHandshake.
	TargetServiceAccounts []string
	// HandshakerServiceAddress represents the ALTS handshaker gRPC service
	// address to connect to. If empty, the credentials fall back to the
	// hypervisor handshaker service address (metadata.google.internal:8080).
	HandshakerServiceAddress string
}
  98. // DefaultClientOptions creates a new ClientOptions object with the default
  99. // values.
  100. func DefaultClientOptions() *ClientOptions {
  101. return &ClientOptions{
  102. HandshakerServiceAddress: hypervisorHandshakerServiceAddress,
  103. }
  104. }
// ServerOptions contains the server-side options of an ALTS channel. These
// options will be passed to the underlying ALTS handshaker.
type ServerOptions struct {
	// HandshakerServiceAddress represents the ALTS handshaker gRPC service
	// address to connect to. If empty, the credentials fall back to the
	// hypervisor handshaker service address (metadata.google.internal:8080).
	HandshakerServiceAddress string
}
  112. // DefaultServerOptions creates a new ServerOptions object with the default
  113. // values.
  114. func DefaultServerOptions() *ServerOptions {
  115. return &ServerOptions{
  116. HandshakerServiceAddress: hypervisorHandshakerServiceAddress,
  117. }
  118. }
// altsTC is the credentials required for authenticating a connection using ALTS.
// It implements the credentials.TransportCredentials interface.
type altsTC struct {
	// info is returned by value from Info and mutated in place by
	// OverrideServerName.
	info *credentials.ProtocolInfo
	// side records whether these credentials were built for the client
	// (core.ClientSide) or the server (core.ServerSide).
	side core.Side
	// accounts lists the expected target service accounts used during
	// ClientHandshake; NewServerCreds leaves it nil.
	accounts []string
	// hsAddress is the ALTS handshaker service address dialed on every
	// handshake; newALTS replaces an empty value with the hypervisor address.
	hsAddress string
}
  127. // NewClientCreds constructs a client-side ALTS TransportCredentials object.
  128. func NewClientCreds(opts *ClientOptions) credentials.TransportCredentials {
  129. return newALTS(core.ClientSide, opts.TargetServiceAccounts, opts.HandshakerServiceAddress)
  130. }
  131. // NewServerCreds constructs a server-side ALTS TransportCredentials object.
  132. func NewServerCreds(opts *ServerOptions) credentials.TransportCredentials {
  133. return newALTS(core.ServerSide, nil, opts.HandshakerServiceAddress)
  134. }
  135. func newALTS(side core.Side, accounts []string, hsAddress string) credentials.TransportCredentials {
  136. once.Do(func() {
  137. vmOnGCP = isRunningOnGCP()
  138. })
  139. if hsAddress == "" {
  140. hsAddress = hypervisorHandshakerServiceAddress
  141. }
  142. return &altsTC{
  143. info: &credentials.ProtocolInfo{
  144. SecurityProtocol: "alts",
  145. SecurityVersion: "1.0",
  146. },
  147. side: side,
  148. accounts: accounts,
  149. hsAddress: hsAddress,
  150. }
  151. }
  152. // ClientHandshake implements the client side handshake protocol.
  153. func (g *altsTC) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) {
  154. if !vmOnGCP {
  155. return nil, nil, ErrUntrustedPlatform
  156. }
  157. // Connecting to ALTS handshaker service.
  158. hsConn, err := service.Dial(g.hsAddress)
  159. if err != nil {
  160. return nil, nil, err
  161. }
  162. // Do not close hsConn since it is shared with other handshakes.
  163. // Possible context leak:
  164. // The cancel function for the child context we create will only be
  165. // called a non-nil error is returned.
  166. var cancel context.CancelFunc
  167. ctx, cancel = context.WithCancel(ctx)
  168. defer func() {
  169. if err != nil {
  170. cancel()
  171. }
  172. }()
  173. opts := handshaker.DefaultClientHandshakerOptions()
  174. opts.TargetName = addr
  175. opts.TargetServiceAccounts = g.accounts
  176. opts.RPCVersions = &altspb.RpcProtocolVersions{
  177. MaxRpcVersion: maxRPCVersion,
  178. MinRpcVersion: minRPCVersion,
  179. }
  180. chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, opts)
  181. defer func() {
  182. if err != nil {
  183. chs.Close()
  184. }
  185. }()
  186. if err != nil {
  187. return nil, nil, err
  188. }
  189. secConn, authInfo, err := chs.ClientHandshake(ctx)
  190. if err != nil {
  191. return nil, nil, err
  192. }
  193. altsAuthInfo, ok := authInfo.(AuthInfo)
  194. if !ok {
  195. return nil, nil, errors.New("client-side auth info is not of type alts.AuthInfo")
  196. }
  197. match, _ := checkRPCVersions(opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
  198. if !match {
  199. return nil, nil, fmt.Errorf("server-side RPC versions are not compatible with this client, local versions: %v, peer versions: %v", opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
  200. }
  201. return secConn, authInfo, nil
  202. }
  203. // ServerHandshake implements the server side ALTS handshaker.
  204. func (g *altsTC) ServerHandshake(rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) {
  205. if !vmOnGCP {
  206. return nil, nil, ErrUntrustedPlatform
  207. }
  208. // Connecting to ALTS handshaker service.
  209. hsConn, err := service.Dial(g.hsAddress)
  210. if err != nil {
  211. return nil, nil, err
  212. }
  213. // Do not close hsConn since it's shared with other handshakes.
  214. ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
  215. defer cancel()
  216. opts := handshaker.DefaultServerHandshakerOptions()
  217. opts.RPCVersions = &altspb.RpcProtocolVersions{
  218. MaxRpcVersion: maxRPCVersion,
  219. MinRpcVersion: minRPCVersion,
  220. }
  221. shs, err := handshaker.NewServerHandshaker(ctx, hsConn, rawConn, opts)
  222. defer func() {
  223. if err != nil {
  224. shs.Close()
  225. }
  226. }()
  227. if err != nil {
  228. return nil, nil, err
  229. }
  230. secConn, authInfo, err := shs.ServerHandshake(ctx)
  231. if err != nil {
  232. return nil, nil, err
  233. }
  234. altsAuthInfo, ok := authInfo.(AuthInfo)
  235. if !ok {
  236. return nil, nil, errors.New("server-side auth info is not of type alts.AuthInfo")
  237. }
  238. match, _ := checkRPCVersions(opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
  239. if !match {
  240. return nil, nil, fmt.Errorf("client-side RPC versions is not compatible with this server, local versions: %v, peer versions: %v", opts.RPCVersions, altsAuthInfo.PeerRPCVersions())
  241. }
  242. return secConn, authInfo, nil
  243. }
  244. func (g *altsTC) Info() credentials.ProtocolInfo {
  245. return *g.info
  246. }
  247. func (g *altsTC) Clone() credentials.TransportCredentials {
  248. info := *g.info
  249. var accounts []string
  250. if g.accounts != nil {
  251. accounts = make([]string, len(g.accounts))
  252. copy(accounts, g.accounts)
  253. }
  254. return &altsTC{
  255. info: &info,
  256. side: g.side,
  257. hsAddress: g.hsAddress,
  258. accounts: accounts,
  259. }
  260. }
  261. func (g *altsTC) OverrideServerName(serverNameOverride string) error {
  262. g.info.ServerName = serverNameOverride
  263. return nil
  264. }
  265. // compareRPCVersion returns 0 if v1 == v2, 1 if v1 > v2 and -1 if v1 < v2.
  266. func compareRPCVersions(v1, v2 *altspb.RpcProtocolVersions_Version) int {
  267. switch {
  268. case v1.GetMajor() > v2.GetMajor(),
  269. v1.GetMajor() == v2.GetMajor() && v1.GetMinor() > v2.GetMinor():
  270. return 1
  271. case v1.GetMajor() < v2.GetMajor(),
  272. v1.GetMajor() == v2.GetMajor() && v1.GetMinor() < v2.GetMinor():
  273. return -1
  274. }
  275. return 0
  276. }
  277. // checkRPCVersions performs a version check between local and peer rpc protocol
  278. // versions. This function returns true if the check passes which means both
  279. // parties agreed on a common rpc protocol to use, and false otherwise. The
  280. // function also returns the highest common RPC protocol version both parties
  281. // agreed on.
  282. func checkRPCVersions(local, peer *altspb.RpcProtocolVersions) (bool, *altspb.RpcProtocolVersions_Version) {
  283. if local == nil || peer == nil {
  284. grpclog.Error("invalid checkRPCVersions argument, either local or peer is nil.")
  285. return false, nil
  286. }
  287. // maxCommonVersion is MIN(local.max, peer.max).
  288. maxCommonVersion := local.GetMaxRpcVersion()
  289. if compareRPCVersions(local.GetMaxRpcVersion(), peer.GetMaxRpcVersion()) > 0 {
  290. maxCommonVersion = peer.GetMaxRpcVersion()
  291. }
  292. // minCommonVersion is MAX(local.min, peer.min).
  293. minCommonVersion := peer.GetMinRpcVersion()
  294. if compareRPCVersions(local.GetMinRpcVersion(), peer.GetMinRpcVersion()) > 0 {
  295. minCommonVersion = local.GetMinRpcVersion()
  296. }
  297. if compareRPCVersions(maxCommonVersion, minCommonVersion) < 0 {
  298. return false, nil
  299. }
  300. return true, maxCommonVersion
  301. }