You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

206 rivejä
6.0 KiB

  1. // Copyright 2015 Google Inc. All rights reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package mitm
  15. import (
  16. "crypto/tls"
  17. "crypto/x509"
  18. "net"
  19. "reflect"
  20. "testing"
  21. "time"
  22. )
  23. func TestMITM(t *testing.T) {
  24. ca, priv, err := NewAuthority("martian.proxy", "Martian Authority", 24*time.Hour)
  25. if err != nil {
  26. t.Fatalf("NewAuthority(): got %v, want no error", err)
  27. }
  28. c, err := NewConfig(ca, priv)
  29. if err != nil {
  30. t.Fatalf("NewConfig(): got %v, want no error", err)
  31. }
  32. c.SetValidity(20 * time.Hour)
  33. c.SetOrganization("Test Organization")
  34. protos := []string{"http/1.1"}
  35. conf := c.TLS()
  36. if got := conf.NextProtos; !reflect.DeepEqual(got, protos) {
  37. t.Errorf("conf.NextProtos: got %v, want %v", got, protos)
  38. }
  39. if conf.InsecureSkipVerify {
  40. t.Error("conf.InsecureSkipVerify: got true, want false")
  41. }
  42. // Simulate a TLS connection without SNI.
  43. clientHello := &tls.ClientHelloInfo{
  44. ServerName: "",
  45. }
  46. if _, err := conf.GetCertificate(clientHello); err == nil {
  47. t.Fatal("conf.GetCertificate(): got nil, want error")
  48. }
  49. // Simulate a TLS connection with SNI.
  50. clientHello.ServerName = "example.com"
  51. tlsc, err := conf.GetCertificate(clientHello)
  52. if err != nil {
  53. t.Fatalf("conf.GetCertificate(): got %v, want no error", err)
  54. }
  55. x509c := tlsc.Leaf
  56. if got, want := x509c.Subject.CommonName, "example.com"; got != want {
  57. t.Errorf("x509c.Subject.CommonName: got %q, want %q", got, want)
  58. }
  59. c.SkipTLSVerify(true)
  60. conf = c.TLSForHost("example.com")
  61. if got := conf.NextProtos; !reflect.DeepEqual(got, protos) {
  62. t.Errorf("conf.NextProtos: got %v, want %v", got, protos)
  63. }
  64. if !conf.InsecureSkipVerify {
  65. t.Error("conf.InsecureSkipVerify: got false, want true")
  66. }
  67. // Set SNI, takes precedence over host.
  68. clientHello.ServerName = "google.com"
  69. tlsc, err = conf.GetCertificate(clientHello)
  70. if err != nil {
  71. t.Fatalf("conf.GetCertificate(): got %v, want no error", err)
  72. }
  73. x509c = tlsc.Leaf
  74. if got, want := x509c.Subject.CommonName, "google.com"; got != want {
  75. t.Errorf("x509c.Subject.CommonName: got %q, want %q", got, want)
  76. }
  77. // Reset SNI to fallback to hostname.
  78. clientHello.ServerName = ""
  79. tlsc, err = conf.GetCertificate(clientHello)
  80. if err != nil {
  81. t.Fatalf("conf.GetCertificate(): got %v, want no error", err)
  82. }
  83. x509c = tlsc.Leaf
  84. if got, want := x509c.Subject.CommonName, "example.com"; got != want {
  85. t.Errorf("x509c.Subject.CommonName: got %q, want %q", got, want)
  86. }
  87. }
  88. func TestCert(t *testing.T) {
  89. ca, priv, err := NewAuthority("martian.proxy", "Martian Authority", 24*time.Hour)
  90. if err != nil {
  91. t.Fatalf("NewAuthority(): got %v, want no error", err)
  92. }
  93. c, err := NewConfig(ca, priv)
  94. if err != nil {
  95. t.Fatalf("NewConfig(): got %v, want no error", err)
  96. }
  97. tlsc, err := c.cert("example.com")
  98. if err != nil {
  99. t.Fatalf("c.cert(%q): got %v, want no error", "example.com:8080", err)
  100. }
  101. if tlsc.Certificate == nil {
  102. t.Error("tlsc.Certificate: got nil, want certificate bytes")
  103. }
  104. if tlsc.PrivateKey == nil {
  105. t.Error("tlsc.PrivateKey: got nil, want private key")
  106. }
  107. x509c := tlsc.Leaf
  108. if x509c == nil {
  109. t.Fatal("x509c: got nil, want *x509.Certificate")
  110. }
  111. if got := x509c.SerialNumber; got.Cmp(MaxSerialNumber) >= 0 {
  112. t.Errorf("x509c.SerialNumber: got %v, want <= MaxSerialNumber", got)
  113. }
  114. if got, want := x509c.Subject.CommonName, "example.com"; got != want {
  115. t.Errorf("X509c.Subject.CommonName: got %q, want %q", got, want)
  116. }
  117. if err := x509c.VerifyHostname("example.com"); err != nil {
  118. t.Errorf("x509c.VerifyHostname(%q): got %v, want no error", "example.com", err)
  119. }
  120. if got, want := x509c.Subject.Organization, []string{"Martian Proxy"}; !reflect.DeepEqual(got, want) {
  121. t.Errorf("x509c.Subject.Organization: got %v, want %v", got, want)
  122. }
  123. if got := x509c.SubjectKeyId; got == nil {
  124. t.Error("x509c.SubjectKeyId: got nothing, want key ID")
  125. }
  126. if !x509c.BasicConstraintsValid {
  127. t.Error("x509c.BasicConstraintsValid: got false, want true")
  128. }
  129. if got, want := x509c.KeyUsage, x509.KeyUsageKeyEncipherment; got&want == 0 {
  130. t.Error("x509c.KeyUsage: got nothing, want to include x509.KeyUsageKeyEncipherment")
  131. }
  132. if got, want := x509c.KeyUsage, x509.KeyUsageDigitalSignature; got&want == 0 {
  133. t.Error("x509c.KeyUsage: got nothing, want to include x509.KeyUsageDigitalSignature")
  134. }
  135. want := []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}
  136. if got := x509c.ExtKeyUsage; !reflect.DeepEqual(got, want) {
  137. t.Errorf("x509c.ExtKeyUsage: got %v, want %v", got, want)
  138. }
  139. if got, want := x509c.DNSNames, []string{"example.com"}; !reflect.DeepEqual(got, want) {
  140. t.Errorf("x509c.DNSNames: got %v, want %v", got, want)
  141. }
  142. before := time.Now().Add(-2 * time.Hour)
  143. if got := x509c.NotBefore; before.After(got) {
  144. t.Errorf("x509c.NotBefore: got %v, want after %v", got, before)
  145. }
  146. after := time.Now().Add(2 * time.Hour)
  147. if got := x509c.NotAfter; !after.After(got) {
  148. t.Errorf("x509c.NotAfter: got %v, want before %v", got, want)
  149. }
  150. // Retrieve cached certificate.
  151. tlsc2, err := c.cert("example.com")
  152. if err != nil {
  153. t.Fatalf("c.cert(%q): got %v, want no error", "example.com", err)
  154. }
  155. if tlsc != tlsc2 {
  156. t.Error("tlsc2: got new certificate, want cached certificate")
  157. }
  158. // TLS certificate for IP.
  159. tlsc, err = c.cert("10.0.0.1:8227")
  160. if err != nil {
  161. t.Fatalf("c.cert(%q): got %v, want no error", "10.0.0.1:8227", err)
  162. }
  163. x509c = tlsc.Leaf
  164. if got, want := len(x509c.IPAddresses), 1; got != want {
  165. t.Fatalf("len(x509c.IPAddresses): got %d, want %d", got, want)
  166. }
  167. if got, want := x509c.IPAddresses[0], net.ParseIP("10.0.0.1"); !got.Equal(want) {
  168. t.Fatalf("x509c.IPAddresses: got %v, want %v", got, want)
  169. }
  170. }