2020
2121from uamqp import AMQPClient , Message , authentication , constants , errors , compat , utils
2222import six
23- from azure .core .credentials import AccessToken , AzureSasCredential , AzureNamedKeyCredential
23+ from azure .core .credentials import (
24+ AccessToken ,
25+ AzureSasCredential ,
26+ AzureNamedKeyCredential ,
27+ )
2428from azure .core .utils import parse_connection_string as core_parse_connection_string
2529
2630
2731from .exceptions import _handle_exception , ClientClosedError , ConnectError
2832from ._configuration import Configuration
33+ from ._retry import RetryMode
2934from ._utils import utc_from_timestamp , parse_sas_credential
3035from ._connection_manager import get_connection_manager
3136from ._constants import (
3439 MGMT_OPERATION ,
3540 MGMT_PARTITION_OPERATION ,
3641 MGMT_STATUS_CODE ,
37- MGMT_STATUS_DESC
42+ MGMT_STATUS_DESC ,
3843)
3944
4045if TYPE_CHECKING :
@@ -52,9 +57,11 @@ def _parse_conn_str(conn_str, **kwargs):
5257 entity_path = None # type: Optional[str]
5358 shared_access_signature = None # type: Optional[str]
5459 shared_access_signature_expiry = None
55- eventhub_name = kwargs .pop ("eventhub_name" , None ) # type: Optional[str]
56- check_case = kwargs .pop ("check_case" , False ) # type: bool
57- conn_settings = core_parse_connection_string (conn_str , case_sensitive_keys = check_case )
60+ eventhub_name = kwargs .pop ("eventhub_name" , None ) # type: Optional[str]
61+ check_case = kwargs .pop ("check_case" , False ) # type: bool
62+ conn_settings = core_parse_connection_string (
63+ conn_str , case_sensitive_keys = check_case
64+ )
5865 if check_case :
5966 shared_access_key = conn_settings .get ("SharedAccessKey" )
6067 shared_access_key_name = conn_settings .get ("SharedAccessKeyName" )
@@ -79,7 +86,7 @@ def _parse_conn_str(conn_str, **kwargs):
7986 try :
8087 # Expiry can be stored in the "se=<timestamp>" clause of the token. ('&'-separated key-value pairs)
8188 shared_access_signature_expiry = int (
82- shared_access_signature .split ("se=" )[1 ].split ("&" )[0 ] # type: ignore
89+ shared_access_signature .split ("se=" )[1 ].split ("&" )[0 ] # type: ignore
8390 )
8491 except (
8592 IndexError ,
@@ -117,12 +124,14 @@ def _parse_conn_str(conn_str, **kwargs):
117124 "At least one of the SharedAccessKey or SharedAccessSignature must be present."
118125 )
119126
120- return (host ,
121- str (shared_access_key_name ) if shared_access_key_name else None ,
122- str (shared_access_key ) if shared_access_key else None ,
123- entity ,
124- str (shared_access_signature ) if shared_access_signature else None ,
125- shared_access_signature_expiry )
127+ return (
128+ host ,
129+ str (shared_access_key_name ) if shared_access_key_name else None ,
130+ str (shared_access_key ) if shared_access_key else None ,
131+ entity ,
132+ str (shared_access_signature ) if shared_access_signature else None ,
133+ shared_access_signature_expiry ,
134+ )
126135
127136
128137def _generate_sas_token (uri , policy , key , expiry = None ):
@@ -154,6 +163,14 @@ def _build_uri(address, entity):
154163 return address
155164
156165
166+ def _get_backoff_time (retry_mode , backoff_factor , backoff_max , retried_times ):
167+ if retry_mode == RetryMode .FIXED :
168+ backoff_value = backoff_factor
169+ else :
170+ backoff_value = backoff_factor * (2 ** retried_times )
171+ return min (backoff_max , backoff_value )
172+
173+
157174class EventHubSharedKeyCredential (object ):
158175 """The shared access key credential used for authentication.
159176
@@ -200,6 +217,7 @@ class EventHubSASTokenCredential(object):
200217 :param str token: The shared access token string
201218 :param int expiry: The epoch timestamp
202219 """
220+
203221 def __init__ (self , token , expiry ):
204222 # type: (str, int) -> None
205223 """
@@ -225,6 +243,7 @@ class EventhubAzureSasTokenCredential(object):
225243 :param azure_sas_credential: The credential to be used for authentication.
226244 :type azure_sas_credential: ~azure.core.credentials.AzureSasCredential
227245 """
246+
228247 def __init__ (self , azure_sas_credential ):
229248 # type: (AzureSasCredential) -> None
230249 """The shared access token credential used for authentication
@@ -257,9 +276,9 @@ def __init__(self, fully_qualified_namespace, eventhub_name, credential, **kwarg
257276 if isinstance (credential , AzureSasCredential ):
258277 self ._credential = EventhubAzureSasTokenCredential (credential )
259278 elif isinstance (credential , AzureNamedKeyCredential ):
260- self ._credential = EventhubAzureNamedKeyTokenCredential (credential ) # type: ignore
279+ self ._credential = EventhubAzureNamedKeyTokenCredential (credential ) # type: ignore
261280 else :
262- self ._credential = credential # type: ignore
281+ self ._credential = credential # type: ignore
263282 self ._keep_alive = kwargs .get ("keep_alive" , 30 )
264283 self ._auto_reconnect = kwargs .get ("auto_reconnect" , True )
265284 self ._mgmt_target = "amqps://{}/{}" .format (
@@ -274,7 +293,9 @@ def __init__(self, fully_qualified_namespace, eventhub_name, credential, **kwarg
274293 @staticmethod
275294 def _from_connection_string (conn_str , ** kwargs ):
276295 # type: (str, Any) -> Dict[str, Any]
277- host , policy , key , entity , token , token_expiry = _parse_conn_str (conn_str , ** kwargs )
296+ host , policy , key , entity , token , token_expiry = _parse_conn_str (
297+ conn_str , ** kwargs
298+ )
278299 kwargs ["fully_qualified_namespace" ] = host
279300 kwargs ["eventhub_name" ] = entity
280301 if token and token_expiry :
@@ -291,7 +312,7 @@ def _create_auth(self):
291312 """
292313 try :
293314 # ignore mypy's warning because token_type is Optional
294- token_type = self ._credential .token_type # type: ignore
315+ token_type = self ._credential .token_type # type: ignore
295316 except AttributeError :
296317 token_type = b"jwt"
297318 if token_type == b"servicebus.windows.net:sastoken" :
@@ -305,7 +326,7 @@ def _create_auth(self):
305326 transport_type = self ._config .transport_type ,
306327 custom_endpoint_hostname = self ._config .custom_endpoint_hostname ,
307328 port = self ._config .connection_port ,
308- verify = self ._config .connection_verify
329+ verify = self ._config .connection_verify ,
309330 )
310331 auth .update_token ()
311332 return auth
@@ -319,7 +340,7 @@ def _create_auth(self):
319340 transport_type = self ._config .transport_type ,
320341 custom_endpoint_hostname = self ._config .custom_endpoint_hostname ,
321342 port = self ._config .connection_port ,
322- verify = self ._config .connection_verify
343+ verify = self ._config .connection_verify ,
323344 )
324345
325346 def _close_connection (self ):
@@ -331,7 +352,12 @@ def _backoff(
331352 ):
332353 # type: (int, Exception, Optional[int], Optional[str]) -> None
333354 entity_name = entity_name or self ._container_id
334- backoff = self ._config .backoff_factor * 2 ** retried_times
355+ backoff = _get_backoff_time (
356+ self ._config .retry_mode ,
357+ self ._config .backoff_factor ,
358+ self ._config .backoff_max ,
359+ retried_times ,
360+ )
335361 if backoff <= self ._config .backoff_max and (
336362 timeout_time is None or time .time () + backoff <= timeout_time
337363 ): # pylint:disable=no-else-return
@@ -360,7 +386,7 @@ def _management_request(self, mgmt_msg, op_type):
360386 self ._mgmt_target , auth = mgmt_auth , debug = self ._config .network_tracing
361387 )
362388 try :
363- conn = self ._conn_manager .get_connection (
389+ conn = self ._conn_manager .get_connection ( # pylint:disable=assignment-from-none
364390 self ._address .hostname , mgmt_auth
365391 )
366392 mgmt_client .open (connection = conn )
@@ -373,29 +399,28 @@ def _management_request(self, mgmt_msg, op_type):
373399 description_fields = MGMT_STATUS_DESC ,
374400 )
375401 status_code = int (response .application_properties [MGMT_STATUS_CODE ])
376- description = response .application_properties .get (MGMT_STATUS_DESC ) # type: Optional[Union[str, bytes]]
402+ description = response .application_properties .get (
403+ MGMT_STATUS_DESC
404+ ) # type: Optional[Union[str, bytes]]
377405 if description and isinstance (description , six .binary_type ):
378- description = description .decode (' utf-8' )
406+ description = description .decode (" utf-8" )
379407 if status_code < 400 :
380408 return response
381409 if status_code in [401 ]:
382410 raise errors .AuthenticationException (
383411 "Management authentication failed. Status code: {}, Description: {!r}" .format (
384- status_code ,
385- description
412+ status_code , description
386413 )
387414 )
388415 if status_code in [404 ]:
389416 raise ConnectError (
390417 "Management connection failed. Status code: {}, Description: {!r}" .format (
391- status_code ,
392- description
418+ status_code , description
393419 )
394420 )
395421 raise errors .AMQPConnectionError (
396422 "Management request error. Status code: {}, Description: {!r}" .format (
397- status_code ,
398- description
423+ status_code , description
399424 )
400425 )
401426 except Exception as exception : # pylint: disable=broad-except
@@ -491,9 +516,7 @@ def _check_closed(self):
491516 )
492517
493518 def _open (self ):
494- """Open the EventHubConsumer/EventHubProducer using the supplied connection.
495-
496- """
519+ """Open the EventHubConsumer/EventHubProducer using the supplied connection."""
497520 # pylint: disable=protected-access
498521 if not self .running :
499522 if self ._handler :
0 commit comments