Skip to content

Commit 9727cbf

Browse files
Specify type parameters for context in async grpc methods
1 parent d946d68 commit 9727cbf

File tree

3 files changed

+15
-14
lines changed

3 files changed

+15
-14
lines changed

mypy_protobuf/main.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -837,10 +837,10 @@ def write_grpc_iterator_type(self) -> None:
837837
)
838838
wl("")
839839

840-
def get_servicer_context_type(self) -> str:
840+
def get_servicer_context_type(self, input_: str, output: str) -> str:
841841
"""Get the type to use for the context parameter in servicer methods."""
842842
if self.grpc_type == GRPCType.ASYNC:
843-
return self._import("grpc.aio", "ServicerContext")
843+
return self._import("grpc.aio", f"ServicerContext[{input_}, {output}]")
844844
elif self.grpc_type == GRPCType.SYNC:
845845
return self._import("grpc", "ServicerContext")
846846
else:
@@ -928,19 +928,20 @@ def write_grpc_methods(self, service: d.ServiceDescriptorProto, scl_prefix: Sour
928928
wl("")
929929
for i, method in methods:
930930
scl = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]
931+
input_type = self._servicer_input_type(method)
932+
output_type = self._servicer_output_type(method)
931933

932934
if self.generate_concrete_servicer_stubs is False:
933935
wl("@{}", self._import("abc", "abstractmethod"))
934936
wl("def {}(", method.name)
935937
with self._indent():
936938
wl("self,")
937939
input_name = "request_iterator" if method.client_streaming else "request"
938-
input_type = self._servicer_input_type(method)
939940
wl(f"{input_name}: {input_type},")
940-
wl("context: {},", self.get_servicer_context_type())
941+
wl("context: {},", self.get_servicer_context_type(input_type, output_type))
941942
wl(
942943
") -> {}:{}",
943-
self._servicer_output_type(method),
944+
output_type,
944945
" ..." if not self._has_comments(scl) else "",
945946
)
946947
if self._has_comments(scl):

test/generated-async-only/testproto/grpc/dummy_pb2_grpc.pyi

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,31 +61,31 @@ class DummyServiceServicer(metaclass=abc.ABCMeta):
6161
def UnaryUnary(
6262
self,
6363
request: testproto.grpc.dummy_pb2.DummyRequest,
64-
context: grpc.aio.ServicerContext,
64+
context: grpc.aio.ServicerContext[testproto.grpc.dummy_pb2.DummyRequest, collections.abc.Awaitable[testproto.grpc.dummy_pb2.DummyReply]],
6565
) -> collections.abc.Awaitable[testproto.grpc.dummy_pb2.DummyReply]:
6666
"""UnaryUnary"""
6767

6868
@abc.abstractmethod
6969
def UnaryStream(
7070
self,
7171
request: testproto.grpc.dummy_pb2.DummyRequest,
72-
context: grpc.aio.ServicerContext,
72+
context: grpc.aio.ServicerContext[testproto.grpc.dummy_pb2.DummyRequest, collections.abc.AsyncIterator[testproto.grpc.dummy_pb2.DummyReply]],
7373
) -> collections.abc.AsyncIterator[testproto.grpc.dummy_pb2.DummyReply]:
7474
"""UnaryStream"""
7575

7676
@abc.abstractmethod
7777
def StreamUnary(
7878
self,
7979
request_iterator: collections.abc.AsyncIterator[testproto.grpc.dummy_pb2.DummyRequest],
80-
context: grpc.aio.ServicerContext,
80+
context: grpc.aio.ServicerContext[collections.abc.AsyncIterator[testproto.grpc.dummy_pb2.DummyRequest], collections.abc.Awaitable[testproto.grpc.dummy_pb2.DummyReply]],
8181
) -> collections.abc.Awaitable[testproto.grpc.dummy_pb2.DummyReply]:
8282
"""StreamUnary"""
8383

8484
@abc.abstractmethod
8585
def StreamStream(
8686
self,
8787
request_iterator: collections.abc.AsyncIterator[testproto.grpc.dummy_pb2.DummyRequest],
88-
context: grpc.aio.ServicerContext,
88+
context: grpc.aio.ServicerContext[collections.abc.AsyncIterator[testproto.grpc.dummy_pb2.DummyRequest], collections.abc.AsyncIterator[testproto.grpc.dummy_pb2.DummyReply]],
8989
) -> collections.abc.AsyncIterator[testproto.grpc.dummy_pb2.DummyReply]:
9090
"""StreamStream"""
9191

@@ -117,15 +117,15 @@ class DeprecatedServiceServicer(metaclass=abc.ABCMeta):
117117
def DeprecatedMethod(
118118
self,
119119
request: testproto.grpc.dummy_pb2.DeprecatedRequest,
120-
context: grpc.aio.ServicerContext,
120+
context: grpc.aio.ServicerContext[testproto.grpc.dummy_pb2.DeprecatedRequest, collections.abc.Awaitable[testproto.grpc.dummy_pb2.DummyReply]],
121121
) -> collections.abc.Awaitable[testproto.grpc.dummy_pb2.DummyReply]:
122122
"""DeprecatedMethod"""
123123

124124
@abc.abstractmethod
125125
def DeprecatedMethodNotDeprecatedRequest(
126126
self,
127127
request: testproto.grpc.dummy_pb2.DummyRequest,
128-
context: grpc.aio.ServicerContext,
128+
context: grpc.aio.ServicerContext[testproto.grpc.dummy_pb2.DummyRequest, collections.abc.Awaitable[testproto.grpc.dummy_pb2.DummyReply]],
129129
) -> collections.abc.Awaitable[testproto.grpc.dummy_pb2.DummyReply]:
130130
"""DeprecatedMethodNotDeprecatedRequest"""
131131

test/generated-async-only/testproto/grpc/import_pb2_grpc.pyi

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,23 +49,23 @@ class SimpleServiceServicer(metaclass=abc.ABCMeta):
4949
def UnaryUnary(
5050
self,
5151
request: google.protobuf.empty_pb2.Empty,
52-
context: grpc.aio.ServicerContext,
52+
context: grpc.aio.ServicerContext[google.protobuf.empty_pb2.Empty, collections.abc.Awaitable[testproto.test_pb2.Simple1]],
5353
) -> collections.abc.Awaitable[testproto.test_pb2.Simple1]:
5454
"""UnaryUnary"""
5555

5656
@abc.abstractmethod
5757
def UnaryStream(
5858
self,
5959
request: testproto.test_pb2.Simple1,
60-
context: grpc.aio.ServicerContext,
60+
context: grpc.aio.ServicerContext[testproto.test_pb2.Simple1, collections.abc.Awaitable[google.protobuf.empty_pb2.Empty]],
6161
) -> collections.abc.Awaitable[google.protobuf.empty_pb2.Empty]:
6262
"""UnaryStream"""
6363

6464
@abc.abstractmethod
6565
def NoComment(
6666
self,
6767
request: testproto.test_pb2.Simple1,
68-
context: grpc.aio.ServicerContext,
68+
context: grpc.aio.ServicerContext[testproto.test_pb2.Simple1, collections.abc.Awaitable[google.protobuf.empty_pb2.Empty]],
6969
) -> collections.abc.Awaitable[google.protobuf.empty_pb2.Empty]: ...
7070

7171
def add_SimpleServiceServicer_to_server(servicer: SimpleServiceServicer, server: grpc.aio.Server) -> None: ...

0 commit comments

Comments
 (0)