Skip to content

Commit dfc0ade

Browse files
Add back struct support (#102)
* Test struct to dict * Fix test struct * Struct support * Add missing file * Make Any's pack a class method * Fix Any.from_dict * Fix remaining problems * Update compiler lib * Fix test * Fix typechecking * Remove typing.Self * Add more tests * Remove useless file * Remove wrapping for structs * Fix typechecking * Remove JSON * Switch from is to == * Fix comparaison
1 parent 5b1ed7a commit dfc0ade

File tree

12 files changed

+477
-157
lines changed

12 files changed

+477
-157
lines changed

betterproto2/src/betterproto2/__init__.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -779,7 +779,7 @@ def __bytes__(self) -> bytes:
779779
# Default (zero) values are not serialized.
780780
continue
781781

782-
if isinstance(value, list):
782+
if meta.repeated:
783783
if meta.proto_type in PACKED_TYPES:
784784
# Packed lists look like a length-delimited field. First,
785785
# preprocess/encode each value into a buffer and then
@@ -802,9 +802,8 @@ def __bytes__(self) -> bytes:
802802
or b"\n\x00"
803803
)
804804

805-
elif isinstance(value, dict):
805+
elif meta.map_meta:
806806
for k, v in value.items():
807-
assert meta.map_meta
808807
sk = _serialize_single(1, meta.map_meta[0].proto_type, k)
809808
sv = _serialize_single(2, meta.map_meta[1].proto_type, v, unwrap=meta.map_meta[1].unwrap)
810809
stream.write(_serialize_single(meta.number, meta.proto_type, sk + sv))
@@ -944,8 +943,10 @@ def load(
944943

945944
meta = proto_meta.meta_by_field_name[field_name]
946945

946+
is_packed_repeated = parsed.wire_type == WIRE_LEN_DELIM and meta.proto_type in PACKED_TYPES
947+
947948
value: Any
948-
if parsed.wire_type == WIRE_LEN_DELIM and meta.proto_type in PACKED_TYPES:
949+
if is_packed_repeated:
949950
# This is a packed repeated field.
950951
pos = 0
951952
value = []
@@ -969,8 +970,8 @@ def load(
969970
if meta.proto_type == TYPE_MAP:
970971
# Value represents a single key/value pair entry in the map.
971972
current[value.key] = value.value
972-
elif isinstance(current, list):
973-
if isinstance(value, list):
973+
elif meta.repeated:
974+
if is_packed_repeated:
974975
current.extend(value)
975976
else:
976977
current.append(value)
@@ -1142,7 +1143,12 @@ def _from_dict_init(cls, mapping: Mapping[str, Any] | Any, *, ignore_unknown_fie
11421143
raise KeyError(f"Unknown field '{field_name}' in message {cls.__name__}.") from None
11431144

11441145
if value is None:
1145-
continue
1146+
name, module = field_cls.__name__, field_cls.__module__
1147+
1148+
# Edge case: None shouldn't be ignored for google.protobuf.Value
1149+
# See https://protobuf.dev/programming-guides/json/
1150+
if not (module.endswith("google.protobuf") and name == "Value"):
1151+
continue
11461152

11471153
if meta.proto_type == TYPE_MESSAGE:
11481154
if meta.repeated:

betterproto2/tests/inputs/config.py

Lines changed: 0 additions & 21 deletions
This file was deleted.

betterproto2/tests/test_any.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
def test_any() -> None:
2-
# TODO using a custom message pool will no longer be necessary when the well-known types will be compiled as well
32
from tests.outputs.any.any import Person
43
from tests.outputs.any.google.protobuf import Any
54

65
person = Person(first_name="John", last_name="Smith")
76

8-
any = Any()
9-
any.pack(person)
7+
any = Any.pack(person)
108

119
new_any = Any.parse(bytes(any))
1210

@@ -19,25 +17,28 @@ def test_any_to_dict() -> None:
1917

2018
person = Person(first_name="John", last_name="Smith")
2119

22-
any = Any()
23-
2420
# TODO test with include defautl value
25-
assert any.to_dict() == {"@type": ""}
21+
assert Any().to_dict() == {"@type": ""}
2622

2723
# Pack an object inside
28-
any.pack(person)
24+
any = Any.pack(person)
2925

3026
assert any.to_dict() == {
3127
"@type": "type.googleapis.com/any.Person",
3228
"firstName": "John",
3329
"lastName": "Smith",
3430
}
3531

32+
assert Any.from_dict(any.to_dict()) == any
33+
assert Any.parse(bytes(any)) == any
34+
3635
# Pack again in another Any
37-
any2 = Any()
38-
any2.pack(any)
36+
any2 = Any.pack(any)
3937

4038
assert any2.to_dict() == {
4139
"@type": "type.googleapis.com/google.protobuf.Any",
4240
"value": {"@type": "type.googleapis.com/any.Person", "firstName": "John", "lastName": "Smith"},
4341
}
42+
43+
assert Any.from_dict(any2.to_dict()) == any2
44+
assert Any.parse(bytes(any2)) == any2

betterproto2/tests/test_inputs.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,13 +107,11 @@ def reset_sys_path():
107107
["googletypes_struct/googletypes_struct.json"],
108108
"googletypes_struct.googletypes_struct",
109109
"googletypes_struct_reference.googletypes_struct_pb2",
110-
xfail=True,
111110
),
112111
TestCase(
113112
["googletypes_value/googletypes_value.json"],
114113
"googletypes_value.googletypes_value",
115114
"googletypes_value_reference.googletypes_value_pb2",
116-
xfail=True,
117115
),
118116
TestCase(["int32/int32.json"], "int32.int32", "int32_reference.int32_pb2"),
119117
TestCase(["map/map.json"], "map.map", "map_reference.map_pb2"),

betterproto2/tests/test_pickling.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,15 @@ def complex_msg():
1717
fe=Fe(abc="1"),
1818
nested_data=NestedData(
1919
struct_foo={
20-
"foo": google.Struct(
21-
fields={
22-
"hello": google.Value(list_value=google.ListValue(values=[google.Value(string_value="world")]))
20+
"foo": google.Struct.from_dict(
21+
{
22+
"hello": [["world"]],
2323
}
2424
),
2525
},
26-
map_str_any_bar={
27-
"key": google.Any(value=b"value"),
28-
},
2926
),
3027
mapping={
31-
"message": google.Any(value=bytes(Fi(abc="hi"))),
32-
"string": google.Any(value=b"howdy"),
28+
"message": google.Any.pack(Fi(abc="hi")),
3329
},
3430
)
3531

@@ -40,9 +36,8 @@ def test_pickling_complex_message():
4036
assert msg == deser
4137
assert msg.fe.abc == "1"
4238
assert msg.is_set("fi") is not True
43-
assert msg.mapping["message"] == google.Any(value=bytes(Fi(abc="hi")))
44-
assert msg.mapping["string"].value.decode() == "howdy"
45-
assert msg.nested_data.struct_foo["foo"].fields["hello"].list_value.values[0].string_value == "world"
39+
assert msg.mapping["message"] == google.Any.pack(Fi(abc="hi"))
40+
assert msg.nested_data.struct_foo["foo"].to_dict()["hello"][0][0] == "world"
4641

4742

4843
def test_recursive_message_defaults():
@@ -51,11 +46,7 @@ def test_recursive_message_defaults():
5146
msg = RecursiveMessage(name="bob", intermediate=Intermediate(42))
5247
msg = unpickled(msg)
5348

54-
# set values are as expected
5549
assert msg == RecursiveMessage(name="bob", intermediate=Intermediate(42))
56-
57-
# lazy initialized works modifies the message
58-
assert msg != RecursiveMessage(name="bob", intermediate=Intermediate(42), child=RecursiveMessage(name="jude"))
5950
msg.child = RecursiveMessage(child=RecursiveMessage(name="jude"))
6051
assert msg == RecursiveMessage(
6152
name="bob",
@@ -104,7 +95,6 @@ def use_cache():
10495
msg = use_cache()
10596
assert use_cache.calls == 1 # The message is only ever built once
10697
assert msg.fe.abc == "1"
107-
assert msg.is_set("fi") is not True
108-
assert msg.mapping["message"] == google.Any(value=bytes(Fi(abc="hi")))
109-
assert msg.mapping["string"].value.decode() == "howdy"
110-
assert msg.nested_data.struct_foo["foo"].fields["hello"].list_value.values[0].string_value == "world"
98+
assert not msg.is_set("fi")
99+
assert msg.mapping["message"] == google.Any.pack(Fi(abc="hi"))
100+
assert msg.nested_data.struct_foo["foo"].to_dict()["hello"][0][0] == "world"

betterproto2/tests/test_struct.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
def test_struct_to_dict():
2+
from tests.outputs.google.google.protobuf import Struct
3+
4+
struct = Struct.from_dict(
5+
{
6+
"null_field": None,
7+
"number_field": 12,
8+
"string_field": "test",
9+
"bool_field": True,
10+
"struct_field": {"x": "abc"},
11+
"list_field": [42, False, None],
12+
}
13+
)
14+
15+
assert struct.to_dict() == {
16+
"null_field": None,
17+
"number_field": 12,
18+
"string_field": "test",
19+
"bool_field": True,
20+
"struct_field": {"x": "abc"},
21+
"list_field": [42, False, None],
22+
}
23+
24+
assert Struct.from_dict(struct.to_dict()) == struct
25+
26+
27+
def test_listvalue_to_dict():
28+
from tests.outputs.google.google.protobuf import ListValue
29+
30+
list_value = ListValue.from_dict([42, False, {}])
31+
32+
assert list_value.to_dict() == [42, False, {}]
33+
assert ListValue.from_dict(list_value.to_dict()) == list_value
34+
35+
36+
def test_nullvalue():
37+
from tests.outputs.google.google.protobuf import NullValue, Value
38+
39+
null_value = NullValue.NULL_VALUE
40+
41+
assert bytes(Value(null_value=null_value)) == b"\x08\x00"
42+
43+
44+
def test_value_to_dict():
45+
from tests.outputs.google.google.protobuf import Value
46+
47+
value = Value.from_dict([1, 2, False])
48+
49+
assert value.to_dict() == [1, 2, False]
50+
assert Value.from_dict(value.to_dict()) == value

betterproto2_compiler/pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,17 @@ keywords = [
1515
requires-python = ">=3.10,<4.0"
1616
dependencies = [
1717
# TODO use the version from the current repo?
18-
"betterproto2[grpclib]>=0.7.0,<0.8",
18+
# "betterproto2[grpclib]>=0.7.0,<0.8",
19+
"betterproto2[grpclib]",
1920
"ruff~=0.9.3",
2021
"jinja2>=3.0.3",
2122
"typing-extensions>=4.7.1,<5",
2223
"strenum>=0.4.15,<0.5 ; python_version == '3.10'",
2324
]
2425

26+
[tool.uv.sources]
27+
"betterproto2" = { path = "../betterproto2" }
28+
2529
[project.urls]
2630
Documentation = "https://betterproto.github.io/python-betterproto2/"
2731
Repository = "https://github.com/betterproto/python-betterproto2"

betterproto2_compiler/src/betterproto2_compiler/known_types/__init__.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@
1313
UInt32Value,
1414
UInt64Value,
1515
)
16+
from .struct import ListValue, Struct, Value
1617
from .timestamp import Timestamp
1718

1819
# For each (package, message name), lists the methods that should be added to the message definition.
1920
# The source code of the method is read from the `known_types` folder. If imports are needed, they can be directly added
2021
# to the template file: they will automatically be removed if not necessary.
2122
KNOWN_METHODS: dict[tuple[str, str], list[Callable]] = {
22-
("google.protobuf", "Any"): [Any.pack, Any.unpack, Any.to_dict],
23+
("google.protobuf", "Any"): [Any.pack, Any.unpack, Any.to_dict, Any.from_dict],
2324
("google.protobuf", "Timestamp"): [
2425
Timestamp.from_datetime,
2526
Timestamp.to_datetime,
@@ -92,6 +93,18 @@
9293
BytesValue.from_wrapped,
9394
BytesValue.to_wrapped,
9495
],
96+
("google.protobuf", "Struct"): [
97+
Struct.from_dict,
98+
Struct.to_dict,
99+
],
100+
("google.protobuf", "ListValue"): [
101+
ListValue.from_dict,
102+
ListValue.to_dict,
103+
],
104+
("google.protobuf", "Value"): [
105+
Value.from_dict,
106+
Value.to_dict,
107+
],
95108
}
96109

97110
# A wrapped type is the type of a message that is automatically replaced by a known Python type.

betterproto2_compiler/src/betterproto2_compiler/known_types/any.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88

99

1010
class Any(VanillaAny):
11-
def pack(self, message: betterproto2.Message, message_pool: "betterproto2.MessagePool | None" = None) -> None:
11+
@classmethod
12+
def pack(cls, message: betterproto2.Message, message_pool: "betterproto2.MessagePool | None" = None) -> "Any":
1213
"""
1314
Pack the given message in the `Any` object.
1415
@@ -17,8 +18,10 @@ def pack(self, message: betterproto2.Message, message_pool: "betterproto2.Messag
1718
"""
1819
message_pool = message_pool or default_message_pool
1920

20-
self.type_url = message_pool.type_to_url[type(message)]
21-
self.value = bytes(message)
21+
type_url = message_pool.type_to_url[type(message)]
22+
value = bytes(message)
23+
24+
return cls(type_url=type_url, value=value)
2225

2326
def unpack(self, message_pool: "betterproto2.MessagePool | None" = None) -> betterproto2.Message | None:
2427
"""
@@ -54,3 +57,21 @@ def to_dict(self, **kwargs) -> dict[str, typing.Any]:
5457
output["value"] = value.to_dict(**kwargs)
5558

5659
return output
60+
61+
# TODO typing
62+
@classmethod
63+
def from_dict(cls, value, *, ignore_unknown_fields: bool = False):
64+
value = dict(value) # Make a copy
65+
66+
type_url = value.pop("@type", None)
67+
msg_cls = default_message_pool.url_to_type.get(type_url, None)
68+
69+
if not msg_cls:
70+
raise TypeError(f"Can't unpack unregistered type: {type_url}")
71+
72+
if not msg_cls.to_dict == betterproto2.Message.to_dict:
73+
value = value["value"]
74+
75+
return cls(
76+
type_url=type_url, value=bytes(msg_cls.from_dict(value, ignore_unknown_fields=ignore_unknown_fields))
77+
)

0 commit comments

Comments
 (0)