You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

369 lines
9.0 KiB

  1. // Copyright 2012 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. // +build darwin dragonfly freebsd linux netbsd openbsd plan9
  5. package test
  6. // functional test harness for unix.
  7. import (
  8. "bytes"
  9. "crypto/rand"
  10. "encoding/base64"
  11. "fmt"
  12. "io/ioutil"
  13. "log"
  14. "net"
  15. "os"
  16. "os/exec"
  17. "os/user"
  18. "path/filepath"
  19. "testing"
  20. "text/template"
  21. "golang.org/x/crypto/ssh"
  22. "golang.org/x/crypto/ssh/testdata"
  23. )
// sshd_config text used by the test server. Both constants are parsed as
// text/template sources by configTmpl; the {{.Dir}} and {{.AuthMethods}}
// placeholders are filled in when a template is executed.
const (
	// defaultSshdConfig is the baseline configuration. Every file it refers
	// to (banner, host keys, host certificate, pid file, authorized keys,
	// trusted CA key) is looked up under {{.Dir}}.
	defaultSshdConfig = `
Protocol 2
Banner {{.Dir}}/banner
HostKey {{.Dir}}/id_rsa
HostKey {{.Dir}}/id_dsa
HostKey {{.Dir}}/id_ecdsa
HostCertificate {{.Dir}}/id_rsa-cert.pub
Pidfile {{.Dir}}/sshd.pid
#UsePrivilegeSeparation no
KeyRegenerationInterval 3600
ServerKeyBits 768
SyslogFacility AUTH
LogLevel DEBUG2
LoginGraceTime 120
PermitRootLogin no
StrictModes no
RSAAuthentication yes
PubkeyAuthentication yes
AuthorizedKeysFile {{.Dir}}/authorized_keys
TrustedUserCAKeys {{.Dir}}/id_ecdsa.pub
IgnoreRhosts yes
RhostsRSAAuthentication no
HostbasedAuthentication no
PubkeyAcceptedKeyTypes=*
`
	// multiAuthSshdConfigTail is appended to defaultSshdConfig to build the
	// "MultiAuth" template; {{.AuthMethods}} supplies the value of the
	// AuthenticationMethods directive.
	multiAuthSshdConfigTail = `
UsePAM yes
PasswordAuthentication yes
ChallengeResponseAuthentication yes
AuthenticationMethods {{.AuthMethods}}
`
)
  57. var configTmpl = map[string]*template.Template{
  58. "default": template.Must(template.New("").Parse(defaultSshdConfig)),
  59. "MultiAuth": template.Must(template.New("").Parse(defaultSshdConfig + multiAuthSshdConfigTail))}
// server holds the state of one sshd process started for a test: how to
// launch it, what it printed, and how to clean up after it.
type server struct {
	t          *testing.T // test that owns this server; used for Skip/Fatal/Logf
	cleanup    func()     // executed during Shutdown
	configfile string     // path handed to sshd with -f
	cmd        *exec.Cmd  // the sshd process; set by TryDialWithAddr
	output     bytes.Buffer // holds stderr from sshd process

	testUser     string // test username for sshd
	testPasswd   string // test password for sshd
	sshdTestPwSo string // dynamic library to inject a custom password into sshd

	// Client half of the network connection.
	clientConn net.Conn
}
  72. func username() string {
  73. var username string
  74. if user, err := user.Current(); err == nil {
  75. username = user.Username
  76. } else {
  77. // user.Current() currently requires cgo. If an error is
  78. // returned attempt to get the username from the environment.
  79. log.Printf("user.Current: %v; falling back on $USER", err)
  80. username = os.Getenv("USER")
  81. }
  82. if username == "" {
  83. panic("Unable to get username")
  84. }
  85. return username
  86. }
// storedHostKey is a small in-memory host key database. Keys are registered
// with Add and verified with Check, whose signature matches what
// clientConfig installs as the ssh.ClientConfig HostKeyCallback.
type storedHostKey struct {
	// keys map from an algorithm string to binary key data.
	keys map[string][]byte

	// checkCount counts the Check calls. Used for testing
	// rekeying.
	checkCount int
}
  94. func (k *storedHostKey) Add(key ssh.PublicKey) {
  95. if k.keys == nil {
  96. k.keys = map[string][]byte{}
  97. }
  98. k.keys[key.Type()] = key.Marshal()
  99. }
  100. func (k *storedHostKey) Check(addr string, remote net.Addr, key ssh.PublicKey) error {
  101. k.checkCount++
  102. algo := key.Type()
  103. if k.keys == nil || bytes.Compare(key.Marshal(), k.keys[algo]) != 0 {
  104. return fmt.Errorf("host key mismatch. Got %q, want %q", key, k.keys[algo])
  105. }
  106. return nil
  107. }
  108. func hostKeyDB() *storedHostKey {
  109. keyChecker := &storedHostKey{}
  110. keyChecker.Add(testPublicKeys["ecdsa"])
  111. keyChecker.Add(testPublicKeys["rsa"])
  112. keyChecker.Add(testPublicKeys["dsa"])
  113. return keyChecker
  114. }
  115. func clientConfig() *ssh.ClientConfig {
  116. config := &ssh.ClientConfig{
  117. User: username(),
  118. Auth: []ssh.AuthMethod{
  119. ssh.PublicKeys(testSigners["user"]),
  120. },
  121. HostKeyCallback: hostKeyDB().Check,
  122. HostKeyAlgorithms: []string{ // by default, don't allow certs as this affects the hostKeyDB checker
  123. ssh.KeyAlgoECDSA256, ssh.KeyAlgoECDSA384, ssh.KeyAlgoECDSA521,
  124. ssh.KeyAlgoRSA, ssh.KeyAlgoDSA,
  125. ssh.KeyAlgoED25519,
  126. },
  127. }
  128. return config
  129. }
  130. // unixConnection creates two halves of a connected net.UnixConn. It
  131. // is used for connecting the Go SSH client with sshd without opening
  132. // ports.
  133. func unixConnection() (*net.UnixConn, *net.UnixConn, error) {
  134. dir, err := ioutil.TempDir("", "unixConnection")
  135. if err != nil {
  136. return nil, nil, err
  137. }
  138. defer os.Remove(dir)
  139. addr := filepath.Join(dir, "ssh")
  140. listener, err := net.Listen("unix", addr)
  141. if err != nil {
  142. return nil, nil, err
  143. }
  144. defer listener.Close()
  145. c1, err := net.Dial("unix", addr)
  146. if err != nil {
  147. return nil, nil, err
  148. }
  149. c2, err := listener.Accept()
  150. if err != nil {
  151. c1.Close()
  152. return nil, nil, err
  153. }
  154. return c1.(*net.UnixConn), c2.(*net.UnixConn), nil
  155. }
// TryDial starts sshd and connects to it with config, passing an empty
// address to TryDialWithAddr. The connection error, if any, is returned to
// the caller.
func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.Client, error) {
	return s.TryDialWithAddr(config, "")
}
  159. // addr is the user specified host:port. While we don't actually dial it,
  160. // we need to know this for host key matching
  161. func (s *server) TryDialWithAddr(config *ssh.ClientConfig, addr string) (*ssh.Client, error) {
  162. sshd, err := exec.LookPath("sshd")
  163. if err != nil {
  164. s.t.Skipf("skipping test: %v", err)
  165. }
  166. c1, c2, err := unixConnection()
  167. if err != nil {
  168. s.t.Fatalf("unixConnection: %v", err)
  169. }
  170. s.cmd = exec.Command(sshd, "-f", s.configfile, "-i", "-e")
  171. f, err := c2.File()
  172. if err != nil {
  173. s.t.Fatalf("UnixConn.File: %v", err)
  174. }
  175. defer f.Close()
  176. s.cmd.Stdin = f
  177. s.cmd.Stdout = f
  178. s.cmd.Stderr = &s.output
  179. if s.sshdTestPwSo != "" {
  180. if s.testUser == "" {
  181. s.t.Fatal("user missing from sshd_test_pw.so config")
  182. }
  183. if s.testPasswd == "" {
  184. s.t.Fatal("password missing from sshd_test_pw.so config")
  185. }
  186. s.cmd.Env = append(os.Environ(),
  187. fmt.Sprintf("LD_PRELOAD=%s", s.sshdTestPwSo),
  188. fmt.Sprintf("TEST_USER=%s", s.testUser),
  189. fmt.Sprintf("TEST_PASSWD=%s", s.testPasswd))
  190. }
  191. if err := s.cmd.Start(); err != nil {
  192. s.t.Fail()
  193. s.Shutdown()
  194. s.t.Fatalf("s.cmd.Start: %v", err)
  195. }
  196. s.clientConn = c1
  197. conn, chans, reqs, err := ssh.NewClientConn(c1, addr, config)
  198. if err != nil {
  199. return nil, err
  200. }
  201. return ssh.NewClient(conn, chans, reqs), nil
  202. }
  203. func (s *server) Dial(config *ssh.ClientConfig) *ssh.Client {
  204. conn, err := s.TryDial(config)
  205. if err != nil {
  206. s.t.Fail()
  207. s.Shutdown()
  208. s.t.Fatalf("ssh.Client: %v", err)
  209. }
  210. return conn
  211. }
  212. func (s *server) Shutdown() {
  213. if s.cmd != nil && s.cmd.Process != nil {
  214. // Don't check for errors; if it fails it's most
  215. // likely "os: process already finished", and we don't
  216. // care about that. Use os.Interrupt, so child
  217. // processes are killed too.
  218. s.cmd.Process.Signal(os.Interrupt)
  219. s.cmd.Wait()
  220. }
  221. if s.t.Failed() {
  222. // log any output from sshd process
  223. s.t.Logf("sshd: %s", s.output.String())
  224. }
  225. s.cleanup()
  226. }
  227. func writeFile(path string, contents []byte) {
  228. f, err := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0600)
  229. if err != nil {
  230. panic(err)
  231. }
  232. defer f.Close()
  233. if _, err := f.Write(contents); err != nil {
  234. panic(err)
  235. }
  236. }
  237. // generate random password
  238. func randomPassword() (string, error) {
  239. b := make([]byte, 12)
  240. _, err := rand.Read(b)
  241. if err != nil {
  242. return "", err
  243. }
  244. return base64.RawURLEncoding.EncodeToString(b), nil
  245. }
  246. // setTestPassword is used for setting user and password data for sshd_test_pw.so
  247. // This function also checks that ./sshd_test_pw.so exists and if not calls s.t.Skip()
  248. func (s *server) setTestPassword(user, passwd string) error {
  249. wd, _ := os.Getwd()
  250. wrapper := filepath.Join(wd, "sshd_test_pw.so")
  251. if _, err := os.Stat(wrapper); err != nil {
  252. s.t.Skip(fmt.Errorf("sshd_test_pw.so is not available"))
  253. return err
  254. }
  255. s.sshdTestPwSo = wrapper
  256. s.testUser = user
  257. s.testPasswd = passwd
  258. return nil
  259. }
// newServer returns a new mock ssh server using the "default" sshd
// configuration template and no extra template variables.
func newServer(t *testing.T) *server {
	return newServerForConfig(t, "default", map[string]string{})
}
  264. // newServerForConfig returns a new mock ssh server.
  265. func newServerForConfig(t *testing.T, config string, configVars map[string]string) *server {
  266. if testing.Short() {
  267. t.Skip("skipping test due to -short")
  268. }
  269. u, err := user.Current()
  270. if err != nil {
  271. t.Fatalf("user.Current: %v", err)
  272. }
  273. if u.Name == "root" {
  274. t.Skip("skipping test because current user is root")
  275. }
  276. dir, err := ioutil.TempDir("", "sshtest")
  277. if err != nil {
  278. t.Fatal(err)
  279. }
  280. f, err := os.Create(filepath.Join(dir, "sshd_config"))
  281. if err != nil {
  282. t.Fatal(err)
  283. }
  284. if _, ok := configTmpl[config]; ok == false {
  285. t.Fatal(fmt.Errorf("Invalid server config '%s'", config))
  286. }
  287. configVars["Dir"] = dir
  288. err = configTmpl[config].Execute(f, configVars)
  289. if err != nil {
  290. t.Fatal(err)
  291. }
  292. f.Close()
  293. writeFile(filepath.Join(dir, "banner"), []byte("Server Banner"))
  294. for k, v := range testdata.PEMBytes {
  295. filename := "id_" + k
  296. writeFile(filepath.Join(dir, filename), v)
  297. writeFile(filepath.Join(dir, filename+".pub"), ssh.MarshalAuthorizedKey(testPublicKeys[k]))
  298. }
  299. for k, v := range testdata.SSHCertificates {
  300. filename := "id_" + k + "-cert.pub"
  301. writeFile(filepath.Join(dir, filename), v)
  302. }
  303. var authkeys bytes.Buffer
  304. for k := range testdata.PEMBytes {
  305. authkeys.Write(ssh.MarshalAuthorizedKey(testPublicKeys[k]))
  306. }
  307. writeFile(filepath.Join(dir, "authorized_keys"), authkeys.Bytes())
  308. return &server{
  309. t: t,
  310. configfile: f.Name(),
  311. cleanup: func() {
  312. if err := os.RemoveAll(dir); err != nil {
  313. t.Error(err)
  314. }
  315. },
  316. }
  317. }
  318. func newTempSocket(t *testing.T) (string, func()) {
  319. dir, err := ioutil.TempDir("", "socket")
  320. if err != nil {
  321. t.Fatal(err)
  322. }
  323. deferFunc := func() { os.RemoveAll(dir) }
  324. addr := filepath.Join(dir, "sock")
  325. return addr, deferFunc
  326. }