You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

524 lines
14 KiB

  1. // Copyright 2012 Google Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package main
  15. // This file contains the model construction by parsing source files.
  16. import (
  17. "errors"
  18. "flag"
  19. "fmt"
  20. "go/ast"
  21. "go/build"
  22. "go/parser"
  23. "go/token"
  24. "log"
  25. "path"
  26. "path/filepath"
  27. "strconv"
  28. "strings"
  29. "github.com/golang/mock/mockgen/model"
  30. "golang.org/x/tools/go/packages"
  31. )
  // imports holds the -imports flag (source mode): comma-separated name=path
// pairs that take precedence over the imports found in the source file. A
// name of "." requests a dot import of the given path.
var imports = flag.String("imports", "", "(source mode) Comma-separated name=path pairs of explicit imports to use.")

// auxFiles holds the -aux_files flag (source mode): comma-separated pkg=path
// pairs naming additional Go source files whose interfaces may be embedded
// by the interfaces being mocked.
var auxFiles = flag.String("aux_files", "", "(source mode) Comma-separated pkg=path pairs of auxiliary Go source files.")
  36. // TODO: simplify error reporting
  37. func parseFile(source string) (*model.Package, error) {
  38. srcDir, err := filepath.Abs(filepath.Dir(source))
  39. if err != nil {
  40. return nil, fmt.Errorf("failed getting source directory: %v", err)
  41. }
  42. cfg := &packages.Config{Mode: packages.LoadSyntax, Tests: true}
  43. pkgs, err := packages.Load(cfg, "file="+source)
  44. if err != nil {
  45. return nil, err
  46. }
  47. if packages.PrintErrors(pkgs) > 0 || len(pkgs) == 0 {
  48. return nil, errors.New("loading package failed")
  49. }
  50. packageImport := pkgs[0].PkgPath
  51. // It is illegal to import a _test package.
  52. packageImport = strings.TrimSuffix(packageImport, "_test")
  53. fs := token.NewFileSet()
  54. file, err := parser.ParseFile(fs, source, nil, 0)
  55. if err != nil {
  56. return nil, fmt.Errorf("failed parsing source file %v: %v", source, err)
  57. }
  58. p := &fileParser{
  59. fileSet: fs,
  60. imports: make(map[string]string),
  61. importedInterfaces: make(map[string]map[string]*ast.InterfaceType),
  62. auxInterfaces: make(map[string]map[string]*ast.InterfaceType),
  63. srcDir: srcDir,
  64. }
  65. // Handle -imports.
  66. dotImports := make(map[string]bool)
  67. if *imports != "" {
  68. for _, kv := range strings.Split(*imports, ",") {
  69. eq := strings.Index(kv, "=")
  70. k, v := kv[:eq], kv[eq+1:]
  71. if k == "." {
  72. // TODO: Catch dupes?
  73. dotImports[v] = true
  74. } else {
  75. // TODO: Catch dupes?
  76. p.imports[k] = v
  77. }
  78. }
  79. }
  80. // Handle -aux_files.
  81. if err := p.parseAuxFiles(*auxFiles); err != nil {
  82. return nil, err
  83. }
  84. p.addAuxInterfacesFromFile(packageImport, file) // this file
  85. pkg, err := p.parseFile(packageImport, file)
  86. if err != nil {
  87. return nil, err
  88. }
  89. for path := range dotImports {
  90. pkg.DotImports = append(pkg.DotImports, path)
  91. }
  92. return pkg, nil
  93. }
  // fileParser holds the state shared while converting parsed Go files into
// model values.
type fileParser struct {
	// fileSet records positions for every file this parser reads.
	fileSet *token.FileSet
	// imports maps a package name to its import path.
	imports map[string]string
	// importedInterfaces maps a package (or "") to interface name to
	// declaration, for packages loaded through parsePackage.
	importedInterfaces map[string]map[string]*ast.InterfaceType
	// auxFiles are the parsed -aux_files sources.
	auxFiles []*ast.File
	// auxInterfaces maps a package (or "") to interface name to declaration,
	// for the aux files and for the file being mocked.
	auxInterfaces map[string]map[string]*ast.InterfaceType
	// srcDir is the absolute directory of the source file; parsePackage
	// resolves import paths relative to it.
	srcDir string
}

// errorf formats format/args like fmt.Errorf, prefixing the message with the
// "file:line:column: " location of pos as recorded in p.fileSet.
func (p *fileParser) errorf(pos token.Pos, format string, args ...interface{}) error {
	where := p.fileSet.Position(pos)
	all := make([]interface{}, 0, len(args)+3)
	all = append(all, where.Filename, where.Line, where.Column)
	all = append(all, args...)
	return fmt.Errorf("%s:%d:%d: "+format, all...)
}
  108. func (p *fileParser) parseAuxFiles(auxFiles string) error {
  109. auxFiles = strings.TrimSpace(auxFiles)
  110. if auxFiles == "" {
  111. return nil
  112. }
  113. for _, kv := range strings.Split(auxFiles, ",") {
  114. parts := strings.SplitN(kv, "=", 2)
  115. if len(parts) != 2 {
  116. return fmt.Errorf("bad aux file spec: %v", kv)
  117. }
  118. pkg, fpath := parts[0], parts[1]
  119. file, err := parser.ParseFile(p.fileSet, fpath, nil, 0)
  120. if err != nil {
  121. return err
  122. }
  123. p.auxFiles = append(p.auxFiles, file)
  124. p.addAuxInterfacesFromFile(pkg, file)
  125. }
  126. return nil
  127. }
  128. func (p *fileParser) addAuxInterfacesFromFile(pkg string, file *ast.File) {
  129. if _, ok := p.auxInterfaces[pkg]; !ok {
  130. p.auxInterfaces[pkg] = make(map[string]*ast.InterfaceType)
  131. }
  132. for ni := range iterInterfaces(file) {
  133. p.auxInterfaces[pkg][ni.name.Name] = ni.it
  134. }
  135. }
  136. // parseFile loads all file imports and auxiliary files import into the
  137. // fileParser, parses all file interfaces and returns package model.
  138. func (p *fileParser) parseFile(importPath string, file *ast.File) (*model.Package, error) {
  139. allImports, dotImports := importsOfFile(file)
  140. // Don't stomp imports provided by -imports. Those should take precedence.
  141. for pkg, path := range allImports {
  142. if _, ok := p.imports[pkg]; !ok {
  143. p.imports[pkg] = path
  144. }
  145. }
  146. // Add imports from auxiliary files, which might be needed for embedded interfaces.
  147. // Don't stomp any other imports.
  148. for _, f := range p.auxFiles {
  149. auxImports, _ := importsOfFile(f)
  150. for pkg, path := range auxImports {
  151. if _, ok := p.imports[pkg]; !ok {
  152. p.imports[pkg] = path
  153. }
  154. }
  155. }
  156. var is []*model.Interface
  157. for ni := range iterInterfaces(file) {
  158. i, err := p.parseInterface(ni.name.String(), importPath, ni.it)
  159. if err != nil {
  160. return nil, err
  161. }
  162. is = append(is, i)
  163. }
  164. return &model.Package{
  165. Name: file.Name.String(),
  166. Interfaces: is,
  167. DotImports: dotImports,
  168. }, nil
  169. }
  170. // parsePackage loads package specified by path, parses it and populates
  171. // corresponding imports and importedInterfaces into the fileParser.
  172. func (p *fileParser) parsePackage(path string) error {
  173. var pkgs map[string]*ast.Package
  174. if imp, err := build.Import(path, p.srcDir, build.FindOnly); err != nil {
  175. return err
  176. } else if pkgs, err = parser.ParseDir(p.fileSet, imp.Dir, nil, 0); err != nil {
  177. return err
  178. }
  179. for _, pkg := range pkgs {
  180. file := ast.MergePackageFiles(pkg, ast.FilterFuncDuplicates|ast.FilterUnassociatedComments|ast.FilterImportDuplicates)
  181. if _, ok := p.importedInterfaces[path]; !ok {
  182. p.importedInterfaces[path] = make(map[string]*ast.InterfaceType)
  183. }
  184. for ni := range iterInterfaces(file) {
  185. p.importedInterfaces[path][ni.name.Name] = ni.it
  186. }
  187. imports, _ := importsOfFile(file)
  188. for pkgName, pkgPath := range imports {
  189. if _, ok := p.imports[pkgName]; !ok {
  190. p.imports[pkgName] = pkgPath
  191. }
  192. }
  193. }
  194. return nil
  195. }
  196. func (p *fileParser) parseInterface(name, pkg string, it *ast.InterfaceType) (*model.Interface, error) {
  197. intf := &model.Interface{Name: name}
  198. for _, field := range it.Methods.List {
  199. switch v := field.Type.(type) {
  200. case *ast.FuncType:
  201. if nn := len(field.Names); nn != 1 {
  202. return nil, fmt.Errorf("expected one name for interface %v, got %d", intf.Name, nn)
  203. }
  204. m := &model.Method{
  205. Name: field.Names[0].String(),
  206. }
  207. var err error
  208. m.In, m.Variadic, m.Out, err = p.parseFunc(pkg, v)
  209. if err != nil {
  210. return nil, err
  211. }
  212. intf.Methods = append(intf.Methods, m)
  213. case *ast.Ident:
  214. // Embedded interface in this package.
  215. ei := p.auxInterfaces[pkg][v.String()]
  216. if ei == nil {
  217. if ei = p.importedInterfaces[pkg][v.String()]; ei == nil {
  218. return nil, p.errorf(v.Pos(), "unknown embedded interface %s", v.String())
  219. }
  220. }
  221. eintf, err := p.parseInterface(v.String(), pkg, ei)
  222. if err != nil {
  223. return nil, err
  224. }
  225. // Copy the methods.
  226. // TODO: apply shadowing rules.
  227. for _, m := range eintf.Methods {
  228. intf.Methods = append(intf.Methods, m)
  229. }
  230. case *ast.SelectorExpr:
  231. // Embedded interface in another package.
  232. fpkg, sel := v.X.(*ast.Ident).String(), v.Sel.String()
  233. epkg, ok := p.imports[fpkg]
  234. if !ok {
  235. return nil, p.errorf(v.X.Pos(), "unknown package %s", fpkg)
  236. }
  237. ei := p.auxInterfaces[fpkg][sel]
  238. if ei == nil {
  239. fpkg = epkg
  240. if _, ok = p.importedInterfaces[epkg]; !ok {
  241. if err := p.parsePackage(epkg); err != nil {
  242. return nil, p.errorf(v.Pos(), "could not parse package %s: %v", fpkg, err)
  243. }
  244. }
  245. if ei = p.importedInterfaces[epkg][sel]; ei == nil {
  246. return nil, p.errorf(v.Pos(), "unknown embedded interface %s.%s", fpkg, sel)
  247. }
  248. }
  249. eintf, err := p.parseInterface(sel, fpkg, ei)
  250. if err != nil {
  251. return nil, err
  252. }
  253. // Copy the methods.
  254. // TODO: apply shadowing rules.
  255. for _, m := range eintf.Methods {
  256. intf.Methods = append(intf.Methods, m)
  257. }
  258. default:
  259. return nil, fmt.Errorf("don't know how to mock method of type %T", field.Type)
  260. }
  261. }
  262. return intf, nil
  263. }
  264. func (p *fileParser) parseFunc(pkg string, f *ast.FuncType) (in []*model.Parameter, variadic *model.Parameter, out []*model.Parameter, err error) {
  265. if f.Params != nil {
  266. regParams := f.Params.List
  267. if isVariadic(f) {
  268. n := len(regParams)
  269. varParams := regParams[n-1:]
  270. regParams = regParams[:n-1]
  271. vp, err := p.parseFieldList(pkg, varParams)
  272. if err != nil {
  273. return nil, nil, nil, p.errorf(varParams[0].Pos(), "failed parsing variadic argument: %v", err)
  274. }
  275. variadic = vp[0]
  276. }
  277. in, err = p.parseFieldList(pkg, regParams)
  278. if err != nil {
  279. return nil, nil, nil, p.errorf(f.Pos(), "failed parsing arguments: %v", err)
  280. }
  281. }
  282. if f.Results != nil {
  283. out, err = p.parseFieldList(pkg, f.Results.List)
  284. if err != nil {
  285. return nil, nil, nil, p.errorf(f.Pos(), "failed parsing returns: %v", err)
  286. }
  287. }
  288. return
  289. }
  290. func (p *fileParser) parseFieldList(pkg string, fields []*ast.Field) ([]*model.Parameter, error) {
  291. nf := 0
  292. for _, f := range fields {
  293. nn := len(f.Names)
  294. if nn == 0 {
  295. nn = 1 // anonymous parameter
  296. }
  297. nf += nn
  298. }
  299. if nf == 0 {
  300. return nil, nil
  301. }
  302. ps := make([]*model.Parameter, nf)
  303. i := 0 // destination index
  304. for _, f := range fields {
  305. t, err := p.parseType(pkg, f.Type)
  306. if err != nil {
  307. return nil, err
  308. }
  309. if len(f.Names) == 0 {
  310. // anonymous arg
  311. ps[i] = &model.Parameter{Type: t}
  312. i++
  313. continue
  314. }
  315. for _, name := range f.Names {
  316. ps[i] = &model.Parameter{Name: name.Name, Type: t}
  317. i++
  318. }
  319. }
  320. return ps, nil
  321. }
  322. func (p *fileParser) parseType(pkg string, typ ast.Expr) (model.Type, error) {
  323. switch v := typ.(type) {
  324. case *ast.ArrayType:
  325. ln := -1
  326. if v.Len != nil {
  327. x, err := strconv.Atoi(v.Len.(*ast.BasicLit).Value)
  328. if err != nil {
  329. return nil, p.errorf(v.Len.Pos(), "bad array size: %v", err)
  330. }
  331. ln = x
  332. }
  333. t, err := p.parseType(pkg, v.Elt)
  334. if err != nil {
  335. return nil, err
  336. }
  337. return &model.ArrayType{Len: ln, Type: t}, nil
  338. case *ast.ChanType:
  339. t, err := p.parseType(pkg, v.Value)
  340. if err != nil {
  341. return nil, err
  342. }
  343. var dir model.ChanDir
  344. if v.Dir == ast.SEND {
  345. dir = model.SendDir
  346. }
  347. if v.Dir == ast.RECV {
  348. dir = model.RecvDir
  349. }
  350. return &model.ChanType{Dir: dir, Type: t}, nil
  351. case *ast.Ellipsis:
  352. // assume we're parsing a variadic argument
  353. return p.parseType(pkg, v.Elt)
  354. case *ast.FuncType:
  355. in, variadic, out, err := p.parseFunc(pkg, v)
  356. if err != nil {
  357. return nil, err
  358. }
  359. return &model.FuncType{In: in, Out: out, Variadic: variadic}, nil
  360. case *ast.Ident:
  361. if v.IsExported() {
  362. // `pkg` may be an aliased imported pkg
  363. // if so, patch the import w/ the fully qualified import
  364. maybeImportedPkg, ok := p.imports[pkg]
  365. if ok {
  366. pkg = maybeImportedPkg
  367. }
  368. // assume type in this package
  369. return &model.NamedType{Package: pkg, Type: v.Name}, nil
  370. }
  371. // assume predeclared type
  372. return model.PredeclaredType(v.Name), nil
  373. case *ast.InterfaceType:
  374. if v.Methods != nil && len(v.Methods.List) > 0 {
  375. return nil, p.errorf(v.Pos(), "can't handle non-empty unnamed interface types")
  376. }
  377. return model.PredeclaredType("interface{}"), nil
  378. case *ast.MapType:
  379. key, err := p.parseType(pkg, v.Key)
  380. if err != nil {
  381. return nil, err
  382. }
  383. value, err := p.parseType(pkg, v.Value)
  384. if err != nil {
  385. return nil, err
  386. }
  387. return &model.MapType{Key: key, Value: value}, nil
  388. case *ast.SelectorExpr:
  389. pkgName := v.X.(*ast.Ident).String()
  390. pkg, ok := p.imports[pkgName]
  391. if !ok {
  392. return nil, p.errorf(v.Pos(), "unknown package %q", pkgName)
  393. }
  394. return &model.NamedType{Package: pkg, Type: v.Sel.String()}, nil
  395. case *ast.StarExpr:
  396. t, err := p.parseType(pkg, v.X)
  397. if err != nil {
  398. return nil, err
  399. }
  400. return &model.PointerType{Type: t}, nil
  401. case *ast.StructType:
  402. if v.Fields != nil && len(v.Fields.List) > 0 {
  403. return nil, p.errorf(v.Pos(), "can't handle non-empty unnamed struct types")
  404. }
  405. return model.PredeclaredType("struct{}"), nil
  406. }
  407. return nil, fmt.Errorf("don't know how to parse type %T", typ)
  408. }
  // importsOfFile returns the imports of file.
//
// normalImports maps each package name to its import path, and dotImports
// lists the paths imported with ".". Blank (_) imports are skipped. An
// unnamed import takes its name from the package located by go/build, or,
// if that fails, from the last element of the import path up to its first
// dot. Two imports that resolve to the same name terminate the program via
// log.Fatalf.
func importsOfFile(file *ast.File) (normalImports map[string]string, dotImports []string) {
	normalImports = make(map[string]string)
	dotImports = make([]string, 0)
	for _, is := range file.Imports {
		// Unquote rather than strip the delimiters so that escape sequences
		// in the import literal are decoded.
		importPath, err := strconv.Unquote(is.Path.Value)
		if err != nil {
			// Not expected for a file that parsed cleanly; keep the old behavior.
			importPath = is.Path.Value[1 : len(is.Path.Value)-1]
		}
		var pkgName string
		if is.Name != nil {
			// Named imports are always certain.
			if is.Name.Name == "_" {
				continue
			}
			pkgName = is.Name.Name
		} else if pkg, err := build.Import(importPath, "", 0); err == nil {
			pkgName = pkg.Name
		} else {
			// Fallback to import path suffix. Note that this is uncertain.
			_, last := path.Split(importPath)
			// If the last path component has dots, the first dot-delimited
			// field is used as the name.
			pkgName = strings.SplitN(last, ".", 2)[0]
		}
		if pkgName == "." {
			dotImports = append(dotImports, importPath)
			continue
		}
		if _, ok := normalImports[pkgName]; ok {
			log.Fatalf("imported package collision: %q imported twice", pkgName)
		}
		normalImports[pkgName] = importPath
	}
	return normalImports, dotImports
}
  // namedInterface pairs an interface type declaration with its name.
type namedInterface struct {
	name *ast.Ident
	it   *ast.InterfaceType
}

// iterInterfaces returns a channel that yields every interface type declared
// among file's top-level declarations, in declaration order, and is then
// closed.
//
// The declarations are collected up front and the channel is buffered to hold
// all of them, so no goroutine is involved. A caller may stop ranging early
// (e.g. on an error) without leaving a sender blocked forever.
func iterInterfaces(file *ast.File) <-chan namedInterface {
	var found []namedInterface
	for _, decl := range file.Decls {
		gd, ok := decl.(*ast.GenDecl)
		if !ok || gd.Tok != token.TYPE {
			continue
		}
		for _, spec := range gd.Specs {
			ts, ok := spec.(*ast.TypeSpec)
			if !ok {
				continue
			}
			if it, ok := ts.Type.(*ast.InterfaceType); ok {
				found = append(found, namedInterface{ts.Name, it})
			}
		}
	}
	ch := make(chan namedInterface, len(found))
	for _, ni := range found {
		ch <- ni
	}
	close(ch)
	return ch
}
  // isVariadic reports whether f is variadic, i.e. whether its final parameter
// is declared with "...". A nil or empty parameter list is not variadic; the
// nil check mirrors the f.Params != nil guard in parseFunc.
func isVariadic(f *ast.FuncType) bool {
	if f.Params == nil || len(f.Params.List) == 0 {
		return false
	}
	last := f.Params.List[len(f.Params.List)-1]
	_, ok := last.Type.(*ast.Ellipsis)
	return ok
}