You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

455 lines
11 KiB

  1. // Copyright 2012 Google Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // Package model contains the data model necessary for generating mock implementations.
  15. package model
  16. import (
  17. "encoding/gob"
  18. "fmt"
  19. "io"
  20. "reflect"
  21. "strings"
  22. )
// pkgPath is the importable path for package model. init uses it to build the
// fixed gob name under which PredeclaredType is registered, so that the name
// does not depend on a (possibly vendored) package path reported by reflect.
const pkgPath = "github.com/golang/mock/mockgen/model"
  25. // Package is a Go package. It may be a subset.
  26. type Package struct {
  27. Name string
  28. Interfaces []*Interface
  29. DotImports []string
  30. }
  31. func (pkg *Package) Print(w io.Writer) {
  32. fmt.Fprintf(w, "package %s\n", pkg.Name)
  33. for _, intf := range pkg.Interfaces {
  34. intf.Print(w)
  35. }
  36. }
  37. // Imports returns the imports needed by the Package as a set of import paths.
  38. func (pkg *Package) Imports() map[string]bool {
  39. im := make(map[string]bool)
  40. for _, intf := range pkg.Interfaces {
  41. intf.addImports(im)
  42. }
  43. return im
  44. }
  45. // Interface is a Go interface.
  46. type Interface struct {
  47. Name string
  48. Methods []*Method
  49. }
  50. func (intf *Interface) Print(w io.Writer) {
  51. fmt.Fprintf(w, "interface %s\n", intf.Name)
  52. for _, m := range intf.Methods {
  53. m.Print(w)
  54. }
  55. }
  56. func (intf *Interface) addImports(im map[string]bool) {
  57. for _, m := range intf.Methods {
  58. m.addImports(im)
  59. }
  60. }
  61. // Method is a single method of an interface.
  62. type Method struct {
  63. Name string
  64. In, Out []*Parameter
  65. Variadic *Parameter // may be nil
  66. }
  67. func (m *Method) Print(w io.Writer) {
  68. fmt.Fprintf(w, " - method %s\n", m.Name)
  69. if len(m.In) > 0 {
  70. fmt.Fprintf(w, " in:\n")
  71. for _, p := range m.In {
  72. p.Print(w)
  73. }
  74. }
  75. if m.Variadic != nil {
  76. fmt.Fprintf(w, " ...:\n")
  77. m.Variadic.Print(w)
  78. }
  79. if len(m.Out) > 0 {
  80. fmt.Fprintf(w, " out:\n")
  81. for _, p := range m.Out {
  82. p.Print(w)
  83. }
  84. }
  85. }
  86. func (m *Method) addImports(im map[string]bool) {
  87. for _, p := range m.In {
  88. p.Type.addImports(im)
  89. }
  90. if m.Variadic != nil {
  91. m.Variadic.Type.addImports(im)
  92. }
  93. for _, p := range m.Out {
  94. p.Type.addImports(im)
  95. }
  96. }
  97. // Parameter is an argument or return parameter of a method.
  98. type Parameter struct {
  99. Name string // may be empty
  100. Type Type
  101. }
  102. func (p *Parameter) Print(w io.Writer) {
  103. n := p.Name
  104. if n == "" {
  105. n = `""`
  106. }
  107. fmt.Fprintf(w, " - %v: %v\n", n, p.Type.String(nil, ""))
  108. }
// Type is a Go type.
type Type interface {
	// String returns the Go spelling of the type. pm maps an import path to
	// the prefix used to qualify named types from that package, and types
	// whose package equals pkgOverride are written unqualified.
	String(pm map[string]string, pkgOverride string) string
	// addImports adds to im the import paths this type refers to.
	addImports(im map[string]bool)
}
  114. func init() {
  115. gob.Register(&ArrayType{})
  116. gob.Register(&ChanType{})
  117. gob.Register(&FuncType{})
  118. gob.Register(&MapType{})
  119. gob.Register(&NamedType{})
  120. gob.Register(&PointerType{})
  121. // Call gob.RegisterName to make sure it has the consistent name registered
  122. // for both gob decoder and encoder.
  123. //
  124. // For a non-pointer type, gob.Register will try to get package full path by
  125. // calling rt.PkgPath() for a name to register. If your project has vendor
  126. // directory, it is possible that PkgPath will get a path like this:
  127. // ../../../vendor/github.com/golang/mock/mockgen/model
  128. gob.RegisterName(pkgPath+".PredeclaredType", PredeclaredType(""))
  129. }
  130. // ArrayType is an array or slice type.
  131. type ArrayType struct {
  132. Len int // -1 for slices, >= 0 for arrays
  133. Type Type
  134. }
  135. func (at *ArrayType) String(pm map[string]string, pkgOverride string) string {
  136. s := "[]"
  137. if at.Len > -1 {
  138. s = fmt.Sprintf("[%d]", at.Len)
  139. }
  140. return s + at.Type.String(pm, pkgOverride)
  141. }
  142. func (at *ArrayType) addImports(im map[string]bool) { at.Type.addImports(im) }
  143. // ChanType is a channel type.
  144. type ChanType struct {
  145. Dir ChanDir // 0, 1 or 2
  146. Type Type
  147. }
  148. func (ct *ChanType) String(pm map[string]string, pkgOverride string) string {
  149. s := ct.Type.String(pm, pkgOverride)
  150. if ct.Dir == RecvDir {
  151. return "<-chan " + s
  152. }
  153. if ct.Dir == SendDir {
  154. return "chan<- " + s
  155. }
  156. return "chan " + s
  157. }
  158. func (ct *ChanType) addImports(im map[string]bool) { ct.Type.addImports(im) }
  159. // ChanDir is a channel direction.
  160. type ChanDir int
  161. const (
  162. RecvDir ChanDir = 1
  163. SendDir ChanDir = 2
  164. )
  165. // FuncType is a function type.
  166. type FuncType struct {
  167. In, Out []*Parameter
  168. Variadic *Parameter // may be nil
  169. }
  170. func (ft *FuncType) String(pm map[string]string, pkgOverride string) string {
  171. args := make([]string, len(ft.In))
  172. for i, p := range ft.In {
  173. args[i] = p.Type.String(pm, pkgOverride)
  174. }
  175. if ft.Variadic != nil {
  176. args = append(args, "..."+ft.Variadic.Type.String(pm, pkgOverride))
  177. }
  178. rets := make([]string, len(ft.Out))
  179. for i, p := range ft.Out {
  180. rets[i] = p.Type.String(pm, pkgOverride)
  181. }
  182. retString := strings.Join(rets, ", ")
  183. if nOut := len(ft.Out); nOut == 1 {
  184. retString = " " + retString
  185. } else if nOut > 1 {
  186. retString = " (" + retString + ")"
  187. }
  188. return "func(" + strings.Join(args, ", ") + ")" + retString
  189. }
  190. func (ft *FuncType) addImports(im map[string]bool) {
  191. for _, p := range ft.In {
  192. p.Type.addImports(im)
  193. }
  194. if ft.Variadic != nil {
  195. ft.Variadic.Type.addImports(im)
  196. }
  197. for _, p := range ft.Out {
  198. p.Type.addImports(im)
  199. }
  200. }
  201. // MapType is a map type.
  202. type MapType struct {
  203. Key, Value Type
  204. }
  205. func (mt *MapType) String(pm map[string]string, pkgOverride string) string {
  206. return "map[" + mt.Key.String(pm, pkgOverride) + "]" + mt.Value.String(pm, pkgOverride)
  207. }
  208. func (mt *MapType) addImports(im map[string]bool) {
  209. mt.Key.addImports(im)
  210. mt.Value.addImports(im)
  211. }
  212. // NamedType is an exported type in a package.
  213. type NamedType struct {
  214. Package string // may be empty
  215. Type string // TODO: should this be typed Type?
  216. }
  217. func (nt *NamedType) String(pm map[string]string, pkgOverride string) string {
  218. // TODO: is this right?
  219. if pkgOverride == nt.Package {
  220. return nt.Type
  221. }
  222. prefix := pm[nt.Package]
  223. if prefix != "" {
  224. return prefix + "." + nt.Type
  225. } else {
  226. return nt.Type
  227. }
  228. }
  229. func (nt *NamedType) addImports(im map[string]bool) {
  230. if nt.Package != "" {
  231. im[nt.Package] = true
  232. }
  233. }
  234. // PointerType is a pointer to another type.
  235. type PointerType struct {
  236. Type Type
  237. }
  238. func (pt *PointerType) String(pm map[string]string, pkgOverride string) string {
  239. return "*" + pt.Type.String(pm, pkgOverride)
  240. }
  241. func (pt *PointerType) addImports(im map[string]bool) { pt.Type.addImports(im) }
  242. // PredeclaredType is a predeclared type such as "int".
  243. type PredeclaredType string
  244. func (pt PredeclaredType) String(pm map[string]string, pkgOverride string) string { return string(pt) }
  245. func (pt PredeclaredType) addImports(im map[string]bool) {}
  246. // The following code is intended to be called by the program generated by ../reflect.go.
  247. func InterfaceFromInterfaceType(it reflect.Type) (*Interface, error) {
  248. if it.Kind() != reflect.Interface {
  249. return nil, fmt.Errorf("%v is not an interface", it)
  250. }
  251. intf := &Interface{}
  252. for i := 0; i < it.NumMethod(); i++ {
  253. mt := it.Method(i)
  254. // TODO: need to skip unexported methods? or just raise an error?
  255. m := &Method{
  256. Name: mt.Name,
  257. }
  258. var err error
  259. m.In, m.Variadic, m.Out, err = funcArgsFromType(mt.Type)
  260. if err != nil {
  261. return nil, err
  262. }
  263. intf.Methods = append(intf.Methods, m)
  264. }
  265. return intf, nil
  266. }
  267. // t's Kind must be a reflect.Func.
  268. func funcArgsFromType(t reflect.Type) (in []*Parameter, variadic *Parameter, out []*Parameter, err error) {
  269. nin := t.NumIn()
  270. if t.IsVariadic() {
  271. nin--
  272. }
  273. var p *Parameter
  274. for i := 0; i < nin; i++ {
  275. p, err = parameterFromType(t.In(i))
  276. if err != nil {
  277. return
  278. }
  279. in = append(in, p)
  280. }
  281. if t.IsVariadic() {
  282. p, err = parameterFromType(t.In(nin).Elem())
  283. if err != nil {
  284. return
  285. }
  286. variadic = p
  287. }
  288. for i := 0; i < t.NumOut(); i++ {
  289. p, err = parameterFromType(t.Out(i))
  290. if err != nil {
  291. return
  292. }
  293. out = append(out, p)
  294. }
  295. return
  296. }
  297. func parameterFromType(t reflect.Type) (*Parameter, error) {
  298. tt, err := typeFromType(t)
  299. if err != nil {
  300. return nil, err
  301. }
  302. return &Parameter{Type: tt}, nil
  303. }
  304. var errorType = reflect.TypeOf((*error)(nil)).Elem()
  305. var byteType = reflect.TypeOf(byte(0))
  306. func typeFromType(t reflect.Type) (Type, error) {
  307. // Hack workaround for https://golang.org/issue/3853.
  308. // This explicit check should not be necessary.
  309. if t == byteType {
  310. return PredeclaredType("byte"), nil
  311. }
  312. if imp := t.PkgPath(); imp != "" {
  313. // PkgPath might return a path that includes "vendor"
  314. // These paths do not compile, so we need to remove everything
  315. // up to and including "/vendor/"
  316. // see https://github.com/golang/go/issues/12019
  317. if i := strings.LastIndex(imp, "/vendor/"); i != -1 {
  318. imp = imp[i+len("/vendor/"):]
  319. }
  320. return &NamedType{
  321. Package: imp,
  322. Type: t.Name(),
  323. }, nil
  324. }
  325. // only unnamed or predeclared types after here
  326. // Lots of types have element types. Let's do the parsing and error checking for all of them.
  327. var elemType Type
  328. switch t.Kind() {
  329. case reflect.Array, reflect.Chan, reflect.Map, reflect.Ptr, reflect.Slice:
  330. var err error
  331. elemType, err = typeFromType(t.Elem())
  332. if err != nil {
  333. return nil, err
  334. }
  335. }
  336. switch t.Kind() {
  337. case reflect.Array:
  338. return &ArrayType{
  339. Len: t.Len(),
  340. Type: elemType,
  341. }, nil
  342. case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
  343. reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr,
  344. reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128, reflect.String:
  345. return PredeclaredType(t.Kind().String()), nil
  346. case reflect.Chan:
  347. var dir ChanDir
  348. switch t.ChanDir() {
  349. case reflect.RecvDir:
  350. dir = RecvDir
  351. case reflect.SendDir:
  352. dir = SendDir
  353. }
  354. return &ChanType{
  355. Dir: dir,
  356. Type: elemType,
  357. }, nil
  358. case reflect.Func:
  359. in, variadic, out, err := funcArgsFromType(t)
  360. if err != nil {
  361. return nil, err
  362. }
  363. return &FuncType{
  364. In: in,
  365. Out: out,
  366. Variadic: variadic,
  367. }, nil
  368. case reflect.Interface:
  369. // Two special interfaces.
  370. if t.NumMethod() == 0 {
  371. return PredeclaredType("interface{}"), nil
  372. }
  373. if t == errorType {
  374. return PredeclaredType("error"), nil
  375. }
  376. case reflect.Map:
  377. kt, err := typeFromType(t.Key())
  378. if err != nil {
  379. return nil, err
  380. }
  381. return &MapType{
  382. Key: kt,
  383. Value: elemType,
  384. }, nil
  385. case reflect.Ptr:
  386. return &PointerType{
  387. Type: elemType,
  388. }, nil
  389. case reflect.Slice:
  390. return &ArrayType{
  391. Len: -1,
  392. Type: elemType,
  393. }, nil
  394. case reflect.Struct:
  395. if t.NumField() == 0 {
  396. return PredeclaredType("struct{}"), nil
  397. }
  398. }
  399. // TODO: Struct, UnsafePointer
  400. return nil, fmt.Errorf("can't yet turn %v (%v) into a model.Type", t, t.Kind())
  401. }