|
20 | 20 |
|
21 | 21 | from uamqp import AMQPClient, Message, authentication, constants, errors, compat, utils |
22 | 22 | import six |
| 23 | +from azure.core.utils import parse_connection_string as core_parse_connection_string |
23 | 24 | from azure.core.credentials import AccessToken, AzureSasCredential |
24 | 25 |
|
25 | 26 | from .exceptions import _handle_exception, ClientClosedError, ConnectError |
|
43 | 44 | _AccessToken = collections.namedtuple("AccessToken", "token expires_on") |
44 | 45 |
|
45 | 46 |
|
46 | | -def _parse_conn_str(conn_str, kwargs): |
47 | | - # type: (str, Dict[str, Any]) -> Tuple[str, Optional[str], Optional[str], str, Optional[str], Optional[int]] |
| 47 | +def _parse_conn_str(conn_str, **kwargs): |
| 48 | + # type: (str, Any) -> Tuple[str, Optional[str], Optional[str], str, Optional[str], Optional[int]] |
48 | 49 | endpoint = None |
49 | 50 | shared_access_key_name = None |
50 | 51 | shared_access_key = None |
51 | 52 | entity_path = None # type: Optional[str] |
52 | 53 | shared_access_signature = None # type: Optional[str] |
53 | | - shared_access_signature_expiry = None # type: Optional[int] |
54 | | - eventhub_name = kwargs.pop("eventhub_name", None) # type: Optional[str] |
55 | | - for element in conn_str.split(";"): |
56 | | - key, _, value = element.partition("=") |
57 | | - if key.lower() == "endpoint": |
58 | | - endpoint = value.rstrip("/") |
59 | | - elif key.lower() == "hostname": |
60 | | - endpoint = value.rstrip("/") |
61 | | - elif key.lower() == "sharedaccesskeyname": |
62 | | - shared_access_key_name = value |
63 | | - elif key.lower() == "sharedaccesskey": |
64 | | - shared_access_key = value |
65 | | - elif key.lower() == "entitypath": |
66 | | - entity_path = value |
67 | | - elif key.lower() == "sharedaccesssignature": |
68 | | - shared_access_signature = value |
69 | | - try: |
70 | | - # Expiry can be stored in the "se=<timestamp>" clause of the token. ('&'-separated key-value pairs) |
71 | | - # type: ignore |
72 | | - shared_access_signature_expiry = int(shared_access_signature.split('se=')[1].split('&')[0]) |
73 | | - except (IndexError, TypeError, ValueError): # Fallback since technically expiry is optional. |
74 | | - # An arbitrary, absurdly large number, since you can't renew. |
75 | | - shared_access_signature_expiry = int(time.time() * 2) |
76 | | - if not (all((endpoint, shared_access_key_name, shared_access_key)) or all((endpoint, shared_access_signature))): |
| 54 | + shared_access_signature_expiry = None |
| 55 | + eventhub_name = kwargs.pop("eventhub_name", None) # type: Optional[str] |
| 56 | + check_case = kwargs.pop("check_case", False) # type: bool |
| 57 | + conn_settings = core_parse_connection_string(conn_str, case_sensitive_keys=check_case) |
| 58 | + if check_case: |
| 59 | + shared_access_key = conn_settings.get("SharedAccessKey") |
| 60 | + shared_access_key_name = conn_settings.get("SharedAccessKeyName") |
| 61 | + endpoint = conn_settings.get("Endpoint") |
| 62 | + entity_path = conn_settings.get("EntityPath") |
| 63 | + # non case sensitive check when parsing connection string for internal use |
| 64 | + for key, value in conn_settings.items(): |
| 65 | + # only sas check is non case sensitive for both conn str properties and internal use |
| 66 | + if key.lower() == "sharedaccesssignature": |
| 67 | + shared_access_signature = value |
| 68 | + |
| 69 | + if not check_case: |
| 70 | + endpoint = conn_settings.get("endpoint") or conn_settings.get("hostname") |
| 71 | + if endpoint: |
| 72 | + endpoint = endpoint.rstrip("/") |
| 73 | + shared_access_key_name = conn_settings.get("sharedaccesskeyname") |
| 74 | + shared_access_key = conn_settings.get("sharedaccesskey") |
| 75 | + entity_path = conn_settings.get("entitypath") |
| 76 | + shared_access_signature = conn_settings.get("sharedaccesssignature") |
| 77 | + |
| 78 | + if shared_access_signature: |
| 79 | + try: |
| 80 | + # Expiry can be stored in the "se=<timestamp>" clause of the token. ('&'-separated key-value pairs) |
| 81 | + shared_access_signature_expiry = int( |
| 82 | + shared_access_signature.split("se=")[1].split("&")[0] # type: ignore |
| 83 | + ) |
| 84 | + except ( |
| 85 | + IndexError, |
| 86 | + TypeError, |
| 87 | + ValueError, |
| 88 | + ): # Fallback since technically expiry is optional. |
| 89 | + # An arbitrary, absurdly large number, since you can't renew. |
| 90 | + shared_access_signature_expiry = int(time.time() * 2) |
| 91 | + |
| 92 | + entity = cast(str, eventhub_name or entity_path) |
| 93 | + |
| 94 | + # check that endpoint is valid |
| 95 | + if not endpoint: |
| 96 | + raise ValueError("Connection string is either blank or malformed.") |
| 97 | + parsed = urlparse(endpoint) |
| 98 | + if not parsed.netloc: |
| 99 | + raise ValueError("Invalid Endpoint on the Connection String.") |
| 100 | + host = cast(str, parsed.netloc.strip()) |
| 101 | + |
| 102 | + if any([shared_access_key, shared_access_key_name]) and not all( |
| 103 | + [shared_access_key, shared_access_key_name] |
| 104 | + ): |
77 | 105 | raise ValueError( |
78 | 106 | "Invalid connection string. Should be in the format: " |
79 | 107 | "Endpoint=sb://<FQDN>/;SharedAccessKeyName=<KeyName>;SharedAccessKey=<KeyValue>" |
80 | 108 | ) |
81 | | - entity = cast(str, eventhub_name or entity_path) |
82 | | - left_slash_pos = cast(str, endpoint).find("//") |
83 | | - if left_slash_pos != -1: |
84 | | - host = cast(str, endpoint)[left_slash_pos + 2 :] |
85 | | - else: |
86 | | - host = str(endpoint) |
| 109 | + # Only connection string parser should check that only one of sas and shared access |
| 110 | + # key exists. For backwards compatibility, client construction should not have this check. |
| 111 | + if check_case and shared_access_signature and shared_access_key: |
| 112 | + raise ValueError( |
| 113 | + "Only one of the SharedAccessKey or SharedAccessSignature must be present." |
| 114 | + ) |
| 115 | + if not shared_access_signature and not shared_access_key: |
| 116 | + raise ValueError( |
| 117 | + "At least one of the SharedAccessKey or SharedAccessSignature must be present." |
| 118 | + ) |
| 119 | + |
87 | 120 | return (host, |
88 | 121 | str(shared_access_key_name) if shared_access_key_name else None, |
89 | 122 | str(shared_access_key) if shared_access_key else None, |
@@ -218,7 +251,7 @@ def __init__(self, fully_qualified_namespace, eventhub_name, credential, **kwarg |
218 | 251 | @staticmethod |
219 | 252 | def _from_connection_string(conn_str, **kwargs): |
220 | 253 | # type: (str, Any) -> Dict[str, Any] |
221 | | - host, policy, key, entity, token, token_expiry = _parse_conn_str(conn_str, kwargs) |
| 254 | + host, policy, key, entity, token, token_expiry = _parse_conn_str(conn_str, **kwargs) |
222 | 255 | kwargs["fully_qualified_namespace"] = host |
223 | 256 | kwargs["eventhub_name"] = entity |
224 | 257 | if token and token_expiry: |
|
0 commit comments