2121 timedelta ,
2222 timezone ,
2323)
24+ from enum import IntEnum
2425from io import BytesIO
2526from itertools import count
2627from typing import (
@@ -560,6 +561,15 @@ def _get_cls_by_field(cls: Type["Message"], fields: Iterable[dataclasses.Field])
560561 return field_cls
561562
562563
564+ class OutputFormat (IntEnum ):
565+ """
566+ Chosen output format for the `Message.to_dict` method.
567+ """
568+
569+ PYTHON = 1
570+ PROTO_JSON = 2
571+
572+
563573class Message (ABC ):
564574 """
565575 The base class for protobuf messages, all generated messages will inherit from
@@ -606,10 +616,6 @@ def __repr__(self) -> str:
606616 ]
607617 return f"{ self .__class__ .__name__ } ({ ', ' .join (parts )} )"
608618
609- # def __rich_repr__(self) -> Iterable[Tuple[str, Any, Any]]:
610- # for field_name in self._betterproto.sorted_field_names:
611- # yield field_name, self.__getattribute__(field_name), PLACEHOLDER
612-
613619 def __bool__ (self ) -> bool :
614620 """True if the message has any fields with non-default values."""
615621 return any (
@@ -946,9 +952,15 @@ def FromString(cls: Type[T], data: bytes) -> T:
946952 """
947953 return cls ().parse (data )
948954
949- def to_dict (self , casing : Casing = Casing .CAMEL , include_default_values : bool = False ) -> Dict [str , Any ]:
955+ def to_dict (
956+ self ,
957+ * ,
958+ output_format : OutputFormat = OutputFormat .PROTO_JSON ,
959+ casing : Casing = Casing .CAMEL ,
960+ include_default_values : bool = False ,
961+ ) -> Dict [str , Any ]:
950962 """
951- Returns a JSON serializable dict representation of this object .
963+ Return a dict representation of the message .
952964
953965 Parameters
954966 -----------
@@ -965,6 +977,12 @@ def to_dict(self, casing: Casing = Casing.CAMEL, include_default_values: bool =
965977 Dict[:class:`str`, Any]
966978 The JSON serializable dict representation of this object.
967979 """
980+ kwargs = { # For recursive calls
981+ "output_format" : output_format ,
982+ "casing" : casing ,
983+ "include_default_values" : include_default_values ,
984+ }
985+
968986 output : Dict [str , Any ] = {}
969987 field_types = self ._type_hints ()
970988 for field_name , meta in self ._betterproto .meta_by_field_name .items ():
@@ -973,74 +991,87 @@ def to_dict(self, casing: Casing = Casing.CAMEL, include_default_values: bool =
973991 cased_name = casing (field_name ).rstrip ("_" ) # type: ignore
974992 if meta .proto_type == TYPE_MESSAGE :
975993 if isinstance (value , datetime ):
976- output [cased_name ] = _Timestamp .timestamp_to_json (value )
994+ if output_format == OutputFormat .PROTO_JSON :
995+ output [cased_name ] = _Timestamp .timestamp_to_json (value )
996+ else :
997+ output [cased_name ] = value
977998 elif isinstance (value , timedelta ):
978- output [cased_name ] = _Duration .delta_to_json (value )
999+ if output_format == OutputFormat .PROTO_JSON :
1000+ output [cased_name ] = _Duration .delta_to_json (value )
1001+ else :
1002+ output [cased_name ] = value
1003+
9791004 elif meta .wraps :
9801005 if value is not None or include_default_values :
9811006 output [cased_name ] = value
9821007 elif field_is_repeated :
9831008 # Convert each item.
984- cls = self ._betterproto .cls_by_field [field_name ]
985- if cls == datetime :
986- value = [_Timestamp .timestamp_to_json (i ) for i in value ]
987- elif cls == timedelta :
988- value = [_Duration .delta_to_json (i ) for i in value ]
1009+ if output_format == OutputFormat .PYTHON :
1010+ value = [i .to_dict (** kwargs ) for i in value ]
9891011 else :
990- value = [i .to_dict (casing , include_default_values ) for i in value ]
1012+ cls = self ._betterproto .cls_by_field [field_name ]
1013+ if cls == datetime :
1014+ value = [_Timestamp .timestamp_to_json (i ) for i in value ]
1015+ elif cls == timedelta :
1016+ value = [_Duration .delta_to_json (i ) for i in value ]
1017+ else :
1018+ value = [i .to_dict (** kwargs ) for i in value ]
9911019 if value or include_default_values :
9921020 output [cased_name ] = value
9931021 elif value is None :
9941022 if include_default_values :
995- output [cased_name ] = value
1023+ output [cased_name ] = None
9961024 else :
997- output [cased_name ] = value .to_dict (casing , include_default_values )
1025+ output [cased_name ] = value .to_dict (** kwargs )
9981026 elif meta .proto_type == TYPE_MAP :
9991027 output_map = {** value }
10001028 for k in value :
10011029 if hasattr (value [k ], "to_dict" ):
1002- output_map [k ] = value [k ].to_dict (casing , include_default_values )
1030+ output_map [k ] = value [k ].to_dict (** kwargs )
10031031
10041032 if value or include_default_values :
10051033 output [cased_name ] = output_map
10061034 elif value != self ._get_field_default (field_name ) or include_default_values :
1007- if meta .proto_type in INT_64_TYPES :
1008- if field_is_repeated :
1009- output [cased_name ] = [str (n ) for n in value ]
1010- elif value is None :
1011- if include_default_values :
1012- output [cased_name ] = value
1013- else :
1014- output [cased_name ] = str (value )
1015- elif meta .proto_type == TYPE_BYTES :
1016- if field_is_repeated :
1017- output [cased_name ] = [b64encode (b ).decode ("utf8" ) for b in value ]
1018- elif value is None and include_default_values :
1019- output [cased_name ] = value
1020- else :
1021- output [cased_name ] = b64encode (value ).decode ("utf8" )
1022- elif meta .proto_type == TYPE_ENUM :
1023- if field_is_repeated :
1024- enum_class = field_types [field_name ].__args__ [0 ]
1025- if isinstance (value , typing .Iterable ) and not isinstance (value , str ):
1026- output [cased_name ] = [enum_class (el ).name for el in value ]
1035+ if output_format == OutputFormat .PROTO_JSON :
1036+ if meta .proto_type in INT_64_TYPES :
1037+ if field_is_repeated :
1038+ output [cased_name ] = [str (n ) for n in value ]
1039+ elif value is None :
1040+ if include_default_values :
1041+ output [cased_name ] = value
10271042 else :
1028- # transparently upgrade single value to repeated
1029- output [cased_name ] = [enum_class (value ).name ]
1030- elif value is None :
1031- if include_default_values :
1043+ output [cased_name ] = str (value )
1044+ elif meta .proto_type == TYPE_BYTES :
1045+ if field_is_repeated :
1046+ output [cased_name ] = [b64encode (b ).decode ("utf8" ) for b in value ]
1047+ elif value is None and include_default_values :
10321048 output [cased_name ] = value
1033- elif meta .optional :
1034- enum_class = field_types [field_name ].__args__ [0 ]
1035- output [cased_name ] = enum_class (value ).name
1036- else :
1037- enum_class = field_types [field_name ] # noqa
1038- output [cased_name ] = enum_class (value ).name
1039- elif meta .proto_type in (TYPE_FLOAT , TYPE_DOUBLE ):
1040- if field_is_repeated :
1041- output [cased_name ] = [_dump_float (n ) for n in value ]
1049+ else :
1050+ output [cased_name ] = b64encode (value ).decode ("utf8" )
1051+ elif meta .proto_type == TYPE_ENUM :
1052+ if field_is_repeated :
1053+ enum_class = field_types [field_name ].__args__ [0 ]
1054+ if isinstance (value , typing .Iterable ) and not isinstance (value , str ):
1055+ output [cased_name ] = [enum_class (el ).name for el in value ]
1056+ else :
1057+ # transparently upgrade single value to repeated
1058+ output [cased_name ] = [enum_class (value ).name ]
1059+ elif value is None :
1060+ if include_default_values :
1061+ output [cased_name ] = value
1062+ elif meta .optional :
1063+ enum_class = field_types [field_name ].__args__ [0 ]
1064+ output [cased_name ] = enum_class (value ).name
1065+ else :
1066+ enum_class = field_types [field_name ] # noqa
1067+ output [cased_name ] = enum_class (value ).name
1068+ elif meta .proto_type in (TYPE_FLOAT , TYPE_DOUBLE ):
1069+ if field_is_repeated :
1070+ output [cased_name ] = [_dump_float (n ) for n in value ]
1071+ else :
1072+ output [cased_name ] = _dump_float (value )
10421073 else :
1043- output [cased_name ] = _dump_float ( value )
1074+ output [cased_name ] = value
10441075 else :
10451076 output [cased_name ] = value
10461077 return output
@@ -1188,69 +1219,6 @@ def from_json(self: T, value: Union[str, bytes]) -> T:
11881219 """
11891220 return self .from_dict (json .loads (value ))
11901221
1191- def to_pydict (self , casing : Casing = Casing .CAMEL , include_default_values : bool = False ) -> Dict [str , Any ]:
1192- """
1193- Returns a python dict representation of this object.
1194-
1195- Parameters
1196- -----------
1197- casing: :class:`Casing`
1198- The casing to use for key values. Default is :attr:`Casing.CAMEL` for
1199- compatibility purposes.
1200- include_default_values: :class:`bool`
1201- If ``True`` will include the default values of fields. Default is ``False``.
1202- E.g. an ``int32`` field will be included with a value of ``0`` if this is
1203- set to ``True``, otherwise this would be ignored.
1204-
1205- Returns
1206- --------
1207- Dict[:class:`str`, Any]
1208- The python dict representation of this object.
1209- """
1210- output : Dict [str , Any ] = {}
1211- for field_name , meta in self ._betterproto .meta_by_field_name .items ():
1212- field_is_repeated = meta .repeated
1213- value = getattr (self , field_name )
1214- cased_name = casing (field_name ).rstrip ("_" ) # type: ignore
1215- if meta .proto_type == TYPE_MESSAGE :
1216- if isinstance (value , datetime ):
1217- if (
1218- value != DATETIME_ZERO
1219- or include_default_values
1220- or self ._include_default_value_for_oneof (field_name = field_name , meta = meta )
1221- ):
1222- output [cased_name ] = value
1223- elif isinstance (value , timedelta ):
1224- if (
1225- value != timedelta (0 )
1226- or include_default_values
1227- or self ._include_default_value_for_oneof (field_name = field_name , meta = meta )
1228- ):
1229- output [cased_name ] = value
1230- elif meta .wraps :
1231- if value is not None or include_default_values :
1232- output [cased_name ] = value
1233- elif field_is_repeated :
1234- # Convert each item.
1235- value = [i .to_pydict (casing , include_default_values ) for i in value ]
1236- if value or include_default_values :
1237- output [cased_name ] = value
1238- elif value is None :
1239- if include_default_values :
1240- output [cased_name ] = None
1241- else :
1242- output [cased_name ] = value .to_pydict (casing , include_default_values )
1243- elif meta .proto_type == TYPE_MAP :
1244- for k in value :
1245- if hasattr (value [k ], "to_pydict" ):
1246- value [k ] = value [k ].to_pydict (casing , include_default_values )
1247-
1248- if value or include_default_values :
1249- output [cased_name ] = value
1250- elif value != self ._get_field_default (field_name ) or include_default_values :
1251- output [cased_name ] = value
1252- return output
1253-
12541222 def from_pydict (self : T , value : Mapping [str , Any ]) -> T :
12551223 """
12561224 Parse the key/value pairs into the current message instance. This returns the
0 commit comments