Skip to content

Commit 811cf03

Browse files
[EventHubs] Fix bug in reusing EventHubProducerClient (Azure#21927)
* fix bug * run black * fix mypy * fix pylint * update changelog to be more clear * Update sdk/eventhub/azure-eventhub/CHANGELOG.md Co-authored-by: swathipil <76007337+swathipil@users.noreply.github.com> Co-authored-by: swathipil <76007337+swathipil@users.noreply.github.com>
1 parent 08819ce commit 811cf03

File tree

5 files changed

+94
-26
lines changed

5 files changed

+94
-26
lines changed

sdk/eventhub/azure-eventhub/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
### Bugs Fixed
1010

11+
- Fixed a bug that `EventHubProducerClient` could be reopened for sending events instead of encountering with `KeyError` when the client is previously closed (issue #21849).
12+
1113
### Other Changes
1214

1315
## 5.6.1 (2021-10-06)

sdk/eventhub/azure-eventhub/azure/eventhub/_producer_client.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@
1717
from ._common import EventDataBatch, EventData
1818

1919
if TYPE_CHECKING:
20-
from azure.core.credentials import TokenCredential, AzureSasCredential, AzureNamedKeyCredential
20+
from azure.core.credentials import (
21+
TokenCredential,
22+
AzureSasCredential,
23+
AzureNamedKeyCredential,
24+
)
2125

2226
SendEventTypes = List[Union[EventData, AmqpAnnotatedMessage]]
2327

@@ -143,7 +147,10 @@ def _start_producer(self, partition_id, send_timeout):
143147
or cast(EventHubProducer, self._producers[partition_id]).closed
144148
):
145149
self._producers[partition_id] = self._create_producer(
146-
partition_id=partition_id, send_timeout=send_timeout
150+
partition_id=(
151+
None if partition_id == ALL_PARTITIONS else partition_id
152+
),
153+
send_timeout=send_timeout,
147154
)
148155

149156
def _create_producer(self, partition_id=None, send_timeout=None):
@@ -261,14 +268,21 @@ def send_batch(self, event_data_batch, **kwargs):
261268

262269
if isinstance(event_data_batch, EventDataBatch):
263270
if partition_id or partition_key:
264-
raise TypeError("partition_id and partition_key should be None when sending an EventDataBatch "
265-
"because type EventDataBatch itself may have partition_id or partition_key")
271+
raise TypeError(
272+
"partition_id and partition_key should be None when sending an EventDataBatch "
273+
"because type EventDataBatch itself may have partition_id or partition_key"
274+
)
266275
to_send_batch = event_data_batch
267276
else:
268-
to_send_batch = self.create_batch(partition_id=partition_id, partition_key=partition_key)
269-
to_send_batch._load_events(event_data_batch) # pylint:disable=protected-access
277+
to_send_batch = self.create_batch(
278+
partition_id=partition_id, partition_key=partition_key
279+
)
280+
to_send_batch._load_events( # pylint:disable=protected-access
281+
event_data_batch
282+
)
270283
partition_id = (
271-
to_send_batch._partition_id or ALL_PARTITIONS # pylint:disable=protected-access
284+
to_send_batch._partition_id # pylint:disable=protected-access
285+
or ALL_PARTITIONS
272286
)
273287

274288
if len(to_send_batch) == 0:
@@ -400,8 +414,8 @@ def close(self):
400414
401415
"""
402416
with self._lock:
403-
for producer in self._producers.values():
404-
if producer:
405-
producer.close()
406-
self._producers = {}
417+
for pid in self._producers:
418+
if self._producers[pid]:
419+
self._producers[pid].close() # type: ignore
420+
self._producers[pid] = None
407421
super(EventHubProducerClient, self)._close()

sdk/eventhub/azure-eventhub/azure/eventhub/aio/_producer_client_async.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
if TYPE_CHECKING:
2121
from azure.core.credentials_async import AsyncTokenCredential
22-
from uamqp.constants import TransportType # pylint: disable=ungrouped-imports
22+
from uamqp.constants import TransportType # pylint: disable=ungrouped-imports
2323

2424
SendEventTypes = List[Union[EventData, AmqpAnnotatedMessage]]
2525

@@ -80,7 +80,9 @@ def __init__(
8080
self,
8181
fully_qualified_namespace: str,
8282
eventhub_name: str,
83-
credential: Union["AsyncTokenCredential", AzureSasCredential, AzureNamedKeyCredential],
83+
credential: Union[
84+
"AsyncTokenCredential", AzureSasCredential, AzureNamedKeyCredential
85+
],
8486
**kwargs
8587
) -> None:
8688
super(EventHubProducerClient, self).__init__(
@@ -145,7 +147,10 @@ async def _start_producer(
145147
or cast(EventHubProducer, self._producers[partition_id]).closed
146148
):
147149
self._producers[partition_id] = self._create_producer(
148-
partition_id=partition_id, send_timeout=send_timeout
150+
partition_id=(
151+
None if partition_id == ALL_PARTITIONS else partition_id
152+
),
153+
send_timeout=send_timeout,
149154
)
150155

151156
def _create_producer(
@@ -294,18 +299,25 @@ async def send_batch(
294299

295300
if isinstance(event_data_batch, EventDataBatch):
296301
if partition_id or partition_key:
297-
raise TypeError("partition_id and partition_key should be None when sending an EventDataBatch "
298-
"because type EventDataBatch itself may have partition_id or partition_key")
302+
raise TypeError(
303+
"partition_id and partition_key should be None when sending an EventDataBatch "
304+
"because type EventDataBatch itself may have partition_id or partition_key"
305+
)
299306
to_send_batch = event_data_batch
300307
else:
301-
to_send_batch = await self.create_batch(partition_id=partition_id, partition_key=partition_key)
302-
to_send_batch._load_events(event_data_batch) # pylint:disable=protected-access
308+
to_send_batch = await self.create_batch(
309+
partition_id=partition_id, partition_key=partition_key
310+
)
311+
to_send_batch._load_events( # pylint:disable=protected-access
312+
event_data_batch
313+
)
303314

304315
if len(to_send_batch) == 0:
305316
return
306317

307318
partition_id = (
308-
to_send_batch._partition_id or ALL_PARTITIONS # pylint:disable=protected-access
319+
to_send_batch._partition_id # pylint:disable=protected-access
320+
or ALL_PARTITIONS
309321
)
310322
try:
311323
await cast(EventHubProducer, self._producers[partition_id]).send(
@@ -431,7 +443,9 @@ async def close(self) -> None:
431443
432444
"""
433445
async with self._lock:
434-
for producer in self._producers.values():
435-
if producer:
436-
await producer.close()
446+
for pid in self._producers:
447+
if self._producers[pid] is not None:
448+
await self._producers[pid].close() # type: ignore
449+
self._producers[pid] = None
450+
437451
await super(EventHubProducerClient, self)._close_async()

sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_send_async.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,15 +189,34 @@ async def test_send_and_receive_small_body_async(connstr_receivers, payload):
189189
async def test_send_partition_async(connstr_receivers):
190190
connection_str, receivers = connstr_receivers
191191
client = EventHubProducerClient.from_connection_string(connection_str)
192+
193+
async with client:
194+
batch = await client.create_batch()
195+
batch.add(EventData(b"Data"))
196+
await client.send_batch(batch)
197+
192198
async with client:
193199
batch = await client.create_batch(partition_id="1")
194200
batch.add(EventData(b"Data"))
195201
await client.send_batch(batch)
196202

197203
partition_0 = receivers[0].receive_message_batch(timeout=5000)
198-
assert len(partition_0) == 0
199204
partition_1 = receivers[1].receive_message_batch(timeout=5000)
200-
assert len(partition_1) == 1
205+
assert len(partition_0) + len(partition_1) == 2
206+
207+
async with client:
208+
batch = await client.create_batch()
209+
batch.add(EventData(b"Data"))
210+
await client.send_batch(batch)
211+
212+
async with client:
213+
batch = await client.create_batch(partition_id="1")
214+
batch.add(EventData(b"Data"))
215+
await client.send_batch(batch)
216+
217+
partition_0 = receivers[0].receive_message_batch(timeout=5000)
218+
partition_1 = receivers[1].receive_message_batch(timeout=5000)
219+
assert len(partition_0) + len(partition_1) == 2
201220

202221

203222
@pytest.mark.liveTest

sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_send.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,15 +206,34 @@ def test_send_and_receive_small_body(connstr_receivers, payload):
206206
def test_send_partition(connstr_receivers):
207207
connection_str, receivers = connstr_receivers
208208
client = EventHubProducerClient.from_connection_string(connection_str)
209+
210+
with client:
211+
batch = client.create_batch()
212+
batch.add(EventData(b"Data"))
213+
client.send_batch(batch)
214+
209215
with client:
210216
batch = client.create_batch(partition_id="1")
211217
batch.add(EventData(b"Data"))
212218
client.send_batch(batch)
213219

214220
partition_0 = receivers[0].receive_message_batch(timeout=5000)
215-
assert len(partition_0) == 0
216221
partition_1 = receivers[1].receive_message_batch(timeout=5000)
217-
assert len(partition_1) == 1
222+
assert len(partition_0) + len(partition_1) == 2
223+
224+
with client:
225+
batch = client.create_batch()
226+
batch.add(EventData(b"Data"))
227+
client.send_batch(batch)
228+
229+
with client:
230+
batch = client.create_batch(partition_id="1")
231+
batch.add(EventData(b"Data"))
232+
client.send_batch(batch)
233+
234+
partition_0 = receivers[0].receive_message_batch(timeout=5000)
235+
partition_1 = receivers[1].receive_message_batch(timeout=5000)
236+
assert len(partition_0) + len(partition_1) == 2
218237

219238

220239
@pytest.mark.liveTest

0 commit comments

Comments
 (0)