Skip to content

Commit 8ac54ee

Browse files
authored
[Event Hubs] combine conn str parsing logic (Azure#18059)
* user core parser + remove redundancy * move sas expiry logic + types * fix error message * mypy error * error message for cs parser only
1 parent 73d0b36 commit 8ac54ee

File tree

4 files changed

+123
-73
lines changed

4 files changed

+123
-73
lines changed

sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py

Lines changed: 66 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from uamqp import AMQPClient, Message, authentication, constants, errors, compat, utils
2222
import six
23+
from azure.core.utils import parse_connection_string as core_parse_connection_string
2324
from azure.core.credentials import AccessToken, AzureSasCredential
2425

2526
from .exceptions import _handle_exception, ClientClosedError, ConnectError
@@ -43,47 +44,79 @@
4344
_AccessToken = collections.namedtuple("AccessToken", "token expires_on")
4445

4546

46-
def _parse_conn_str(conn_str, kwargs):
47-
# type: (str, Dict[str, Any]) -> Tuple[str, Optional[str], Optional[str], str, Optional[str], Optional[int]]
47+
def _parse_conn_str(conn_str, **kwargs):
48+
# type: (str, Any) -> Tuple[str, Optional[str], Optional[str], str, Optional[str], Optional[int]]
4849
endpoint = None
4950
shared_access_key_name = None
5051
shared_access_key = None
5152
entity_path = None # type: Optional[str]
5253
shared_access_signature = None # type: Optional[str]
53-
shared_access_signature_expiry = None # type: Optional[int]
54-
eventhub_name = kwargs.pop("eventhub_name", None) # type: Optional[str]
55-
for element in conn_str.split(";"):
56-
key, _, value = element.partition("=")
57-
if key.lower() == "endpoint":
58-
endpoint = value.rstrip("/")
59-
elif key.lower() == "hostname":
60-
endpoint = value.rstrip("/")
61-
elif key.lower() == "sharedaccesskeyname":
62-
shared_access_key_name = value
63-
elif key.lower() == "sharedaccesskey":
64-
shared_access_key = value
65-
elif key.lower() == "entitypath":
66-
entity_path = value
67-
elif key.lower() == "sharedaccesssignature":
68-
shared_access_signature = value
69-
try:
70-
# Expiry can be stored in the "se=<timestamp>" clause of the token. ('&'-separated key-value pairs)
71-
# type: ignore
72-
shared_access_signature_expiry = int(shared_access_signature.split('se=')[1].split('&')[0])
73-
except (IndexError, TypeError, ValueError): # Fallback since technically expiry is optional.
74-
# An arbitrary, absurdly large number, since you can't renew.
75-
shared_access_signature_expiry = int(time.time() * 2)
76-
if not (all((endpoint, shared_access_key_name, shared_access_key)) or all((endpoint, shared_access_signature))):
54+
shared_access_signature_expiry = None
55+
eventhub_name = kwargs.pop("eventhub_name", None) # type: Optional[str]
56+
check_case = kwargs.pop("check_case", False) # type: bool
57+
conn_settings = core_parse_connection_string(conn_str, case_sensitive_keys=check_case)
58+
if check_case:
59+
shared_access_key = conn_settings.get("SharedAccessKey")
60+
shared_access_key_name = conn_settings.get("SharedAccessKeyName")
61+
endpoint = conn_settings.get("Endpoint")
62+
entity_path = conn_settings.get("EntityPath")
63+
# non case sensitive check when parsing connection string for internal use
64+
for key, value in conn_settings.items():
65+
# only sas check is non case sensitive for both conn str properties and internal use
66+
if key.lower() == "sharedaccesssignature":
67+
shared_access_signature = value
68+
69+
if not check_case:
70+
endpoint = conn_settings.get("endpoint") or conn_settings.get("hostname")
71+
if endpoint:
72+
endpoint = endpoint.rstrip("/")
73+
shared_access_key_name = conn_settings.get("sharedaccesskeyname")
74+
shared_access_key = conn_settings.get("sharedaccesskey")
75+
entity_path = conn_settings.get("entitypath")
76+
shared_access_signature = conn_settings.get("sharedaccesssignature")
77+
78+
if shared_access_signature:
79+
try:
80+
# Expiry can be stored in the "se=<timestamp>" clause of the token. ('&'-separated key-value pairs)
81+
shared_access_signature_expiry = int(
82+
shared_access_signature.split("se=")[1].split("&")[0] # type: ignore
83+
)
84+
except (
85+
IndexError,
86+
TypeError,
87+
ValueError,
88+
): # Fallback since technically expiry is optional.
89+
# An arbitrary, absurdly large number, since you can't renew.
90+
shared_access_signature_expiry = int(time.time() * 2)
91+
92+
entity = cast(str, eventhub_name or entity_path)
93+
94+
# check that endpoint is valid
95+
if not endpoint:
96+
raise ValueError("Connection string is either blank or malformed.")
97+
parsed = urlparse(endpoint)
98+
if not parsed.netloc:
99+
raise ValueError("Invalid Endpoint on the Connection String.")
100+
host = cast(str, parsed.netloc.strip())
101+
102+
if any([shared_access_key, shared_access_key_name]) and not all(
103+
[shared_access_key, shared_access_key_name]
104+
):
77105
raise ValueError(
78106
"Invalid connection string. Should be in the format: "
79107
"Endpoint=sb://<FQDN>/;SharedAccessKeyName=<KeyName>;SharedAccessKey=<KeyValue>"
80108
)
81-
entity = cast(str, eventhub_name or entity_path)
82-
left_slash_pos = cast(str, endpoint).find("//")
83-
if left_slash_pos != -1:
84-
host = cast(str, endpoint)[left_slash_pos + 2 :]
85-
else:
86-
host = str(endpoint)
109+
# Only connection string parser should check that only one of sas and shared access
110+
# key exists. For backwards compatibility, client construction should not have this check.
111+
if check_case and shared_access_signature and shared_access_key:
112+
raise ValueError(
113+
"Only one of the SharedAccessKey or SharedAccessSignature must be present."
114+
)
115+
if not shared_access_signature and not shared_access_key:
116+
raise ValueError(
117+
"At least one of the SharedAccessKey or SharedAccessSignature must be present."
118+
)
119+
87120
return (host,
88121
str(shared_access_key_name) if shared_access_key_name else None,
89122
str(shared_access_key) if shared_access_key else None,
@@ -218,7 +251,7 @@ def __init__(self, fully_qualified_namespace, eventhub_name, credential, **kwarg
218251
@staticmethod
219252
def _from_connection_string(conn_str, **kwargs):
220253
# type: (str, Any) -> Dict[str, Any]
221-
host, policy, key, entity, token, token_expiry = _parse_conn_str(conn_str, kwargs)
254+
host, policy, key, entity, token, token_expiry = _parse_conn_str(conn_str, **kwargs)
222255
kwargs["fully_qualified_namespace"] = host
223256
kwargs["eventhub_name"] = entity
224257
if token and token_expiry:

sdk/eventhub/azure-eventhub/azure/eventhub/_connection_string_parser.py

Lines changed: 8 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,8 @@
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33
# Licensed under the MIT License. See License.txt in the project root for license information.
44
# --------------------------------------------------------------------------------------------
5-
try:
6-
from urllib.parse import urlparse
7-
except ImportError:
8-
from urlparse import urlparse # type: ignore
9-
105
from ._common import DictMixin
6+
from ._client_base import _parse_conn_str
117

128

139
class EventHubConnectionStringProperties(DictMixin):
@@ -70,39 +66,14 @@ def parse_connection_string(conn_str):
7066
:type conn_str: str
7167
:rtype: ~azure.eventhub.EventHubConnectionStringProperties
7268
"""
73-
conn_settings = [s.split("=", 1) for s in conn_str.split(";")]
74-
if any(len(tup) != 2 for tup in conn_settings):
75-
raise ValueError("Connection string is either blank or malformed.")
76-
conn_settings = dict(conn_settings)
77-
shared_access_signature = None
78-
for key, value in conn_settings.items():
79-
if key.lower() == "sharedaccesssignature":
80-
shared_access_signature = value
81-
shared_access_key = conn_settings.get("SharedAccessKey")
82-
shared_access_key_name = conn_settings.get("SharedAccessKeyName")
83-
if any([shared_access_key, shared_access_key_name]) and not all(
84-
[shared_access_key, shared_access_key_name]
85-
):
86-
raise ValueError(
87-
"Connection string must have both SharedAccessKeyName and SharedAccessKey."
88-
)
89-
if shared_access_signature is not None and shared_access_key is not None:
90-
raise ValueError(
91-
"Only one of the SharedAccessKey or SharedAccessSignature must be present."
92-
)
93-
endpoint = conn_settings.get("Endpoint")
94-
if not endpoint:
95-
raise ValueError("Connection string is either blank or malformed.")
96-
parsed = urlparse(endpoint.rstrip("/"))
97-
if not parsed.netloc:
98-
raise ValueError("Invalid Endpoint on the Connection String.")
99-
namespace = parsed.netloc.strip()
69+
fully_qualified_namespace, policy, key, entity, signature = _parse_conn_str(conn_str, check_case=True)[:-1]
70+
endpoint = "sb://" + fully_qualified_namespace + "/"
10071
props = {
101-
"fully_qualified_namespace": namespace,
72+
"fully_qualified_namespace": fully_qualified_namespace,
10273
"endpoint": endpoint,
103-
"eventhub_name": conn_settings.get("EntityPath"),
104-
"shared_access_signature": shared_access_signature,
105-
"shared_access_key_name": shared_access_key_name,
106-
"shared_access_key": shared_access_key,
74+
"eventhub_name": entity,
75+
"shared_access_signature": signature,
76+
"shared_access_key_name": policy,
77+
"shared_access_key": key,
10778
}
10879
return EventHubConnectionStringProperties(**props)

sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def __enter__(self):
130130

131131
@staticmethod
132132
def _from_connection_string(conn_str: str, **kwargs) -> Dict[str, Any]:
133-
host, policy, key, entity, token, token_expiry = _parse_conn_str(conn_str, kwargs)
133+
host, policy, key, entity, token, token_expiry = _parse_conn_str(conn_str, **kwargs)
134134
kwargs["fully_qualified_namespace"] = host
135135
kwargs["eventhub_name"] = entity
136136
if token and token_expiry:

sdk/eventhub/azure-eventhub/tests/unittest/test_connection_string_parser.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@ def test_eh_parse_malformed_conn_str_no_endpoint(self, **kwargs):
4343
with pytest.raises(ValueError) as e:
4444
parse_result = parse_connection_string(conn_str)
4545
assert str(e.value) == 'Connection string is either blank or malformed.'
46+
47+
def test_eh_parse_malformed_conn_str_no_endpoint_value(self, **kwargs):
48+
conn_str = 'Endpoint=;SharedAccessKeyName=test;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
49+
with pytest.raises(ValueError) as e:
50+
parse_result = parse_connection_string(conn_str)
51+
assert str(e.value) == 'Connection string is either blank or malformed.'
4652

4753
def test_eh_parse_malformed_conn_str_no_netloc(self, **kwargs):
4854
conn_str = 'Endpoint=MALFORMED;SharedAccessKeyName=test-policy;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
@@ -57,15 +63,55 @@ def test_eh_parse_conn_str_sas(self, **kwargs):
5763
assert parse_result.fully_qualified_namespace == 'eh-namespace.servicebus.windows.net'
5864
assert parse_result.shared_access_signature == 'THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
5965
assert parse_result.shared_access_key_name == None
66+
67+
def test_eh_parse_conn_str_whitespace_trailing_semicolon(self, **kwargs):
68+
conn_str = ' Endpoint=sb://resourcename.servicebus.windows.net/;SharedAccessSignature=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=; '
69+
parse_result = parse_connection_string(conn_str)
70+
assert parse_result.endpoint == 'sb://resourcename.servicebus.windows.net/'
71+
assert parse_result.fully_qualified_namespace == 'resourcename.servicebus.windows.net'
72+
assert parse_result.shared_access_signature == 'THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
73+
assert parse_result.shared_access_key_name == None
74+
75+
def test_eh_parse_conn_str_sas_trailing_semicolon(self, **kwargs):
76+
conn_str = 'Endpoint=sb://resourcename.servicebus.windows.net/;SharedAccessSignature=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=;'
77+
parse_result = parse_connection_string(conn_str)
78+
assert parse_result.endpoint == 'sb://resourcename.servicebus.windows.net/'
79+
assert parse_result.fully_qualified_namespace == 'resourcename.servicebus.windows.net'
80+
assert parse_result.shared_access_signature == 'THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
81+
assert parse_result.shared_access_key_name == None
6082

6183
def test_eh_parse_conn_str_no_keyname(self, **kwargs):
6284
conn_str = 'Endpoint=sb://eh-namespace.servicebus.windows.net/;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
6385
with pytest.raises(ValueError) as e:
6486
parse_result = parse_connection_string(conn_str)
65-
assert str(e.value) == 'Connection string must have both SharedAccessKeyName and SharedAccessKey.'
87+
assert "Invalid connection string" in str(e.value)
6688

6789
def test_eh_parse_conn_str_no_key(self, **kwargs):
6890
conn_str = 'Endpoint=sb://eh-namespace.servicebus.windows.net/;SharedAccessKeyName=test-policy'
6991
with pytest.raises(ValueError) as e:
7092
parse_result = parse_connection_string(conn_str)
71-
assert str(e.value) == 'Connection string must have both SharedAccessKeyName and SharedAccessKey.'
93+
assert "Invalid connection string" in str(e.value)
94+
95+
def test_eh_parse_conn_str_no_key_or_sas(self, **kwargs):
96+
conn_str = 'Endpoint=sb://resourcename.servicebus.windows.net/'
97+
with pytest.raises(ValueError) as e:
98+
parse_result = parse_connection_string(conn_str)
99+
assert str(e.value) == 'At least one of the SharedAccessKey or SharedAccessSignature must be present.'
100+
101+
def test_eh_parse_malformed_conn_str_lowercase_endpoint(self, **kwargs):
102+
conn_str = 'endpoint=sb://resourcename.servicebus.windows.net/;SharedAccessKeyName=test;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
103+
with pytest.raises(ValueError) as e:
104+
parse_result = parse_connection_string(conn_str)
105+
assert str(e.value) == 'Connection string is either blank or malformed.'
106+
107+
def test_eh_parse_malformed_conn_str_lowercase_sa_key_name(self, **kwargs):
108+
conn_str = 'Endpoint=sb://resourcename.servicebus.windows.net/;sharedaccesskeyname=test;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
109+
with pytest.raises(ValueError) as e:
110+
parse_result = parse_connection_string(conn_str)
111+
assert "Invalid connection string" in str(e.value)
112+
113+
def test_eh_parse_malformed_conn_str_lowercase_sa_key_name(self, **kwargs):
114+
conn_str = 'Endpoint=sb://resourcename.servicebus.windows.net/;SharedAccessKeyName=test;sharedaccesskey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX='
115+
with pytest.raises(ValueError) as e:
116+
parse_result = parse_connection_string(conn_str)
117+
assert "Invalid connection string" in str(e.value)

0 commit comments

Comments
 (0)