Skip to content

Commit c81aebe

Browse files
committed
Fixed closing of client sessions
1 parent ae3ed8d commit c81aebe

File tree

2 files changed

+22
-15
lines changed

2 files changed

+22
-15
lines changed

src/omnipy/data/dataset.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
_DatasetT = TypeVar('_DatasetT')
4444

4545
DATA_KEY = 'data'
46+
ASYNC_LOAD_SLEEP_TIME = 0.05
4647

4748
# def orjson_dumps(v, *, default):
4849
# # orjson.dumps returns bytes, to match standard json.dumps we need to decode
@@ -588,19 +589,22 @@ def _load_http_urls(self, http_url_dataset: 'HttpUrlDataset') -> list[asyncio.Ta
588589

589590
async def load_all():
590591
tasks = []
591-
client_sessions = {}
592+
592593
for host in hosts:
593-
client_sessions[host] = RateLimitingClientSession(
594-
self.config.http_config_for_host[host].requests_per_time_period,
595-
self.config.http_config_for_host[host].time_period_in_secs)
596-
597-
for host, indices in hosts.items():
598-
task = (
599-
get_json_from_api_endpoint.refine(output_dataset_param='output_dataset').run(
600-
http_url_dataset[indices],
601-
client_session=client_sessions[host],
602-
output_dataset=self))
603-
tasks.append(task)
594+
async with RateLimitingClientSession(
595+
self.config.http_config_for_host[host].requests_per_time_period,
596+
self.config.http_config_for_host[host].time_period_in_secs
597+
) as client_session:
598+
indices = hosts[host]
599+
task = (
600+
get_json_from_api_endpoint.refine(
601+
output_dataset_param='output_dataset').run(
602+
http_url_dataset[indices],
603+
client_session=client_session,
604+
output_dataset=self))
605+
tasks.append(task)
606+
while not task.done():
607+
await asyncio.sleep(ASYNC_LOAD_SLEEP_TIME)
604608

605609
await asyncio.gather(*tasks)
606610
return self

src/omnipy/modules/remote/tasks.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ async def get_json_from_api_endpoint(
9292
retry_backoff_strategy,
9393
):
9494
async for response in _call_get(url, cast(ClientSession, retry_session)):
95-
return JsonModel(await response.json(content_type=None))
95+
output_data = JsonModel(await response.json(content_type=None))
96+
return output_data
9697

9798

9899
@TaskTemplate(iterate_over_data_files=True, output_dataset_cls=StrDataset)
@@ -110,7 +111,8 @@ async def get_str_from_api_endpoint(
110111
retry_backoff_strategy,
111112
):
112113
async for response in _call_get(url, cast(ClientSession, retry_session)):
113-
return StrModel(await response.text())
114+
output_data = StrModel(await response.text())
115+
return output_data
114116

115117

116118
@TaskTemplate(iterate_over_data_files=True, output_dataset_cls=BytesDataset)
@@ -128,7 +130,8 @@ async def get_bytes_from_api_endpoint(
128130
retry_backoff_strategy,
129131
):
130132
async for response in _call_get(url, cast(ClientSession, retry_session)):
131-
return BytesModel(await response.read())
133+
output_data = BytesModel(await response.read())
134+
return output_data
132135

133136

134137
JsonDatasetT = TypeVar('JsonDatasetT', bound=Dataset)

0 commit comments

Comments
 (0)