You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

244 rivejä
5.8 KiB

  1. // Copyright 2012 Google Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package main
  15. // This file contains the model construction by reflection.
  import (
	"bytes"
	"encoding/gob"
	"flag"
	"go/build"
	"io/ioutil"
	"log"
	"os"
	"os/exec"
	"path/filepath"
	"runtime"
	"strings"
	"text/template"

	"github.com/golang/mock/mockgen/model"
)
  // Command-line flags that only matter in reflect mode.
var (
	// progOnly: emit the generated reflection program on stdout, then exit.
	progOnly = flag.Bool("prog_only", false, "(reflect mode) Only generate the reflection program; write it to stdout and exit.")

	// execOnly: path to an already-built reflection program to run instead of
	// generating and building one.
	execOnly = flag.String("exec_only", "", "(reflect mode) If set, execute this reflection program.")

	// buildFlags: extra arguments handed to "go build" for the reflection program.
	buildFlags = flag.String("build_flags", "", "(reflect mode) Additional flags for go build.")
)
  35. func writeProgram(importPath string, symbols []string) ([]byte, error) {
  36. var program bytes.Buffer
  37. data := reflectData{
  38. ImportPath: importPath,
  39. Symbols: symbols,
  40. }
  41. if err := reflectProgram.Execute(&program, &data); err != nil {
  42. return nil, err
  43. }
  44. return program.Bytes(), nil
  45. }
  46. // run the given program and parse the output as a model.Package.
  47. func run(program string) (*model.Package, error) {
  48. f, err := ioutil.TempFile("", "")
  49. if err != nil {
  50. return nil, err
  51. }
  52. filename := f.Name()
  53. defer os.Remove(filename)
  54. if err := f.Close(); err != nil {
  55. return nil, err
  56. }
  57. // Run the program.
  58. cmd := exec.Command(program, "-output", filename)
  59. cmd.Stdout = os.Stdout
  60. cmd.Stderr = os.Stderr
  61. if err := cmd.Run(); err != nil {
  62. return nil, err
  63. }
  64. f, err = os.Open(filename)
  65. if err != nil {
  66. return nil, err
  67. }
  68. // Process output.
  69. var pkg model.Package
  70. if err := gob.NewDecoder(f).Decode(&pkg); err != nil {
  71. return nil, err
  72. }
  73. if err := f.Close(); err != nil {
  74. return nil, err
  75. }
  76. return &pkg, nil
  77. }
  78. // runInDir writes the given program into the given dir, runs it there, and
  79. // parses the output as a model.Package.
  80. func runInDir(program []byte, dir string) (*model.Package, error) {
  81. // We use TempDir instead of TempFile so we can control the filename.
  82. tmpDir, err := ioutil.TempDir(dir, "gomock_reflect_")
  83. if err != nil {
  84. return nil, err
  85. }
  86. defer func() {
  87. if err := os.RemoveAll(tmpDir); err != nil {
  88. log.Printf("failed to remove temp directory: %s", err)
  89. }
  90. }()
  91. const progSource = "prog.go"
  92. var progBinary = "prog.bin"
  93. if runtime.GOOS == "windows" {
  94. // Windows won't execute a program unless it has a ".exe" suffix.
  95. progBinary += ".exe"
  96. }
  97. if err := ioutil.WriteFile(filepath.Join(tmpDir, progSource), program, 0600); err != nil {
  98. return nil, err
  99. }
  100. cmdArgs := []string{}
  101. cmdArgs = append(cmdArgs, "build")
  102. if *buildFlags != "" {
  103. cmdArgs = append(cmdArgs, *buildFlags)
  104. }
  105. cmdArgs = append(cmdArgs, "-o", progBinary, progSource)
  106. // Build the program.
  107. cmd := exec.Command("go", cmdArgs...)
  108. cmd.Dir = tmpDir
  109. cmd.Stdout = os.Stdout
  110. cmd.Stderr = os.Stderr
  111. if err := cmd.Run(); err != nil {
  112. return nil, err
  113. }
  114. return run(filepath.Join(tmpDir, progBinary))
  115. }
  116. func reflect(importPath string, symbols []string) (*model.Package, error) {
  117. // TODO: sanity check arguments
  118. if *execOnly != "" {
  119. return run(*execOnly)
  120. }
  121. program, err := writeProgram(importPath, symbols)
  122. if err != nil {
  123. return nil, err
  124. }
  125. if *progOnly {
  126. os.Stdout.Write(program)
  127. os.Exit(0)
  128. }
  129. wd, _ := os.Getwd()
  130. // Try to run the program in the same directory as the input package.
  131. if p, err := build.Import(importPath, wd, build.FindOnly); err == nil {
  132. dir := p.Dir
  133. if p, err := runInDir(program, dir); err == nil {
  134. return p, nil
  135. }
  136. }
  137. // Since that didn't work, try to run it in the current working directory.
  138. if p, err := runInDir(program, wd); err == nil {
  139. return p, nil
  140. }
  141. // Since that didn't work, try to run it in a standard temp directory.
  142. return runInDir(program, "")
  143. }
  // reflectData is the template input for reflectProgram.
type reflectData struct {
	// ImportPath is the import path of the package whose interfaces are
	// reflected on; it is imported as pkg_ in the generated program.
	ImportPath string

	// Symbols lists the names of the interface types, within ImportPath, to
	// reflect on.
	Symbols []string
}
  148. // This program reflects on an interface value, and prints the
  149. // gob encoding of a model.Package to standard output.
  150. // JSON doesn't work because of the model.Type interface.
  151. var reflectProgram = template.Must(template.New("program").Parse(`
  152. package main
  153. import (
  154. "encoding/gob"
  155. "flag"
  156. "fmt"
  157. "os"
  158. "path"
  159. "reflect"
  160. "github.com/golang/mock/mockgen/model"
  161. pkg_ {{printf "%q" .ImportPath}}
  162. )
  163. var output = flag.String("output", "", "The output file name, or empty to use stdout.")
  164. func main() {
  165. flag.Parse()
  166. its := []struct{
  167. sym string
  168. typ reflect.Type
  169. }{
  170. {{range .Symbols}}
  171. { {{printf "%q" .}}, reflect.TypeOf((*pkg_.{{.}})(nil)).Elem()},
  172. {{end}}
  173. }
  174. pkg := &model.Package{
  175. // NOTE: This behaves contrary to documented behaviour if the
  176. // package name is not the final component of the import path.
  177. // The reflect package doesn't expose the package name, though.
  178. Name: path.Base({{printf "%q" .ImportPath}}),
  179. }
  180. for _, it := range its {
  181. intf, err := model.InterfaceFromInterfaceType(it.typ)
  182. if err != nil {
  183. fmt.Fprintf(os.Stderr, "Reflection: %v\n", err)
  184. os.Exit(1)
  185. }
  186. intf.Name = it.sym
  187. pkg.Interfaces = append(pkg.Interfaces, intf)
  188. }
  189. outfile := os.Stdout
  190. if len(*output) != 0 {
  191. var err error
  192. outfile, err = os.Create(*output)
  193. if err != nil {
  194. fmt.Fprintf(os.Stderr, "failed to open output file %q", *output)
  195. }
  196. defer func() {
  197. if err := outfile.Close(); err != nil {
  198. fmt.Fprintf(os.Stderr, "failed to close output file %q", *output)
  199. os.Exit(1)
  200. }
  201. }()
  202. }
  203. if err := gob.NewEncoder(outfile).Encode(pkg); err != nil {
  204. fmt.Fprintf(os.Stderr, "gob encode: %v\n", err)
  205. os.Exit(1)
  206. }
  207. }
  208. `))