You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

552 lines
13 KiB

  1. // Copyright 2018 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. // Package protowire parses and formats the raw wire encoding.
  5. // See https://developers.google.com/protocol-buffers/docs/encoding.
  6. //
  7. // For marshaling and unmarshaling entire protobuf messages,
  8. // use the "google.golang.org/protobuf/proto" package instead.
  9. package protowire
  10. import (
  11. "io"
  12. "math"
  13. "math/bits"
  14. "google.golang.org/protobuf/internal/errors"
  15. )
  // Number is a protobuf field number.
  type Number int32

  // Limits of the field-number space, plus the default nesting budget used
  // when parsing groups.
  const (
  	// MinValidNumber and MaxValidNumber bound the numbers accepted by IsValid.
  	MinValidNumber Number = 1
  	MaxValidNumber Number = 1<<29 - 1

  	// FirstReservedNumber through LastReservedNumber (inclusive) is the
  	// reserved range, which IsValid rejects.
  	FirstReservedNumber Number = 19000
  	LastReservedNumber  Number = 19999

  	// DefaultRecursionLimit is the depth budget that ConsumeFieldValue hands
  	// to the group parser.
  	DefaultRecursionLimit = 10000
  )

  // IsValid reports whether n is a semantically valid field number, i.e. it
  // lies within [MinValidNumber, MaxValidNumber] and outside the reserved
  // range [FirstReservedNumber, LastReservedNumber].
  //
  // Reserved numbers are semantically invalid but still syntactically valid in
  // the wire format, so implementations may treat records carrying them as
  // unknown fields.
  func (n Number) IsValid() bool {
  	if n < MinValidNumber || n > MaxValidNumber {
  		return false
  	}
  	return n < FirstReservedNumber || n > LastReservedNumber
  }
  // Type is the wire type carried in the low three bits of a tag.
  type Type int8

  // Wire types, listed in numeric order.
  const (
  	VarintType     Type = 0 // value parsed by ConsumeVarint
  	Fixed64Type    Type = 1 // value parsed by ConsumeFixed64
  	BytesType      Type = 2 // value parsed by ConsumeBytes
  	StartGroupType Type = 3 // opens a group
  	EndGroupType   Type = 4 // closes a group
  	Fixed32Type    Type = 5 // value parsed by ConsumeFixed32
  )
  // Negative error codes returned in place of a length by the Consume
  // functions. The codes count down from -1 in declaration order; ParseError
  // maps a code to an error value.
  const (
  	errCodeTruncated = -(iota + 1)
  	errCodeFieldNumber
  	errCodeOverflow
  	errCodeReserved
  	errCodeEndGroup
  	errCodeRecursionDepth
  )
  // Error values handed out by ParseError for the negative error codes.
  var (
  	// errFieldNumber is reported for errCodeFieldNumber.
  	errFieldNumber = errors.New("invalid field number")
  	// errOverflow is reported for errCodeOverflow.
  	errOverflow = errors.New("variable length integer overflow")
  	// errReserved is reported for errCodeReserved.
  	errReserved = errors.New("cannot parse reserved wire type")
  	// errEndGroup is reported for errCodeEndGroup.
  	errEndGroup = errors.New("mismatching end group marker")
  	// errParse is the fallback for any other negative code.
  	errParse = errors.New("parse error")
  )
  59. // ParseError converts an error code into an error value.
  60. // This returns nil if n is a non-negative number.
  61. func ParseError(n int) error {
  62. if n >= 0 {
  63. return nil
  64. }
  65. switch n {
  66. case errCodeTruncated:
  67. return io.ErrUnexpectedEOF
  68. case errCodeFieldNumber:
  69. return errFieldNumber
  70. case errCodeOverflow:
  71. return errOverflow
  72. case errCodeReserved:
  73. return errReserved
  74. case errCodeEndGroup:
  75. return errEndGroup
  76. default:
  77. return errParse
  78. }
  79. }
  80. // ConsumeField parses an entire field record (both tag and value) and returns
  81. // the field number, the wire type, and the total length.
  82. // This returns a negative length upon an error (see ParseError).
  83. //
  84. // The total length includes the tag header and the end group marker (if the
  85. // field is a group).
  86. func ConsumeField(b []byte) (Number, Type, int) {
  87. num, typ, n := ConsumeTag(b)
  88. if n < 0 {
  89. return 0, 0, n // forward error code
  90. }
  91. m := ConsumeFieldValue(num, typ, b[n:])
  92. if m < 0 {
  93. return 0, 0, m // forward error code
  94. }
  95. return num, typ, n + m
  96. }
  97. // ConsumeFieldValue parses a field value and returns its length.
  98. // This assumes that the field Number and wire Type have already been parsed.
  99. // This returns a negative length upon an error (see ParseError).
  100. //
  101. // When parsing a group, the length includes the end group marker and
  102. // the end group is verified to match the starting field number.
  103. func ConsumeFieldValue(num Number, typ Type, b []byte) (n int) {
  104. return consumeFieldValueD(num, typ, b, DefaultRecursionLimit)
  105. }
  106. func consumeFieldValueD(num Number, typ Type, b []byte, depth int) (n int) {
  107. switch typ {
  108. case VarintType:
  109. _, n = ConsumeVarint(b)
  110. return n
  111. case Fixed32Type:
  112. _, n = ConsumeFixed32(b)
  113. return n
  114. case Fixed64Type:
  115. _, n = ConsumeFixed64(b)
  116. return n
  117. case BytesType:
  118. _, n = ConsumeBytes(b)
  119. return n
  120. case StartGroupType:
  121. if depth < 0 {
  122. return errCodeRecursionDepth
  123. }
  124. n0 := len(b)
  125. for {
  126. num2, typ2, n := ConsumeTag(b)
  127. if n < 0 {
  128. return n // forward error code
  129. }
  130. b = b[n:]
  131. if typ2 == EndGroupType {
  132. if num != num2 {
  133. return errCodeEndGroup
  134. }
  135. return n0 - len(b)
  136. }
  137. n = consumeFieldValueD(num2, typ2, b, depth-1)
  138. if n < 0 {
  139. return n // forward error code
  140. }
  141. b = b[n:]
  142. }
  143. case EndGroupType:
  144. return errCodeEndGroup
  145. default:
  146. return errCodeReserved
  147. }
  148. }
  149. // AppendTag encodes num and typ as a varint-encoded tag and appends it to b.
  150. func AppendTag(b []byte, num Number, typ Type) []byte {
  151. return AppendVarint(b, EncodeTag(num, typ))
  152. }
  153. // ConsumeTag parses b as a varint-encoded tag, reporting its length.
  154. // This returns a negative length upon an error (see ParseError).
  155. func ConsumeTag(b []byte) (Number, Type, int) {
  156. v, n := ConsumeVarint(b)
  157. if n < 0 {
  158. return 0, 0, n // forward error code
  159. }
  160. num, typ := DecodeTag(v)
  161. if num < MinValidNumber {
  162. return 0, 0, errCodeFieldNumber
  163. }
  164. return num, typ, n
  165. }
  166. func SizeTag(num Number) int {
  167. return SizeVarint(EncodeTag(num, 0)) // wire type has no effect on size
  168. }
  // AppendVarint appends the varint encoding of v to b and returns the extended
  // slice.
  //
  // The value is written seven bits at a time, least-significant group first;
  // every byte except the last has its high bit (0x80) set. Each case below
  // emits exactly the bytes needed for values of that magnitude, so the
  // output is between 1 and 10 bytes long.
  func AppendVarint(b []byte, v uint64) []byte {
  	switch {
  	case v < 1<<7:
  		return append(b, byte(v))
  	case v < 1<<14:
  		return append(b,
  			byte(v)|0x80,
  			byte(v>>7))
  	case v < 1<<21:
  		return append(b,
  			byte(v)|0x80,
  			byte(v>>7)|0x80,
  			byte(v>>14))
  	case v < 1<<28:
  		return append(b,
  			byte(v)|0x80,
  			byte(v>>7)|0x80,
  			byte(v>>14)|0x80,
  			byte(v>>21))
  	case v < 1<<35:
  		return append(b,
  			byte(v)|0x80,
  			byte(v>>7)|0x80,
  			byte(v>>14)|0x80,
  			byte(v>>21)|0x80,
  			byte(v>>28))
  	case v < 1<<42:
  		return append(b,
  			byte(v)|0x80,
  			byte(v>>7)|0x80,
  			byte(v>>14)|0x80,
  			byte(v>>21)|0x80,
  			byte(v>>28)|0x80,
  			byte(v>>35))
  	case v < 1<<49:
  		return append(b,
  			byte(v)|0x80,
  			byte(v>>7)|0x80,
  			byte(v>>14)|0x80,
  			byte(v>>21)|0x80,
  			byte(v>>28)|0x80,
  			byte(v>>35)|0x80,
  			byte(v>>42))
  	case v < 1<<56:
  		return append(b,
  			byte(v)|0x80,
  			byte(v>>7)|0x80,
  			byte(v>>14)|0x80,
  			byte(v>>21)|0x80,
  			byte(v>>28)|0x80,
  			byte(v>>35)|0x80,
  			byte(v>>42)|0x80,
  			byte(v>>49))
  	case v < 1<<63:
  		return append(b,
  			byte(v)|0x80,
  			byte(v>>7)|0x80,
  			byte(v>>14)|0x80,
  			byte(v>>21)|0x80,
  			byte(v>>28)|0x80,
  			byte(v>>35)|0x80,
  			byte(v>>42)|0x80,
  			byte(v>>49)|0x80,
  			byte(v>>56))
  	}
  	// Bit 63 is set: nine continuation bytes, then a final byte holding
  	// that single top bit.
  	return append(b,
  		byte(v)|0x80,
  		byte(v>>7)|0x80,
  		byte(v>>14)|0x80,
  		byte(v>>21)|0x80,
  		byte(v>>28)|0x80,
  		byte(v>>35)|0x80,
  		byte(v>>42)|0x80,
  		byte(v>>49)|0x80,
  		byte(v>>56)|0x80,
  		1)
  }
  249. // ConsumeVarint parses b as a varint-encoded uint64, reporting its length.
  250. // This returns a negative length upon an error (see ParseError).
  251. func ConsumeVarint(b []byte) (v uint64, n int) {
  252. var y uint64
  253. if len(b) <= 0 {
  254. return 0, errCodeTruncated
  255. }
  256. v = uint64(b[0])
  257. if v < 0x80 {
  258. return v, 1
  259. }
  260. v -= 0x80
  261. if len(b) <= 1 {
  262. return 0, errCodeTruncated
  263. }
  264. y = uint64(b[1])
  265. v += y << 7
  266. if y < 0x80 {
  267. return v, 2
  268. }
  269. v -= 0x80 << 7
  270. if len(b) <= 2 {
  271. return 0, errCodeTruncated
  272. }
  273. y = uint64(b[2])
  274. v += y << 14
  275. if y < 0x80 {
  276. return v, 3
  277. }
  278. v -= 0x80 << 14
  279. if len(b) <= 3 {
  280. return 0, errCodeTruncated
  281. }
  282. y = uint64(b[3])
  283. v += y << 21
  284. if y < 0x80 {
  285. return v, 4
  286. }
  287. v -= 0x80 << 21
  288. if len(b) <= 4 {
  289. return 0, errCodeTruncated
  290. }
  291. y = uint64(b[4])
  292. v += y << 28
  293. if y < 0x80 {
  294. return v, 5
  295. }
  296. v -= 0x80 << 28
  297. if len(b) <= 5 {
  298. return 0, errCodeTruncated
  299. }
  300. y = uint64(b[5])
  301. v += y << 35
  302. if y < 0x80 {
  303. return v, 6
  304. }
  305. v -= 0x80 << 35
  306. if len(b) <= 6 {
  307. return 0, errCodeTruncated
  308. }
  309. y = uint64(b[6])
  310. v += y << 42
  311. if y < 0x80 {
  312. return v, 7
  313. }
  314. v -= 0x80 << 42
  315. if len(b) <= 7 {
  316. return 0, errCodeTruncated
  317. }
  318. y = uint64(b[7])
  319. v += y << 49
  320. if y < 0x80 {
  321. return v, 8
  322. }
  323. v -= 0x80 << 49
  324. if len(b) <= 8 {
  325. return 0, errCodeTruncated
  326. }
  327. y = uint64(b[8])
  328. v += y << 56
  329. if y < 0x80 {
  330. return v, 9
  331. }
  332. v -= 0x80 << 56
  333. if len(b) <= 9 {
  334. return 0, errCodeTruncated
  335. }
  336. y = uint64(b[9])
  337. v += y << 63
  338. if y < 2 {
  339. return v, 10
  340. }
  341. return 0, errCodeOverflow
  342. }
  // SizeVarint returns how many bytes the varint encoding of v occupies.
  // The result is always between 1 and 10, inclusive.
  func SizeVarint(v uint64) int {
  	// The exact formula is 1 + (bits.Len64(v)-1)/7. Over the possible bit
  	// lengths, multiplying by 9/64 is a close enough stand-in for dividing
  	// by 7.
  	bitLen := uint32(bits.Len64(v))
  	return int(9*bitLen+64) / 64
  }
  // AppendFixed32 appends the four bytes of v to b in little-endian order
  // (least-significant byte first) and returns the extended slice.
  func AppendFixed32(b []byte, v uint32) []byte {
  	return append(b, byte(v), byte(v>>8), byte(v>>16), byte(v>>24))
  }
  358. // ConsumeFixed32 parses b as a little-endian uint32, reporting its length.
  359. // This returns a negative length upon an error (see ParseError).
  360. func ConsumeFixed32(b []byte) (v uint32, n int) {
  361. if len(b) < 4 {
  362. return 0, errCodeTruncated
  363. }
  364. v = uint32(b[0])<<0 | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
  365. return v, 4
  366. }
  // SizeFixed32 returns the encoded size of a fixed32 value, which is the
  // constant 4.
  func SizeFixed32() int {
  	const fixed32Size = 4
  	return fixed32Size
  }
  // AppendFixed64 appends the eight bytes of v to b in little-endian order
  // (least-significant byte first) and returns the extended slice.
  func AppendFixed64(b []byte, v uint64) []byte {
  	return append(b,
  		byte(v), byte(v>>8), byte(v>>16), byte(v>>24),
  		byte(v>>32), byte(v>>40), byte(v>>48), byte(v>>56))
  }
  383. // ConsumeFixed64 parses b as a little-endian uint64, reporting its length.
  384. // This returns a negative length upon an error (see ParseError).
  385. func ConsumeFixed64(b []byte) (v uint64, n int) {
  386. if len(b) < 8 {
  387. return 0, errCodeTruncated
  388. }
  389. v = uint64(b[0])<<0 | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 | uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56
  390. return v, 8
  391. }
  // SizeFixed64 returns the encoded size of a fixed64 value, which is the
  // constant 8.
  func SizeFixed64() int {
  	const fixed64Size = 8
  	return fixed64Size
  }
  396. // AppendBytes appends v to b as a length-prefixed bytes value.
  397. func AppendBytes(b []byte, v []byte) []byte {
  398. return append(AppendVarint(b, uint64(len(v))), v...)
  399. }
  400. // ConsumeBytes parses b as a length-prefixed bytes value, reporting its length.
  401. // This returns a negative length upon an error (see ParseError).
  402. func ConsumeBytes(b []byte) (v []byte, n int) {
  403. m, n := ConsumeVarint(b)
  404. if n < 0 {
  405. return nil, n // forward error code
  406. }
  407. if m > uint64(len(b[n:])) {
  408. return nil, errCodeTruncated
  409. }
  410. return b[n:][:m], n + int(m)
  411. }
  412. // SizeBytes returns the encoded size of a length-prefixed bytes value,
  413. // given only the length.
  414. func SizeBytes(n int) int {
  415. return SizeVarint(uint64(n)) + n
  416. }
  417. // AppendString appends v to b as a length-prefixed bytes value.
  418. func AppendString(b []byte, v string) []byte {
  419. return append(AppendVarint(b, uint64(len(v))), v...)
  420. }
  421. // ConsumeString parses b as a length-prefixed bytes value, reporting its length.
  422. // This returns a negative length upon an error (see ParseError).
  423. func ConsumeString(b []byte) (v string, n int) {
  424. bb, n := ConsumeBytes(b)
  425. return string(bb), n
  426. }
  427. // AppendGroup appends v to b as group value, with a trailing end group marker.
  428. // The value v must not contain the end marker.
  429. func AppendGroup(b []byte, num Number, v []byte) []byte {
  430. return AppendVarint(append(b, v...), EncodeTag(num, EndGroupType))
  431. }
  432. // ConsumeGroup parses b as a group value until the trailing end group marker,
  433. // and verifies that the end marker matches the provided num. The value v
  434. // does not contain the end marker, while the length does contain the end marker.
  435. // This returns a negative length upon an error (see ParseError).
  436. func ConsumeGroup(num Number, b []byte) (v []byte, n int) {
  437. n = ConsumeFieldValue(num, StartGroupType, b)
  438. if n < 0 {
  439. return nil, n // forward error code
  440. }
  441. b = b[:n]
  442. // Truncate off end group marker, but need to handle denormalized varints.
  443. // Assuming end marker is never 0 (which is always the case since
  444. // EndGroupType is non-zero), we can truncate all trailing bytes where the
  445. // lower 7 bits are all zero (implying that the varint is denormalized).
  446. for len(b) > 0 && b[len(b)-1]&0x7f == 0 {
  447. b = b[:len(b)-1]
  448. }
  449. b = b[:len(b)-SizeTag(num)]
  450. return b, n
  451. }
  452. // SizeGroup returns the encoded size of a group, given only the length.
  453. func SizeGroup(num Number, n int) int {
  454. return n + SizeTag(num)
  455. }
  456. // DecodeTag decodes the field Number and wire Type from its unified form.
  457. // The Number is -1 if the decoded field number overflows int32.
  458. // Other than overflow, this does not check for field number validity.
  459. func DecodeTag(x uint64) (Number, Type) {
  460. // NOTE: MessageSet allows for larger field numbers than normal.
  461. if x>>3 > uint64(math.MaxInt32) {
  462. return -1, 0
  463. }
  464. return Number(x >> 3), Type(x & 7)
  465. }
  466. // EncodeTag encodes the field Number and wire Type into its unified form.
  467. func EncodeTag(num Number, typ Type) uint64 {
  468. return uint64(num)<<3 | uint64(typ&7)
  469. }
  // DecodeZigZag converts a zig-zag-encoded uint64 back into an int64.
  //
  //	Input:  {…,  5,  3,  1,  0,  2,  4,  6, …}
  //	Output: {…, -3, -2, -1,  0, +1, +2, +3, …}
  func DecodeZigZag(x uint64) int64 {
  	half := int64(x >> 1)
  	// The low bit carries the sign: the mask is 0 for even inputs and -1
  	// (all ones) for odd inputs, so the XOR flips every bit when odd.
  	mask := -int64(x & 1)
  	return half ^ mask
  }
  // EncodeZigZag converts an int64 into its zig-zag-encoded uint64 form.
  //
  //	Input:  {…, -3, -2, -1,  0, +1, +2, +3, …}
  //	Output: {…,  5,  3,  1,  0,  2,  4,  6, …}
  func EncodeZigZag(x int64) uint64 {
  	doubled := uint64(x) << 1
  	// The arithmetic shift yields all ones for negative x and zero
  	// otherwise, so the XOR flips every bit of doubled when x is negative.
  	signMask := uint64(x >> 63)
  	return doubled ^ signMask
  }
  // DecodeBool converts a uint64 into a bool: zero is false and every other
  // value is true.
  //
  //	Input:  {    0,    1,    2, …}
  //	Output: {false, true, true, …}
  func DecodeBool(x uint64) bool {
  	// x is unsigned, so "greater than zero" is the same as "non-zero".
  	return x > 0
  }
  // EncodeBool converts a bool into a uint64: false becomes 0 and true
  // becomes 1.
  //
  //	Input:  {false, true}
  //	Output: {    0,    1}
  func EncodeBool(x bool) uint64 {
  	var encoded uint64
  	if x {
  		encoded = 1
  	}
  	return encoded
  }
  500. }