Skip to content

Commit 87e4b8d

Browse files
committed
Implement async download of data from URLs in Dataset.load(). Refactor file-based loading.
1 parent dc310ba commit 87e4b8d

File tree

1 file changed

+91
-44
lines changed

1 file changed

+91
-44
lines changed

src/omnipy/data/dataset.py

Lines changed: 91 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
1-
from collections import UserDict
1+
import asyncio
2+
from collections import defaultdict, UserDict
23
from collections.abc import Iterable, Mapping, MutableMapping
34
from copy import copy
45
import json
56
import os
67
import tarfile
7-
from tempfile import TemporaryDirectory
8-
from typing import Any, Callable, cast, Generic, Iterator
9-
from urllib.parse import ParseResult, urlparse
8+
from typing import Any, Callable, cast, Generic, Iterator, TYPE_CHECKING
109

11-
# from orjson import orjson
1210
from pydantic import Field, PrivateAttr, root_validator, ValidationError
1311
from pydantic.fields import ModelField, Undefined, UndefinedType
1412
from pydantic.generics import GenericModel
@@ -30,9 +28,16 @@
3028
prepare_selected_items_with_mapping_data,
3129
select_keys)
3230
from omnipy.util.decorators import call_super_if_available
33-
from omnipy.util.helpers import get_default_if_typevar, is_iterable, remove_forward_ref_notation
31+
from omnipy.util.helpers import (get_default_if_typevar,
32+
get_event_loop_and_check_if_loop_is_running,
33+
is_iterable,
34+
remove_forward_ref_notation)
3435
from omnipy.util.web import download_file_to_memory
3536

37+
if TYPE_CHECKING:
38+
from omnipy.modules.remote.datasets import HttpUrlDataset
39+
from omnipy.modules.remote.models import HttpUrlModel
40+
3641
ModelT = TypeVar('ModelT', bound=Model)
3742
GeneralModelT = TypeVar('GeneralModelT', bound=Model)
3843
_DatasetT = TypeVar('_DatasetT')
@@ -544,45 +549,87 @@ def save(self, path: str):
544549
tar.extractall(path=directory)
545550
tar.close()
546551

547-
def load(self, *path_or_urls: str, by_file_suffix=False):
552+
def load(self,
553+
paths_or_urls: 'str | Iterable[str] | HttpUrlModel | HttpUrlDataset',
554+
by_file_suffix: bool = False) -> list[asyncio.Task] | None:
555+
from omnipy import HttpUrlDataset, HttpUrlModel
556+
557+
match paths_or_urls:
558+
case HttpUrlDataset():
559+
return self._load_http_urls(paths_or_urls)
560+
561+
case HttpUrlModel():
562+
return self._load_http_urls(HttpUrlDataset({str(paths_or_urls): paths_or_urls}))
563+
564+
case str():
565+
try:
566+
http_url_dataset = HttpUrlDataset({paths_or_urls: paths_or_urls})
567+
except ValidationError:
568+
return self._load_paths([paths_or_urls], by_file_suffix)
569+
return self._load_http_urls(http_url_dataset)
570+
case Iterable():
571+
try:
572+
path_or_url_iterable = cast(Iterable[str], paths_or_urls)
573+
http_url_dataset = HttpUrlDataset(
574+
zip(path_or_url_iterable, path_or_url_iterable))
575+
except ValidationError:
576+
return self._load_paths(path_or_url_iterable, by_file_suffix)
577+
return self._load_http_urls(http_url_dataset)
578+
case _:
579+
raise TypeError(f'"paths_or_urls" argument is of incorrect type. Type '
580+
f'{type(paths_or_urls)} is not supported.')
581+
582+
def _load_http_urls(self, http_url_dataset: 'HttpUrlDataset') -> list[asyncio.Task]:
583+
from omnipy.modules.remote.helpers import RateLimitingClientSession
584+
from omnipy.modules.remote.tasks import get_json_from_api_endpoint
585+
hosts: defaultdict[str, list[int]] = defaultdict(list)
586+
for i, url in enumerate(http_url_dataset.values()):
587+
hosts[url.host].append(i)
588+
589+
async def load_all():
590+
tasks = []
591+
client_sessions = {}
592+
for host in hosts:
593+
client_sessions[host] = RateLimitingClientSession(
594+
self.config.http_config_for_host[host].requests_per_time_period,
595+
self.config.http_config_for_host[host].time_period_in_secs)
596+
597+
for host, indices in hosts.items():
598+
task = (
599+
get_json_from_api_endpoint.refine(output_dataset_param='output_dataset').run(
600+
http_url_dataset[indices],
601+
client_session=client_sessions[host],
602+
output_dataset=self))
603+
tasks.append(task)
604+
605+
await asyncio.gather(*tasks)
606+
return self
607+
608+
loop, loop_is_running = get_event_loop_and_check_if_loop_is_running()
609+
610+
if loop and loop_is_running:
611+
return loop.create_task(load_all())
612+
else:
613+
return asyncio.run(load_all())
614+
615+
def _load_paths(self, path_or_urls: Iterable[str], by_file_suffix: bool) -> None:
548616
for path_or_url in path_or_urls:
549-
if is_model_instance(path_or_url):
550-
path_or_url = path_or_url.contents
551-
552-
with TemporaryDirectory() as tmp_dir_path:
553-
serializer_registry = self._get_serializer_registry()
554-
555-
parsed_url = urlparse(path_or_url)
556-
557-
if parsed_url.scheme in ['http', 'https']:
558-
download_path = self._download_file(path_or_url, parsed_url.path, tmp_dir_path)
559-
if download_path is None:
560-
continue
561-
tar_gz_file_path = self._ensure_tar_gz_file(download_path)
562-
elif parsed_url.scheme in ['file', '']:
563-
tar_gz_file_path = self._ensure_tar_gz_file(parsed_url.path)
564-
elif self._is_windows_path(parsed_url):
565-
tar_gz_file_path = self._ensure_tar_gz_file(path_or_url)
566-
else:
567-
raise ValueError(f'Unsupported scheme "{parsed_url.scheme}"')
568-
569-
if by_file_suffix:
570-
loaded_dataset = \
571-
serializer_registry.load_from_tar_file_path_based_on_file_suffix(
572-
self, tar_gz_file_path, self)
573-
else:
574-
loaded_dataset = \
575-
serializer_registry.load_from_tar_file_path_based_on_dataset_cls(
576-
self, tar_gz_file_path, self)
577-
if loaded_dataset is not None:
578-
self.absorb(loaded_dataset)
579-
continue
580-
else:
581-
raise RuntimeError('Unable to load serializer')
617+
serializer_registry = self._get_serializer_registry()
618+
tar_gz_file_path = self._ensure_tar_gz_file(path_or_url)
582619

583-
@staticmethod
584-
def _is_windows_path(parsed_url: ParseResult) -> bool:
585-
return len(parsed_url.scheme) == 1 and parsed_url.scheme.isalpha()
620+
if by_file_suffix:
621+
loaded_dataset = \
622+
serializer_registry.load_from_tar_file_path_based_on_file_suffix(
623+
self, tar_gz_file_path, self)
624+
else:
625+
loaded_dataset = \
626+
serializer_registry.load_from_tar_file_path_based_on_dataset_cls(
627+
self, tar_gz_file_path, self)
628+
if loaded_dataset is not None:
629+
self.absorb(loaded_dataset)
630+
continue
631+
else:
632+
raise RuntimeError('Unable to load from serializer')
586633

587634
@staticmethod
588635
def _download_file(url: str, path: str, tmp_dir_path: str) -> str | None:
@@ -638,7 +685,7 @@ def __eq__(self, other: object) -> bool:
638685
and self.to_data() == other.to_data() # last is probably unnecessary, but just in case
639686

640687
def __repr_args__(self):
641-
return [(k, v.contents) for k, v in self.data.items()]
688+
return [(k, v.contents) if is_model_instance(v) else (k, v) for k, v in self.data.items()]
642689

643690

644691
class MultiModelDataset(Dataset[GeneralModelT], Generic[GeneralModelT]):

0 commit comments

Comments
 (0)