You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

389 lines
10 KiB

  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package impl
  5. import (
  6. "reflect"
  7. "sort"
  8. "google.golang.org/protobuf/encoding/protowire"
  9. "google.golang.org/protobuf/internal/genid"
  10. "google.golang.org/protobuf/reflect/protoreflect"
  11. )
  12. type mapInfo struct {
  13. goType reflect.Type
  14. keyWiretag uint64
  15. valWiretag uint64
  16. keyFuncs valueCoderFuncs
  17. valFuncs valueCoderFuncs
  18. keyZero protoreflect.Value
  19. keyKind protoreflect.Kind
  20. conv *mapConverter
  21. }
  22. func encoderFuncsForMap(fd protoreflect.FieldDescriptor, ft reflect.Type) (valueMessage *MessageInfo, funcs pointerCoderFuncs) {
  23. // TODO: Consider generating specialized map coders.
  24. keyField := fd.MapKey()
  25. valField := fd.MapValue()
  26. keyWiretag := protowire.EncodeTag(1, wireTypes[keyField.Kind()])
  27. valWiretag := protowire.EncodeTag(2, wireTypes[valField.Kind()])
  28. keyFuncs := encoderFuncsForValue(keyField)
  29. valFuncs := encoderFuncsForValue(valField)
  30. conv := newMapConverter(ft, fd)
  31. mapi := &mapInfo{
  32. goType: ft,
  33. keyWiretag: keyWiretag,
  34. valWiretag: valWiretag,
  35. keyFuncs: keyFuncs,
  36. valFuncs: valFuncs,
  37. keyZero: keyField.Default(),
  38. keyKind: keyField.Kind(),
  39. conv: conv,
  40. }
  41. if valField.Kind() == protoreflect.MessageKind {
  42. valueMessage = getMessageInfo(ft.Elem())
  43. }
  44. funcs = pointerCoderFuncs{
  45. size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
  46. return sizeMap(p.AsValueOf(ft).Elem(), mapi, f, opts)
  47. },
  48. marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
  49. return appendMap(b, p.AsValueOf(ft).Elem(), mapi, f, opts)
  50. },
  51. unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
  52. mp := p.AsValueOf(ft)
  53. if mp.Elem().IsNil() {
  54. mp.Elem().Set(reflect.MakeMap(mapi.goType))
  55. }
  56. if f.mi == nil {
  57. return consumeMap(b, mp.Elem(), wtyp, mapi, f, opts)
  58. } else {
  59. return consumeMapOfMessage(b, mp.Elem(), wtyp, mapi, f, opts)
  60. }
  61. },
  62. }
  63. switch valField.Kind() {
  64. case protoreflect.MessageKind:
  65. funcs.merge = mergeMapOfMessage
  66. case protoreflect.BytesKind:
  67. funcs.merge = mergeMapOfBytes
  68. default:
  69. funcs.merge = mergeMap
  70. }
  71. if valFuncs.isInit != nil {
  72. funcs.isInit = func(p pointer, f *coderFieldInfo) error {
  73. return isInitMap(p.AsValueOf(ft).Elem(), mapi, f)
  74. }
  75. }
  76. return valueMessage, funcs
  77. }
  78. const (
  79. mapKeyTagSize = 1 // field 1, tag size 1.
  80. mapValTagSize = 1 // field 2, tag size 2.
  81. )
  82. func sizeMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) int {
  83. if mapv.Len() == 0 {
  84. return 0
  85. }
  86. n := 0
  87. iter := mapRange(mapv)
  88. for iter.Next() {
  89. key := mapi.conv.keyConv.PBValueOf(iter.Key()).MapKey()
  90. keySize := mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
  91. var valSize int
  92. value := mapi.conv.valConv.PBValueOf(iter.Value())
  93. if f.mi == nil {
  94. valSize = mapi.valFuncs.size(value, mapValTagSize, opts)
  95. } else {
  96. p := pointerOfValue(iter.Value())
  97. valSize += mapValTagSize
  98. valSize += protowire.SizeBytes(f.mi.sizePointer(p, opts))
  99. }
  100. n += f.tagsize + protowire.SizeBytes(keySize+valSize)
  101. }
  102. return n
  103. }
  104. func consumeMap(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
  105. if wtyp != protowire.BytesType {
  106. return out, errUnknown
  107. }
  108. b, n := protowire.ConsumeBytes(b)
  109. if n < 0 {
  110. return out, errDecode
  111. }
  112. var (
  113. key = mapi.keyZero
  114. val = mapi.conv.valConv.New()
  115. )
  116. for len(b) > 0 {
  117. num, wtyp, n := protowire.ConsumeTag(b)
  118. if n < 0 {
  119. return out, errDecode
  120. }
  121. if num > protowire.MaxValidNumber {
  122. return out, errDecode
  123. }
  124. b = b[n:]
  125. err := errUnknown
  126. switch num {
  127. case genid.MapEntry_Key_field_number:
  128. var v protoreflect.Value
  129. var o unmarshalOutput
  130. v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
  131. if err != nil {
  132. break
  133. }
  134. key = v
  135. n = o.n
  136. case genid.MapEntry_Value_field_number:
  137. var v protoreflect.Value
  138. var o unmarshalOutput
  139. v, o, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
  140. if err != nil {
  141. break
  142. }
  143. val = v
  144. n = o.n
  145. }
  146. if err == errUnknown {
  147. n = protowire.ConsumeFieldValue(num, wtyp, b)
  148. if n < 0 {
  149. return out, errDecode
  150. }
  151. } else if err != nil {
  152. return out, err
  153. }
  154. b = b[n:]
  155. }
  156. mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), mapi.conv.valConv.GoValueOf(val))
  157. out.n = n
  158. return out, nil
  159. }
  160. func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
  161. if wtyp != protowire.BytesType {
  162. return out, errUnknown
  163. }
  164. b, n := protowire.ConsumeBytes(b)
  165. if n < 0 {
  166. return out, errDecode
  167. }
  168. var (
  169. key = mapi.keyZero
  170. val = reflect.New(f.mi.GoReflectType.Elem())
  171. )
  172. for len(b) > 0 {
  173. num, wtyp, n := protowire.ConsumeTag(b)
  174. if n < 0 {
  175. return out, errDecode
  176. }
  177. if num > protowire.MaxValidNumber {
  178. return out, errDecode
  179. }
  180. b = b[n:]
  181. err := errUnknown
  182. switch num {
  183. case 1:
  184. var v protoreflect.Value
  185. var o unmarshalOutput
  186. v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
  187. if err != nil {
  188. break
  189. }
  190. key = v
  191. n = o.n
  192. case 2:
  193. if wtyp != protowire.BytesType {
  194. break
  195. }
  196. var v []byte
  197. v, n = protowire.ConsumeBytes(b)
  198. if n < 0 {
  199. return out, errDecode
  200. }
  201. var o unmarshalOutput
  202. o, err = f.mi.unmarshalPointer(v, pointerOfValue(val), 0, opts)
  203. if o.initialized {
  204. // Consider this map item initialized so long as we see
  205. // an initialized value.
  206. out.initialized = true
  207. }
  208. }
  209. if err == errUnknown {
  210. n = protowire.ConsumeFieldValue(num, wtyp, b)
  211. if n < 0 {
  212. return out, errDecode
  213. }
  214. } else if err != nil {
  215. return out, err
  216. }
  217. b = b[n:]
  218. }
  219. mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), val)
  220. out.n = n
  221. return out, nil
  222. }
  223. func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
  224. if f.mi == nil {
  225. key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
  226. val := mapi.conv.valConv.PBValueOf(valrv)
  227. size := 0
  228. size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
  229. size += mapi.valFuncs.size(val, mapValTagSize, opts)
  230. b = protowire.AppendVarint(b, uint64(size))
  231. b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
  232. if err != nil {
  233. return nil, err
  234. }
  235. return mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts)
  236. } else {
  237. key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
  238. val := pointerOfValue(valrv)
  239. valSize := f.mi.sizePointer(val, opts)
  240. size := 0
  241. size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
  242. size += mapValTagSize + protowire.SizeBytes(valSize)
  243. b = protowire.AppendVarint(b, uint64(size))
  244. b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
  245. if err != nil {
  246. return nil, err
  247. }
  248. b = protowire.AppendVarint(b, mapi.valWiretag)
  249. b = protowire.AppendVarint(b, uint64(valSize))
  250. return f.mi.marshalAppendPointer(b, val, opts)
  251. }
  252. }
  253. func appendMap(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
  254. if mapv.Len() == 0 {
  255. return b, nil
  256. }
  257. if opts.Deterministic() {
  258. return appendMapDeterministic(b, mapv, mapi, f, opts)
  259. }
  260. iter := mapRange(mapv)
  261. for iter.Next() {
  262. var err error
  263. b = protowire.AppendVarint(b, f.wiretag)
  264. b, err = appendMapItem(b, iter.Key(), iter.Value(), mapi, f, opts)
  265. if err != nil {
  266. return b, err
  267. }
  268. }
  269. return b, nil
  270. }
  271. func appendMapDeterministic(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
  272. keys := mapv.MapKeys()
  273. sort.Slice(keys, func(i, j int) bool {
  274. switch keys[i].Kind() {
  275. case reflect.Bool:
  276. return !keys[i].Bool() && keys[j].Bool()
  277. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  278. return keys[i].Int() < keys[j].Int()
  279. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
  280. return keys[i].Uint() < keys[j].Uint()
  281. case reflect.Float32, reflect.Float64:
  282. return keys[i].Float() < keys[j].Float()
  283. case reflect.String:
  284. return keys[i].String() < keys[j].String()
  285. default:
  286. panic("invalid kind: " + keys[i].Kind().String())
  287. }
  288. })
  289. for _, key := range keys {
  290. var err error
  291. b = protowire.AppendVarint(b, f.wiretag)
  292. b, err = appendMapItem(b, key, mapv.MapIndex(key), mapi, f, opts)
  293. if err != nil {
  294. return b, err
  295. }
  296. }
  297. return b, nil
  298. }
  299. func isInitMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo) error {
  300. if mi := f.mi; mi != nil {
  301. mi.init()
  302. if !mi.needsInitCheck {
  303. return nil
  304. }
  305. iter := mapRange(mapv)
  306. for iter.Next() {
  307. val := pointerOfValue(iter.Value())
  308. if err := mi.checkInitializedPointer(val); err != nil {
  309. return err
  310. }
  311. }
  312. } else {
  313. iter := mapRange(mapv)
  314. for iter.Next() {
  315. val := mapi.conv.valConv.PBValueOf(iter.Value())
  316. if err := mapi.valFuncs.isInit(val); err != nil {
  317. return err
  318. }
  319. }
  320. }
  321. return nil
  322. }
  323. func mergeMap(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
  324. dstm := dst.AsValueOf(f.ft).Elem()
  325. srcm := src.AsValueOf(f.ft).Elem()
  326. if srcm.Len() == 0 {
  327. return
  328. }
  329. if dstm.IsNil() {
  330. dstm.Set(reflect.MakeMap(f.ft))
  331. }
  332. iter := mapRange(srcm)
  333. for iter.Next() {
  334. dstm.SetMapIndex(iter.Key(), iter.Value())
  335. }
  336. }
  337. func mergeMapOfBytes(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
  338. dstm := dst.AsValueOf(f.ft).Elem()
  339. srcm := src.AsValueOf(f.ft).Elem()
  340. if srcm.Len() == 0 {
  341. return
  342. }
  343. if dstm.IsNil() {
  344. dstm.Set(reflect.MakeMap(f.ft))
  345. }
  346. iter := mapRange(srcm)
  347. for iter.Next() {
  348. dstm.SetMapIndex(iter.Key(), reflect.ValueOf(append(emptyBuf[:], iter.Value().Bytes()...)))
  349. }
  350. }
  351. func mergeMapOfMessage(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
  352. dstm := dst.AsValueOf(f.ft).Elem()
  353. srcm := src.AsValueOf(f.ft).Elem()
  354. if srcm.Len() == 0 {
  355. return
  356. }
  357. if dstm.IsNil() {
  358. dstm.Set(reflect.MakeMap(f.ft))
  359. }
  360. iter := mapRange(srcm)
  361. for iter.Next() {
  362. val := reflect.New(f.ft.Elem().Elem())
  363. if f.mi != nil {
  364. f.mi.mergePointer(pointerOfValue(val), pointerOfValue(iter.Value()), opts)
  365. } else {
  366. opts.Merge(asMessage(val), asMessage(iter.Value()))
  367. }
  368. dstm.SetMapIndex(iter.Key(), val)
  369. }
  370. }