|
11 | 11 | ) |
12 | 12 | from kili.adapters.kili_api_gateway.asset.mappers import asset_where_mapper |
13 | 13 | from kili.adapters.kili_api_gateway.asset.operations import ( |
| 14 | + GQL_COUNT_ASSET_ANNOTATIONS, |
14 | 15 | GQL_COUNT_ASSETS, |
15 | 16 | GQL_CREATE_UPLOAD_BUCKET_SIGNED_URLS, |
16 | 17 | GQL_FILTER_EXISTING_ASSETS, |
|
27 | 28 | from kili.domain.asset import AssetFilters |
28 | 29 | from kili.domain.types import ListOrTuple |
29 | 30 |
|
| 31 | +# Threshold used to pick the batch size when fetching assets.
| 32 | +# The number of annotations counted for the filtered assets is divided by the
| 33 | +# candidate batch size; if that ratio exceeds this threshold, the assets are
| 34 | +# fetched one at a time (batch size of 1) to avoid performance issues.
| 35 | +THRESHOLD_FOR_BATCHING = 200 |
| 36 | + |
30 | 37 |
|
31 | 38 | class AssetOperationMixin(BaseOperationMixin): |
32 | 39 | """Mixin extending Kili API Gateway class with Assets related operations.""" |
@@ -74,13 +81,15 @@ def list_assets_split( |
74 | 81 | self, filters: AssetFilters, fields: ListOrTuple[str], options: QueryOptions, project_info |
75 | 82 | ) -> Generator[Dict, None, None]: |
76 | 83 | """List assets with given options.""" |
77 | | - options = QueryOptions( |
78 | | - options.disable_tqdm, |
79 | | - options.first, |
80 | | - options.skip, |
81 | | - min(options.batch_size, 10 if project_info["inputType"] == "VIDEO" else 50), |
| 84 | + nb_annotations = self.count_assets_annotations(filters) |
| 85 | + assets_batch_max_amount = 10 if project_info["inputType"] == "VIDEO" else 50 |
| 86 | + batch_size_to_use = min(options.batch_size, assets_batch_max_amount) |
| 87 | + batch_size = ( |
| 88 | + 1 if nb_annotations / batch_size_to_use > THRESHOLD_FOR_BATCHING else batch_size_to_use |
82 | 89 | ) |
83 | 90 |
|
| 91 | + options = QueryOptions(options.disable_tqdm, options.first, options.skip, batch_size) |
| 92 | + |
84 | 93 | inner_annotation_fragment = get_annotation_fragment() |
85 | 94 | annotation_fragment = f""" |
86 | 95 | annotations {{ |
@@ -149,3 +158,11 @@ def filter_existing_assets(self, project_id: str, assets_external_ids: ListOrTup |
149 | 158 | } |
150 | 159 | external_id_response = self.graphql_client.execute(GQL_FILTER_EXISTING_ASSETS, payload) |
151 | 160 | return external_id_response["external_ids"] |
| 161 | + |
def count_assets_annotations(self, filters: AssetFilters) -> int:
    """Return the annotation count for the assets selected by ``filters``.

    The filters are turned into a ``where`` clause with ``asset_where_mapper``
    and sent with the ``GQL_COUNT_ASSET_ANNOTATIONS`` query; the ``"data"``
    entry of the response is returned as the count.
    """
    # Build the where clause first so the mapper runs before the client is touched.
    where_clause = asset_where_mapper(filters)
    response = self.graphql_client.execute(
        GQL_COUNT_ASSET_ANNOTATIONS,
        {"where": where_clause},
    )
    annotation_total: int = response["data"]
    return annotation_total
0 commit comments