You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

464 lines
14 KiB

  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package impl
  5. import (
  6. "fmt"
  7. "reflect"
  8. "google.golang.org/protobuf/internal/detrand"
  9. "google.golang.org/protobuf/internal/pragma"
  10. "google.golang.org/protobuf/reflect/protoreflect"
  11. )
// reflectMessageInfo holds the per-message state used to implement protobuf
// reflection. Most fields are assigned by makeReflectFuncs and its helpers
// (presumably via embedding in MessageInfo — the methods write mi.fields,
// mi.oneofs, etc.).
type reflectMessageInfo struct {
	// fields maps each field number of the message descriptor to its accessor.
	fields map[protoreflect.FieldNumber]*fieldInfo
	// oneofs maps each oneof name of the message descriptor to its accessor.
	oneofs map[protoreflect.Name]*oneofInfo

	// fieldTypes contains the zero value of an enum or message field.
	// For lists, it contains the element type.
	// For maps, it contains the entry value type.
	fieldTypes map[protoreflect.FieldNumber]interface{}

	// denseFields is a subset of fields where:
	//	0 < fieldDesc.Number() < len(denseFields)
	// It provides faster access to the fieldInfo, but may be incomplete.
	denseFields []*fieldInfo

	// rangeInfos is a list of all fields (not belonging to a oneof) and oneofs.
	rangeInfos []interface{} // either *fieldInfo or *oneofInfo

	// getUnknown and setUnknown read and write the message's unknown-field
	// bytes; both receive a pointer to the message data.
	getUnknown func(pointer) protoreflect.RawFields
	setUnknown func(pointer, protoreflect.RawFields)
	// extensionMap returns the message's extension storage. It may return a
	// nil *extensionMap (for a nil message or a type without extension storage).
	extensionMap func(pointer) *extensionMap

	// nilMessage is used by MessageOf to build the view it returns for a nil
	// message pointer.
	nilMessage atomicNilMessage
}
  30. // makeReflectFuncs generates the set of functions to support reflection.
  31. func (mi *MessageInfo) makeReflectFuncs(t reflect.Type, si structInfo) {
  32. mi.makeKnownFieldsFunc(si)
  33. mi.makeUnknownFieldsFunc(t, si)
  34. mi.makeExtensionFieldsFunc(t, si)
  35. mi.makeFieldTypes(si)
  36. }
// makeKnownFieldsFunc generates functions for operations that can be performed
// on each protobuf message field. It takes in a structInfo describing the Go
// struct and matches message fields with struct fields.
//
// This code assumes that the struct is well-formed and panics if there are
// any discrepancies.
//
// It populates:
//   - mi.fields: one fieldInfo per field of mi.Desc, keyed by field number.
//   - mi.oneofs: one oneofInfo per oneof of mi.Desc, keyed by oneof name.
//   - mi.denseFields: a slice of length 2*fds.Len() indexed by field number,
//     holding the mi.fields entries whose numbers fit inside it.
//   - mi.rangeInfos: one entry per field outside a real (non-synthetic)
//     oneof plus one entry per real oneof, in field declaration order except
//     for the optional adjacent swap at the end.
func (mi *MessageInfo) makeKnownFieldsFunc(si structInfo) {
	mi.fields = map[protoreflect.FieldNumber]*fieldInfo{}
	md := mi.Desc
	fds := md.Fields()
	for i := 0; i < fds.Len(); i++ {
		fd := fds.Get(i)
		fs := si.fieldsByNumber[fd.Number()]
		// Synthetic oneofs are handled like plain fields; for a real oneof the
		// struct field is looked up by the oneof's name instead of by number.
		isOneof := fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic()
		if isOneof {
			fs = si.oneofsByName[fd.ContainingOneof().Name()]
		}
		var fi fieldInfo
		// Cases are tested in order, so the specific field shapes are matched
		// before the generic message and scalar fallbacks.
		switch {
		case fs.Type == nil:
			fi = fieldInfoForMissing(fd) // never occurs for officially generated message types
		case isOneof:
			fi = fieldInfoForOneof(fd, fs, mi.Exporter, si.oneofWrappersByNumber[fd.Number()])
		case fd.IsMap():
			fi = fieldInfoForMap(fd, fs, mi.Exporter)
		case fd.IsList():
			fi = fieldInfoForList(fd, fs, mi.Exporter)
		case fd.IsWeak():
			fi = fieldInfoForWeakMessage(fd, si.weakOffset)
		case fd.Message() != nil:
			fi = fieldInfoForMessage(fd, fs, mi.Exporter)
		default:
			fi = fieldInfoForScalar(fd, fs, mi.Exporter)
		}
		mi.fields[fd.Number()] = &fi
	}

	mi.oneofs = map[protoreflect.Name]*oneofInfo{}
	for i := 0; i < md.Oneofs().Len(); i++ {
		od := md.Oneofs().Get(i)
		mi.oneofs[od.Name()] = makeOneofInfo(od, si, mi.Exporter)
	}

	// The dense table is sized at twice the field count. Fields numbered at or
	// beyond that bound are reachable only through mi.fields.
	mi.denseFields = make([]*fieldInfo, fds.Len()*2)
	for i := 0; i < fds.Len(); i++ {
		if fd := fds.Get(i); int(fd.Number()) < len(mi.denseFields) {
			mi.denseFields[fd.Number()] = mi.fields[fd.Number()]
		}
	}

	// Walk fields in declaration order. A real oneof contributes a single
	// entry, and i then skips ahead by the oneof's field count.
	// NOTE(review): this assumes the oneof's fields appear contiguously in fds
	// starting at fd — confirm this holds for every descriptor source.
	for i := 0; i < fds.Len(); {
		fd := fds.Get(i)
		if od := fd.ContainingOneof(); od != nil && !od.IsSynthetic() {
			mi.rangeInfos = append(mi.rangeInfos, mi.oneofs[od.Name()])
			i += od.Fields().Len()
		} else {
			mi.rangeInfos = append(mi.rangeInfos, mi.fields[fd.Number()])
			i++
		}
	}

	// Introduce instability to iteration order, but keep it deterministic.
	// When detrand.Bool() is true, one adjacent pair (chosen by detrand.Intn)
	// is swapped.
	if len(mi.rangeInfos) > 1 && detrand.Bool() {
		i := detrand.Intn(len(mi.rangeInfos) - 1)
		mi.rangeInfos[i], mi.rangeInfos[i+1] = mi.rangeInfos[i+1], mi.rangeInfos[i]
	}
}
  100. func (mi *MessageInfo) makeUnknownFieldsFunc(t reflect.Type, si structInfo) {
  101. switch {
  102. case si.unknownOffset.IsValid() && si.unknownType == unknownFieldsAType:
  103. // Handle as []byte.
  104. mi.getUnknown = func(p pointer) protoreflect.RawFields {
  105. if p.IsNil() {
  106. return nil
  107. }
  108. return *p.Apply(mi.unknownOffset).Bytes()
  109. }
  110. mi.setUnknown = func(p pointer, b protoreflect.RawFields) {
  111. if p.IsNil() {
  112. panic("invalid SetUnknown on nil Message")
  113. }
  114. *p.Apply(mi.unknownOffset).Bytes() = b
  115. }
  116. case si.unknownOffset.IsValid() && si.unknownType == unknownFieldsBType:
  117. // Handle as *[]byte.
  118. mi.getUnknown = func(p pointer) protoreflect.RawFields {
  119. if p.IsNil() {
  120. return nil
  121. }
  122. bp := p.Apply(mi.unknownOffset).BytesPtr()
  123. if *bp == nil {
  124. return nil
  125. }
  126. return **bp
  127. }
  128. mi.setUnknown = func(p pointer, b protoreflect.RawFields) {
  129. if p.IsNil() {
  130. panic("invalid SetUnknown on nil Message")
  131. }
  132. bp := p.Apply(mi.unknownOffset).BytesPtr()
  133. if *bp == nil {
  134. *bp = new([]byte)
  135. }
  136. **bp = b
  137. }
  138. default:
  139. mi.getUnknown = func(pointer) protoreflect.RawFields {
  140. return nil
  141. }
  142. mi.setUnknown = func(p pointer, _ protoreflect.RawFields) {
  143. if p.IsNil() {
  144. panic("invalid SetUnknown on nil Message")
  145. }
  146. }
  147. }
  148. }
  149. func (mi *MessageInfo) makeExtensionFieldsFunc(t reflect.Type, si structInfo) {
  150. if si.extensionOffset.IsValid() {
  151. mi.extensionMap = func(p pointer) *extensionMap {
  152. if p.IsNil() {
  153. return (*extensionMap)(nil)
  154. }
  155. v := p.Apply(si.extensionOffset).AsValueOf(extensionFieldsType)
  156. return (*extensionMap)(v.Interface().(*map[int32]ExtensionField))
  157. }
  158. } else {
  159. mi.extensionMap = func(pointer) *extensionMap {
  160. return (*extensionMap)(nil)
  161. }
  162. }
  163. }
  164. func (mi *MessageInfo) makeFieldTypes(si structInfo) {
  165. md := mi.Desc
  166. fds := md.Fields()
  167. for i := 0; i < fds.Len(); i++ {
  168. var ft reflect.Type
  169. fd := fds.Get(i)
  170. fs := si.fieldsByNumber[fd.Number()]
  171. isOneof := fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic()
  172. if isOneof {
  173. fs = si.oneofsByName[fd.ContainingOneof().Name()]
  174. }
  175. var isMessage bool
  176. switch {
  177. case fs.Type == nil:
  178. continue // never occurs for officially generated message types
  179. case isOneof:
  180. if fd.Enum() != nil || fd.Message() != nil {
  181. ft = si.oneofWrappersByNumber[fd.Number()].Field(0).Type
  182. }
  183. case fd.IsMap():
  184. if fd.MapValue().Enum() != nil || fd.MapValue().Message() != nil {
  185. ft = fs.Type.Elem()
  186. }
  187. isMessage = fd.MapValue().Message() != nil
  188. case fd.IsList():
  189. if fd.Enum() != nil || fd.Message() != nil {
  190. ft = fs.Type.Elem()
  191. }
  192. isMessage = fd.Message() != nil
  193. case fd.Enum() != nil:
  194. ft = fs.Type
  195. if fd.HasPresence() && ft.Kind() == reflect.Ptr {
  196. ft = ft.Elem()
  197. }
  198. case fd.Message() != nil:
  199. ft = fs.Type
  200. if fd.IsWeak() {
  201. ft = nil
  202. }
  203. isMessage = true
  204. }
  205. if isMessage && ft != nil && ft.Kind() != reflect.Ptr {
  206. ft = reflect.PtrTo(ft) // never occurs for officially generated message types
  207. }
  208. if ft != nil {
  209. if mi.fieldTypes == nil {
  210. mi.fieldTypes = make(map[protoreflect.FieldNumber]interface{})
  211. }
  212. mi.fieldTypes[fd.Number()] = reflect.Zero(ft).Interface()
  213. }
  214. }
  215. }
// extensionMap holds a message's extension fields, keyed by extension field
// number (its methods index it with int32(xd.Number())).
type extensionMap map[int32]ExtensionField
  217. func (m *extensionMap) Range(f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) {
  218. if m != nil {
  219. for _, x := range *m {
  220. xd := x.Type().TypeDescriptor()
  221. v := x.Value()
  222. if xd.IsList() && v.List().Len() == 0 {
  223. continue
  224. }
  225. if !f(xd, v) {
  226. return
  227. }
  228. }
  229. }
  230. }
  231. func (m *extensionMap) Has(xt protoreflect.ExtensionType) (ok bool) {
  232. if m == nil {
  233. return false
  234. }
  235. xd := xt.TypeDescriptor()
  236. x, ok := (*m)[int32(xd.Number())]
  237. if !ok {
  238. return false
  239. }
  240. switch {
  241. case xd.IsList():
  242. return x.Value().List().Len() > 0
  243. case xd.IsMap():
  244. return x.Value().Map().Len() > 0
  245. case xd.Message() != nil:
  246. return x.Value().Message().IsValid()
  247. }
  248. return true
  249. }
  250. func (m *extensionMap) Clear(xt protoreflect.ExtensionType) {
  251. delete(*m, int32(xt.TypeDescriptor().Number()))
  252. }
  253. func (m *extensionMap) Get(xt protoreflect.ExtensionType) protoreflect.Value {
  254. xd := xt.TypeDescriptor()
  255. if m != nil {
  256. if x, ok := (*m)[int32(xd.Number())]; ok {
  257. return x.Value()
  258. }
  259. }
  260. return xt.Zero()
  261. }
  262. func (m *extensionMap) Set(xt protoreflect.ExtensionType, v protoreflect.Value) {
  263. xd := xt.TypeDescriptor()
  264. isValid := true
  265. switch {
  266. case !xt.IsValidValue(v):
  267. isValid = false
  268. case xd.IsList():
  269. isValid = v.List().IsValid()
  270. case xd.IsMap():
  271. isValid = v.Map().IsValid()
  272. case xd.Message() != nil:
  273. isValid = v.Message().IsValid()
  274. }
  275. if !isValid {
  276. panic(fmt.Sprintf("%v: assigning invalid value", xt.TypeDescriptor().FullName()))
  277. }
  278. if *m == nil {
  279. *m = make(map[int32]ExtensionField)
  280. }
  281. var x ExtensionField
  282. x.Set(xt, v)
  283. (*m)[int32(xd.Number())] = x
  284. }
  285. func (m *extensionMap) Mutable(xt protoreflect.ExtensionType) protoreflect.Value {
  286. xd := xt.TypeDescriptor()
  287. if xd.Kind() != protoreflect.MessageKind && xd.Kind() != protoreflect.GroupKind && !xd.IsList() && !xd.IsMap() {
  288. panic("invalid Mutable on field with non-composite type")
  289. }
  290. if x, ok := (*m)[int32(xd.Number())]; ok {
  291. return x.Value()
  292. }
  293. v := xt.New()
  294. m.Set(xt, v)
  295. return v
  296. }
// MessageState is a data structure that is nested as the first field in a
// concrete message. It provides a way to implement the ProtoReflect method
// in an allocation-free way without needing to have a shadow Go type generated
// for every message type. This technique only works using unsafe.
//
// Example generated code:
//
//	type M struct {
//		state protoimpl.MessageState
//
//		Field1 int32
//		Field2 string
//		Field3 *BarMessage
//		...
//	}
//
//	func (m *M) ProtoReflect() protoreflect.Message {
//		mi := &file_fizz_buzz_proto_msgInfos[5]
//		if protoimpl.UnsafeEnabled && m != nil {
//			ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(m))
//			if ms.LoadMessageInfo() == nil {
//				ms.StoreMessageInfo(mi)
//			}
//			return ms
//		}
//		return mi.MessageOf(m)
//	}
//
// The MessageState type holds a *MessageInfo, which must be atomically set to
// the message info associated with a given message instance.
// By unsafely converting a *M into a *MessageState, the MessageState object
// has access to all the information needed to implement protobuf reflection.
// It has access to the message info as its first field, and a pointer to the
// MessageState is identical to a pointer to the concrete message value.
//
// Requirements:
//   - The type M must implement protoreflect.ProtoMessage.
//   - The address of m must not be nil.
//   - The address of m and the address of m.state must be equal,
//     even though they are different Go types.
type MessageState struct {
	pragma.NoUnkeyedLiterals
	pragma.DoNotCompare
	pragma.DoNotCopy

	// atomicMessageInfo is the *MessageInfo of the enclosing message and, per
	// the contract above, must be set atomically.
	// NOTE(review): the pragma fields above are presumably zero-size so this
	// field sits at the start of the struct — confirm before changing layout.
	atomicMessageInfo *MessageInfo
}
// messageState is a defined type with the same underlying type as
// MessageState. The assertions below require *messageState to implement
// protoreflect.Message and unwrapper; those methods are declared elsewhere.
// NOTE(review): presumably this split keeps the reflection methods off the
// exported MessageState type — confirm.
type messageState MessageState

// Compile-time checks that *messageState satisfies the expected interfaces.
var (
	_ protoreflect.Message = (*messageState)(nil)
	_ unwrapper            = (*messageState)(nil)
)
// messageDataType is a tuple of a pointer to the message data and
// a pointer to the message type. It is a generalized way of providing a
// reflective view over a message instance. The disadvantage of this approach
// is the need to allocate this tuple of 16B.
type messageDataType struct {
	p  pointer      // the message data
	mi *MessageInfo // the message type information
}
// Two views built on messageDataType:
//   - messageReflectWrapper is the protoreflect.Message view that MessageOf
//     returns for a non-nil message.
//   - messageIfaceWrapper is a protoreflect.ProtoMessage view whose
//     ProtoReflect method converts it to a messageReflectWrapper.
type (
	messageReflectWrapper messageDataType
	messageIfaceWrapper   messageDataType
)

// Compile-time checks that the wrappers satisfy the expected interfaces.
var (
	_ protoreflect.Message      = (*messageReflectWrapper)(nil)
	_ unwrapper                 = (*messageReflectWrapper)(nil)
	_ protoreflect.ProtoMessage = (*messageIfaceWrapper)(nil)
	_ unwrapper                 = (*messageIfaceWrapper)(nil)
)
  366. // MessageOf returns a reflective view over a message. The input must be a
  367. // pointer to a named Go struct. If the provided type has a ProtoReflect method,
  368. // it must be implemented by calling this method.
  369. func (mi *MessageInfo) MessageOf(m interface{}) protoreflect.Message {
  370. if reflect.TypeOf(m) != mi.GoReflectType {
  371. panic(fmt.Sprintf("type mismatch: got %T, want %v", m, mi.GoReflectType))
  372. }
  373. p := pointerOfIface(m)
  374. if p.IsNil() {
  375. return mi.nilMessage.Init(mi)
  376. }
  377. return &messageReflectWrapper{p, mi}
  378. }
  379. func (m *messageReflectWrapper) pointer() pointer { return m.p }
  380. func (m *messageReflectWrapper) messageInfo() *MessageInfo { return m.mi }
  381. // Reset implements the v1 proto.Message.Reset method.
  382. func (m *messageIfaceWrapper) Reset() {
  383. if mr, ok := m.protoUnwrap().(interface{ Reset() }); ok {
  384. mr.Reset()
  385. return
  386. }
  387. rv := reflect.ValueOf(m.protoUnwrap())
  388. if rv.Kind() == reflect.Ptr && !rv.IsNil() {
  389. rv.Elem().Set(reflect.Zero(rv.Type().Elem()))
  390. }
  391. }
  392. func (m *messageIfaceWrapper) ProtoReflect() protoreflect.Message {
  393. return (*messageReflectWrapper)(m)
  394. }
  395. func (m *messageIfaceWrapper) protoUnwrap() interface{} {
  396. return m.p.AsIfaceOf(m.mi.GoReflectType.Elem())
  397. }
  398. // checkField verifies that the provided field descriptor is valid.
  399. // Exactly one of the returned values is populated.
  400. func (mi *MessageInfo) checkField(fd protoreflect.FieldDescriptor) (*fieldInfo, protoreflect.ExtensionType) {
  401. var fi *fieldInfo
  402. if n := fd.Number(); 0 < n && int(n) < len(mi.denseFields) {
  403. fi = mi.denseFields[n]
  404. } else {
  405. fi = mi.fields[n]
  406. }
  407. if fi != nil {
  408. if fi.fieldDesc != fd {
  409. if got, want := fd.FullName(), fi.fieldDesc.FullName(); got != want {
  410. panic(fmt.Sprintf("mismatching field: got %v, want %v", got, want))
  411. }
  412. panic(fmt.Sprintf("mismatching field: %v", fd.FullName()))
  413. }
  414. return fi, nil
  415. }
  416. if fd.IsExtension() {
  417. if got, want := fd.ContainingMessage().FullName(), mi.Desc.FullName(); got != want {
  418. // TODO: Should this be exact containing message descriptor match?
  419. panic(fmt.Sprintf("extension %v has mismatching containing message: got %v, want %v", fd.FullName(), got, want))
  420. }
  421. if !mi.Desc.ExtensionRanges().Has(fd.Number()) {
  422. panic(fmt.Sprintf("extension %v extends %v outside the extension range", fd.FullName(), mi.Desc.FullName()))
  423. }
  424. xtd, ok := fd.(protoreflect.ExtensionTypeDescriptor)
  425. if !ok {
  426. panic(fmt.Sprintf("extension %v does not implement protoreflect.ExtensionTypeDescriptor", fd.FullName()))
  427. }
  428. return nil, xtd.Type()
  429. }
  430. panic(fmt.Sprintf("field %v is invalid", fd.FullName()))
  431. }