Skip to content

Commit 12725f5

Browse files
[ServiceBus] Adjust AutoLockRenewer to only allow registration of intended types (ReceivedMessage and ServiceBusSession) (Azure#14600)
* Adjust AutoLockRenewer to only allow registration of intended types (ReceivedMessage and ServiceBusSession) with the intent that if it was desired to allow an interfaced based approach it'd be easier to open that up later, and provide guardrails for now. * via_partition_key removal assumed another branch had been merged prior, reverted until that goes in. Co-authored-by: Adam Ling (MSFT) <adam_ling@outlook.com>
1 parent 8b12ebe commit 12725f5

File tree

6 files changed

+63
-27
lines changed

6 files changed

+63
-27
lines changed

sdk/servicebus/azure-servicebus/azure/servicebus/_common/auto_lock_renewer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313

1414
from .._servicebus_receiver import ServiceBusReceiver
1515
from .._servicebus_session import ServiceBusSession
16+
from .message import ServiceBusReceivedMessage
1617
from ..exceptions import AutoLockRenewFailed, AutoLockRenewTimeout, ServiceBusError
1718
from .utils import renewable_start_time, utc_now
1819

1920
if TYPE_CHECKING:
20-
from typing import Callable, Union, Optional
21-
from .message import ServiceBusReceivedMessage
21+
from typing import Callable, Union, Optional, Awaitable
2222
LockRenewFailureCallback = Callable[[Union[ServiceBusSession, ServiceBusReceivedMessage],
2323
Optional[Exception]], None]
2424
Renewable = Union[ServiceBusSession, ServiceBusReceivedMessage]
@@ -144,6 +144,10 @@ def register(self, receiver, renewable, timeout=300, on_lock_renew_failure=None)
144144
145145
:rtype: None
146146
"""
147+
if not isinstance(renewable, (ServiceBusReceivedMessage, ServiceBusSession)):
148+
raise TypeError("AutoLockRenewer only supports registration of types "
149+
"azure.servicebus.ServiceBusReceivedMessage (via a receiver's receive methods) and "
150+
"azure.servicebus.ServiceBusSession (via a session receiver's property receiver.session).")
147151
if self._shutdown.is_set():
148152
raise ServiceBusError("The AutoLockRenewer has already been shutdown. Please create a new instance for"
149153
" auto lock renewing.")

sdk/servicebus/azure-servicebus/azure/servicebus/aio/_async_auto_lock_renewer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,11 @@ def register(
144144
Default value is None (no callback).
145145
:rtype: None
146146
"""
147+
if not isinstance(renewable, (ServiceBusReceivedMessage, ServiceBusSession)):
148+
raise TypeError("AutoLockRenewer only supports registration of types "
149+
"azure.servicebus.ServiceBusReceivedMessage (via a receiver's receive methods) and "
150+
"azure.servicebus.aio.ServiceBusSession "
151+
"(via a session receiver's property receiver.session).")
147152
if self._shutdown.is_set():
148153
raise ServiceBusError("The AutoLockRenewer has already been shutdown. Please create a new instance for"
149154
" auto lock renewing.")

sdk/servicebus/azure-servicebus/tests/async_tests/mocks_async.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from datetime import timedelta
22

33
from azure.servicebus._common.utils import utc_now
4+
from azure.servicebus import ServiceBusReceivedMessage
45

56
class MockReceiver:
67
def __init__(self):
@@ -13,7 +14,7 @@ async def renew_message_lock(self, message):
1314
message.locked_until_utc = message.locked_until_utc + timedelta(seconds=message._lock_duration)
1415

1516

16-
class MockReceivedMessage:
17+
class MockReceivedMessage(ServiceBusReceivedMessage):
1718
def __init__(self, prevent_renew_lock=False, exception_on_renew_lock=False):
1819
self._lock_duration = 2
1920

@@ -29,4 +30,12 @@ def __init__(self, prevent_renew_lock=False, exception_on_renew_lock=False):
2930
def _lock_expired(self):
3031
if self.locked_until_utc and self.locked_until_utc <= utc_now():
3132
return True
32-
return False
33+
return False
34+
35+
@property
36+
def locked_until_utc(self):
37+
return self._locked_until_utc
38+
39+
@locked_until_utc.setter
40+
def locked_until_utc(self, value):
41+
self._locked_until_utc = value

sdk/servicebus/azure-servicebus/tests/async_tests/test_queues_async.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,10 @@
4242
MessageContentTooLarge,
4343
OperationTimeoutError
4444
)
45-
from devtools_testutils import AzureMgmtTestCase, CachedResourceGroupPreparer
45+
from devtools_testutils import AzureMgmtTestCase, CachedResourceGroupPreparer, AzureTestCase
4646
from servicebus_preparer import CachedServiceBusNamespacePreparer, CachedServiceBusQueuePreparer, ServiceBusQueuePreparer
4747
from utilities import get_logger, print_message, sleep_until_expired
48-
from mocks_async import MockReceivedMessage
48+
from mocks_async import MockReceivedMessage, MockReceiver
4949

5050
_logger = get_logger(logging.DEBUG)
5151

@@ -1145,7 +1145,7 @@ async def test_queue_message_settle_through_mgmt_link_due_to_broken_receiver_lin
11451145
assert len(messages) == 1
11461146
await receiver.complete_message(messages[0])
11471147

1148-
@pytest.mark.asyncio
1148+
@AzureTestCase.await_prepared_test
11491149
async def test_async_queue_mock_auto_lock_renew_callback(self):
11501150
results = []
11511151
errors = []
@@ -1154,11 +1154,16 @@ async def callback_mock(renewable, error):
11541154
if error:
11551155
errors.append(error)
11561156

1157+
receiver = MockReceiver()
11571158
auto_lock_renew = AutoLockRenewer()
1158-
auto_lock_renew._renew_period = 1 # So we can run the test fast.
1159-
async with auto_lock_renew: # Check that it is called when the object expires for any reason (silent renew failure)
1159+
with pytest.raises(TypeError):
1160+
auto_lock_renew.register(receiver, renewable=Exception()) # an arbitrary invalid type.
1161+
1162+
auto_lock_renew = AutoLockRenewer()
1163+
auto_lock_renew._renew_period = 1 # So we can run the test fast.
1164+
async with auto_lock_renew: # Check that it is called when the object expires for any reason (silent renew failure)
11601165
message = MockReceivedMessage(prevent_renew_lock=True)
1161-
auto_lock_renew.register(renewable=message, on_lock_renew_failure=callback_mock)
1166+
auto_lock_renew.register(receiver, renewable=message, on_lock_renew_failure=callback_mock)
11621167
await asyncio.sleep(3)
11631168
assert len(results) == 1 and results[-1]._lock_expired == True
11641169
assert not errors
@@ -1167,8 +1172,8 @@ async def callback_mock(renewable, error):
11671172
del errors[:]
11681173
auto_lock_renew = AutoLockRenewer()
11691174
auto_lock_renew._renew_period = 1
1170-
async with auto_lock_renew: # Check that in normal operation it does not get called
1171-
auto_lock_renew.register(renewable=MockReceivedMessage(), on_lock_renew_failure=callback_mock)
1175+
async with auto_lock_renew: # Check that in normal operation it does not get called
1176+
auto_lock_renew.register(receiver, renewable=MockReceivedMessage(), on_lock_renew_failure=callback_mock)
11721177
await asyncio.sleep(3)
11731178
assert not results
11741179
assert not errors
@@ -1177,9 +1182,9 @@ async def callback_mock(renewable, error):
11771182
del errors[:]
11781183
auto_lock_renew = AutoLockRenewer()
11791184
auto_lock_renew._renew_period = 1
1180-
async with auto_lock_renew: # Check that when a message is settled, it will not get called even after expiry
1185+
async with auto_lock_renew: # Check that when a message is settled, it will not get called even after expiry
11811186
message = MockReceivedMessage(prevent_renew_lock=True)
1182-
auto_lock_renew.register(renewable=message, on_lock_renew_failure=callback_mock)
1187+
auto_lock_renew.register(receiver, renewable=message, on_lock_renew_failure=callback_mock)
11831188
message._settled = True
11841189
await asyncio.sleep(3)
11851190
assert not results
@@ -1191,7 +1196,7 @@ async def callback_mock(renewable, error):
11911196
auto_lock_renew._renew_period = 1
11921197
async with auto_lock_renew: # Check that it is called when there is an overt renew failure
11931198
message = MockReceivedMessage(exception_on_renew_lock=True)
1194-
auto_lock_renew.register(renewable=message, on_lock_renew_failure=callback_mock)
1199+
auto_lock_renew.register(receiver, renewable=message, on_lock_renew_failure=callback_mock)
11951200
await asyncio.sleep(3)
11961201
assert len(results) == 1 and results[-1]._lock_expired == True
11971202
assert errors[-1]
@@ -1200,9 +1205,9 @@ async def callback_mock(renewable, error):
12001205
del errors[:]
12011206
auto_lock_renew = AutoLockRenewer()
12021207
auto_lock_renew._renew_period = 1
1203-
async with auto_lock_renew: # Check that it is not called when the renewer is shutdown
1208+
async with auto_lock_renew: # Check that it is not called when the renewer is shutdown
12041209
message = MockReceivedMessage(prevent_renew_lock=True)
1205-
auto_lock_renew.register(renewable=message, on_lock_renew_failure=callback_mock)
1210+
auto_lock_renew.register(receiver, renewable=message, on_lock_renew_failure=callback_mock)
12061211
await auto_lock_renew.close()
12071212
await asyncio.sleep(3)
12081213
assert not results
@@ -1212,35 +1217,35 @@ async def callback_mock(renewable, error):
12121217
del errors[:]
12131218
auto_lock_renew = AutoLockRenewer()
12141219
auto_lock_renew._renew_period = 1
1215-
async with auto_lock_renew: # Check that it is not called when the receiver is shutdown
1220+
async with auto_lock_renew: # Check that it is not called when the receiver is shutdown
12161221
message = MockReceivedMessage(prevent_renew_lock=True)
1217-
auto_lock_renew.register(renewable=message, on_lock_renew_failure=callback_mock)
1222+
auto_lock_renew.register(receiver, renewable=message, on_lock_renew_failure=callback_mock)
12181223
message._receiver._running = False
12191224
await asyncio.sleep(3)
12201225
assert not results
12211226
assert not errors
12221227

1223-
1224-
@pytest.mark.asyncio
1228+
@AzureTestCase.await_prepared_test
12251229
async def test_async_queue_mock_no_reusing_auto_lock_renew(self):
12261230
auto_lock_renew = AutoLockRenewer()
12271231
auto_lock_renew._renew_period = 1
12281232

1233+
receiver = MockReceiver()
12291234
async with auto_lock_renew:
1230-
auto_lock_renew.register(renewable=MockReceivedMessage())
1235+
auto_lock_renew.register(receiver, renewable=MockReceivedMessage())
12311236
await asyncio.sleep(3)
12321237

12331238
with pytest.raises(ServiceBusError):
12341239
async with auto_lock_renew:
12351240
pass
12361241

12371242
with pytest.raises(ServiceBusError):
1238-
auto_lock_renew.register(renewable=MockReceivedMessage())
1243+
auto_lock_renew.register(receiver, renewable=MockReceivedMessage())
12391244

12401245
auto_lock_renew = AutoLockRenewer()
12411246
auto_lock_renew._renew_period = 1
12421247

1243-
auto_lock_renew.register(renewable=MockReceivedMessage())
1248+
auto_lock_renew.register(receiver, renewable=MockReceivedMessage())
12441249
time.sleep(3)
12451250

12461251
await auto_lock_renew.close()
@@ -1250,7 +1255,7 @@ async def test_async_queue_mock_no_reusing_auto_lock_renew(self):
12501255
pass
12511256

12521257
with pytest.raises(ServiceBusError):
1253-
auto_lock_renew.register(renewable=MockReceivedMessage())
1258+
auto_lock_renew.register(receiver, renewable=MockReceivedMessage())
12541259

12551260
@pytest.mark.liveTest
12561261
@pytest.mark.live_test_only

sdk/servicebus/azure-servicebus/tests/mocks.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from datetime import timedelta
22

33
from azure.servicebus._common.utils import utc_now
4+
from azure.servicebus import ServiceBusReceivedMessage
45

56

67
class MockReceiver:
@@ -14,7 +15,7 @@ def renew_message_lock(self, message):
1415
message.locked_until_utc = message.locked_until_utc + timedelta(seconds=message._lock_duration)
1516

1617

17-
class MockReceivedMessage:
18+
class MockReceivedMessage(ServiceBusReceivedMessage):
1819
def __init__(self, prevent_renew_lock=False, exception_on_renew_lock=False):
1920
self._lock_duration = 2
2021

@@ -31,4 +32,12 @@ def __init__(self, prevent_renew_lock=False, exception_on_renew_lock=False):
3132
def _lock_expired(self):
3233
if self.locked_until_utc and self.locked_until_utc <= utc_now():
3334
return True
34-
return False
35+
return False
36+
37+
@property
38+
def locked_until_utc(self):
39+
return self._locked_until_utc
40+
41+
@locked_until_utc.setter
42+
def locked_until_utc(self, value):
43+
self._locked_until_utc = value

sdk/servicebus/azure-servicebus/tests/test_queues.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1409,6 +1409,10 @@ def callback_mock(renewable, error):
14091409
errors.append(error)
14101410

14111411
receiver = MockReceiver()
1412+
auto_lock_renew = AutoLockRenewer()
1413+
with pytest.raises(TypeError):
1414+
auto_lock_renew.register(Exception()) # an arbitrary invalid type.
1415+
14121416
auto_lock_renew = AutoLockRenewer()
14131417
auto_lock_renew._renew_period = 1 # So we can run the test fast.
14141418
with auto_lock_renew: # Check that it is called when the object expires for any reason (silent renew failure)

0 commit comments

Comments
 (0)