You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

260 lines
6.2 KiB

  1. // Copyright 2012 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package agent
  5. import (
  6. "crypto"
  7. "crypto/rand"
  8. "fmt"
  9. pseudorand "math/rand"
  10. "reflect"
  11. "strings"
  12. "testing"
  13. "golang.org/x/crypto/ssh"
  14. )
  15. func TestServer(t *testing.T) {
  16. c1, c2, err := netPipe()
  17. if err != nil {
  18. t.Fatalf("netPipe: %v", err)
  19. }
  20. defer c1.Close()
  21. defer c2.Close()
  22. client := NewClient(c1)
  23. go ServeAgent(NewKeyring(), c2)
  24. testAgentInterface(t, client, testPrivateKeys["rsa"], nil, 0)
  25. }
  26. func TestLockServer(t *testing.T) {
  27. testLockAgent(NewKeyring(), t)
  28. }
  29. func TestSetupForwardAgent(t *testing.T) {
  30. a, b, err := netPipe()
  31. if err != nil {
  32. t.Fatalf("netPipe: %v", err)
  33. }
  34. defer a.Close()
  35. defer b.Close()
  36. _, socket, cleanup := startOpenSSHAgent(t)
  37. defer cleanup()
  38. serverConf := ssh.ServerConfig{
  39. NoClientAuth: true,
  40. }
  41. serverConf.AddHostKey(testSigners["rsa"])
  42. incoming := make(chan *ssh.ServerConn, 1)
  43. go func() {
  44. conn, _, _, err := ssh.NewServerConn(a, &serverConf)
  45. if err != nil {
  46. t.Fatalf("Server: %v", err)
  47. }
  48. incoming <- conn
  49. }()
  50. conf := ssh.ClientConfig{
  51. HostKeyCallback: ssh.InsecureIgnoreHostKey(),
  52. }
  53. conn, chans, reqs, err := ssh.NewClientConn(b, "", &conf)
  54. if err != nil {
  55. t.Fatalf("NewClientConn: %v", err)
  56. }
  57. client := ssh.NewClient(conn, chans, reqs)
  58. if err := ForwardToRemote(client, socket); err != nil {
  59. t.Fatalf("SetupForwardAgent: %v", err)
  60. }
  61. server := <-incoming
  62. ch, reqs, err := server.OpenChannel(channelType, nil)
  63. if err != nil {
  64. t.Fatalf("OpenChannel(%q): %v", channelType, err)
  65. }
  66. go ssh.DiscardRequests(reqs)
  67. agentClient := NewClient(ch)
  68. testAgentInterface(t, agentClient, testPrivateKeys["rsa"], nil, 0)
  69. conn.Close()
  70. }
  71. func TestV1ProtocolMessages(t *testing.T) {
  72. c1, c2, err := netPipe()
  73. if err != nil {
  74. t.Fatalf("netPipe: %v", err)
  75. }
  76. defer c1.Close()
  77. defer c2.Close()
  78. c := NewClient(c1)
  79. go ServeAgent(NewKeyring(), c2)
  80. testV1ProtocolMessages(t, c.(*client))
  81. }
  82. func testV1ProtocolMessages(t *testing.T, c *client) {
  83. reply, err := c.call([]byte{agentRequestV1Identities})
  84. if err != nil {
  85. t.Fatalf("v1 request all failed: %v", err)
  86. }
  87. if msg, ok := reply.(*agentV1IdentityMsg); !ok || msg.Numkeys != 0 {
  88. t.Fatalf("invalid request all response: %#v", reply)
  89. }
  90. reply, err = c.call([]byte{agentRemoveAllV1Identities})
  91. if err != nil {
  92. t.Fatalf("v1 remove all failed: %v", err)
  93. }
  94. if _, ok := reply.(*successAgentMsg); !ok {
  95. t.Fatalf("invalid remove all response: %#v", reply)
  96. }
  97. }
  98. func verifyKey(sshAgent Agent) error {
  99. keys, err := sshAgent.List()
  100. if err != nil {
  101. return fmt.Errorf("listing keys: %v", err)
  102. }
  103. if len(keys) != 1 {
  104. return fmt.Errorf("bad number of keys found. expected 1, got %d", len(keys))
  105. }
  106. buf := make([]byte, 128)
  107. if _, err := rand.Read(buf); err != nil {
  108. return fmt.Errorf("rand: %v", err)
  109. }
  110. sig, err := sshAgent.Sign(keys[0], buf)
  111. if err != nil {
  112. return fmt.Errorf("sign: %v", err)
  113. }
  114. if err := keys[0].Verify(buf, sig); err != nil {
  115. return fmt.Errorf("verify: %v", err)
  116. }
  117. return nil
  118. }
  119. func addKeyToAgent(key crypto.PrivateKey) error {
  120. sshAgent := NewKeyring()
  121. if err := sshAgent.Add(AddedKey{PrivateKey: key}); err != nil {
  122. return fmt.Errorf("add: %v", err)
  123. }
  124. return verifyKey(sshAgent)
  125. }
  126. func TestKeyTypes(t *testing.T) {
  127. for k, v := range testPrivateKeys {
  128. if err := addKeyToAgent(v); err != nil {
  129. t.Errorf("error adding key type %s, %v", k, err)
  130. }
  131. if err := addCertToAgentSock(v, nil); err != nil {
  132. t.Errorf("error adding key type %s, %v", k, err)
  133. }
  134. }
  135. }
  136. func addCertToAgentSock(key crypto.PrivateKey, cert *ssh.Certificate) error {
  137. a, b, err := netPipe()
  138. if err != nil {
  139. return err
  140. }
  141. agentServer := NewKeyring()
  142. go ServeAgent(agentServer, a)
  143. agentClient := NewClient(b)
  144. if err := agentClient.Add(AddedKey{PrivateKey: key, Certificate: cert}); err != nil {
  145. return fmt.Errorf("add: %v", err)
  146. }
  147. return verifyKey(agentClient)
  148. }
  149. func addCertToAgent(key crypto.PrivateKey, cert *ssh.Certificate) error {
  150. sshAgent := NewKeyring()
  151. if err := sshAgent.Add(AddedKey{PrivateKey: key, Certificate: cert}); err != nil {
  152. return fmt.Errorf("add: %v", err)
  153. }
  154. return verifyKey(sshAgent)
  155. }
  156. func TestCertTypes(t *testing.T) {
  157. for keyType, key := range testPublicKeys {
  158. cert := &ssh.Certificate{
  159. ValidPrincipals: []string{"gopher1"},
  160. ValidAfter: 0,
  161. ValidBefore: ssh.CertTimeInfinity,
  162. Key: key,
  163. Serial: 1,
  164. CertType: ssh.UserCert,
  165. SignatureKey: testPublicKeys["rsa"],
  166. Permissions: ssh.Permissions{
  167. CriticalOptions: map[string]string{},
  168. Extensions: map[string]string{},
  169. },
  170. }
  171. if err := cert.SignCert(rand.Reader, testSigners["rsa"]); err != nil {
  172. t.Fatalf("signcert: %v", err)
  173. }
  174. if err := addCertToAgent(testPrivateKeys[keyType], cert); err != nil {
  175. t.Fatalf("%v", err)
  176. }
  177. if err := addCertToAgentSock(testPrivateKeys[keyType], cert); err != nil {
  178. t.Fatalf("%v", err)
  179. }
  180. }
  181. }
  182. func TestParseConstraints(t *testing.T) {
  183. // Test LifetimeSecs
  184. var msg = constrainLifetimeAgentMsg{pseudorand.Uint32()}
  185. lifetimeSecs, _, _, err := parseConstraints(ssh.Marshal(msg))
  186. if err != nil {
  187. t.Fatalf("parseConstraints: %v", err)
  188. }
  189. if lifetimeSecs != msg.LifetimeSecs {
  190. t.Errorf("got lifetime %v, want %v", lifetimeSecs, msg.LifetimeSecs)
  191. }
  192. // Test ConfirmBeforeUse
  193. _, confirmBeforeUse, _, err := parseConstraints([]byte{agentConstrainConfirm})
  194. if err != nil {
  195. t.Fatalf("%v", err)
  196. }
  197. if !confirmBeforeUse {
  198. t.Error("got comfirmBeforeUse == false")
  199. }
  200. // Test ConstraintExtensions
  201. var data []byte
  202. var expect []ConstraintExtension
  203. for i := 0; i < 10; i++ {
  204. var ext = ConstraintExtension{
  205. ExtensionName: fmt.Sprintf("name%d", i),
  206. ExtensionDetails: []byte(fmt.Sprintf("details: %d", i)),
  207. }
  208. expect = append(expect, ext)
  209. data = append(data, agentConstrainExtension)
  210. data = append(data, ssh.Marshal(ext)...)
  211. }
  212. _, _, extensions, err := parseConstraints(data)
  213. if err != nil {
  214. t.Fatalf("%v", err)
  215. }
  216. if !reflect.DeepEqual(expect, extensions) {
  217. t.Errorf("got extension %v, want %v", extensions, expect)
  218. }
  219. // Test Unknown Constraint
  220. _, _, _, err = parseConstraints([]byte{128})
  221. if err == nil || !strings.Contains(err.Error(), "unknown constraint") {
  222. t.Errorf("unexpected error: %v", err)
  223. }
  224. }