You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

561 lines
13 KiB

  1. // Copyright 2010 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package proto
  5. import (
  6. "bytes"
  7. "encoding"
  8. "fmt"
  9. "io"
  10. "math"
  11. "sort"
  12. "strings"
  13. "google.golang.org/protobuf/encoding/prototext"
  14. "google.golang.org/protobuf/encoding/protowire"
  15. "google.golang.org/protobuf/proto"
  16. "google.golang.org/protobuf/reflect/protoreflect"
  17. "google.golang.org/protobuf/reflect/protoregistry"
  18. )
  19. const wrapTextMarshalV2 = false
  20. // TextMarshaler is a configurable text format marshaler.
  21. type TextMarshaler struct {
  22. Compact bool // use compact text format (one line)
  23. ExpandAny bool // expand google.protobuf.Any messages of known types
  24. }
  25. // Marshal writes the proto text format of m to w.
  26. func (tm *TextMarshaler) Marshal(w io.Writer, m Message) error {
  27. b, err := tm.marshal(m)
  28. if len(b) > 0 {
  29. if _, err := w.Write(b); err != nil {
  30. return err
  31. }
  32. }
  33. return err
  34. }
  35. // Text returns a proto text formatted string of m.
  36. func (tm *TextMarshaler) Text(m Message) string {
  37. b, _ := tm.marshal(m)
  38. return string(b)
  39. }
  40. func (tm *TextMarshaler) marshal(m Message) ([]byte, error) {
  41. mr := MessageReflect(m)
  42. if mr == nil || !mr.IsValid() {
  43. return []byte("<nil>"), nil
  44. }
  45. if wrapTextMarshalV2 {
  46. if m, ok := m.(encoding.TextMarshaler); ok {
  47. return m.MarshalText()
  48. }
  49. opts := prototext.MarshalOptions{
  50. AllowPartial: true,
  51. EmitUnknown: true,
  52. }
  53. if !tm.Compact {
  54. opts.Indent = " "
  55. }
  56. if !tm.ExpandAny {
  57. opts.Resolver = (*protoregistry.Types)(nil)
  58. }
  59. return opts.Marshal(mr.Interface())
  60. } else {
  61. w := &textWriter{
  62. compact: tm.Compact,
  63. expandAny: tm.ExpandAny,
  64. complete: true,
  65. }
  66. if m, ok := m.(encoding.TextMarshaler); ok {
  67. b, err := m.MarshalText()
  68. if err != nil {
  69. return nil, err
  70. }
  71. w.Write(b)
  72. return w.buf, nil
  73. }
  74. err := w.writeMessage(mr)
  75. return w.buf, err
  76. }
  77. }
  78. var (
  79. defaultTextMarshaler = TextMarshaler{}
  80. compactTextMarshaler = TextMarshaler{Compact: true}
  81. )
  82. // MarshalText writes the proto text format of m to w.
  83. func MarshalText(w io.Writer, m Message) error { return defaultTextMarshaler.Marshal(w, m) }
  84. // MarshalTextString returns a proto text formatted string of m.
  85. func MarshalTextString(m Message) string { return defaultTextMarshaler.Text(m) }
  86. // CompactText writes the compact proto text format of m to w.
  87. func CompactText(w io.Writer, m Message) error { return compactTextMarshaler.Marshal(w, m) }
  88. // CompactTextString returns a compact proto text formatted string of m.
  89. func CompactTextString(m Message) string { return compactTextMarshaler.Text(m) }
  90. var (
  91. newline = []byte("\n")
  92. endBraceNewline = []byte("}\n")
  93. posInf = []byte("inf")
  94. negInf = []byte("-inf")
  95. nan = []byte("nan")
  96. )
  97. // textWriter is an io.Writer that tracks its indentation level.
  98. type textWriter struct {
  99. compact bool // same as TextMarshaler.Compact
  100. expandAny bool // same as TextMarshaler.ExpandAny
  101. complete bool // whether the current position is a complete line
  102. indent int // indentation level; never negative
  103. buf []byte
  104. }
  105. func (w *textWriter) Write(p []byte) (n int, _ error) {
  106. newlines := bytes.Count(p, newline)
  107. if newlines == 0 {
  108. if !w.compact && w.complete {
  109. w.writeIndent()
  110. }
  111. w.buf = append(w.buf, p...)
  112. w.complete = false
  113. return len(p), nil
  114. }
  115. frags := bytes.SplitN(p, newline, newlines+1)
  116. if w.compact {
  117. for i, frag := range frags {
  118. if i > 0 {
  119. w.buf = append(w.buf, ' ')
  120. n++
  121. }
  122. w.buf = append(w.buf, frag...)
  123. n += len(frag)
  124. }
  125. return n, nil
  126. }
  127. for i, frag := range frags {
  128. if w.complete {
  129. w.writeIndent()
  130. }
  131. w.buf = append(w.buf, frag...)
  132. n += len(frag)
  133. if i+1 < len(frags) {
  134. w.buf = append(w.buf, '\n')
  135. n++
  136. }
  137. }
  138. w.complete = len(frags[len(frags)-1]) == 0
  139. return n, nil
  140. }
  141. func (w *textWriter) WriteByte(c byte) error {
  142. if w.compact && c == '\n' {
  143. c = ' '
  144. }
  145. if !w.compact && w.complete {
  146. w.writeIndent()
  147. }
  148. w.buf = append(w.buf, c)
  149. w.complete = c == '\n'
  150. return nil
  151. }
  152. func (w *textWriter) writeName(fd protoreflect.FieldDescriptor) {
  153. if !w.compact && w.complete {
  154. w.writeIndent()
  155. }
  156. w.complete = false
  157. if fd.Kind() != protoreflect.GroupKind {
  158. w.buf = append(w.buf, fd.Name()...)
  159. w.WriteByte(':')
  160. } else {
  161. // Use message type name for group field name.
  162. w.buf = append(w.buf, fd.Message().Name()...)
  163. }
  164. if !w.compact {
  165. w.WriteByte(' ')
  166. }
  167. }
  168. func requiresQuotes(u string) bool {
  169. // When type URL contains any characters except [0-9A-Za-z./\-]*, it must be quoted.
  170. for _, ch := range u {
  171. switch {
  172. case ch == '.' || ch == '/' || ch == '_':
  173. continue
  174. case '0' <= ch && ch <= '9':
  175. continue
  176. case 'A' <= ch && ch <= 'Z':
  177. continue
  178. case 'a' <= ch && ch <= 'z':
  179. continue
  180. default:
  181. return true
  182. }
  183. }
  184. return false
  185. }
  186. // writeProto3Any writes an expanded google.protobuf.Any message.
  187. //
  188. // It returns (false, nil) if sv value can't be unmarshaled (e.g. because
  189. // required messages are not linked in).
  190. //
  191. // It returns (true, error) when sv was written in expanded format or an error
  192. // was encountered.
  193. func (w *textWriter) writeProto3Any(m protoreflect.Message) (bool, error) {
  194. md := m.Descriptor()
  195. fdURL := md.Fields().ByName("type_url")
  196. fdVal := md.Fields().ByName("value")
  197. url := m.Get(fdURL).String()
  198. mt, err := protoregistry.GlobalTypes.FindMessageByURL(url)
  199. if err != nil {
  200. return false, nil
  201. }
  202. b := m.Get(fdVal).Bytes()
  203. m2 := mt.New()
  204. if err := proto.Unmarshal(b, m2.Interface()); err != nil {
  205. return false, nil
  206. }
  207. w.Write([]byte("["))
  208. if requiresQuotes(url) {
  209. w.writeQuotedString(url)
  210. } else {
  211. w.Write([]byte(url))
  212. }
  213. if w.compact {
  214. w.Write([]byte("]:<"))
  215. } else {
  216. w.Write([]byte("]: <\n"))
  217. w.indent++
  218. }
  219. if err := w.writeMessage(m2); err != nil {
  220. return true, err
  221. }
  222. if w.compact {
  223. w.Write([]byte("> "))
  224. } else {
  225. w.indent--
  226. w.Write([]byte(">\n"))
  227. }
  228. return true, nil
  229. }
// writeMessage writes the populated fields of m in the text format.
//
// When w.expandAny is set and m is a google.protobuf.Any that writeProto3Any
// can expand, the expanded form is written instead. Otherwise, fields are
// visited in the order of md.Fields() and each populated one is written.
// Unknown fields (if any) follow, then extensions via writeExtensions.
// Errors returned by writeSingularValue or writeExtensions are propagated.
func (w *textWriter) writeMessage(m protoreflect.Message) error {
	md := m.Descriptor()
	if w.expandAny && md.FullName() == "google.protobuf.Any" {
		// canExpand is false when the Any could not be expanded; in that case
		// fall through and print its raw fields.
		if canExpand, err := w.writeProto3Any(m); canExpand {
			return err
		}
	}

	fds := md.Fields()
	for i := 0; i < fds.Len(); {
		fd := fds.Get(i)
		if od := fd.ContainingOneof(); od != nil {
			// Visit each oneof once: take its populated member (nil if none)
			// and advance past all of its members.
			// NOTE(review): assumes a oneof's members are contiguous in
			// md.Fields() and that i is at its first member — confirm.
			fd = m.WhichOneof(od)
			i += od.Fields().Len()
		} else {
			i++
		}
		if fd == nil || !m.Has(fd) {
			continue
		}

		switch {
		case fd.IsList():
			// Repeated fields: one "name: value" line per element.
			lv := m.Get(fd).List()
			for j := 0; j < lv.Len(); j++ {
				w.writeName(fd)
				v := lv.Get(j)
				if err := w.writeSingularValue(v, fd); err != nil {
					return err
				}
				w.WriteByte('\n')
			}
		case fd.IsMap():
			kfd := fd.MapKey()
			vfd := fd.MapValue()
			mv := m.Get(fd).Map()

			// Collect and sort entries by key so output is deterministic.
			type entry struct{ key, val protoreflect.Value }
			var entries []entry
			mv.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
				entries = append(entries, entry{k.Value(), v})
				return true
			})
			sort.Slice(entries, func(i, j int) bool {
				switch kfd.Kind() {
				case protoreflect.BoolKind:
					// false sorts before true.
					return !entries[i].key.Bool() && entries[j].key.Bool()
				case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind, protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
					return entries[i].key.Int() < entries[j].key.Int()
				case protoreflect.Uint32Kind, protoreflect.Fixed32Kind, protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
					return entries[i].key.Uint() < entries[j].key.Uint()
				case protoreflect.StringKind:
					return entries[i].key.String() < entries[j].key.String()
				default:
					panic("invalid kind")
				}
			})
			// Each entry is written as a nested "<key: ... value: ...>" block.
			for _, entry := range entries {
				w.writeName(fd)
				w.WriteByte('<')
				if !w.compact {
					w.WriteByte('\n')
				}
				w.indent++
				w.writeName(kfd)
				if err := w.writeSingularValue(entry.key, kfd); err != nil {
					return err
				}
				w.WriteByte('\n')
				w.writeName(vfd)
				if err := w.writeSingularValue(entry.val, vfd); err != nil {
					return err
				}
				w.WriteByte('\n')
				w.indent--
				w.WriteByte('>')
				w.WriteByte('\n')
			}
		default:
			// Singular field: "name: value" on its own line.
			w.writeName(fd)
			if err := w.writeSingularValue(m.Get(fd), fd); err != nil {
				return err
			}
			w.WriteByte('\n')
		}
	}

	if b := m.GetUnknown(); len(b) > 0 {
		w.writeUnknownFields(b)
	}
	return w.writeExtensions(m)
}
  318. func (w *textWriter) writeSingularValue(v protoreflect.Value, fd protoreflect.FieldDescriptor) error {
  319. switch fd.Kind() {
  320. case protoreflect.FloatKind, protoreflect.DoubleKind:
  321. switch vf := v.Float(); {
  322. case math.IsInf(vf, +1):
  323. w.Write(posInf)
  324. case math.IsInf(vf, -1):
  325. w.Write(negInf)
  326. case math.IsNaN(vf):
  327. w.Write(nan)
  328. default:
  329. fmt.Fprint(w, v.Interface())
  330. }
  331. case protoreflect.StringKind:
  332. // NOTE: This does not validate UTF-8 for historical reasons.
  333. w.writeQuotedString(string(v.String()))
  334. case protoreflect.BytesKind:
  335. w.writeQuotedString(string(v.Bytes()))
  336. case protoreflect.MessageKind, protoreflect.GroupKind:
  337. var bra, ket byte = '<', '>'
  338. if fd.Kind() == protoreflect.GroupKind {
  339. bra, ket = '{', '}'
  340. }
  341. w.WriteByte(bra)
  342. if !w.compact {
  343. w.WriteByte('\n')
  344. }
  345. w.indent++
  346. m := v.Message()
  347. if m2, ok := m.Interface().(encoding.TextMarshaler); ok {
  348. b, err := m2.MarshalText()
  349. if err != nil {
  350. return err
  351. }
  352. w.Write(b)
  353. } else {
  354. w.writeMessage(m)
  355. }
  356. w.indent--
  357. w.WriteByte(ket)
  358. case protoreflect.EnumKind:
  359. if ev := fd.Enum().Values().ByNumber(v.Enum()); ev != nil {
  360. fmt.Fprint(w, ev.Name())
  361. } else {
  362. fmt.Fprint(w, v.Enum())
  363. }
  364. default:
  365. fmt.Fprint(w, v.Interface())
  366. }
  367. return nil
  368. }
  369. // writeQuotedString writes a quoted string in the protocol buffer text format.
  370. func (w *textWriter) writeQuotedString(s string) {
  371. w.WriteByte('"')
  372. for i := 0; i < len(s); i++ {
  373. switch c := s[i]; c {
  374. case '\n':
  375. w.buf = append(w.buf, `\n`...)
  376. case '\r':
  377. w.buf = append(w.buf, `\r`...)
  378. case '\t':
  379. w.buf = append(w.buf, `\t`...)
  380. case '"':
  381. w.buf = append(w.buf, `\"`...)
  382. case '\\':
  383. w.buf = append(w.buf, `\\`...)
  384. default:
  385. if isPrint := c >= 0x20 && c < 0x7f; isPrint {
  386. w.buf = append(w.buf, c)
  387. } else {
  388. w.buf = append(w.buf, fmt.Sprintf(`\%03o`, c)...)
  389. }
  390. }
  391. }
  392. w.WriteByte('"')
  393. }
  394. func (w *textWriter) writeUnknownFields(b []byte) {
  395. if !w.compact {
  396. fmt.Fprintf(w, "/* %d unknown bytes */\n", len(b))
  397. }
  398. for len(b) > 0 {
  399. num, wtyp, n := protowire.ConsumeTag(b)
  400. if n < 0 {
  401. return
  402. }
  403. b = b[n:]
  404. if wtyp == protowire.EndGroupType {
  405. w.indent--
  406. w.Write(endBraceNewline)
  407. continue
  408. }
  409. fmt.Fprint(w, num)
  410. if wtyp != protowire.StartGroupType {
  411. w.WriteByte(':')
  412. }
  413. if !w.compact || wtyp == protowire.StartGroupType {
  414. w.WriteByte(' ')
  415. }
  416. switch wtyp {
  417. case protowire.VarintType:
  418. v, n := protowire.ConsumeVarint(b)
  419. if n < 0 {
  420. return
  421. }
  422. b = b[n:]
  423. fmt.Fprint(w, v)
  424. case protowire.Fixed32Type:
  425. v, n := protowire.ConsumeFixed32(b)
  426. if n < 0 {
  427. return
  428. }
  429. b = b[n:]
  430. fmt.Fprint(w, v)
  431. case protowire.Fixed64Type:
  432. v, n := protowire.ConsumeFixed64(b)
  433. if n < 0 {
  434. return
  435. }
  436. b = b[n:]
  437. fmt.Fprint(w, v)
  438. case protowire.BytesType:
  439. v, n := protowire.ConsumeBytes(b)
  440. if n < 0 {
  441. return
  442. }
  443. b = b[n:]
  444. fmt.Fprintf(w, "%q", v)
  445. case protowire.StartGroupType:
  446. w.WriteByte('{')
  447. w.indent++
  448. default:
  449. fmt.Fprintf(w, "/* unknown wire type %d */", wtyp)
  450. }
  451. w.WriteByte('\n')
  452. }
  453. }
  454. // writeExtensions writes all the extensions in m.
  455. func (w *textWriter) writeExtensions(m protoreflect.Message) error {
  456. md := m.Descriptor()
  457. if md.ExtensionRanges().Len() == 0 {
  458. return nil
  459. }
  460. type ext struct {
  461. desc protoreflect.FieldDescriptor
  462. val protoreflect.Value
  463. }
  464. var exts []ext
  465. m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
  466. if fd.IsExtension() {
  467. exts = append(exts, ext{fd, v})
  468. }
  469. return true
  470. })
  471. sort.Slice(exts, func(i, j int) bool {
  472. return exts[i].desc.Number() < exts[j].desc.Number()
  473. })
  474. for _, ext := range exts {
  475. // For message set, use the name of the message as the extension name.
  476. name := string(ext.desc.FullName())
  477. if isMessageSet(ext.desc.ContainingMessage()) {
  478. name = strings.TrimSuffix(name, ".message_set_extension")
  479. }
  480. if !ext.desc.IsList() {
  481. if err := w.writeSingularExtension(name, ext.val, ext.desc); err != nil {
  482. return err
  483. }
  484. } else {
  485. lv := ext.val.List()
  486. for i := 0; i < lv.Len(); i++ {
  487. if err := w.writeSingularExtension(name, lv.Get(i), ext.desc); err != nil {
  488. return err
  489. }
  490. }
  491. }
  492. }
  493. return nil
  494. }
  495. func (w *textWriter) writeSingularExtension(name string, v protoreflect.Value, fd protoreflect.FieldDescriptor) error {
  496. fmt.Fprintf(w, "[%s]:", name)
  497. if !w.compact {
  498. w.WriteByte(' ')
  499. }
  500. if err := w.writeSingularValue(v, fd); err != nil {
  501. return err
  502. }
  503. w.WriteByte('\n')
  504. return nil
  505. }
  506. func (w *textWriter) writeIndent() {
  507. if !w.complete {
  508. return
  509. }
  510. for i := 0; i < w.indent*2; i++ {
  511. w.buf = append(w.buf, ' ')
  512. }
  513. w.complete = false
  514. }