Skip to content

Commit fb99b85

Browse files
Fix enum to dict (#124)
* Use proto name in enum to dict * Remove useless import, fix precommit * Make compiler generate enum entries dictionnary * Fix for Python 3.10
1 parent 03be812 commit fb99b85

File tree

9 files changed

+82
-16
lines changed

9 files changed

+82
-16
lines changed

betterproto2/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ build-backend = "hatchling.build"
5959
# ]
6060

6161
[tool.ruff]
62-
extend-exclude = ["tests/output_*", "src/betterproto2/internal_lib"]
62+
extend-exclude = ["tests/outputs", "src/betterproto2/internal_lib"]
6363
target-version = "py310"
6464
line-length = 120
6565

betterproto2/src/betterproto2/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from ._types import T
3535
from ._version import __version__, check_compiler_version
3636
from .casing import camel_case, safe_snake_case, snake_case
37-
from .enum import Enum as Enum
37+
from .enum_ import Enum as Enum
3838
from .grpc.grpclib_client import ServiceStub as ServiceStub
3939
from .utils import classproperty
4040

@@ -585,9 +585,10 @@ def _value_to_dict(
585585
if proto_type in INT_64_TYPES:
586586
return str(value), not bool(value)
587587
if proto_type == TYPE_BYTES:
588-
return b64encode(value).decode("utf8"), not (bool(value))
588+
return b64encode(value).decode("utf8"), not bool(value)
589589
if proto_type == TYPE_ENUM:
590-
return field_type(value).name, not bool(value)
590+
enum_value = field_type(value)
591+
return enum_value.proto_name or enum_value.name, not bool(value)
591592
if proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
592593
return _dump_float(value), not bool(value)
593594
return value, not bool(value)

betterproto2/src/betterproto2/enum.py renamed to betterproto2/src/betterproto2/enum_.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,42 @@
1-
from enum import IntEnum
1+
import sys
2+
from enum import EnumMeta, IntEnum
23

34
from typing_extensions import Self
45

56

6-
class Enum(IntEnum):
7+
class _EnumMeta(EnumMeta):
8+
def __new__(metacls, cls, bases, classdict):
9+
# Find the proto names if defined
10+
if sys.version_info >= (3, 11):
11+
proto_names = classdict.pop("betterproto_proto_names", {})
12+
classdict._member_names.pop("betterproto_proto_names", None)
13+
else:
14+
proto_names = {}
15+
if "betterproto_proto_names" in classdict:
16+
proto_names = classdict.pop("betterproto_proto_names")
17+
classdict._member_names.remove("betterproto_proto_names")
18+
19+
enum_class = super().__new__(metacls, cls, bases, classdict)
20+
21+
# Attach extra info to each enum member
22+
for member in enum_class:
23+
value = member.value # type: ignore[reportAttributeAccessIssue]
24+
extra = proto_names.get(value)
25+
member._proto_name = extra # type: ignore[reportAttributeAccessIssue]
26+
27+
return enum_class
28+
29+
30+
class Enum(IntEnum, metaclass=_EnumMeta):
31+
@property
32+
def proto_name(self) -> str | None:
33+
return self._proto_name # type: ignore[reportAttributeAccessIssue]
34+
735
@classmethod
836
def _missing_(cls, value):
937
# If the given value is not an integer, let the standard enum implementation raise an error
1038
if not isinstance(value, int):
11-
return None
39+
return
1240

1341
# Create a new "unknown" instance with the given value.
1442
obj = int.__new__(cls, value)

betterproto2/tests/test_all_definition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,4 @@ def test_all_definition():
1717
"TestSyncStub",
1818
"ThingType",
1919
)
20-
assert enum.__all__ == ("ArithmeticOperator", "Choice", "HttpCode", "NoStriping", "Test")
20+
assert enum.__all__ == ("ArithmeticOperator", "Choice", "EnumMessage", "HttpCode", "NoStriping", "Test")

betterproto2/tests/test_enum.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,19 @@ def test_enum_renaming() -> None:
8282
assert set(ArithmeticOperator.__members__) == {"NONE", "PLUS", "MINUS", "_0_PREFIXED"}
8383
assert set(HttpCode.__members__) == {"UNSPECIFIED", "OK", "NOT_FOUND"}
8484
assert set(NoStriping.__members__) == {"NO_STRIPING_NONE", "NO_STRIPING_A", "B"}
85+
86+
87+
def test_enum_to_dict() -> None:
88+
from tests.outputs.enum.enum import ArithmeticOperator, EnumMessage, NoStriping
89+
90+
msg = EnumMessage(
91+
arithmetic_operator=ArithmeticOperator.PLUS,
92+
no_striping=NoStriping.NO_STRIPING_A,
93+
)
94+
95+
print(ArithmeticOperator.PLUS.proto_name)
96+
97+
assert msg.to_dict() == {
98+
"arithmeticOperator": "ARITHMETIC_OPERATOR_PLUS", # The original proto name must be preserved
99+
"noStriping": "NO_STRIPING_A",
100+
}

betterproto2_compiler/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ requires = ["hatchling"]
6060
build-backend = "hatchling.build"
6161

6262
[tool.ruff]
63-
extend-exclude = ["tests/output_*", "src/betterproto2_compiler/lib"]
63+
extend-exclude = ["tests/outputs", "src/betterproto2_compiler/lib"]
6464
target-version = "py310"
6565
line-length = 120
6666

betterproto2_compiler/src/betterproto2_compiler/plugin/models.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,13 +610,15 @@ class EnumEntry:
610610
"""Representation of an Enum entry."""
611611

612612
name: str
613+
proto_name: str
613614
value: int
614615
comment: str
615616

616617
def __post_init__(self) -> None:
617618
self.entries = [
618619
self.EnumEntry(
619620
name=entry_proto_value.name,
621+
proto_name=entry_proto_value.name,
620622
value=entry_proto_value.number,
621623
comment=get_comment(proto_file=self.source_file, path=self.path + [2, entry_number]),
622624
)
@@ -672,6 +674,10 @@ def descriptor_name(self) -> str:
672674
"""
673675
return self.output_file.get_descriptor_name(self.source_file)
674676

677+
@property
678+
def has_renamed_entries(self) -> bool:
679+
return any(entry.proto_name != entry.name for entry in self.entries)
680+
675681

676682
@dataclass(kw_only=True)
677683
class ServiceCompiler(ProtoContentBase):

betterproto2_compiler/src/betterproto2_compiler/templates/template.py.j2

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,16 @@ class {{ enum.py_name | add_to_all }}(betterproto2.Enum):
3131
return core_schema.int_schema(ge=0)
3232
{% endif %}
3333

34+
{% if enum.has_renamed_entries %}
35+
betterproto_proto_names = {
36+
{% for entry in enum.entries %}
37+
{% if entry.proto_name != entry.name %}
38+
{{ entry.value }}: "{{ entry.proto_name }}",
39+
{% endif %}
40+
{% endfor %}
41+
}
42+
{% endif %}
43+
3444
{% endfor %}
3545
{% for _, message in output_file.messages|dictsort(by="key") %}
3646
{% if output_file.settings.pydantic_dataclasses %}

betterproto2_compiler/tests/inputs/enum/enum.proto

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,16 @@ package enum;
44

55
// Tests that enums are correctly serialized and that it correctly handles skipped and out-of-order enum values
66
message Test {
7-
Choice choice = 1;
8-
repeated Choice choices = 2;
7+
Choice choice = 1;
8+
repeated Choice choices = 2;
99
}
1010

1111
enum Choice {
12-
ZERO = 0;
13-
ONE = 1;
14-
// TWO = 2;
15-
FOUR = 4;
16-
THREE = 3;
12+
ZERO = 0;
13+
ONE = 1;
14+
// TWO = 2;
15+
FOUR = 4;
16+
THREE = 3;
1717
}
1818

1919
// A "C" like enum with the enum name prefixed onto members, these should be stripped
@@ -38,3 +38,8 @@ enum HTTPCode {
3838
HTTP_CODE_OK = 200;
3939
HTTP_CODE_NOT_FOUND = 404;
4040
}
41+
42+
message EnumMessage {
43+
ArithmeticOperator arithmetic_operator = 1;
44+
NoStriping no_striping = 2;
45+
}

0 commit comments

Comments
 (0)