// Copyright 2019 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package impl import ( "fmt" "reflect" "google.golang.org/protobuf/encoding/protowire" "google.golang.org/protobuf/internal/strs" "google.golang.org/protobuf/reflect/protoreflect" ) // pointerCoderFuncs is a set of pointer encoding functions. type pointerCoderFuncs struct { mi *MessageInfo size func(p pointer, f *coderFieldInfo, opts marshalOptions) int marshal func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) unmarshal func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) isInit func(p pointer, f *coderFieldInfo) error merge func(dst, src pointer, f *coderFieldInfo, opts mergeOptions) } // valueCoderFuncs is a set of protoreflect.Value encoding functions. type valueCoderFuncs struct { size func(v protoreflect.Value, tagsize int, opts marshalOptions) int marshal func(b []byte, v protoreflect.Value, wiretag uint64, opts marshalOptions) ([]byte, error) unmarshal func(b []byte, v protoreflect.Value, num protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (protoreflect.Value, unmarshalOutput, error) isInit func(v protoreflect.Value) error merge func(dst, src protoreflect.Value, opts mergeOptions) protoreflect.Value } // fieldCoder returns pointer functions for a field, used for operating on // struct fields. func fieldCoder(fd protoreflect.FieldDescriptor, ft reflect.Type) (*MessageInfo, pointerCoderFuncs) { switch { case fd.IsMap(): return encoderFuncsForMap(fd, ft) case fd.Cardinality() == protoreflect.Repeated && !fd.IsPacked(): // Repeated fields (not packed). if ft.Kind() != reflect.Slice { break } ft := ft.Elem() switch fd.Kind() { case protoreflect.BoolKind: if ft.Kind() == reflect.Bool { return nil, coderBoolSlice } case protoreflect.EnumKind: if ft.Kind() == reflect.Int32 { return nil, coderEnumSlice } case protoreflect.Int32Kind: if ft.Kind() == reflect.Int32 { return nil, coderInt32Slice } case protoreflect.Sint32Kind: if ft.Kind() == reflect.Int32 { return nil, coderSint32Slice } case protoreflect.Uint32Kind: if ft.Kind() == reflect.Uint32 { return nil, coderUint32Slice } case protoreflect.Int64Kind: if ft.Kind() == reflect.Int64 { return nil, coderInt64Slice } case protoreflect.Sint64Kind: if ft.Kind() == reflect.Int64 { return nil, coderSint64Slice } case protoreflect.Uint64Kind: if ft.Kind() == reflect.Uint64 { return nil, coderUint64Slice } case protoreflect.Sfixed32Kind: if ft.Kind() == reflect.Int32 { return nil, coderSfixed32Slice } case protoreflect.Fixed32Kind: if ft.Kind() == reflect.Uint32 { return nil, coderFixed32Slice } case protoreflect.FloatKind: if ft.Kind() == reflect.Float32 { return nil, coderFloatSlice } case protoreflect.Sfixed64Kind: if ft.Kind() == reflect.Int64 { return nil, coderSfixed64Slice } case protoreflect.Fixed64Kind: if ft.Kind() == reflect.Uint64 { return nil, coderFixed64Slice } case protoreflect.DoubleKind: if ft.Kind() == reflect.Float64 { return nil, coderDoubleSlice } case protoreflect.StringKind: if ft.Kind() == reflect.String && strs.EnforceUTF8(fd) { return nil, coderStringSliceValidateUTF8 } if ft.Kind() == reflect.String { return nil, coderStringSlice } if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 && strs.EnforceUTF8(fd) { return nil, coderBytesSliceValidateUTF8 } if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 { return nil, coderBytesSlice } case protoreflect.BytesKind: if ft.Kind() == reflect.String { return nil, coderStringSlice } if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 { return nil, coderBytesSlice } case protoreflect.MessageKind: return getMessageInfo(ft), makeMessageSliceFieldCoder(fd, ft) case protoreflect.GroupKind: return getMessageInfo(ft), makeGroupSliceFieldCoder(fd, ft) } case fd.Cardinality() == protoreflect.Repeated && fd.IsPacked(): // Packed repeated fields. // // Only repeated fields of primitive numeric types // (Varint, Fixed32, or Fixed64 wire type) can be packed. if ft.Kind() != reflect.Slice { break } ft := ft.Elem() switch fd.Kind() { case protoreflect.BoolKind: if ft.Kind() == reflect.Bool { return nil, coderBoolPackedSlice } case protoreflect.EnumKind: if ft.Kind() == reflect.Int32 { return nil, coderEnumPackedSlice } case protoreflect.Int32Kind: if ft.Kind() == reflect.Int32 { return nil, coderInt32PackedSlice } case protoreflect.Sint32Kind: if ft.Kind() == reflect.Int32 { return nil, coderSint32PackedSlice } case protoreflect.Uint32Kind: if ft.Kind() == reflect.Uint32 { return nil, coderUint32PackedSlice } case protoreflect.Int64Kind: if ft.Kind() == reflect.Int64 { return nil, coderInt64PackedSlice } case protoreflect.Sint64Kind: if ft.Kind() == reflect.Int64 { return nil, coderSint64PackedSlice } case protoreflect.Uint64Kind: if ft.Kind() == reflect.Uint64 { return nil, coderUint64PackedSlice } case protoreflect.Sfixed32Kind: if ft.Kind() == reflect.Int32 { return nil, coderSfixed32PackedSlice } case protoreflect.Fixed32Kind: if ft.Kind() == reflect.Uint32 { return nil, coderFixed32PackedSlice } case protoreflect.FloatKind: if ft.Kind() == reflect.Float32 { return nil, coderFloatPackedSlice } case protoreflect.Sfixed64Kind: if ft.Kind() == reflect.Int64 { return nil, coderSfixed64PackedSlice } case protoreflect.Fixed64Kind: if ft.Kind() == reflect.Uint64 { return nil, coderFixed64PackedSlice } case protoreflect.DoubleKind: if ft.Kind() == reflect.Float64 { return nil, coderDoublePackedSlice } } case fd.Kind() == protoreflect.MessageKind: return getMessageInfo(ft), makeMessageFieldCoder(fd, ft) case fd.Kind() == protoreflect.GroupKind: return getMessageInfo(ft), makeGroupFieldCoder(fd, ft) case fd.Syntax() == protoreflect.Proto3 && fd.ContainingOneof() == nil: // Populated oneof fields always encode even if set to the zero value, // which normally are not encoded in proto3. switch fd.Kind() { case protoreflect.BoolKind: if ft.Kind() == reflect.Bool { return nil, coderBoolNoZero } case protoreflect.EnumKind: if ft.Kind() == reflect.Int32 { return nil, coderEnumNoZero } case protoreflect.Int32Kind: if ft.Kind() == reflect.Int32 { return nil, coderInt32NoZero } case protoreflect.Sint32Kind: if ft.Kind() == reflect.Int32 { return nil, coderSint32NoZero } case protoreflect.Uint32Kind: if ft.Kind() == reflect.Uint32 { return nil, coderUint32NoZero } case protoreflect.Int64Kind: if ft.Kind() == reflect.Int64 { return nil, coderInt64NoZero } case protoreflect.Sint64Kind: if ft.Kind() == reflect.Int64 { return nil, coderSint64NoZero } case protoreflect.Uint64Kind: if ft.Kind() == reflect.Uint64 { return nil, coderUint64NoZero } case protoreflect.Sfixed32Kind: if ft.Kind() == reflect.Int32 { return nil, coderSfixed32NoZero } case protoreflect.Fixed32Kind: if ft.Kind() == reflect.Uint32 { return nil, coderFixed32NoZero } case protoreflect.FloatKind: if ft.Kind() == reflect.Float32 { return nil, coderFloatNoZero } case protoreflect.Sfixed64Kind: if ft.Kind() == reflect.Int64 { return nil, coderSfixed64NoZero } case protoreflect.Fixed64Kind: if ft.Kind() == reflect.Uint64 { return nil, coderFixed64NoZero } case protoreflect.DoubleKind: if ft.Kind() == reflect.Float64 { return nil, coderDoubleNoZero } case protoreflect.StringKind: if ft.Kind() == reflect.String && strs.EnforceUTF8(fd) { return nil, coderStringNoZeroValidateUTF8 } if ft.Kind() == reflect.String { return nil, coderStringNoZero } if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 && strs.EnforceUTF8(fd) { return nil, coderBytesNoZeroValidateUTF8 } if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 { return nil, coderBytesNoZero } case protoreflect.BytesKind: if ft.Kind() == reflect.String { return nil, coderStringNoZero } if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 { return nil, coderBytesNoZero } } case ft.Kind() == reflect.Ptr: ft := ft.Elem() switch fd.Kind() { case protoreflect.BoolKind: if ft.Kind() == reflect.Bool { return nil, coderBoolPtr } case protoreflect.EnumKind: if ft.Kind() == reflect.Int32 { return nil, coderEnumPtr } case protoreflect.Int32Kind: if ft.Kind() == reflect.Int32 { return nil, coderInt32Ptr } case protoreflect.Sint32Kind: if ft.Kind() == reflect.Int32 { return nil, coderSint32Ptr } case protoreflect.Uint32Kind: if ft.Kind() == reflect.Uint32 { return nil, coderUint32Ptr } case protoreflect.Int64Kind: if ft.Kind() == reflect.Int64 { return nil, coderInt64Ptr } case protoreflect.Sint64Kind: if ft.Kind() == reflect.Int64 { return nil, coderSint64Ptr } case protoreflect.Uint64Kind: if ft.Kind() == reflect.Uint64 { return nil, coderUint64Ptr } case protoreflect.Sfixed32Kind: if ft.Kind() == reflect.Int32 { return nil, coderSfixed32Ptr } case protoreflect.Fixed32Kind: if ft.Kind() == reflect.Uint32 { return nil, coderFixed32Ptr } case protoreflect.FloatKind: if ft.Kind() == reflect.Float32 { return nil, coderFloatPtr } case protoreflect.Sfixed64Kind: if ft.Kind() == reflect.Int64 { return nil, coderSfixed64Ptr } case protoreflect.Fixed64Kind: if ft.Kind() == reflect.Uint64 { return nil, coderFixed64Ptr } case protoreflect.DoubleKind: if ft.Kind() == reflect.Float64 { return nil, coderDoublePtr } case protoreflect.StringKind: if ft.Kind() == reflect.String && strs.EnforceUTF8(fd) { return nil, coderStringPtrValidateUTF8 } if ft.Kind() == reflect.String { return nil, coderStringPtr } case protoreflect.BytesKind: if ft.Kind() == reflect.String { return nil, coderStringPtr } } default: switch fd.Kind() { case protoreflect.BoolKind: if ft.Kind() == reflect.Bool { return nil, coderBool } case protoreflect.EnumKind: if ft.Kind() == reflect.Int32 { return nil, coderEnum } case protoreflect.Int32Kind: if ft.Kind() == reflect.Int32 { return nil, coderInt32 } case protoreflect.Sint32Kind: if ft.Kind() == reflect.Int32 { return nil, coderSint32 } case protoreflect.Uint32Kind: if ft.Kind() == reflect.Uint32 { return nil, coderUint32 } case protoreflect.Int64Kind: if ft.Kind() == reflect.Int64 { return nil, coderInt64 } case protoreflect.Sint64Kind: if ft.Kind() == reflect.Int64 { return nil, coderSint64 } case protoreflect.Uint64Kind: if ft.Kind() == reflect.Uint64 { return nil, coderUint64 } case protoreflect.Sfixed32Kind: if ft.Kind() == reflect.Int32 { return nil, coderSfixed32 } case protoreflect.Fixed32Kind: if ft.Kind() == reflect.Uint32 { return nil, coderFixed32 } case protoreflect.FloatKind: if ft.Kind() == reflect.Float32 { return nil, coderFloat } case protoreflect.Sfixed64Kind: if ft.Kind() == reflect.Int64 { return nil, coderSfixed64 } case protoreflect.Fixed64Kind: if ft.Kind() == reflect.Uint64 { return nil, coderFixed64 } case protoreflect.DoubleKind: if ft.Kind() == reflect.Float64 { return nil, coderDouble } case protoreflect.StringKind: if ft.Kind() == reflect.String && strs.EnforceUTF8(fd) { return nil, coderStringValidateUTF8 } if ft.Kind() == reflect.String { return nil, coderString } if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 && strs.EnforceUTF8(fd) { return nil, coderBytesValidateUTF8 } if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 { return nil, coderBytes } case protoreflect.BytesKind: if ft.Kind() == reflect.String { return nil, coderString } if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 { return nil, coderBytes } } } panic(fmt.Sprintf("invalid type: no encoder for %v %v %v/%v", fd.FullName(), fd.Cardinality(), fd.Kind(), ft)) } // encoderFuncsForValue returns value functions for a field, used for // extension values and map encoding. func encoderFuncsForValue(fd protoreflect.FieldDescriptor) valueCoderFuncs { switch { case fd.Cardinality() == protoreflect.Repeated && !fd.IsPacked(): switch fd.Kind() { case protoreflect.BoolKind: return coderBoolSliceValue case protoreflect.EnumKind: return coderEnumSliceValue case protoreflect.Int32Kind: return coderInt32SliceValue case protoreflect.Sint32Kind: return coderSint32SliceValue case protoreflect.Uint32Kind: return coderUint32SliceValue case protoreflect.Int64Kind: return coderInt64SliceValue case protoreflect.Sint64Kind: return coderSint64SliceValue case protoreflect.Uint64Kind: return coderUint64SliceValue case protoreflect.Sfixed32Kind: return coderSfixed32SliceValue case protoreflect.Fixed32Kind: return coderFixed32SliceValue case protoreflect.FloatKind: return coderFloatSliceValue case protoreflect.Sfixed64Kind: return coderSfixed64SliceValue case protoreflect.Fixed64Kind: return coderFixed64SliceValue case protoreflect.DoubleKind: return coderDoubleSliceValue case protoreflect.StringKind: // We don't have a UTF-8 validating coder for repeated string fields. // Value coders are used for extensions and maps. // Extensions are never proto3, and maps never contain lists. return coderStringSliceValue case protoreflect.BytesKind: return coderBytesSliceValue case protoreflect.MessageKind: return coderMessageSliceValue case protoreflect.GroupKind: return coderGroupSliceValue } case fd.Cardinality() == protoreflect.Repeated && fd.IsPacked(): switch fd.Kind() { case protoreflect.BoolKind: return coderBoolPackedSliceValue case protoreflect.EnumKind: return coderEnumPackedSliceValue case protoreflect.Int32Kind: return coderInt32PackedSliceValue case protoreflect.Sint32Kind: return coderSint32PackedSliceValue case protoreflect.Uint32Kind: return coderUint32PackedSliceValue case protoreflect.Int64Kind: return coderInt64PackedSliceValue case protoreflect.Sint64Kind: return coderSint64PackedSliceValue case protoreflect.Uint64Kind: return coderUint64PackedSliceValue case protoreflect.Sfixed32Kind: return coderSfixed32PackedSliceValue case protoreflect.Fixed32Kind: return coderFixed32PackedSliceValue case protoreflect.FloatKind: return coderFloatPackedSliceValue case protoreflect.Sfixed64Kind: return coderSfixed64PackedSliceValue case protoreflect.Fixed64Kind: return coderFixed64PackedSliceValue case protoreflect.DoubleKind: return coderDoublePackedSliceValue } default: switch fd.Kind() { default: case protoreflect.BoolKind: return coderBoolValue case protoreflect.EnumKind: return coderEnumValue case protoreflect.Int32Kind: return coderInt32Value case protoreflect.Sint32Kind: return coderSint32Value case protoreflect.Uint32Kind: return coderUint32Value case protoreflect.Int64Kind: return coderInt64Value case protoreflect.Sint64Kind: return coderSint64Value case protoreflect.Uint64Kind: return coderUint64Value case protoreflect.Sfixed32Kind: return coderSfixed32Value case protoreflect.Fixed32Kind: return coderFixed32Value case protoreflect.FloatKind: return coderFloatValue case protoreflect.Sfixed64Kind: return coderSfixed64Value case protoreflect.Fixed64Kind: return coderFixed64Value case protoreflect.DoubleKind: return coderDoubleValue case protoreflect.StringKind: if strs.EnforceUTF8(fd) { return coderStringValueValidateUTF8 } return coderStringValue case protoreflect.BytesKind: return coderBytesValue case protoreflect.MessageKind: return coderMessageValue case protoreflect.GroupKind: return coderGroupValue } } panic(fmt.Sprintf("invalid field: no encoder for %v %v %v", fd.FullName(), fd.Cardinality(), fd.Kind())) }