text_encode.go raw

   1  // Copyright 2010 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package proto
   6  
   7  import (
   8  	"bytes"
   9  	"encoding"
  10  	"fmt"
  11  	"io"
  12  	"math"
  13  	"sort"
  14  	"strings"
  15  
  16  	"google.golang.org/protobuf/encoding/prototext"
  17  	"google.golang.org/protobuf/encoding/protowire"
  18  	"google.golang.org/protobuf/proto"
  19  	"google.golang.org/protobuf/reflect/protoreflect"
  20  	"google.golang.org/protobuf/reflect/protoregistry"
  21  )
  22  
  23  const wrapTextMarshalV2 = false
  24  
  25  // TextMarshaler is a configurable text format marshaler.
  26  type TextMarshaler struct {
  27  	Compact   bool // use compact text format (one line)
  28  	ExpandAny bool // expand google.protobuf.Any messages of known types
  29  }
  30  
  31  // Marshal writes the proto text format of m to w.
  32  func (tm *TextMarshaler) Marshal(w io.Writer, m Message) error {
  33  	b, err := tm.marshal(m)
  34  	if len(b) > 0 {
  35  		if _, err := w.Write(b); err != nil {
  36  			return err
  37  		}
  38  	}
  39  	return err
  40  }
  41  
  42  // Text returns a proto text formatted string of m.
  43  func (tm *TextMarshaler) Text(m Message) string {
  44  	b, _ := tm.marshal(m)
  45  	return string(b)
  46  }
  47  
  48  func (tm *TextMarshaler) marshal(m Message) ([]byte, error) {
  49  	mr := MessageReflect(m)
  50  	if mr == nil || !mr.IsValid() {
  51  		return []byte("<nil>"), nil
  52  	}
  53  
  54  	if wrapTextMarshalV2 {
  55  		if m, ok := m.(encoding.TextMarshaler); ok {
  56  			return m.MarshalText()
  57  		}
  58  
  59  		opts := prototext.MarshalOptions{
  60  			AllowPartial: true,
  61  			EmitUnknown:  true,
  62  		}
  63  		if !tm.Compact {
  64  			opts.Indent = "  "
  65  		}
  66  		if !tm.ExpandAny {
  67  			opts.Resolver = (*protoregistry.Types)(nil)
  68  		}
  69  		return opts.Marshal(mr.Interface())
  70  	} else {
  71  		w := &textWriter{
  72  			compact:   tm.Compact,
  73  			expandAny: tm.ExpandAny,
  74  			complete:  true,
  75  		}
  76  
  77  		if m, ok := m.(encoding.TextMarshaler); ok {
  78  			b, err := m.MarshalText()
  79  			if err != nil {
  80  				return nil, err
  81  			}
  82  			w.Write(b)
  83  			return w.buf, nil
  84  		}
  85  
  86  		err := w.writeMessage(mr)
  87  		return w.buf, err
  88  	}
  89  }
  90  
  91  var (
  92  	defaultTextMarshaler = TextMarshaler{}
  93  	compactTextMarshaler = TextMarshaler{Compact: true}
  94  )
  95  
  96  // MarshalText writes the proto text format of m to w.
  97  func MarshalText(w io.Writer, m Message) error { return defaultTextMarshaler.Marshal(w, m) }
  98  
  99  // MarshalTextString returns a proto text formatted string of m.
 100  func MarshalTextString(m Message) string { return defaultTextMarshaler.Text(m) }
 101  
 102  // CompactText writes the compact proto text format of m to w.
 103  func CompactText(w io.Writer, m Message) error { return compactTextMarshaler.Marshal(w, m) }
 104  
 105  // CompactTextString returns a compact proto text formatted string of m.
 106  func CompactTextString(m Message) string { return compactTextMarshaler.Text(m) }
 107  
 108  var (
 109  	newline         = []byte("\n")
 110  	endBraceNewline = []byte("}\n")
 111  	posInf          = []byte("inf")
 112  	negInf          = []byte("-inf")
 113  	nan             = []byte("nan")
 114  )
 115  
 116  // textWriter is an io.Writer that tracks its indentation level.
 117  type textWriter struct {
 118  	compact   bool // same as TextMarshaler.Compact
 119  	expandAny bool // same as TextMarshaler.ExpandAny
 120  	complete  bool // whether the current position is a complete line
 121  	indent    int  // indentation level; never negative
 122  	buf       []byte
 123  }
 124  
 125  func (w *textWriter) Write(p []byte) (n int, _ error) {
 126  	newlines := bytes.Count(p, newline)
 127  	if newlines == 0 {
 128  		if !w.compact && w.complete {
 129  			w.writeIndent()
 130  		}
 131  		w.buf = append(w.buf, p...)
 132  		w.complete = false
 133  		return len(p), nil
 134  	}
 135  
 136  	frags := bytes.SplitN(p, newline, newlines+1)
 137  	if w.compact {
 138  		for i, frag := range frags {
 139  			if i > 0 {
 140  				w.buf = append(w.buf, ' ')
 141  				n++
 142  			}
 143  			w.buf = append(w.buf, frag...)
 144  			n += len(frag)
 145  		}
 146  		return n, nil
 147  	}
 148  
 149  	for i, frag := range frags {
 150  		if w.complete {
 151  			w.writeIndent()
 152  		}
 153  		w.buf = append(w.buf, frag...)
 154  		n += len(frag)
 155  		if i+1 < len(frags) {
 156  			w.buf = append(w.buf, '\n')
 157  			n++
 158  		}
 159  	}
 160  	w.complete = len(frags[len(frags)-1]) == 0
 161  	return n, nil
 162  }
 163  
 164  func (w *textWriter) WriteByte(c byte) error {
 165  	if w.compact && c == '\n' {
 166  		c = ' '
 167  	}
 168  	if !w.compact && w.complete {
 169  		w.writeIndent()
 170  	}
 171  	w.buf = append(w.buf, c)
 172  	w.complete = c == '\n'
 173  	return nil
 174  }
 175  
 176  func (w *textWriter) writeName(fd protoreflect.FieldDescriptor) {
 177  	if !w.compact && w.complete {
 178  		w.writeIndent()
 179  	}
 180  	w.complete = false
 181  
 182  	if fd.Kind() != protoreflect.GroupKind {
 183  		w.buf = append(w.buf, fd.Name()...)
 184  		w.WriteByte(':')
 185  	} else {
 186  		// Use message type name for group field name.
 187  		w.buf = append(w.buf, fd.Message().Name()...)
 188  	}
 189  
 190  	if !w.compact {
 191  		w.WriteByte(' ')
 192  	}
 193  }
 194  
 195  func requiresQuotes(u string) bool {
 196  	// When type URL contains any characters except [0-9A-Za-z./\-]*, it must be quoted.
 197  	for _, ch := range u {
 198  		switch {
 199  		case ch == '.' || ch == '/' || ch == '_':
 200  			continue
 201  		case '0' <= ch && ch <= '9':
 202  			continue
 203  		case 'A' <= ch && ch <= 'Z':
 204  			continue
 205  		case 'a' <= ch && ch <= 'z':
 206  			continue
 207  		default:
 208  			return true
 209  		}
 210  	}
 211  	return false
 212  }
 213  
 214  // writeProto3Any writes an expanded google.protobuf.Any message.
 215  //
 216  // It returns (false, nil) if sv value can't be unmarshaled (e.g. because
 217  // required messages are not linked in).
 218  //
 219  // It returns (true, error) when sv was written in expanded format or an error
 220  // was encountered.
 221  func (w *textWriter) writeProto3Any(m protoreflect.Message) (bool, error) {
 222  	md := m.Descriptor()
 223  	fdURL := md.Fields().ByName("type_url")
 224  	fdVal := md.Fields().ByName("value")
 225  
 226  	url := m.Get(fdURL).String()
 227  	mt, err := protoregistry.GlobalTypes.FindMessageByURL(url)
 228  	if err != nil {
 229  		return false, nil
 230  	}
 231  
 232  	b := m.Get(fdVal).Bytes()
 233  	m2 := mt.New()
 234  	if err := proto.Unmarshal(b, m2.Interface()); err != nil {
 235  		return false, nil
 236  	}
 237  	w.Write([]byte("["))
 238  	if requiresQuotes(url) {
 239  		w.writeQuotedString(url)
 240  	} else {
 241  		w.Write([]byte(url))
 242  	}
 243  	if w.compact {
 244  		w.Write([]byte("]:<"))
 245  	} else {
 246  		w.Write([]byte("]: <\n"))
 247  		w.indent++
 248  	}
 249  	if err := w.writeMessage(m2); err != nil {
 250  		return true, err
 251  	}
 252  	if w.compact {
 253  		w.Write([]byte("> "))
 254  	} else {
 255  		w.indent--
 256  		w.Write([]byte(">\n"))
 257  	}
 258  	return true, nil
 259  }
 260  
 261  func (w *textWriter) writeMessage(m protoreflect.Message) error {
 262  	md := m.Descriptor()
 263  	if w.expandAny && md.FullName() == "google.protobuf.Any" {
 264  		if canExpand, err := w.writeProto3Any(m); canExpand {
 265  			return err
 266  		}
 267  	}
 268  
 269  	fds := md.Fields()
 270  	for i := 0; i < fds.Len(); {
 271  		fd := fds.Get(i)
 272  		if od := fd.ContainingOneof(); od != nil {
 273  			fd = m.WhichOneof(od)
 274  			i += od.Fields().Len()
 275  		} else {
 276  			i++
 277  		}
 278  		if fd == nil || !m.Has(fd) {
 279  			continue
 280  		}
 281  
 282  		switch {
 283  		case fd.IsList():
 284  			lv := m.Get(fd).List()
 285  			for j := 0; j < lv.Len(); j++ {
 286  				w.writeName(fd)
 287  				v := lv.Get(j)
 288  				if err := w.writeSingularValue(v, fd); err != nil {
 289  					return err
 290  				}
 291  				w.WriteByte('\n')
 292  			}
 293  		case fd.IsMap():
 294  			kfd := fd.MapKey()
 295  			vfd := fd.MapValue()
 296  			mv := m.Get(fd).Map()
 297  
 298  			type entry struct{ key, val protoreflect.Value }
 299  			var entries []entry
 300  			mv.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
 301  				entries = append(entries, entry{k.Value(), v})
 302  				return true
 303  			})
 304  			sort.Slice(entries, func(i, j int) bool {
 305  				switch kfd.Kind() {
 306  				case protoreflect.BoolKind:
 307  					return !entries[i].key.Bool() && entries[j].key.Bool()
 308  				case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind, protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
 309  					return entries[i].key.Int() < entries[j].key.Int()
 310  				case protoreflect.Uint32Kind, protoreflect.Fixed32Kind, protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
 311  					return entries[i].key.Uint() < entries[j].key.Uint()
 312  				case protoreflect.StringKind:
 313  					return entries[i].key.String() < entries[j].key.String()
 314  				default:
 315  					panic("invalid kind")
 316  				}
 317  			})
 318  			for _, entry := range entries {
 319  				w.writeName(fd)
 320  				w.WriteByte('<')
 321  				if !w.compact {
 322  					w.WriteByte('\n')
 323  				}
 324  				w.indent++
 325  				w.writeName(kfd)
 326  				if err := w.writeSingularValue(entry.key, kfd); err != nil {
 327  					return err
 328  				}
 329  				w.WriteByte('\n')
 330  				w.writeName(vfd)
 331  				if err := w.writeSingularValue(entry.val, vfd); err != nil {
 332  					return err
 333  				}
 334  				w.WriteByte('\n')
 335  				w.indent--
 336  				w.WriteByte('>')
 337  				w.WriteByte('\n')
 338  			}
 339  		default:
 340  			w.writeName(fd)
 341  			if err := w.writeSingularValue(m.Get(fd), fd); err != nil {
 342  				return err
 343  			}
 344  			w.WriteByte('\n')
 345  		}
 346  	}
 347  
 348  	if b := m.GetUnknown(); len(b) > 0 {
 349  		w.writeUnknownFields(b)
 350  	}
 351  	return w.writeExtensions(m)
 352  }
 353  
 354  func (w *textWriter) writeSingularValue(v protoreflect.Value, fd protoreflect.FieldDescriptor) error {
 355  	switch fd.Kind() {
 356  	case protoreflect.FloatKind, protoreflect.DoubleKind:
 357  		switch vf := v.Float(); {
 358  		case math.IsInf(vf, +1):
 359  			w.Write(posInf)
 360  		case math.IsInf(vf, -1):
 361  			w.Write(negInf)
 362  		case math.IsNaN(vf):
 363  			w.Write(nan)
 364  		default:
 365  			fmt.Fprint(w, v.Interface())
 366  		}
 367  	case protoreflect.StringKind:
 368  		// NOTE: This does not validate UTF-8 for historical reasons.
 369  		w.writeQuotedString(string(v.String()))
 370  	case protoreflect.BytesKind:
 371  		w.writeQuotedString(string(v.Bytes()))
 372  	case protoreflect.MessageKind, protoreflect.GroupKind:
 373  		var bra, ket byte = '<', '>'
 374  		if fd.Kind() == protoreflect.GroupKind {
 375  			bra, ket = '{', '}'
 376  		}
 377  		w.WriteByte(bra)
 378  		if !w.compact {
 379  			w.WriteByte('\n')
 380  		}
 381  		w.indent++
 382  		m := v.Message()
 383  		if m2, ok := m.Interface().(encoding.TextMarshaler); ok {
 384  			b, err := m2.MarshalText()
 385  			if err != nil {
 386  				return err
 387  			}
 388  			w.Write(b)
 389  		} else {
 390  			w.writeMessage(m)
 391  		}
 392  		w.indent--
 393  		w.WriteByte(ket)
 394  	case protoreflect.EnumKind:
 395  		if ev := fd.Enum().Values().ByNumber(v.Enum()); ev != nil {
 396  			fmt.Fprint(w, ev.Name())
 397  		} else {
 398  			fmt.Fprint(w, v.Enum())
 399  		}
 400  	default:
 401  		fmt.Fprint(w, v.Interface())
 402  	}
 403  	return nil
 404  }
 405  
 406  // writeQuotedString writes a quoted string in the protocol buffer text format.
 407  func (w *textWriter) writeQuotedString(s string) {
 408  	w.WriteByte('"')
 409  	for i := 0; i < len(s); i++ {
 410  		switch c := s[i]; c {
 411  		case '\n':
 412  			w.buf = append(w.buf, `\n`...)
 413  		case '\r':
 414  			w.buf = append(w.buf, `\r`...)
 415  		case '\t':
 416  			w.buf = append(w.buf, `\t`...)
 417  		case '"':
 418  			w.buf = append(w.buf, `\"`...)
 419  		case '\\':
 420  			w.buf = append(w.buf, `\\`...)
 421  		default:
 422  			if isPrint := c >= 0x20 && c < 0x7f; isPrint {
 423  				w.buf = append(w.buf, c)
 424  			} else {
 425  				w.buf = append(w.buf, fmt.Sprintf(`\%03o`, c)...)
 426  			}
 427  		}
 428  	}
 429  	w.WriteByte('"')
 430  }
 431  
 432  func (w *textWriter) writeUnknownFields(b []byte) {
 433  	if !w.compact {
 434  		fmt.Fprintf(w, "/* %d unknown bytes */\n", len(b))
 435  	}
 436  
 437  	for len(b) > 0 {
 438  		num, wtyp, n := protowire.ConsumeTag(b)
 439  		if n < 0 {
 440  			return
 441  		}
 442  		b = b[n:]
 443  
 444  		if wtyp == protowire.EndGroupType {
 445  			w.indent--
 446  			w.Write(endBraceNewline)
 447  			continue
 448  		}
 449  		fmt.Fprint(w, num)
 450  		if wtyp != protowire.StartGroupType {
 451  			w.WriteByte(':')
 452  		}
 453  		if !w.compact || wtyp == protowire.StartGroupType {
 454  			w.WriteByte(' ')
 455  		}
 456  		switch wtyp {
 457  		case protowire.VarintType:
 458  			v, n := protowire.ConsumeVarint(b)
 459  			if n < 0 {
 460  				return
 461  			}
 462  			b = b[n:]
 463  			fmt.Fprint(w, v)
 464  		case protowire.Fixed32Type:
 465  			v, n := protowire.ConsumeFixed32(b)
 466  			if n < 0 {
 467  				return
 468  			}
 469  			b = b[n:]
 470  			fmt.Fprint(w, v)
 471  		case protowire.Fixed64Type:
 472  			v, n := protowire.ConsumeFixed64(b)
 473  			if n < 0 {
 474  				return
 475  			}
 476  			b = b[n:]
 477  			fmt.Fprint(w, v)
 478  		case protowire.BytesType:
 479  			v, n := protowire.ConsumeBytes(b)
 480  			if n < 0 {
 481  				return
 482  			}
 483  			b = b[n:]
 484  			fmt.Fprintf(w, "%q", v)
 485  		case protowire.StartGroupType:
 486  			w.WriteByte('{')
 487  			w.indent++
 488  		default:
 489  			fmt.Fprintf(w, "/* unknown wire type %d */", wtyp)
 490  		}
 491  		w.WriteByte('\n')
 492  	}
 493  }
 494  
 495  // writeExtensions writes all the extensions in m.
 496  func (w *textWriter) writeExtensions(m protoreflect.Message) error {
 497  	md := m.Descriptor()
 498  	if md.ExtensionRanges().Len() == 0 {
 499  		return nil
 500  	}
 501  
 502  	type ext struct {
 503  		desc protoreflect.FieldDescriptor
 504  		val  protoreflect.Value
 505  	}
 506  	var exts []ext
 507  	m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
 508  		if fd.IsExtension() {
 509  			exts = append(exts, ext{fd, v})
 510  		}
 511  		return true
 512  	})
 513  	sort.Slice(exts, func(i, j int) bool {
 514  		return exts[i].desc.Number() < exts[j].desc.Number()
 515  	})
 516  
 517  	for _, ext := range exts {
 518  		// For message set, use the name of the message as the extension name.
 519  		name := string(ext.desc.FullName())
 520  		if isMessageSet(ext.desc.ContainingMessage()) {
 521  			name = strings.TrimSuffix(name, ".message_set_extension")
 522  		}
 523  
 524  		if !ext.desc.IsList() {
 525  			if err := w.writeSingularExtension(name, ext.val, ext.desc); err != nil {
 526  				return err
 527  			}
 528  		} else {
 529  			lv := ext.val.List()
 530  			for i := 0; i < lv.Len(); i++ {
 531  				if err := w.writeSingularExtension(name, lv.Get(i), ext.desc); err != nil {
 532  					return err
 533  				}
 534  			}
 535  		}
 536  	}
 537  	return nil
 538  }
 539  
 540  func (w *textWriter) writeSingularExtension(name string, v protoreflect.Value, fd protoreflect.FieldDescriptor) error {
 541  	fmt.Fprintf(w, "[%s]:", name)
 542  	if !w.compact {
 543  		w.WriteByte(' ')
 544  	}
 545  	if err := w.writeSingularValue(v, fd); err != nil {
 546  		return err
 547  	}
 548  	w.WriteByte('\n')
 549  	return nil
 550  }
 551  
 552  func (w *textWriter) writeIndent() {
 553  	if !w.complete {
 554  		return
 555  	}
 556  	for i := 0; i < w.indent*2; i++ {
 557  		w.buf = append(w.buf, ' ')
 558  	}
 559  	w.complete = false
 560  }
 561