You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

243 lines
7.4 KiB

  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. // Package messageset encodes and decodes the obsolete MessageSet wire format.
  5. package messageset
  6. import (
  7. "math"
  8. "google.golang.org/protobuf/encoding/protowire"
  9. "google.golang.org/protobuf/internal/errors"
  10. "google.golang.org/protobuf/reflect/protoreflect"
  11. )
  12. // The MessageSet wire format is equivalent to a message defined as follows,
  13. // where each Item defines an extension field with a field number of 'type_id'
  14. // and content of 'message'. MessageSet extensions must be non-repeated message
  15. // fields.
  16. //
  17. // message MessageSet {
  18. // repeated group Item = 1 {
  19. // required int32 type_id = 2;
  20. // required string message = 3;
  21. // }
  22. // }
  23. const (
  24. FieldItem = protowire.Number(1)
  25. FieldTypeID = protowire.Number(2)
  26. FieldMessage = protowire.Number(3)
  27. )
  28. // ExtensionName is the field name for extensions of MessageSet.
  29. //
  30. // A valid MessageSet extension must be of the form:
  31. //
  32. // message MyMessage {
  33. // extend proto2.bridge.MessageSet {
  34. // optional MyMessage message_set_extension = 1234;
  35. // }
  36. // ...
  37. // }
  38. const ExtensionName = "message_set_extension"
  39. // IsMessageSet returns whether the message uses the MessageSet wire format.
  40. func IsMessageSet(md protoreflect.MessageDescriptor) bool {
  41. xmd, ok := md.(interface{ IsMessageSet() bool })
  42. return ok && xmd.IsMessageSet()
  43. }
  44. // IsMessageSetExtension reports this field properly extends a MessageSet.
  45. func IsMessageSetExtension(fd protoreflect.FieldDescriptor) bool {
  46. switch {
  47. case fd.Name() != ExtensionName:
  48. return false
  49. case !IsMessageSet(fd.ContainingMessage()):
  50. return false
  51. case fd.FullName().Parent() != fd.Message().FullName():
  52. return false
  53. }
  54. return true
  55. }
  56. // SizeField returns the size of a MessageSet item field containing an extension
  57. // with the given field number, not counting the contents of the message subfield.
  58. func SizeField(num protowire.Number) int {
  59. return 2*protowire.SizeTag(FieldItem) + protowire.SizeTag(FieldTypeID) + protowire.SizeVarint(uint64(num))
  60. }
  61. // Unmarshal parses a MessageSet.
  62. //
  63. // It calls fn with the type ID and value of each item in the MessageSet.
  64. // Unknown fields are discarded.
  65. //
  66. // If wantLen is true, the item values include the varint length prefix.
  67. // This is ugly, but simplifies the fast-path decoder in internal/impl.
  68. func Unmarshal(b []byte, wantLen bool, fn func(typeID protowire.Number, value []byte) error) error {
  69. for len(b) > 0 {
  70. num, wtyp, n := protowire.ConsumeTag(b)
  71. if n < 0 {
  72. return protowire.ParseError(n)
  73. }
  74. b = b[n:]
  75. if num != FieldItem || wtyp != protowire.StartGroupType {
  76. n := protowire.ConsumeFieldValue(num, wtyp, b)
  77. if n < 0 {
  78. return protowire.ParseError(n)
  79. }
  80. b = b[n:]
  81. continue
  82. }
  83. typeID, value, n, err := ConsumeFieldValue(b, wantLen)
  84. if err != nil {
  85. return err
  86. }
  87. b = b[n:]
  88. if typeID == 0 {
  89. continue
  90. }
  91. if err := fn(typeID, value); err != nil {
  92. return err
  93. }
  94. }
  95. return nil
  96. }
  97. // ConsumeFieldValue parses b as a MessageSet item field value until and including
  98. // the trailing end group marker. It assumes the start group tag has already been parsed.
  99. // It returns the contents of the type_id and message subfields and the total
  100. // item length.
  101. //
  102. // If wantLen is true, the returned message value includes the length prefix.
  103. func ConsumeFieldValue(b []byte, wantLen bool) (typeid protowire.Number, message []byte, n int, err error) {
  104. ilen := len(b)
  105. for {
  106. num, wtyp, n := protowire.ConsumeTag(b)
  107. if n < 0 {
  108. return 0, nil, 0, protowire.ParseError(n)
  109. }
  110. b = b[n:]
  111. switch {
  112. case num == FieldItem && wtyp == protowire.EndGroupType:
  113. if wantLen && len(message) == 0 {
  114. // The message field was missing, which should never happen.
  115. // Be prepared for this case anyway.
  116. message = protowire.AppendVarint(message, 0)
  117. }
  118. return typeid, message, ilen - len(b), nil
  119. case num == FieldTypeID && wtyp == protowire.VarintType:
  120. v, n := protowire.ConsumeVarint(b)
  121. if n < 0 {
  122. return 0, nil, 0, protowire.ParseError(n)
  123. }
  124. b = b[n:]
  125. if v < 1 || v > math.MaxInt32 {
  126. return 0, nil, 0, errors.New("invalid type_id in message set")
  127. }
  128. typeid = protowire.Number(v)
  129. case num == FieldMessage && wtyp == protowire.BytesType:
  130. m, n := protowire.ConsumeBytes(b)
  131. if n < 0 {
  132. return 0, nil, 0, protowire.ParseError(n)
  133. }
  134. if message == nil {
  135. if wantLen {
  136. message = b[:n:n]
  137. } else {
  138. message = m[:len(m):len(m)]
  139. }
  140. } else {
  141. // This case should never happen in practice, but handle it for
  142. // correctness: The MessageSet item contains multiple message
  143. // fields, which need to be merged.
  144. //
  145. // In the case where we're returning the length, this becomes
  146. // quite inefficient since we need to strip the length off
  147. // the existing data and reconstruct it with the combined length.
  148. if wantLen {
  149. _, nn := protowire.ConsumeVarint(message)
  150. m0 := message[nn:]
  151. message = nil
  152. message = protowire.AppendVarint(message, uint64(len(m0)+len(m)))
  153. message = append(message, m0...)
  154. message = append(message, m...)
  155. } else {
  156. message = append(message, m...)
  157. }
  158. }
  159. b = b[n:]
  160. default:
  161. // We have no place to put it, so we just ignore unknown fields.
  162. n := protowire.ConsumeFieldValue(num, wtyp, b)
  163. if n < 0 {
  164. return 0, nil, 0, protowire.ParseError(n)
  165. }
  166. b = b[n:]
  167. }
  168. }
  169. }
  170. // AppendFieldStart appends the start of a MessageSet item field containing
  171. // an extension with the given number. The caller must add the message
  172. // subfield (including the tag).
  173. func AppendFieldStart(b []byte, num protowire.Number) []byte {
  174. b = protowire.AppendTag(b, FieldItem, protowire.StartGroupType)
  175. b = protowire.AppendTag(b, FieldTypeID, protowire.VarintType)
  176. b = protowire.AppendVarint(b, uint64(num))
  177. return b
  178. }
  179. // AppendFieldEnd appends the trailing end group marker for a MessageSet item field.
  180. func AppendFieldEnd(b []byte) []byte {
  181. return protowire.AppendTag(b, FieldItem, protowire.EndGroupType)
  182. }
  183. // SizeUnknown returns the size of an unknown fields section in MessageSet format.
  184. //
  185. // See AppendUnknown.
  186. func SizeUnknown(unknown []byte) (size int) {
  187. for len(unknown) > 0 {
  188. num, typ, n := protowire.ConsumeTag(unknown)
  189. if n < 0 || typ != protowire.BytesType {
  190. return 0
  191. }
  192. unknown = unknown[n:]
  193. _, n = protowire.ConsumeBytes(unknown)
  194. if n < 0 {
  195. return 0
  196. }
  197. unknown = unknown[n:]
  198. size += SizeField(num) + protowire.SizeTag(FieldMessage) + n
  199. }
  200. return size
  201. }
  202. // AppendUnknown appends unknown fields to b in MessageSet format.
  203. //
  204. // For historic reasons, unresolved items in a MessageSet are stored in a
  205. // message's unknown fields section in non-MessageSet format. That is, an
  206. // unknown item with typeID T and value V appears in the unknown fields as
  207. // a field with number T and value V.
  208. //
  209. // This function converts the unknown fields back into MessageSet form.
  210. func AppendUnknown(b, unknown []byte) ([]byte, error) {
  211. for len(unknown) > 0 {
  212. num, typ, n := protowire.ConsumeTag(unknown)
  213. if n < 0 || typ != protowire.BytesType {
  214. return nil, errors.New("invalid data in message set unknown fields")
  215. }
  216. unknown = unknown[n:]
  217. _, n = protowire.ConsumeBytes(unknown)
  218. if n < 0 {
  219. return nil, errors.New("invalid data in message set unknown fields")
  220. }
  221. b = AppendFieldStart(b, num)
  222. b = protowire.AppendTag(b, FieldMessage, protowire.BytesType)
  223. b = append(b, unknown[:n]...)
  224. b = AppendFieldEnd(b)
  225. unknown = unknown[n:]
  226. }
  227. return b, nil
  228. }