Skip to content

Commit 94a7adf

Browse files
add missing Self return types to from_dict methods (#173)
* add missing Self return types to from_dict methods * prefer cls and self over hardcoded class name to fix pyright and better subclassing * add Self to generated code imports * fix: need to import Self from typing_extensions for older Python versions...
1 parent 662476a commit 94a7adf

File tree

7 files changed

+62
-55
lines changed

7 files changed

+62
-55
lines changed

betterproto2_compiler/src/betterproto2_compiler/known_types/any.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import typing
22

33
import betterproto2
4+
from typing_extensions import Self
45

56
from betterproto2_compiler.lib.google.protobuf import Any as VanillaAny
67

@@ -60,7 +61,7 @@ def to_dict(self, **kwargs) -> dict[str, typing.Any]:
6061

6162
# TODO typing
6263
@classmethod
63-
def from_dict(cls, value, *, ignore_unknown_fields: bool = False):
64+
def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self:
6465
value = dict(value) # Make a copy
6566

6667
type_url = value.pop("@type", None)

betterproto2_compiler/src/betterproto2_compiler/known_types/duration.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import typing
44

55
import betterproto2
6+
from typing_extensions import Self
67

78
from betterproto2_compiler.lib.google.protobuf import Duration as VanillaDuration
89

@@ -30,13 +31,13 @@ def delta_to_json(delta: datetime.timedelta) -> str:
3031

3132
# TODO typing
3233
@classmethod
33-
def from_dict(cls, value, *, ignore_unknown_fields: bool = False):
34+
def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self:
3435
if isinstance(value, str):
3536
if not re.match(r"^\d+(\.\d+)?s$", value):
3637
raise ValueError(f"Invalid duration string: {value}")
3738

3839
seconds = float(value[:-1])
39-
return Duration(seconds=int(seconds), nanos=int((seconds - int(seconds)) * 1e9))
40+
return cls(seconds=int(seconds), nanos=int((seconds - int(seconds)) * 1e9))
4041

4142
return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields)
4243

betterproto2_compiler/src/betterproto2_compiler/known_types/google_values.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import typing
22

33
import betterproto2
4+
from typing_extensions import Self
45

56
from betterproto2_compiler.lib.google.protobuf import (
67
BoolValue as VanillaBoolValue,
@@ -24,9 +25,9 @@ def to_wrapped(self) -> bool:
2425
return self.value
2526

2627
@classmethod
27-
def from_dict(cls, value, *, ignore_unknown_fields: bool = False):
28+
def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self:
2829
if isinstance(value, bool):
29-
return BoolValue(value=value)
30+
return cls(value=value)
3031
return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields)
3132

3233
def to_dict(
@@ -48,9 +49,9 @@ def to_wrapped(self) -> int:
4849
return self.value
4950

5051
@classmethod
51-
def from_dict(cls, value, *, ignore_unknown_fields: bool = False):
52+
def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self:
5253
if isinstance(value, int):
53-
return Int32Value(value=value)
54+
return cls(value=value)
5455
return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields)
5556

5657
def to_dict(
@@ -72,9 +73,9 @@ def to_wrapped(self) -> int:
7273
return self.value
7374

7475
@classmethod
75-
def from_dict(cls, value, *, ignore_unknown_fields: bool = False):
76+
def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self:
7677
if isinstance(value, int):
77-
return Int64Value(value=value)
78+
return cls(value=value)
7879
return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields)
7980

8081
def to_dict(
@@ -96,9 +97,9 @@ def to_wrapped(self) -> int:
9697
return self.value
9798

9899
@classmethod
99-
def from_dict(cls, value, *, ignore_unknown_fields: bool = False):
100+
def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self:
100101
if isinstance(value, int):
101-
return UInt32Value(value=value)
102+
return cls(value=value)
102103
return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields)
103104

104105
def to_dict(
@@ -120,9 +121,9 @@ def to_wrapped(self) -> int:
120121
return self.value
121122

122123
@classmethod
123-
def from_dict(cls, value, *, ignore_unknown_fields: bool = False):
124+
def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self:
124125
if isinstance(value, int):
125-
return UInt64Value(value=value)
126+
return cls(value=value)
126127
return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields)
127128

128129
def to_dict(
@@ -144,9 +145,9 @@ def to_wrapped(self) -> float:
144145
return self.value
145146

146147
@classmethod
147-
def from_dict(cls, value, *, ignore_unknown_fields: bool = False):
148+
def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self:
148149
if isinstance(value, float):
149-
return FloatValue(value=value)
150+
return cls(value=value)
150151
return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields)
151152

152153
def to_dict(
@@ -168,9 +169,9 @@ def to_wrapped(self) -> float:
168169
return self.value
169170

170171
@classmethod
171-
def from_dict(cls, value, *, ignore_unknown_fields: bool = False):
172+
def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self:
172173
if isinstance(value, float):
173-
return DoubleValue(value=value)
174+
return cls(value=value)
174175
return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields)
175176

176177
def to_dict(
@@ -192,9 +193,9 @@ def to_wrapped(self) -> str:
192193
return self.value
193194

194195
@classmethod
195-
def from_dict(cls, value, *, ignore_unknown_fields: bool = False):
196+
def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self:
196197
if isinstance(value, str):
197-
return StringValue(value=value)
198+
return cls(value=value)
198199
return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields)
199200

200201
def to_dict(
@@ -216,9 +217,9 @@ def to_wrapped(self) -> bytes:
216217
return self.value
217218

218219
@classmethod
219-
def from_dict(cls, value, *, ignore_unknown_fields: bool = False):
220+
def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self:
220221
if isinstance(value, bytes):
221-
return BytesValue(value=value)
222+
return cls(value=value)
222223
return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields)
223224

224225
def to_dict(

betterproto2_compiler/src/betterproto2_compiler/known_types/struct.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import typing
22

33
import betterproto2
4+
from typing_extensions import Self
45

56
from betterproto2_compiler.lib.google.protobuf import (
67
ListValue as VanillaListValue,
@@ -13,7 +14,7 @@
1314
class Struct(VanillaStruct):
1415
# TODO typing
1516
@classmethod
16-
def from_dict(cls, value, *, ignore_unknown_fields: bool = False):
17+
def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self:
1718
assert isinstance(value, dict)
1819

1920
fields: dict[str, Value] = {}
@@ -47,7 +48,7 @@ def to_dict(
4748
class Value(VanillaValue):
4849
# TODO typing
4950
@classmethod
50-
def from_dict(cls, value, *, ignore_unknown_fields: bool = False):
51+
def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self:
5152
match value:
5253
case bool() as b:
5354
return cls(bool_value=b)
@@ -94,7 +95,7 @@ def to_dict(
9495
class ListValue(VanillaListValue):
9596
# TODO typing
9697
@classmethod
97-
def from_dict(cls, value, *, ignore_unknown_fields: bool = False):
98+
def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self:
9899
return cls(values=[Value.from_dict(v) for v in value])
99100

100101
# TODO typing

betterproto2_compiler/src/betterproto2_compiler/known_types/timestamp.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33

44
import betterproto2
55
import dateutil.parser
6+
from typing_extensions import Self
67

78
from betterproto2_compiler.lib.google.protobuf import Timestamp as VanillaTimestamp
89

910

1011
class Timestamp(VanillaTimestamp):
1112
@classmethod
12-
def from_datetime(cls, dt: datetime.datetime) -> "Timestamp":
13+
def from_datetime(cls, dt: datetime.datetime) -> Self:
1314
if not dt.tzinfo:
1415
raise ValueError("datetime must be timezone aware")
1516

@@ -55,11 +56,11 @@ def timestamp_to_json(dt: datetime.datetime) -> str:
5556

5657
# TODO typing
5758
@classmethod
58-
def from_dict(cls, value, *, ignore_unknown_fields: bool = False):
59+
def from_dict(cls, value, *, ignore_unknown_fields: bool = False) -> Self:
5960
if isinstance(value, str):
6061
dt = dateutil.parser.isoparse(value)
6162
dt = dt.astimezone(datetime.timezone.utc)
62-
return Timestamp.from_datetime(dt)
63+
return cls.from_datetime(dt)
6364

6465
return super().from_dict(value, ignore_unknown_fields=ignore_unknown_fields)
6566

0 commit comments

Comments
 (0)