Skip to content

Commit 48f8205

Browse files
refactor to fix typing errors for mypy 1.0.0 (Azure#29038)
1 parent 089e364 commit 48f8205

File tree

4 files changed

+119
-47
lines changed

4 files changed

+119
-47
lines changed

sdk/tables/azure-data-tables/azure/data/tables/_authentication.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77
from urllib.parse import urlparse
88
except ImportError:
99
from urlparse import urlparse # type: ignore
10+
from typing import Optional, Union, overload, cast
1011

11-
from azure.core.credentials import TokenCredential
12+
from azure.core.credentials import TokenCredential, AzureSasCredential, AzureNamedKeyCredential
1213
from azure.core.exceptions import ClientAuthenticationError
1314
from azure.core.pipeline import PipelineResponse, PipelineRequest
14-
from azure.core.pipeline.policies import BearerTokenCredentialPolicy, SansIOHTTPPolicy
15+
from azure.core.pipeline.policies import BearerTokenCredentialPolicy, SansIOHTTPPolicy, AzureSasCredentialPolicy
1516

1617
try:
1718
from azure.core.pipeline.transport import AsyncHttpTransport
@@ -25,6 +26,7 @@
2526

2627
from ._common_conversion import _sign_string
2728
from ._error import _wrap_exception
29+
from ._constants import STORAGE_OAUTH_SCOPE
2830

2931

3032
class AzureSigningError(ClientAuthenticationError):
@@ -216,3 +218,55 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) ->
216218
else:
217219
self.authorize_request(request, scope)
218220
return True
221+
222+
223+
@overload
224+
def _configure_credential(credential: AzureNamedKeyCredential) -> SharedKeyCredentialPolicy:
225+
...
226+
227+
@overload
228+
def _configure_credential(credential: SharedKeyCredentialPolicy) -> SharedKeyCredentialPolicy:
229+
...
230+
231+
@overload
232+
def _configure_credential(credential: AzureSasCredential) -> AzureSasCredentialPolicy:
233+
...
234+
235+
@overload
236+
def _configure_credential(credential: TokenCredential) -> BearerTokenChallengePolicy:
237+
...
238+
239+
@overload
240+
def _configure_credential(credential: None) -> None:
241+
...
242+
243+
def _configure_credential(
244+
credential: Optional[
245+
Union[
246+
AzureNamedKeyCredential,
247+
AzureSasCredential,
248+
TokenCredential,
249+
SharedKeyCredentialPolicy
250+
]
251+
]
252+
) -> Optional[
253+
Union[
254+
BearerTokenChallengePolicy,
255+
AzureSasCredentialPolicy,
256+
SharedKeyCredentialPolicy
257+
]
258+
]:
259+
if hasattr(credential, "get_token"):
260+
credential = cast(TokenCredential, credential)
261+
return BearerTokenChallengePolicy(
262+
credential, STORAGE_OAUTH_SCOPE
263+
)
264+
if isinstance(credential, SharedKeyCredentialPolicy):
265+
return credential
266+
if isinstance(credential, AzureSasCredential):
267+
return AzureSasCredentialPolicy(credential)
268+
if isinstance(credential, AzureNamedKeyCredential):
269+
return SharedKeyCredentialPolicy(credential)
270+
if credential is not None:
271+
raise TypeError("Unsupported credential: {}".format(credential))
272+
return None

sdk/tables/azure-data-tables/azure/data/tables/_base_client.py

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
DistributedTracingPolicy,
2727
HttpLoggingPolicy,
2828
UserAgentPolicy,
29-
AzureSasCredentialPolicy,
3029
NetworkTraceLoggingPolicy,
3130
CustomHookPolicy,
3231
RequestIdPolicy,
@@ -36,7 +35,6 @@
3635
from ._common_conversion import _is_cosmos_endpoint
3736
from ._shared_access_signature import QueryStringConstants
3837
from ._constants import (
39-
STORAGE_OAUTH_SCOPE,
4038
DEFAULT_COSMOS_ENDPOINT_SUFFIX,
4139
DEFAULT_STORAGE_ENDPOINT_SUFFIX,
4240
)
@@ -47,7 +45,7 @@
4745
_validate_tablename_error
4846
)
4947
from ._models import LocationMode
50-
from ._authentication import BearerTokenChallengePolicy, SharedKeyCredentialPolicy
48+
from ._authentication import _configure_credential
5149
from ._policies import (
5250
CosmosPatchTransformPolicy,
5351
StorageHeadersPolicy,
@@ -139,8 +137,7 @@ def __init__(
139137
LocationMode.PRIMARY: primary_hostname,
140138
LocationMode.SECONDARY: secondary_hostname,
141139
}
142-
self._credential_policy = None # type: ignore
143-
self._configure_credential(self.credential) # type: ignore
140+
144141
self._policies = self._configure_policies(hosts=self._hosts, **kwargs) # type: ignore
145142
if self._cosmos_endpoint:
146143
self._policies.insert(0, CosmosPatchTransformPolicy())
@@ -244,12 +241,13 @@ def __exit__(self, *args):
244241
self._client.__exit__(*args)
245242

246243
def _configure_policies(self, **kwargs):
244+
credential_policy = _configure_credential(self.credential)
247245
return [
248246
RequestIdPolicy(**kwargs),
249247
StorageHeadersPolicy(**kwargs),
250248
UserAgentPolicy(sdk_moniker=SDK_MONIKER, **kwargs),
251249
ProxyPolicy(**kwargs),
252-
self._credential_policy,
250+
credential_policy,
253251
ContentDecodePolicy(response_encoding="utf-8"),
254252
RedirectPolicy(**kwargs),
255253
StorageHosts(**kwargs),
@@ -260,22 +258,6 @@ def _configure_policies(self, **kwargs):
260258
HttpLoggingPolicy(**kwargs),
261259
]
262260

263-
def _configure_credential(
264-
self, credential: Optional[Union[AzureNamedKeyCredential, AzureSasCredential, TokenCredential]]
265-
) -> None:
266-
if hasattr(credential, "get_token"):
267-
self._credential_policy = BearerTokenChallengePolicy(
268-
credential, STORAGE_OAUTH_SCOPE # type: ignore
269-
)
270-
elif isinstance(credential, SharedKeyCredentialPolicy):
271-
self._credential_policy = credential # type: ignore
272-
elif isinstance(credential, AzureSasCredential):
273-
self._credential_policy = AzureSasCredentialPolicy(credential) # type: ignore
274-
elif isinstance(credential, AzureNamedKeyCredential):
275-
self._credential_policy = SharedKeyCredentialPolicy(credential) # type: ignore
276-
elif credential is not None:
277-
raise TypeError("Unsupported credential: {}".format(credential))
278-
279261
def _batch_send(self, table_name: str, *reqs: HttpRequest, **kwargs) -> List[Mapping[str, Any]]:
280262
"""Given a series of request, do a Storage batch call."""
281263
# Pop it here, so requests doesn't feel bad about additional kwarg

sdk/tables/azure-data-tables/azure/data/tables/aio/_authentication_async.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
# Licensed under the MIT License. See License.txt in the project root for
44
# license information.
55
# --------------------------------------------------------------------------
6-
from typing import Any
6+
from typing import Union, Optional, cast, overload
77

8+
from azure.core.credentials import AzureSasCredential, AzureNamedKeyCredential
89
from azure.core.credentials_async import AsyncTokenCredential
910
from azure.core.pipeline import PipelineResponse, PipelineRequest
1011
from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy
1112

12-
from .._authentication import _HttpChallenge
13+
from .._constants import STORAGE_OAUTH_SCOPE
14+
from .._authentication import _HttpChallenge, AzureSasCredentialPolicy, SharedKeyCredentialPolicy
1315

1416

1517
class AsyncBearerTokenChallengePolicy(AsyncBearerTokenCredentialPolicy):
@@ -67,3 +69,55 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons
6769
else:
6870
await self.authorize_request(request, scope)
6971
return True
72+
73+
74+
@overload
75+
def _configure_credential(credential: AzureNamedKeyCredential) -> SharedKeyCredentialPolicy:
76+
...
77+
78+
@overload
79+
def _configure_credential(credential: SharedKeyCredentialPolicy) -> SharedKeyCredentialPolicy:
80+
...
81+
82+
@overload
83+
def _configure_credential(credential: AzureSasCredential) -> AzureSasCredentialPolicy:
84+
...
85+
86+
@overload
87+
def _configure_credential(credential: AsyncTokenCredential) -> AsyncBearerTokenChallengePolicy:
88+
...
89+
90+
@overload
91+
def _configure_credential(credential: None) -> None:
92+
...
93+
94+
def _configure_credential(
95+
credential: Optional[
96+
Union[
97+
AzureNamedKeyCredential,
98+
AzureSasCredential,
99+
AsyncTokenCredential,
100+
SharedKeyCredentialPolicy
101+
]
102+
]
103+
) -> Optional[
104+
Union[
105+
AsyncBearerTokenChallengePolicy,
106+
AzureSasCredentialPolicy,
107+
SharedKeyCredentialPolicy
108+
]
109+
]:
110+
if hasattr(credential, "get_token"):
111+
credential = cast(AsyncTokenCredential, credential)
112+
return AsyncBearerTokenChallengePolicy(
113+
credential, STORAGE_OAUTH_SCOPE
114+
)
115+
if isinstance(credential, SharedKeyCredentialPolicy):
116+
return credential
117+
if isinstance(credential, AzureSasCredential):
118+
return AzureSasCredentialPolicy(credential)
119+
if isinstance(credential, AzureNamedKeyCredential):
120+
return SharedKeyCredentialPolicy(credential)
121+
if credential is not None:
122+
raise TypeError("Unsupported credential: {}".format(credential))
123+
return None

sdk/tables/azure-data-tables/azure/data/tables/aio/_base_client_async.py

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
HttpLoggingPolicy,
1616
UserAgentPolicy,
1717
ProxyPolicy,
18-
AzureSasCredentialPolicy,
1918
RequestIdPolicy,
2019
CustomHookPolicy,
2120
NetworkTraceLoggingPolicy,
@@ -25,11 +24,9 @@
2524
HttpRequest,
2625
)
2726

28-
from ._authentication_async import AsyncBearerTokenChallengePolicy
27+
from ._authentication_async import _configure_credential
2928
from .._generated.aio import AzureTable
3029
from .._base_client import AccountHostsMixin, get_api_version, extract_batch_part_metadata
31-
from .._authentication import SharedKeyCredentialPolicy
32-
from .._constants import STORAGE_OAUTH_SCOPE
3330
from .._error import (
3431
RequestTooLargeError,
3532
TableTransactionError,
@@ -86,29 +83,14 @@ async def close(self) -> None:
8683
"""
8784
await self._client.close()
8885

89-
def _configure_credential(
90-
self, credential: Optional[Union[AzureSasCredential, AzureNamedKeyCredential, AsyncTokenCredential]]
91-
) -> None:
92-
if hasattr(credential, "get_token"):
93-
self._credential_policy = AsyncBearerTokenChallengePolicy(
94-
credential, STORAGE_OAUTH_SCOPE # type: ignore
95-
)
96-
elif isinstance(credential, SharedKeyCredentialPolicy):
97-
self._credential_policy = credential # type: ignore
98-
elif isinstance(credential, AzureSasCredential):
99-
self._credential_policy = AzureSasCredentialPolicy(credential) # type: ignore
100-
elif isinstance(credential, AzureNamedKeyCredential):
101-
self._credential_policy = SharedKeyCredentialPolicy(credential) # type: ignore
102-
elif credential is not None:
103-
raise TypeError("Unsupported credential: {}".format(credential))
104-
10586
def _configure_policies(self, **kwargs):
87+
credential_policy = _configure_credential(self.credential)
10688
return [
10789
RequestIdPolicy(**kwargs),
10890
StorageHeadersPolicy(**kwargs),
10991
UserAgentPolicy(sdk_moniker=SDK_MONIKER, **kwargs),
11092
ProxyPolicy(**kwargs),
111-
self._credential_policy,
93+
credential_policy,
11294
ContentDecodePolicy(response_encoding="utf-8"),
11395
AsyncRedirectPolicy(**kwargs),
11496
StorageHosts(**kwargs),

0 commit comments

Comments
 (0)