You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

441 lines
11 KiB

  1. // Copyright 2016 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package autocert
  5. import (
  6. "crypto"
  7. "crypto/ecdsa"
  8. "crypto/elliptic"
  9. "crypto/rand"
  10. "crypto/rsa"
  11. "crypto/tls"
  12. "crypto/x509"
  13. "crypto/x509/pkix"
  14. "encoding/base64"
  15. "encoding/json"
  16. "fmt"
  17. "html/template"
  18. "io"
  19. "math/big"
  20. "net/http"
  21. "net/http/httptest"
  22. "reflect"
  23. "sync"
  24. "testing"
  25. "time"
  26. "golang.org/x/crypto/acme"
  27. "golang.org/x/net/context"
  28. )
  29. var discoTmpl = template.Must(template.New("disco").Parse(`{
  30. "new-reg": "{{.}}/new-reg",
  31. "new-authz": "{{.}}/new-authz",
  32. "new-cert": "{{.}}/new-cert"
  33. }`))
  34. var authzTmpl = template.Must(template.New("authz").Parse(`{
  35. "status": "pending",
  36. "challenges": [
  37. {
  38. "uri": "{{.}}/challenge/1",
  39. "type": "tls-sni-01",
  40. "token": "token-01"
  41. },
  42. {
  43. "uri": "{{.}}/challenge/2",
  44. "type": "tls-sni-02",
  45. "token": "token-02"
  46. }
  47. ]
  48. }`))
  49. type memCache struct {
  50. mu sync.Mutex
  51. keyData map[string][]byte
  52. }
  53. func (m *memCache) Get(ctx context.Context, key string) ([]byte, error) {
  54. m.mu.Lock()
  55. defer m.mu.Unlock()
  56. v, ok := m.keyData[key]
  57. if !ok {
  58. return nil, ErrCacheMiss
  59. }
  60. return v, nil
  61. }
  62. func (m *memCache) Put(ctx context.Context, key string, data []byte) error {
  63. m.mu.Lock()
  64. defer m.mu.Unlock()
  65. m.keyData[key] = data
  66. return nil
  67. }
  68. func (m *memCache) Delete(ctx context.Context, key string) error {
  69. m.mu.Lock()
  70. defer m.mu.Unlock()
  71. delete(m.keyData, key)
  72. return nil
  73. }
  74. func newMemCache() *memCache {
  75. return &memCache{
  76. keyData: make(map[string][]byte),
  77. }
  78. }
  79. func dummyCert(pub interface{}, san ...string) ([]byte, error) {
  80. return dateDummyCert(pub, time.Now(), time.Now().Add(90*24*time.Hour), san...)
  81. }
  82. func dateDummyCert(pub interface{}, start, end time.Time, san ...string) ([]byte, error) {
  83. // use EC key to run faster on 386
  84. key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
  85. if err != nil {
  86. return nil, err
  87. }
  88. t := &x509.Certificate{
  89. SerialNumber: big.NewInt(1),
  90. NotBefore: start,
  91. NotAfter: end,
  92. BasicConstraintsValid: true,
  93. KeyUsage: x509.KeyUsageKeyEncipherment,
  94. DNSNames: san,
  95. }
  96. if pub == nil {
  97. pub = &key.PublicKey
  98. }
  99. return x509.CreateCertificate(rand.Reader, t, t, pub, key)
  100. }
  101. func decodePayload(v interface{}, r io.Reader) error {
  102. var req struct{ Payload string }
  103. if err := json.NewDecoder(r).Decode(&req); err != nil {
  104. return err
  105. }
  106. payload, err := base64.RawURLEncoding.DecodeString(req.Payload)
  107. if err != nil {
  108. return err
  109. }
  110. return json.Unmarshal(payload, v)
  111. }
  112. func TestGetCertificate(t *testing.T) {
  113. man := &Manager{Prompt: AcceptTOS}
  114. defer man.stopRenew()
  115. hello := &tls.ClientHelloInfo{ServerName: "example.org"}
  116. testGetCertificate(t, man, "example.org", hello)
  117. }
  118. func TestGetCertificate_trailingDot(t *testing.T) {
  119. man := &Manager{Prompt: AcceptTOS}
  120. defer man.stopRenew()
  121. hello := &tls.ClientHelloInfo{ServerName: "example.org."}
  122. testGetCertificate(t, man, "example.org", hello)
  123. }
  124. func TestGetCertificate_ForceRSA(t *testing.T) {
  125. man := &Manager{
  126. Prompt: AcceptTOS,
  127. Cache: newMemCache(),
  128. ForceRSA: true,
  129. }
  130. defer man.stopRenew()
  131. hello := &tls.ClientHelloInfo{ServerName: "example.org"}
  132. testGetCertificate(t, man, "example.org", hello)
  133. cert, err := man.cacheGet("example.org")
  134. if err != nil {
  135. t.Fatalf("man.cacheGet: %v", err)
  136. }
  137. if _, ok := cert.PrivateKey.(*rsa.PrivateKey); !ok {
  138. t.Errorf("cert.PrivateKey is %T; want *rsa.PrivateKey", cert.PrivateKey)
  139. }
  140. }
  141. // tests man.GetCertificate flow using the provided hello argument.
  142. // The domain argument is the expected domain name of a certificate request.
  143. func testGetCertificate(t *testing.T, man *Manager, domain string, hello *tls.ClientHelloInfo) {
  144. // echo token-02 | shasum -a 256
  145. // then divide result in 2 parts separated by dot
  146. tokenCertName := "4e8eb87631187e9ff2153b56b13a4dec.13a35d002e485d60ff37354b32f665d9.token.acme.invalid"
  147. verifyTokenCert := func() {
  148. hello := &tls.ClientHelloInfo{ServerName: tokenCertName}
  149. _, err := man.GetCertificate(hello)
  150. if err != nil {
  151. t.Errorf("verifyTokenCert: GetCertificate(%q): %v", tokenCertName, err)
  152. return
  153. }
  154. }
  155. // ACME CA server stub
  156. var ca *httptest.Server
  157. ca = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  158. w.Header().Set("replay-nonce", "nonce")
  159. if r.Method == "HEAD" {
  160. // a nonce request
  161. return
  162. }
  163. switch r.URL.Path {
  164. // discovery
  165. case "/":
  166. if err := discoTmpl.Execute(w, ca.URL); err != nil {
  167. t.Fatalf("discoTmpl: %v", err)
  168. }
  169. // client key registration
  170. case "/new-reg":
  171. w.Write([]byte("{}"))
  172. // domain authorization
  173. case "/new-authz":
  174. w.Header().Set("location", ca.URL+"/authz/1")
  175. w.WriteHeader(http.StatusCreated)
  176. if err := authzTmpl.Execute(w, ca.URL); err != nil {
  177. t.Fatalf("authzTmpl: %v", err)
  178. }
  179. // accept tls-sni-02 challenge
  180. case "/challenge/2":
  181. verifyTokenCert()
  182. w.Write([]byte("{}"))
  183. // authorization status
  184. case "/authz/1":
  185. w.Write([]byte(`{"status": "valid"}`))
  186. // cert request
  187. case "/new-cert":
  188. var req struct {
  189. CSR string `json:"csr"`
  190. }
  191. decodePayload(&req, r.Body)
  192. b, _ := base64.RawURLEncoding.DecodeString(req.CSR)
  193. csr, err := x509.ParseCertificateRequest(b)
  194. if err != nil {
  195. t.Fatalf("new-cert: CSR: %v", err)
  196. }
  197. if csr.Subject.CommonName != domain {
  198. t.Errorf("CommonName in CSR = %q; want %q", csr.Subject.CommonName, domain)
  199. }
  200. der, err := dummyCert(csr.PublicKey, domain)
  201. if err != nil {
  202. t.Fatalf("new-cert: dummyCert: %v", err)
  203. }
  204. chainUp := fmt.Sprintf("<%s/ca-cert>; rel=up", ca.URL)
  205. w.Header().Set("link", chainUp)
  206. w.WriteHeader(http.StatusCreated)
  207. w.Write(der)
  208. // CA chain cert
  209. case "/ca-cert":
  210. der, err := dummyCert(nil, "ca")
  211. if err != nil {
  212. t.Fatalf("ca-cert: dummyCert: %v", err)
  213. }
  214. w.Write(der)
  215. default:
  216. t.Errorf("unrecognized r.URL.Path: %s", r.URL.Path)
  217. }
  218. }))
  219. defer ca.Close()
  220. // use EC key to run faster on 386
  221. key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
  222. if err != nil {
  223. t.Fatal(err)
  224. }
  225. man.Client = &acme.Client{
  226. Key: key,
  227. DirectoryURL: ca.URL,
  228. }
  229. // simulate tls.Config.GetCertificate
  230. var tlscert *tls.Certificate
  231. done := make(chan struct{})
  232. go func() {
  233. tlscert, err = man.GetCertificate(hello)
  234. close(done)
  235. }()
  236. select {
  237. case <-time.After(time.Minute):
  238. t.Fatal("man.GetCertificate took too long to return")
  239. case <-done:
  240. }
  241. if err != nil {
  242. t.Fatalf("man.GetCertificate: %v", err)
  243. }
  244. // verify the tlscert is the same we responded with from the CA stub
  245. if len(tlscert.Certificate) == 0 {
  246. t.Fatal("len(tlscert.Certificate) is 0")
  247. }
  248. cert, err := x509.ParseCertificate(tlscert.Certificate[0])
  249. if err != nil {
  250. t.Fatalf("x509.ParseCertificate: %v", err)
  251. }
  252. if len(cert.DNSNames) == 0 || cert.DNSNames[0] != domain {
  253. t.Errorf("cert.DNSNames = %v; want %q", cert.DNSNames, domain)
  254. }
  255. // make sure token cert was removed
  256. done = make(chan struct{})
  257. go func() {
  258. for {
  259. hello := &tls.ClientHelloInfo{ServerName: tokenCertName}
  260. if _, err := man.GetCertificate(hello); err != nil {
  261. break
  262. }
  263. time.Sleep(100 * time.Millisecond)
  264. }
  265. close(done)
  266. }()
  267. select {
  268. case <-time.After(5 * time.Second):
  269. t.Error("token cert was not removed")
  270. case <-done:
  271. }
  272. }
  273. func TestAccountKeyCache(t *testing.T) {
  274. m := Manager{Cache: newMemCache()}
  275. ctx := context.Background()
  276. k1, err := m.accountKey(ctx)
  277. if err != nil {
  278. t.Fatal(err)
  279. }
  280. k2, err := m.accountKey(ctx)
  281. if err != nil {
  282. t.Fatal(err)
  283. }
  284. if !reflect.DeepEqual(k1, k2) {
  285. t.Errorf("account keys don't match: k1 = %#v; k2 = %#v", k1, k2)
  286. }
  287. }
  288. func TestCache(t *testing.T) {
  289. privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
  290. if err != nil {
  291. t.Fatal(err)
  292. }
  293. tmpl := &x509.Certificate{
  294. SerialNumber: big.NewInt(1),
  295. Subject: pkix.Name{CommonName: "example.org"},
  296. NotAfter: time.Now().Add(time.Hour),
  297. }
  298. pub, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &privKey.PublicKey, privKey)
  299. if err != nil {
  300. t.Fatal(err)
  301. }
  302. tlscert := &tls.Certificate{
  303. Certificate: [][]byte{pub},
  304. PrivateKey: privKey,
  305. }
  306. man := &Manager{Cache: newMemCache()}
  307. defer man.stopRenew()
  308. if err := man.cachePut("example.org", tlscert); err != nil {
  309. t.Fatalf("man.cachePut: %v", err)
  310. }
  311. res, err := man.cacheGet("example.org")
  312. if err != nil {
  313. t.Fatalf("man.cacheGet: %v", err)
  314. }
  315. if res == nil {
  316. t.Fatal("res is nil")
  317. }
  318. }
  319. func TestHostWhitelist(t *testing.T) {
  320. policy := HostWhitelist("example.com", "example.org", "*.example.net")
  321. tt := []struct {
  322. host string
  323. allow bool
  324. }{
  325. {"example.com", true},
  326. {"example.org", true},
  327. {"one.example.com", false},
  328. {"two.example.org", false},
  329. {"three.example.net", false},
  330. {"dummy", false},
  331. }
  332. for i, test := range tt {
  333. err := policy(nil, test.host)
  334. if err != nil && test.allow {
  335. t.Errorf("%d: policy(%q): %v; want nil", i, test.host, err)
  336. }
  337. if err == nil && !test.allow {
  338. t.Errorf("%d: policy(%q): nil; want an error", i, test.host)
  339. }
  340. }
  341. }
  342. func TestValidCert(t *testing.T) {
  343. key1, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
  344. if err != nil {
  345. t.Fatal(err)
  346. }
  347. key2, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
  348. if err != nil {
  349. t.Fatal(err)
  350. }
  351. key3, err := rsa.GenerateKey(rand.Reader, 512)
  352. if err != nil {
  353. t.Fatal(err)
  354. }
  355. cert1, err := dummyCert(key1.Public(), "example.org")
  356. if err != nil {
  357. t.Fatal(err)
  358. }
  359. cert2, err := dummyCert(key2.Public(), "example.org")
  360. if err != nil {
  361. t.Fatal(err)
  362. }
  363. cert3, err := dummyCert(key3.Public(), "example.org")
  364. if err != nil {
  365. t.Fatal(err)
  366. }
  367. now := time.Now()
  368. early, err := dateDummyCert(key1.Public(), now.Add(time.Hour), now.Add(2*time.Hour), "example.org")
  369. if err != nil {
  370. t.Fatal(err)
  371. }
  372. expired, err := dateDummyCert(key1.Public(), now.Add(-2*time.Hour), now.Add(-time.Hour), "example.org")
  373. if err != nil {
  374. t.Fatal(err)
  375. }
  376. tt := []struct {
  377. domain string
  378. key crypto.Signer
  379. cert [][]byte
  380. ok bool
  381. }{
  382. {"example.org", key1, [][]byte{cert1}, true},
  383. {"example.org", key3, [][]byte{cert3}, true},
  384. {"example.org", key1, [][]byte{cert1, cert2, cert3}, true},
  385. {"example.org", key1, [][]byte{cert1, {1}}, false},
  386. {"example.org", key1, [][]byte{{1}}, false},
  387. {"example.org", key1, [][]byte{cert2}, false},
  388. {"example.org", key2, [][]byte{cert1}, false},
  389. {"example.org", key1, [][]byte{cert3}, false},
  390. {"example.org", key3, [][]byte{cert1}, false},
  391. {"example.net", key1, [][]byte{cert1}, false},
  392. {"example.org", key1, [][]byte{early}, false},
  393. {"example.org", key1, [][]byte{expired}, false},
  394. }
  395. for i, test := range tt {
  396. leaf, err := validCert(test.domain, test.cert, test.key)
  397. if err != nil && test.ok {
  398. t.Errorf("%d: err = %v", i, err)
  399. }
  400. if err == nil && !test.ok {
  401. t.Errorf("%d: err is nil", i)
  402. }
  403. if err == nil && test.ok && leaf == nil {
  404. t.Errorf("%d: leaf is nil", i)
  405. }
  406. }
  407. }