|
1 | | -from collections import UserDict |
| 1 | +import asyncio |
| 2 | +from collections import defaultdict, UserDict |
2 | 3 | from collections.abc import Iterable, Mapping, MutableMapping |
3 | 4 | from copy import copy |
4 | 5 | import json |
5 | 6 | import os |
6 | 7 | import tarfile |
7 | | -from tempfile import TemporaryDirectory |
8 | | -from typing import Any, Callable, cast, Generic, Iterator |
9 | | -from urllib.parse import ParseResult, urlparse |
| 8 | +from typing import Any, Callable, cast, Generic, Iterator, TYPE_CHECKING |
10 | 9 |
|
11 | | -# from orjson import orjson |
12 | 10 | from pydantic import Field, PrivateAttr, root_validator, ValidationError |
13 | 11 | from pydantic.fields import ModelField, Undefined, UndefinedType |
14 | 12 | from pydantic.generics import GenericModel |
|
30 | 28 | prepare_selected_items_with_mapping_data, |
31 | 29 | select_keys) |
32 | 30 | from omnipy.util.decorators import call_super_if_available |
33 | | -from omnipy.util.helpers import get_default_if_typevar, is_iterable, remove_forward_ref_notation |
| 31 | +from omnipy.util.helpers import (get_default_if_typevar, |
| 32 | + get_event_loop_and_check_if_loop_is_running, |
| 33 | + is_iterable, |
| 34 | + remove_forward_ref_notation) |
34 | 35 | from omnipy.util.web import download_file_to_memory |
35 | 36 |
|
| 37 | +if TYPE_CHECKING: |
| 38 | + from omnipy.modules.remote.datasets import HttpUrlDataset |
| 39 | + from omnipy.modules.remote.models import HttpUrlModel |
| 40 | + |
36 | 41 | ModelT = TypeVar('ModelT', bound=Model) |
37 | 42 | GeneralModelT = TypeVar('GeneralModelT', bound=Model) |
38 | 43 | _DatasetT = TypeVar('_DatasetT') |
@@ -544,45 +549,87 @@ def save(self, path: str): |
544 | 549 | tar.extractall(path=directory) |
545 | 550 | tar.close() |
546 | 551 |
|
547 | | - def load(self, *path_or_urls: str, by_file_suffix=False): |
| 552 | + def load(self, |
| 553 | + paths_or_urls: 'str | Iterable[str] | HttpUrlModel | HttpUrlDataset', |
| 554 | + by_file_suffix: bool = False) -> list[asyncio.Task] | None: |
| 555 | + from omnipy import HttpUrlDataset, HttpUrlModel |
| 556 | + |
| 557 | + match paths_or_urls: |
| 558 | + case HttpUrlDataset(): |
| 559 | + return self._load_http_urls(paths_or_urls) |
| 560 | + |
| 561 | + case HttpUrlModel(): |
| 562 | + return self._load_http_urls(HttpUrlDataset({str(paths_or_urls): paths_or_urls})) |
| 563 | + |
| 564 | + case str(): |
| 565 | + try: |
| 566 | + http_url_dataset = HttpUrlDataset({paths_or_urls: paths_or_urls}) |
| 567 | + except ValidationError: |
| 568 | + return self._load_paths([paths_or_urls], by_file_suffix) |
| 569 | + return self._load_http_urls(http_url_dataset) |
| 570 | + case Iterable(): |
| 571 | + try: |
| 572 | + path_or_url_iterable = cast(Iterable[str], paths_or_urls) |
| 573 | + http_url_dataset = HttpUrlDataset( |
| 574 | + zip(path_or_url_iterable, path_or_url_iterable)) |
| 575 | + except ValidationError: |
| 576 | + return self._load_paths(path_or_url_iterable, by_file_suffix) |
| 577 | + return self._load_http_urls(http_url_dataset) |
| 578 | + case _: |
| 579 | + raise TypeError(f'"paths_or_urls" argument is of incorrect type. Type ' |
| 580 | + f'{type(paths_or_urls)} is not supported.') |
| 581 | + |
def _load_http_urls(self, http_url_dataset: 'HttpUrlDataset') -> list[asyncio.Task]:
    """Fetch the URLs in ``http_url_dataset`` into this dataset, grouped by host.

    The positions of the URLs are grouped by ``url.host``. For every host one
    ``RateLimitingClientSession`` is created from the matching entry in
    ``self.config.http_config_for_host``, and one
    ``get_json_from_api_endpoint`` run is started for that host's slice of the
    dataset with ``output_dataset=self``. All runs are awaited together.

    Returns:
        If an event loop is already running, the result of
        ``loop.create_task(...)`` for the combined load; otherwise the result
        of ``asyncio.run(...)``, i.e. ``self`` once loading has finished.
        NOTE(review): neither value is a ``list[asyncio.Task]`` as the
        annotation declares -- confirm the intended return type.
    """
    from omnipy.modules.remote.helpers import RateLimitingClientSession
    from omnipy.modules.remote.tasks import get_json_from_api_endpoint

    # Map each host to the positions of its URLs within the dataset.
    indices_by_host: defaultdict[str, list[int]] = defaultdict(list)
    for position, http_url in enumerate(http_url_dataset.values()):
        indices_by_host[http_url.host].append(position)

    async def load_all():
        # One session per host, configured from the per-host HTTP settings.
        # NOTE(review): the sessions are not closed in this method -- confirm
        # whether the task or the session type takes care of that.
        sessions_by_host = {}
        for host in indices_by_host:
            host_config = self.config.http_config_for_host[host]
            sessions_by_host[host] = RateLimitingClientSession(
                host_config.requests_per_time_period,
                host_config.time_period_in_secs,
            )

        # One run per host, started only after every session exists.
        pending_runs = [
            get_json_from_api_endpoint.refine(output_dataset_param='output_dataset').run(
                http_url_dataset[indices],
                client_session=sessions_by_host[host],
                output_dataset=self,
            ) for host, indices in indices_by_host.items()
        ]

        await asyncio.gather(*pending_runs)
        return self

    loop, loop_is_running = get_event_loop_and_check_if_loop_is_running()

    # Without a running loop, drive the coroutine to completion here.
    if not (loop and loop_is_running):
        return asyncio.run(load_all())

    # A loop is already running: schedule the work and hand back the task.
    return loop.create_task(load_all())
| 614 | + |
def _load_paths(self, path_or_urls: Iterable[str], by_file_suffix: bool) -> None:
    """Load each given path through the serializer registry and absorb the result.

    For every item, the path is passed through ``_ensure_tar_gz_file`` and the
    resulting tar file path is handed to the serializer registry, which is
    fetched anew for each item.

    Args:
        path_or_urls: The paths to load, processed in iteration order.
        by_file_suffix: If true, use the registry's
            ``load_from_tar_file_path_based_on_file_suffix``; otherwise use
            ``load_from_tar_file_path_based_on_dataset_cls``.

    Raises:
        RuntimeError: If the registry returns ``None`` for a path. Items
            absorbed before the failing one remain absorbed.
    """
    for path_or_url in path_or_urls:
        registry = self._get_serializer_registry()
        tar_gz_file_path = self._ensure_tar_gz_file(path_or_url)

        # Pick the loading strategy, then call it with this dataset passed
        # as both the first and the last argument.
        load_from_tar = (
            registry.load_from_tar_file_path_based_on_file_suffix
            if by_file_suffix else registry.load_from_tar_file_path_based_on_dataset_cls)
        loaded_dataset = load_from_tar(self, tar_gz_file_path, self)

        if loaded_dataset is None:
            raise RuntimeError('Unable to load from serializer')
        self.absorb(loaded_dataset)
586 | 633 |
|
587 | 634 | @staticmethod |
588 | 635 | def _download_file(url: str, path: str, tmp_dir_path: str) -> str | None: |
@@ -638,7 +685,7 @@ def __eq__(self, other: object) -> bool: |
638 | 685 | and self.to_data() == other.to_data() # last is probably unnecessary, but just in case |
639 | 686 |
|
def __repr_args__(self):
    """Return ``(key, value)`` pairs for the repr of this dataset.

    Values for which ``is_model_instance`` is true are shown by their
    ``contents``; any other value is shown as it is.
    """
    repr_args = []
    for key, value in self.data.items():
        shown_value = value.contents if is_model_instance(value) else value
        repr_args.append((key, shown_value))
    return repr_args
642 | 689 |
|
643 | 690 |
|
644 | 691 | class MultiModelDataset(Dataset[GeneralModelT], Generic[GeneralModelT]): |
|
0 commit comments