From d9121f8232a54149e631e0ea67b365bfcaf0a19a Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Fri, 26 Sep 2025 09:30:31 +0200 Subject: [PATCH 01/83] feat: add sharding support for GeoZarr conversion and CLI --- src/eopf_geozarr/cli.py | 6 ++ src/eopf_geozarr/conversion/geozarr.py | 97 ++++++++++++++++++++++++-- 2 files changed, 96 insertions(+), 7 deletions(-) diff --git a/src/eopf_geozarr/cli.py b/src/eopf_geozarr/cli.py index 5ff9b81b..77e24e22 100644 --- a/src/eopf_geozarr/cli.py +++ b/src/eopf_geozarr/cli.py @@ -175,6 +175,7 @@ def convert_command(args: argparse.Namespace) -> None: max_retries=args.max_retries, crs_groups=args.crs_groups, gcp_group=args.gcp_group, + enable_sharding=args.enable_sharding, ) print("✅ Successfully converted EOPF dataset to GeoZarr format") @@ -1109,6 +1110,11 @@ def create_parser() -> argparse.ArgumentParser: action="store_true", help="Start a local dask cluster for parallel processing of chunks", ) + convert_parser.add_argument( + "--enable-sharding", + action="store_true", + help="Enable zarr sharding for spatial dimensions of each variable", + ) convert_parser.set_defaults(func=convert_command) # Info command diff --git a/src/eopf_geozarr/conversion/geozarr.py b/src/eopf_geozarr/conversion/geozarr.py index ccfe4e45..896f0004 100644 --- a/src/eopf_geozarr/conversion/geozarr.py +++ b/src/eopf_geozarr/conversion/geozarr.py @@ -26,7 +26,7 @@ import zarr from pyproj import CRS from rasterio.warp import calculate_default_transform -from zarr.codecs import BloscCodec +from zarr.codecs import BloscCodec, ShardingCodec from zarr.core.sync import sync from zarr.storage import StoreLike from zarr.storage._common import make_store_path @@ -57,6 +57,7 @@ def create_geozarr_dataset( max_retries: int = 3, crs_groups: Iterable[str] | None = None, gcp_group: str | None = None, + enable_sharding: bool = False, ) -> xr.DataTree: """ Create a GeoZarr-spec 0.4 compliant dataset from EOPF data. @@ -81,6 +82,8 @@ def create_geozarr_dataset( Iterable of group names that need CRS information added on best-effort basis gcp_group : str, optional Group name where GCPs (Ground Control Points) are located. + enable_sharding : bool, default False + Enable zarr sharding for spatial dimensions of each variable Returns ------- @@ -89,6 +92,9 @@ def create_geozarr_dataset( """ dt = dt_input.copy() compressor = BloscCodec(cname="zstd", clevel=3, shuffle="shuffle", blocksize=0) + + if enable_sharding: + print("🔧 Zarr sharding enabled for spatial dimensions") if _is_sentinel1(dt_input): if gcp_group is None: @@ -132,6 +138,7 @@ def create_geozarr_dataset( max_retries, crs_groups, gcp_group, + enable_sharding, ) # Consolidate metadata at the root level AFTER all groups are written @@ -230,6 +237,7 @@ def iterative_copy( max_retries: int = 3, crs_groups: Iterable[str] | None = None, gcp_group: str | None = None, + enable_sharding: bool = False, ) -> xr.DataTree: """ Iteratively copy groups from original DataTree to GeoZarr DataTree. @@ -301,6 +309,7 @@ def iterative_copy( min_dimension=min_dimension, tile_width=tile_width, gcp_group=gcp_group, + enable_sharding=enable_sharding, ) written_groups.add(current_group_path) continue @@ -407,6 +416,7 @@ def write_geozarr_group( min_dimension: int = 256, tile_width: int = 256, gcp_group: str | None = None, + enable_sharding: bool = False, ) -> xr.DataTree: """ Write a group to a GeoZarr dataset with multiscales support. 
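[Editor's sketch, not part of the patch] For orientation, a minimal standalone illustration of the layout the new enable_sharding path is aiming for, written against zarr-python's v3 API (assuming zarr >= 3.0, where create_array exposes a shards argument that assembles the sharding codec internally): inner chunks stay at the spatial chunk size, while a single shard spans the whole array, so each variable level is stored as a few large objects instead of thousands of small chunk files. The array name and sizes below are placeholders, not values taken from the converter.

import numpy as np
import zarr
from zarr.codecs import BloscCodec

# Same compressor settings as create_geozarr_dataset uses above.
compressor = BloscCodec(cname="zstd", clevel=3, shuffle="shuffle", blocksize=0)

store = zarr.storage.MemoryStore()
arr = zarr.create_array(
    store=store,
    name="b02",              # illustrative variable name
    shape=(1024, 1024),      # illustrative spatial extent
    dtype="uint16",
    chunks=(256, 256),       # inner chunks: the unit of compression and of reads
    shards=(1024, 1024),     # one shard covers the whole array
    compressors=[compressor],
)
arr[:] = np.ones((1024, 1024), dtype="uint16")
print(arr.chunks)            # (256, 256); the shard shape is recorded in the array metadata
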
@@ -451,7 +461,7 @@ def write_geozarr_group( dt.attrs = ds.attrs.copy() # Create encoding for all variables - encoding = _create_geozarr_encoding(ds, compressor, spatial_chunk) + encoding = _create_geozarr_encoding(ds, compressor, spatial_chunk, enable_sharding) # Write native data in the group 0 (overview level 0) native_dataset_group_name = f"{group_name}/0" @@ -1442,7 +1452,7 @@ def _create_encoding( def _create_geozarr_encoding( - ds: xr.Dataset, compressor: Any, spatial_chunk: int + ds: xr.Dataset, compressor: Any, spatial_chunk: int, enable_sharding: bool = False ) -> dict[Hashable, XarrayEncodingJSON]: """Create encoding for GeoZarr dataset variables.""" encoding: dict[Hashable, XarrayEncodingJSON] = {} @@ -1461,10 +1471,16 @@ def _create_geozarr_encoding( else: spatial_chunk_aligned = spatial_chunk - encoding[var] = { - "chunks": (spatial_chunk_aligned, spatial_chunk_aligned), - "compressors": compressor, - } + if enable_sharding and len(data_shape) >= 2: + # Create sharding configuration for spatial dimensions + encoding[var] = _create_sharded_encoding( + data_shape, spatial_chunk_aligned, compressor + ) + else: + encoding[var] = { + "chunks": (spatial_chunk_aligned, spatial_chunk_aligned), + "compressors": compressor, + } # Add coordinate encoding for coord in ds.coords: @@ -1473,6 +1489,73 @@ def _create_geozarr_encoding( return encoding +def _create_sharded_encoding( + data_shape: tuple[int, ...], spatial_chunk: int, compressor: Any +) -> XarrayEncodingJSON: + """ + Create sharded encoding configuration for spatial dimensions. + + Parameters + ---------- + data_shape : tuple[int, ...] + Shape of the data array + spatial_chunk : int + Spatial chunk size + compressor : Any + Compressor to use + + Returns + ------- + dict + Encoding configuration with sharding + """ + # Calculate shard configuration based on spatial dimensions + if len(data_shape) == 3: + # 3D array (time, y, x) or (band, y, x) + height, width = data_shape[-2:] + + # Use full spatial dimensions for shards + shard_height = height + shard_width = width + + # Chunk dimensions within shards + chunk_height = min(spatial_chunk, height) + chunk_width = min(spatial_chunk, width) + + # Create sharding codec + sharding_codec = ShardingCodec( + chunk_shape=(1, chunk_height, chunk_width), + codecs=[compressor] + ) + + return { + "chunks": (1, shard_height, shard_width), # Full spatial dimensions per shard + "compressors": [sharding_codec], + } + else: + # 2D array (y, x) + height, width = data_shape[-2:] + + # Use full spatial dimensions for shards + shard_height = height + shard_width = width + + # Chunk dimensions within shards + chunk_height = min(spatial_chunk, height) + chunk_width = min(spatial_chunk, width) + + # Create sharding codec + sharding_codec = ShardingCodec( + chunk_shape=(chunk_height, chunk_width), + codecs=[compressor] + ) + + return { + "chunks": (shard_height, shard_width), # Full spatial dimensions per shard + "compressors": [sharding_codec], + } + + def _load_existing_dataset(path: str) -> xr.Dataset | None: """Load existing dataset if it exists.""" try: From 8b83e771bfecc744c216a5955655add628c946a5 Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Fri, 26 Sep 2025 07:33:24 +0000 Subject: [PATCH 02/83] update launch configurations for GeoZarr conversion with new data sources and adjusted parameters --- .vscode/launch.json | 41 +++++++++++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index acfba750..731a5e77 
100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -13,15 +13,16 @@ "module": "eopf_geozarr", "args": [ "convert", - "https://objectstore.eodc.eu:2222/e05ab01a9d56408d82ac32d69a5aae2a:sample-data/tutorial_data/cpm_v253/S2B_MSIL1C_20250113T103309_N0511_R108_T32TLQ_20250113T122458.zarr", - "./tests-output/eopf_geozarr/s2b_test.zarr", - "--groups", "/measurements/reflectance/r10m", "/measurements/reflectance/r20m", "/measurements/reflectance/r60m", "/quality/l1c_quicklook/r10m", + "https://objects.eodc.eu/e05ab01a9d56408d82ac32d69a5aae2a:202508-s02msil2a/11/products/cpm_v256/S2C_MSIL2A_20250811T112131_N0511_R037_T29TPF_20250811T152216.zarr", + "./tests-output/eopf_geozarr/s2l2_test.zarr", + "--groups", "/measurements/reflectance/r10m", "/measurements/reflectance/r20m", "/measurements/reflectance/r60m", "/quality/l2a_quicklook/r10m", "--crs-groups", "/conditions/geometry", - "--spatial-chunk", "4096", + "--spatial-chunk", "512", "--min-dimension", "256", "--tile-width", "256", "--max-retries", "2", - "--verbose" + "--verbose", + "--enable" ], "cwd": "${workspaceFolder}", "justMyCode": false, @@ -99,14 +100,22 @@ // "https://objects.eodc.eu/e05ab01a9d56408d82ac32d69a5aae2a:202507-s02msil2a/04/products/cpm_v256/S2A_MSIL2A_20250704T094051_N0511_R036_T33SWB_20250704T115824.zarr", // "https://objects.eodc.eu/e05ab01a9d56408d82ac32d69a5aae2a:202508-s02msil2a/04/products/cpm_v256/S2B_MSIL2A_20250804T103629_N0511_R008_T31TDH_20250804T130722.zarr", // "https://objects.eodc.eu/e05ab01a9d56408d82ac32d69a5aae2a:202508-s02msil2a/07/products/cpm_v256/S2B_MSIL2A_20250807T104619_N0511_R051_T31TDH_20250807T131144.zarr", - "https://objects.eodc.eu/e05ab01a9d56408d82ac32d69a5aae2a:202508-s02msil2a/11/products/cpm_v256/S2C_MSIL2A_20250811T112131_N0511_R037_T29TPF_20250811T152216.zarr", - // "s3://esa-zarr-sentinel-explorer-fra/tests-output/eopf_geozarr/S2A_MSIL2A_20250704T094051_N0511_R036_T33SWB_20250704T115824.zarr", - // "s3://esa-zarr-sentinel-explorer-fra/tests-output/eopf_geozarr/S2B_MSIL2A_20250804T103629_N0511_R008_T31TDH_20250804T130722.zarr", - // "s3://esa-zarr-sentinel-explorer-fra/tests-output/eopf_geozarr/S2B_MSIL2A_20250807T104619_N0511_R051_T31TDH_20250807T131144.zarr", - "s3://esa-zarr-sentinel-explorer-fra/tests-output/eopf_geozarr/S2C_MSIL2A_20250811T112131_N0511_R037_T29TPF_20250811T152216.zarr", + // "https://objects.eodc.eu/e05ab01a9d56408d82ac32d69a5aae2a:202508-s02msil2a/11/products/cpm_v256/S2C_MSIL2A_20250811T112131_N0511_R037_T29TPF_20250811T152216.zarr", + // "https://objects.eodc.eu/e05ab01a9d56408d82ac32d69a5aae2a:202509-s02msil2a/13/products/cpm_v256/S2C_MSIL2A_20250913T095041_N0511_R079_T33TVF_20250913T151113.zarr", + // "https://objects.eodc.eu/e05ab01a9d56408d82ac32d69a5aae2a:202509-s02msil2a/21/products/cpm_v256/S2B_MSIL2A_20250921T100029_N0511_R122_T32TQM_20250921T135752.zarr", + // "https://objects.eodc.eu/e05ab01a9d56408d82ac32d69a5aae2a:202509-s02msil2a/21/products/cpm_v256/S2B_MSIL2A_20250921T100029_N0511_R122_T33TTG_20250921T135752.zarr", + "https://objects.eodc.eu/e05ab01a9d56408d82ac32d69a5aae2a:202509-s02msil2a/08/products/cpm_v256/S2A_MSIL2A_20250908T100041_N0511_R122_T32TQM_20250908T115116.zarr", + // "s3://esa-zarr-sentinel-explorer-fra/tests-output/sentinel-2-l2a/S2A_MSIL2A_20250704T094051_N0511_R036_T33SWB_20250704T115824.zarr", + // "s3://esa-zarr-sentinel-explorer-fra/tests-output/sentinel-2-l2a/S2B_MSIL2A_20250804T103629_N0511_R008_T31TDH_20250804T130722.zarr", + // 
"s3://esa-zarr-sentinel-explorer-fra/tests-output/sentinel-2-l2a/S2B_MSIL2A_20250807T104619_N0511_R051_T31TDH_20250807T131144.zarr", + // "s3://esa-zarr-sentinel-explorer-fra/tests-output/sentinel-2-l2a/S2C_MSIL2A_20250811T112131_N0511_R037_T29TPF_20250811T152216.zarr", + // "s3://esa-zarr-sentinel-explorer-fra/tests-output/sentinel-2-l2a/S2C_MSIL2A_20250913T095041_N0511_R079_T33TVF_20250913T151113.zarr", + // "s3://esa-zarr-sentinel-explorer-fra/tests-output/sentinel-2-l2a/S2B_MSIL2A_20250921T100029_N0511_R122_T32TQM_20250921T135752.zarr", + // "s3://esa-zarr-sentinel-explorer-fra/tests-output/sentinel-2-l2a/S2B_MSIL2A_20250921T100029_N0511_R122_T33TTG_20250921T135752.zarr", + "s3://esa-zarr-sentinel-explorer-fra/tests-output/sentinel-2-l2a/S2A_MSIL2A_20250908T100041_N0511_R122_T32TQM_20250908T115116.zarr", "--groups", "/measurements/reflectance/r10m", "/measurements/reflectance/r20m", "/measurements/reflectance/r60m", "/quality/l2a_quicklook/r10m", "--crs-groups", "/conditions/geometry", - "--spatial-chunk", "512", + "--spatial-chunk", "256", "--min-dimension", "256", "--tile-width", "256", "--max-retries", "2", @@ -156,8 +165,12 @@ "module": "eopf_geozarr", "args": [ "convert", - "https://objects.eodc.eu/e05ab01a9d56408d82ac32d69a5aae2a:202509-s01siwgrh/12/products/cpm_v256/S1C_IW_GRDH_1SDV_20250912T053648_20250912T053713_004087_0081FD_5AA4.zarr", - "s3://esa-zarr-sentinel-explorer-fra/tests-output/eopf_geozarr/S1C_IW_GRDH_1SDV_20250912T053648_20250912T053713_004087_0081FD_5AA4.zarr", + // "https://objects.eodc.eu/e05ab01a9d56408d82ac32d69a5aae2a:notebook-data/tutorial_data/cpm_v260/S1A_IW_GRDH_1SDV_20241124T180254_20241124T180319_056700_06F516_BA27.zarr", + "https://objects.eodc.eu/e05ab01a9d56408d82ac32d69a5aae2a:notebook-data/tutorial_data/cpm_v260/S1A_IW_GRDH_1SDV_20241218T180252_20241218T180317_057050_0702F2_0BC2.zarr", + // "https://objects.eodc.eu/e05ab01a9d56408d82ac32d69a5aae2a:202509-s01siwgrh/12/products/cpm_v256/S1A_IW_GRDH_1SDV_20241230T180251_20241230T180316_057225_0709DD_15AC.zarr", + // "s3://esa-zarr-sentinel-explorer-fra/tests-output/eopf_geozarr/S1A_IW_GRDH_1SDV_20241124T180254_20241124T180319_056700_06F516_BA27_2.zarr", + "s3://esa-zarr-sentinel-explorer-fra/tests-output/sentinel1-l1-grd/S1A_IW_GRDH_1SDV_20241218T180252_20241218T180317_057050_0702F2_0BC2.zarr", + // "s3://esa-zarr-sentinel-explorer-fra/tests-output/eopf_geozarr/S1A_IW_GRDH_1SDV_20241230T180251_20241230T180316_057225_0709DD_15AC.zarr", "--groups", "/measurements", "--gcp-group", "/conditions/gcp", // "--crs-groups", "/conditions/geometry", @@ -205,7 +218,7 @@ "module": "eopf_geozarr", "args": [ "info", - "./tests-output/eopf_geozarr/s2b_test.zarr", + "./tests-output/eopf_geozarr/s2l2_test.zarr", "--verbose", "--html-output", "dataset_info.html" ], From 367b14615bee81b4ae4c21f52e422488f4dbaa80 Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Fri, 26 Sep 2025 07:36:26 +0000 Subject: [PATCH 03/83] feat: enable sharding in GeoZarr conversion launch configuration --- .vscode/launch.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 744fcc87..e7bd8e81 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -22,7 +22,7 @@ "--tile-width", "256", "--max-retries", "2", "--verbose", - "--enable" + "--enable-sharding" ], "cwd": "${workspaceFolder}", "justMyCode": false, From 30d6fb046a362cb99f675830ae464994ef02c25a Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Fri, 26 Sep 2025 09:36:56 +0200 Subject: [PATCH 04/83] fix: update 
sharding codec handling in _create_sharded_encoding function --- src/eopf_geozarr/conversion/geozarr.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/eopf_geozarr/conversion/geozarr.py b/src/eopf_geozarr/conversion/geozarr.py index 896f0004..dff3f9d1 100644 --- a/src/eopf_geozarr/conversion/geozarr.py +++ b/src/eopf_geozarr/conversion/geozarr.py @@ -1522,7 +1522,7 @@ def _create_sharded_encoding( chunk_height = min(spatial_chunk, height) chunk_width = min(spatial_chunk, width) - # Create sharding codec + # Create sharding codec with proper codec chain sharding_codec = ShardingCodec( chunk_shape=(1, chunk_height, chunk_width), codecs=[compressor] @@ -1530,7 +1530,7 @@ def _create_sharded_encoding( return { "chunks": (1, shard_height, shard_width), # Full spatial dimensions per shard - "compressors": [sharding_codec], + "compressors": sharding_codec, # Pass sharding codec directly, not in list } else: # 2D array (y, x) @@ -1544,7 +1544,7 @@ def _create_sharded_encoding( chunk_height = min(spatial_chunk, height) chunk_width = min(spatial_chunk, width) - # Create sharding codec + # Create sharding codec with proper codec chain sharding_codec = ShardingCodec( chunk_shape=(chunk_height, chunk_width), codecs=[compressor] @@ -1552,7 +1552,7 @@ def _create_sharded_encoding( return { "chunks": (shard_height, shard_width), # Full spatial dimensions per shard - "compressors": [sharding_codec], + "compressors": sharding_codec, # Pass sharding codec directly, not in list } From 4f38d681e3eac34d5a2d51314109462353d5d05b Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Fri, 26 Sep 2025 09:48:02 +0200 Subject: [PATCH 05/83] refactor: streamline sharding configuration in _create_geozarr_encoding function --- src/eopf_geozarr/conversion/geozarr.py | 92 ++++---------------------- 1 file changed, 14 insertions(+), 78 deletions(-) diff --git a/src/eopf_geozarr/conversion/geozarr.py b/src/eopf_geozarr/conversion/geozarr.py index dff3f9d1..6097a685 100644 --- a/src/eopf_geozarr/conversion/geozarr.py +++ b/src/eopf_geozarr/conversion/geozarr.py @@ -1470,17 +1470,20 @@ def _create_geozarr_encoding( ) else: spatial_chunk_aligned = spatial_chunk - - if enable_sharding and len(data_shape) >= 2: - # Create sharding configuration for spatial dimensions - encoding[var] = _create_sharded_encoding( - data_shape, spatial_chunk_aligned, compressor - ) - else: - encoding[var] = { - "chunks": (spatial_chunk_aligned, spatial_chunk_aligned), - "compressors": compressor, - } + + shards = None + + if enable_sharding: + if len(data_shape) == 3: + shards = (1, data_shape[1], data_shape[2]) + else: + shards = (data_shape[0], data_shape[1]) + + encoding[var] = { + "chunks": (spatial_chunk_aligned, spatial_chunk_aligned), + "compressors": compressor, + "shards": shards, + } # Add coordinate encoding for coord in ds.coords: @@ -1489,73 +1492,6 @@ def _create_geozarr_encoding( return encoding -def _create_sharded_encoding( - data_shape: tuple[int, ...], spatial_chunk: int, compressor: Any -) -> XarrayEncodingJSON: - """ - Create sharded encoding configuration for spatial dimensions. - - Parameters - ---------- - data_shape : tuple[int, ...] 
- Shape of the data array - spatial_chunk : int - Spatial chunk size - compressor : Any - Compressor to use - - Returns - ------- - dict - Encoding configuration with sharding - """ - # Calculate shard configuration based on spatial dimensions - if len(data_shape) == 3: - # 3D array (time, y, x) or (band, y, x) - height, width = data_shape[-2:] - - # Use full spatial dimensions for shards - shard_height = height - shard_width = width - - # Chunk dimensions within shards - chunk_height = min(spatial_chunk, height) - chunk_width = min(spatial_chunk, width) - - # Create sharding codec with proper codec chain - sharding_codec = ShardingCodec( - chunk_shape=(1, chunk_height, chunk_width), - codecs=[compressor] - ) - - return { - "chunks": (1, shard_height, shard_width), # Full spatial dimensions per shard - "compressors": sharding_codec, # Pass sharding codec directly, not in list - } - else: - # 2D array (y, x) - height, width = data_shape[-2:] - - # Use full spatial dimensions for shards - shard_height = height - shard_width = width - - # Chunk dimensions within shards - chunk_height = min(spatial_chunk, height) - chunk_width = min(spatial_chunk, width) - - # Create sharding codec with proper codec chain - sharding_codec = ShardingCodec( - chunk_shape=(chunk_height, chunk_width), - codecs=[compressor] - ) - - return { - "chunks": (shard_height, shard_width), # Full spatial dimensions per shard - "compressors": sharding_codec, # Pass sharding codec directly, not in list - } - - def _load_existing_dataset(path: str) -> xr.Dataset | None: """Load existing dataset if it exists.""" try: From 37ad2c035b8b34636d50fd4ec4769bb8e5ff9c3d Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Fri, 26 Sep 2025 09:51:51 +0200 Subject: [PATCH 06/83] feat: enhance sharding logic in _create_geozarr_encoding and add _calculate_shard_dimension utility --- src/eopf_geozarr/conversion/geozarr.py | 48 ++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 3 deletions(-) diff --git a/src/eopf_geozarr/conversion/geozarr.py b/src/eopf_geozarr/conversion/geozarr.py index 6097a685..fa2e86e9 100644 --- a/src/eopf_geozarr/conversion/geozarr.py +++ b/src/eopf_geozarr/conversion/geozarr.py @@ -1468,19 +1468,36 @@ def _create_geozarr_encoding( utils.calculate_aligned_chunk_size(width, spatial_chunk), utils.calculate_aligned_chunk_size(height, spatial_chunk), ) + + if len(data_shape) == 3: + chunks = (1, spatial_chunk_aligned, spatial_chunk_aligned) + else: + chunks = (spatial_chunk_aligned, spatial_chunk_aligned) else: spatial_chunk_aligned = spatial_chunk + chunks = (spatial_chunk_aligned,) shards = None if enable_sharding: + # Calculate shard dimensions that are divisible by chunk dimensions if len(data_shape) == 3: - shards = (1, data_shape[1], data_shape[2]) + # For 3D data (time, y, x), ensure shard dimensions are divisible by chunks + shard_time = data_shape[0] # Keep full time dimension + shard_y = _calculate_shard_dimension(data_shape[1], chunks[1]) + shard_x = _calculate_shard_dimension(data_shape[2], chunks[2]) + shards = (shard_time, shard_y, shard_x) + elif len(data_shape) == 2: + # For 2D data (y, x), ensure shard dimensions are divisible by chunks + shard_y = _calculate_shard_dimension(data_shape[0], chunks[0]) + shard_x = _calculate_shard_dimension(data_shape[1], chunks[1]) + shards = (shard_y, shard_x) else: - shards = (data_shape[0], data_shape[1]) + # For 1D data, use the full dimension + shards = (data_shape[0],) encoding[var] = { - "chunks": (spatial_chunk_aligned, spatial_chunk_aligned), + "chunks": 
chunks, "compressors": compressor, "shards": shards, } @@ -1637,6 +1654,31 @@ def _add_grid_mapping_variable( print(f" Added grid_mapping attribute to {var_name}") +def _calculate_shard_dimension(data_dim: int, chunk_dim: int) -> int: + """ + Calculate shard dimension that is divisible by chunk dimension. + + Parameters + ---------- + data_dim : int + Size of the data dimension + chunk_dim : int + Size of the chunk dimension + + Returns + ------- + int + Shard dimension that is divisible by chunk_dim + """ + # If chunk is larger than or equal to data dimension, use full dimension + if chunk_dim >= data_dim: + return data_dim + + # Find the largest multiple of chunk_dim that doesn't exceed data_dim + # This ensures the shard dimension is divisible by chunk dimension + return (data_dim // chunk_dim) * chunk_dim + + def _is_sentinel1(dt: xr.DataTree) -> bool: """Return True if the input DataTree represents a Sentinel-1 product.""" stac_props = dt.attrs.get("stac_discovery", {}).get("properties", {}) From 67faa58c45da8f536a7456a76866c9df4b4177da Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Fri, 26 Sep 2025 10:05:59 +0200 Subject: [PATCH 07/83] feat: improve sharding configuration and validation in _create_geozarr_encoding --- src/eopf_geozarr/conversion/geozarr.py | 32 ++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/src/eopf_geozarr/conversion/geozarr.py b/src/eopf_geozarr/conversion/geozarr.py index fa2e86e9..19d2a7a7 100644 --- a/src/eopf_geozarr/conversion/geozarr.py +++ b/src/eopf_geozarr/conversion/geozarr.py @@ -1487,14 +1487,22 @@ def _create_geozarr_encoding( shard_y = _calculate_shard_dimension(data_shape[1], chunks[1]) shard_x = _calculate_shard_dimension(data_shape[2], chunks[2]) shards = (shard_time, shard_y, shard_x) + print(f" 🔧 Sharding config for {var}: data_shape={data_shape}, chunks={chunks}, shards={shards}") elif len(data_shape) == 2: # For 2D data (y, x), ensure shard dimensions are divisible by chunks shard_y = _calculate_shard_dimension(data_shape[0], chunks[0]) shard_x = _calculate_shard_dimension(data_shape[1], chunks[1]) shards = (shard_y, shard_x) + print(f" 🔧 Sharding config for {var}: data_shape={data_shape}, chunks={chunks}, shards={shards}") else: # For 1D data, use the full dimension shards = (data_shape[0],) + print(f" 🔧 Sharding config for {var}: data_shape={data_shape}, chunks={chunks}, shards={shards}") + + # Validate that shards are evenly divisible by chunks + for i, (shard_dim, chunk_dim) in enumerate(zip(shards, chunks)): + if shard_dim % chunk_dim != 0: + print(f" ⚠️ Warning: Shard dimension {shard_dim} not evenly divisible by chunk dimension {chunk_dim} at axis {i}") encoding[var] = { "chunks": chunks, @@ -1656,7 +1664,10 @@ def _add_grid_mapping_variable( def _calculate_shard_dimension(data_dim: int, chunk_dim: int) -> int: """ - Calculate shard dimension that is divisible by chunk dimension. + Calculate shard dimension that is evenly divisible by chunk dimension. + + For Zarr v3 sharding with Dask, the shard dimension must be evenly + divisible by the chunk dimension to avoid checksum mismatches. 
Parameters ---------- @@ -1668,15 +1679,26 @@ def _calculate_shard_dimension(data_dim: int, chunk_dim: int) -> int: Returns ------- int - Shard dimension that is divisible by chunk_dim + Shard dimension that is evenly divisible by chunk_dim """ # If chunk is larger than or equal to data dimension, use full dimension if chunk_dim >= data_dim: return data_dim - # Find the largest multiple of chunk_dim that doesn't exceed data_dim - # This ensures the shard dimension is divisible by chunk dimension - return (data_dim // chunk_dim) * chunk_dim + # Calculate how many complete chunks fit in the data dimension + num_complete_chunks = data_dim // chunk_dim + + # If we have at least 2 complete chunks, use a multiple of chunk_dim + if num_complete_chunks >= 2: + # Use a shard size that's a multiple of chunk_dim + # Prefer 2x, 3x, 4x, 5x, 6x chunk size, but don't exceed data dimension + for multiplier in [6, 5, 4, 3, 2]: + shard_size = multiplier * chunk_dim + if shard_size <= data_dim: + return shard_size + + # Fallback: use the largest multiple of chunk_dim that fits + return num_complete_chunks * chunk_dim if num_complete_chunks > 0 else data_dim def _is_sentinel1(dt: xr.DataTree) -> bool: From cccd8fd446b070a41022c09ab97bde20847e6aeb Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Fri, 26 Sep 2025 10:13:59 +0200 Subject: [PATCH 08/83] fix: refine shard dimension calculation and improve divisor check in utility functions --- src/eopf_geozarr/conversion/geozarr.py | 7 ++++--- src/eopf_geozarr/conversion/utils.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/eopf_geozarr/conversion/geozarr.py b/src/eopf_geozarr/conversion/geozarr.py index 19d2a7a7..19eae088 100644 --- a/src/eopf_geozarr/conversion/geozarr.py +++ b/src/eopf_geozarr/conversion/geozarr.py @@ -1681,7 +1681,8 @@ def _calculate_shard_dimension(data_dim: int, chunk_dim: int) -> int: int Shard dimension that is evenly divisible by chunk_dim """ - # If chunk is larger than or equal to data dimension, use full dimension + # If chunk is larger than data dimension, the effective chunk will be data_dim + # In this case, shard should also be data_dim to maintain divisibility if chunk_dim >= data_dim: return data_dim @@ -1691,13 +1692,13 @@ def _calculate_shard_dimension(data_dim: int, chunk_dim: int) -> int: # If we have at least 2 complete chunks, use a multiple of chunk_dim if num_complete_chunks >= 2: # Use a shard size that's a multiple of chunk_dim - # Prefer 2x, 3x, 4x, 5x, 6x chunk size, but don't exceed data dimension - for multiplier in [6, 5, 4, 3, 2]: + for multiplier in range(num_complete_chunks + 1, 2, -1): shard_size = multiplier * chunk_dim if shard_size <= data_dim: return shard_size # Fallback: use the largest multiple of chunk_dim that fits + # If no complete chunks fit, use data_dim (this handles edge cases) return num_complete_chunks * chunk_dim if num_complete_chunks > 0 else data_dim diff --git a/src/eopf_geozarr/conversion/utils.py b/src/eopf_geozarr/conversion/utils.py index 17edc246..d6ccd033 100644 --- a/src/eopf_geozarr/conversion/utils.py +++ b/src/eopf_geozarr/conversion/utils.py @@ -125,7 +125,7 @@ def calculate_aligned_chunk_size(dimension_size: int, target_chunk_size: int) -> # Find the largest divisor of dimension_size that is <= target_chunk_size for chunk_size in range(target_chunk_size, 0, -1): - if dimension_size % chunk_size < 0.1 * chunk_size: + if dimension_size % chunk_size == 0: return chunk_size # If no divisor is found, return the closest value to target_chunk_size From 
064917a61a32938f77b16eff6d531ba9b3dd6c8e Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Fri, 26 Sep 2025 11:41:14 +0200 Subject: [PATCH 09/83] Add dataset tree structure and test script for sharding fix - Introduced a new dataset tree structure for Sentinel-2 data, detailing conditions, quality, and measurements. - Added a comprehensive test script to verify the sharding fix for GeoZarr conversion. - Implemented tests for shard dimension calculations and encoding creation with sharding enabled/disabled. - Enhanced output for better debugging and validation of shard dimensions against chunk dimensions. --- .vscode/launch.json | 2 +- dataset_info.html | 29443 +++++++++++++++++++++++ dataset_tree_simplified.txt | 112 + src/eopf_geozarr/conversion/geozarr.py | 6 +- test_sharding_fix.py | 145 + 5 files changed, 29706 insertions(+), 2 deletions(-) create mode 100644 dataset_info.html create mode 100644 dataset_tree_simplified.txt create mode 100644 test_sharding_fix.py diff --git a/.vscode/launch.json b/.vscode/launch.json index e7bd8e81..a76bd927 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -237,7 +237,7 @@ "module": "eopf_geozarr", "args": [ "info", - "s3://esa-zarr-sentinel-explorer-fra/tests-output/sentinel-2-l2a/S2B_MSIL2A_20250921T100029_N0511_R122_T33TTG_20250921T135752.zarr", + "s3://esa-zarr-sentinel-explorer-fra/tests-output/sentinel-2-l2a/S2A_MSIL2A_20250704T094051_N0511_R036_T33SWB_20250704T115824.zarr", "--verbose", "--html-output", "dataset_info.html" ], diff --git a/dataset_info.html b/dataset_info.html new file mode 100644 index 00000000..6a36a8cb --- /dev/null +++ b/dataset_info.html @@ -0,0 +1,29443 @@ + + + + + + + DataTree Visualization - S2A_MSIL2A_20250704T094051_N0511_R036_T33SWB_20250704T115824.zarr + + + +
+
+

S2A_MSIL2A_20250704T094051_N0511_R036_T33SWB_20250704T115824.zarr

+
+
+
Dataset Path
+
s3://esa-zarr-sentinel-explorer-fra/tests-output/sentinel-2-l2a/S2A_MSIL2A_20250704T094051_N0511_R036_T33SWB_20250704T115824.zarr
+
+
+
Total Groups
+
3
+
+
+
Generated
+
2025-09-26 11:24:44
+
+
+
+ +
+
+ +
+ + +
+
+ + 📁 + root + (3 subgroups) + +
+ +
+

Subgroups

+
+ +
+
+ + 📁 + conditions + (3 subgroups) + +
+ +
+

Subgroups

+
+ +
+
+ + 📄 + geometry + (5 variables • 1 attributes) + +
+ +
+

Variables

+
+
+ + + + + + + + + + + + + + +
<xarray.Dataset> Size: 670kB
+Dimensions:                        (angle: 2, band: 13, x: 23, y: 23,
+                                    detector: 6)
+Coordinates:
+  * angle                          (angle) <U7 56B 'zenith' 'azimuth'
+  * band                           (band) <U3 156B 'b01' 'b02' ... 'b11' 'b12'
+  * x                              (x) int64 184B 499980 504980 ... 609980
+  * y                              (y) int64 184B 4200000 4195000 ... 4090000
+  * detector                       (detector) int64 48B 1 2 3 4 5 6
+Data variables:
+    mean_sun_angles                (angle) float64 16B dask.array<chunksize=(2,), meta=np.ndarray>
+    mean_viewing_incidence_angles  (band, angle) float64 208B dask.array<chunksize=(13, 2), meta=np.ndarray>
+    spatial_ref                    int64 8B ...
+    sun_angles                     (angle, y, x) float64 8kB dask.array<chunksize=(2, 23, 23), meta=np.ndarray>
+    viewing_incidence_angles       (band, detector, angle, y, x) float64 660kB dask.array<chunksize=(13, 6, 2, 23, 23), meta=np.ndarray>
+
+
+ +
+

Attributes

+
+ +
+ grid_mapping: + spatial_ref +
+ +
+
+ +
+
+
+ +
+
+ + 📁 + mask + (3 subgroups) + +
+ +
+

Subgroups

+
+ +
+
+ + 📁 + l1c_classification + (1 subgroups) + +
+ +
+

Subgroups

+
+ +
+
+ + 📄 + r60m + (1 variables) + +
+ +
+

Variables

+
+
+ + + + + + + + + + + + + + +
<xarray.Dataset> Size: 3MB
+Dimensions:  (x: 1830, y: 1830)
+Coordinates:
+  * x        (x) int64 15kB 500010 500070 500130 500190 ... 609630 609690 609750
+  * y        (y) int64 15kB 4199970 4199910 4199850 ... 4090350 4090290 4090230
+Data variables:
+    b00      (y, x) uint8 3MB dask.array<chunksize=(1830, 1830), meta=np.ndarray>
+
+
+ +
+
+
+ +
+
+ +
+
+
+ +
+
+ + 📁 + l2a_classification + (2 subgroups) + +
+ +
+

Subgroups

+
+ +
+
+ + 📄 + r20m + (1 variables) + +
+ +
+

Variables

+
+
+ + + + + + + + + + + + + + +
<xarray.Dataset> Size: 30MB
+Dimensions:  (x: 5490, y: 5490)
+Coordinates:
+  * x        (x) int64 44kB 499990 500010 500030 500050 ... 609730 609750 609770
+  * y        (y) int64 44kB 4199990 4199970 4199950 ... 4090250 4090230 4090210
+Data variables:
+    scl      (y, x) uint8 30MB dask.array<chunksize=(5490, 5490), meta=np.ndarray>
+
+
+ +
+
+
+ +
+
+ + 📄 + r60m + (1 variables) + +
+ +
+

Variables

+
+
+ + + + + + + + + + + + + + +
<xarray.Dataset> Size: 3MB
+Dimensions:  (y: 1830, x: 1830)
+Coordinates:
+  * y        (y) int64 15kB 4199970 4199910 4199850 ... 4090350 4090290 4090230
+  * x        (x) int64 15kB 500010 500070 500130 500190 ... 609630 609690 609750
+Data variables:
+    scl      (y, x) uint8 3MB dask.array<chunksize=(1830, 1830), meta=np.ndarray>
+
+
+ +
+
+
+ +
+
+ +
+
+
+ +
+
+ + 📁 + detector_footprint + (3 subgroups) + +
+ +
+

Subgroups

+
+ +
+
+ + 📄 + r10m + (4 variables) + +
+ +
+

Variables

+
+
+ + + + + + + + + + + + + + +
<xarray.Dataset> Size: 482MB
+Dimensions:  (x: 10980, y: 10980)
+Coordinates:
+  * x        (x) int64 88kB 499985 499995 500005 500015 ... 609755 609765 609775
+  * y        (y) int64 88kB 4199995 4199985 4199975 ... 4090225 4090215 4090205
+Data variables:
+    b03      (y, x) uint8 121MB dask.array<chunksize=(10980, 10980), meta=np.ndarray>
+    b04      (y, x) uint8 121MB dask.array<chunksize=(10980, 10980), meta=np.ndarray>
+    b02      (y, x) uint8 121MB dask.array<chunksize=(10980, 10980), meta=np.ndarray>
+    b08      (y, x) uint8 121MB dask.array<chunksize=(10980, 10980), meta=np.ndarray>
+
+
+ +
+
+
+ +
+
+ + 📄 + r20m + (6 variables) + +
+ +
+

Variables

+
+
+ + + + + + + + + + + + + + +
<xarray.Dataset> Size: 181MB
+Dimensions:  (y: 5490, x: 5490)
+Coordinates:
+  * y        (y) int64 44kB 4199990 4199970 4199950 ... 4090250 4090230 4090210
+  * x        (x) int64 44kB 499990 500010 500030 500050 ... 609730 609750 609770
+Data variables:
+    b05      (y, x) uint8 30MB dask.array<chunksize=(5490, 5490), meta=np.ndarray>
+    b06      (y, x) uint8 30MB dask.array<chunksize=(5490, 5490), meta=np.ndarray>
+    b07      (y, x) uint8 30MB dask.array<chunksize=(5490, 5490), meta=np.ndarray>
+    b11      (y, x) uint8 30MB dask.array<chunksize=(5490, 5490), meta=np.ndarray>
+    b12      (y, x) uint8 30MB dask.array<chunksize=(5490, 5490), meta=np.ndarray>
+    b8a      (y, x) uint8 30MB dask.array<chunksize=(5490, 5490), meta=np.ndarray>
+
+
+ +
+
+
+ +
+
+ + 📄 + r60m + (3 variables) + +
+ +
+

Variables

+
+
+ + + + + + + + + + + + + + +
<xarray.Dataset> Size: 10MB
+Dimensions:  (x: 1830, y: 1830)
+Coordinates:
+  * x        (x) int64 15kB 500010 500070 500130 500190 ... 609630 609690 609750
+  * y        (y) int64 15kB 4199970 4199910 4199850 ... 4090350 4090290 4090230
+Data variables:
+    b09      (y, x) uint8 3MB dask.array<chunksize=(1830, 1830), meta=np.ndarray>
+    b01      (y, x) uint8 3MB dask.array<chunksize=(1830, 1830), meta=np.ndarray>
+    b10      (y, x) uint8 3MB dask.array<chunksize=(1830, 1830), meta=np.ndarray>
+
+
+ +
+
+
+ +
+
+ +
+
+
+ +
+
+ +
+
+
+ +
+
+ + 📁 + meteorology + (2 subgroups) + +
+ +
+

Subgroups

+
+ +
+
+ + 📄 + cams + (11 variables • 7 attributes) + +
+ +
+

Variables

+
+
+ + + + + + + + + + + + + + +
<xarray.Dataset> Size: 4kB
+Dimensions:        (longitude: 9, latitude: 9)
+Coordinates:
+  * longitude      (longitude) float64 72B 15.0 15.15 15.31 ... 16.08 16.23
+    number         int64 8B 0
+  * latitude       (latitude) float64 72B 37.95 37.82 37.7 ... 37.2 37.07 36.95
+    isobaricInhPa  float64 8B 0.0
+    surface        float64 8B 0.0
+    step           timedelta64[ns] 8B 09:46:59.999999998
+    valid_time     datetime64[ns] 8B 2025-07-04T09:47:00
+    time           datetime64[ns] 8B 2025-07-04
+Data variables:
+    aod1240        (latitude, longitude) float32 324B dask.array<chunksize=(9, 9), meta=np.ndarray>
+    aod469         (latitude, longitude) float32 324B dask.array<chunksize=(9, 9), meta=np.ndarray>
+    aod550         (latitude, longitude) float32 324B dask.array<chunksize=(9, 9), meta=np.ndarray>
+    bcaod550       (latitude, longitude) float32 324B dask.array<chunksize=(9, 9), meta=np.ndarray>
+    duaod550       (latitude, longitude) float32 324B dask.array<chunksize=(9, 9), meta=np.ndarray>
+    aod670         (latitude, longitude) float32 324B dask.array<chunksize=(9, 9), meta=np.ndarray>
+    aod865         (latitude, longitude) float32 324B dask.array<chunksize=(9, 9), meta=np.ndarray>
+    ssaod550       (latitude, longitude) float32 324B dask.array<chunksize=(9, 9), meta=np.ndarray>
+    omaod550       (latitude, longitude) float32 324B dask.array<chunksize=(9, 9), meta=np.ndarray>
+    suaod550       (latitude, longitude) float32 324B dask.array<chunksize=(9, 9), meta=np.ndarray>
+    z              (latitude, longitude) float32 324B dask.array<chunksize=(9, 9), meta=np.ndarray>
+
+
+ +
+

Attributes

+
+ +
+ Conventions: + CF-1.7 +
+ +
+ GRIB_centre: + ecmf +
+ +
+ GRIB_centreDescription: + European Centre for Medium-Range Weather Forecasts +
+ +
+ GRIB_edition: + 1 +
+ +
+ GRIB_subCentre: + 0 +
+
... and 2 more
+
+
+ +
+
+
+ +
+
+ + 📄 + ecmwf + (6 variables • 7 attributes) + +
+ +
+

Variables

+
+
+ + + + + + + + + + + + + + +
<xarray.Dataset> Size: 2kB
+Dimensions:        (latitude: 9, longitude: 9)
+Coordinates:
+    isobaricInhPa  float64 8B 0.0
+  * latitude       (latitude) float64 72B 37.95 37.82 37.7 ... 37.2 37.07 36.95
+    number         int64 8B 0
+  * longitude      (longitude) float64 72B 15.0 15.15 15.31 ... 16.08 16.23
+    surface        float64 8B 0.0
+    step           timedelta64[ns] 8B 09:46:59.999999998
+    time           datetime64[ns] 8B 2025-07-04
+    valid_time     datetime64[ns] 8B 2025-07-04T09:47:00
+Data variables:
+    r              (latitude, longitude) float32 324B dask.array<chunksize=(9, 9), meta=np.ndarray>
+    msl            (latitude, longitude) float32 324B dask.array<chunksize=(9, 9), meta=np.ndarray>
+    tco3           (latitude, longitude) float32 324B dask.array<chunksize=(9, 9), meta=np.ndarray>
+    tcwv           (latitude, longitude) float32 324B dask.array<chunksize=(9, 9), meta=np.ndarray>
+    u10            (latitude, longitude) float32 324B dask.array<chunksize=(9, 9), meta=np.ndarray>
+    v10            (latitude, longitude) float32 324B dask.array<chunksize=(9, 9), meta=np.ndarray>
+
+
+ +
+

Attributes

+
+ +
+ Conventions: + CF-1.7 +
+ +
+ GRIB_centre: + ecmf +
+ +
+ GRIB_centreDescription: + European Centre for Medium-Range Weather Forecasts +
+ +
+ GRIB_edition: + 1 +
+ +
+ GRIB_subCentre: + 0 +
+
... and 2 more
+
+
+ +
+
+
+ +
+
+ +
+
+
+ +
+
+ +
+
+
+ +
+
+ + 📁 + quality + (4 subgroups) + +
+ +
+

Subgroups

+
+ +
+
+ + 📁 + atmosphere + (3 subgroups) + +
+ +
+

Subgroups

+
+ +
+
+ + 📄 + r20m + (2 variables) + +
+ +
+

Variables

+
+
+ + + + + + + + + + + + + + +
<xarray.Dataset> Size: 121MB
+Dimensions:  (x: 5490, y: 5490)
+Coordinates:
+  * x        (x) int64 44kB 499990 500010 500030 500050 ... 609730 609750 609770
+  * y        (y) int64 44kB 4199990 4199970 4199950 ... 4090250 4090230 4090210
+Data variables:
+    aot      (y, x) uint16 60MB dask.array<chunksize=(5490, 5490), meta=np.ndarray>
+    wvp      (y, x) uint16 60MB dask.array<chunksize=(5490, 5490), meta=np.ndarray>
+
+
+ +
+
+
+ +
+
+ + 📄 + r60m + (2 variables) + +
+ +
+

Variables

+
+
+ + + + + + + + + + + + + + +
<xarray.Dataset> Size: 13MB
+Dimensions:  (x: 1830, y: 1830)
+Coordinates:
+  * x        (x) int64 15kB 500010 500070 500130 500190 ... 609630 609690 609750
+  * y        (y) int64 15kB 4199970 4199910 4199850 ... 4090350 4090290 4090230
+Data variables:
+    aot      (y, x) uint16 7MB dask.array<chunksize=(1830, 1830), meta=np.ndarray>
+    wvp      (y, x) uint16 7MB dask.array<chunksize=(1830, 1830), meta=np.ndarray>
+
+
+ +
+
+
+ +
+
+ + 📄 + r10m + (2 variables) + +
+ +
+

Variables

+
+
+ + + + + + + + + + + + + + +
<xarray.Dataset> Size: 482MB
+Dimensions:  (x: 10980, y: 10980)
+Coordinates:
+  * x        (x) int64 88kB 499985 499995 500005 500015 ... 609755 609765 609775
+  * y        (y) int64 88kB 4199995 4199985 4199975 ... 4090225 4090215 4090205
+Data variables:
+    aot      (y, x) uint16 241MB dask.array<chunksize=(8192, 8192), meta=np.ndarray>
+    wvp      (y, x) uint16 241MB dask.array<chunksize=(8192, 8192), meta=np.ndarray>
+
+
+ +
+
+
+ +
+
+ +
+
+
+ +
+
+ + 📁 + mask + (3 subgroups) + +
+ +
+

Subgroups

+
+ +
+
+ + 📄 + r20m + (6 variables) + +
+ +
+

Variables

+
+
+ + + + + + + + + + + + + + +
<xarray.Dataset> Size: 181MB
+Dimensions:  (x: 5490, y: 5490)
+Coordinates:
+  * x        (x) int64 44kB 499990 500010 500030 500050 ... 609730 609750 609770
+  * y        (y) int64 44kB 4199990 4199970 4199950 ... 4090250 4090230 4090210
+Data variables:
+    b05      (y, x) uint8 30MB dask.array<chunksize=(5490, 5490), meta=np.ndarray>
+    b8a      (y, x) uint8 30MB dask.array<chunksize=(5490, 5490), meta=np.ndarray>
+    b07      (y, x) uint8 30MB dask.array<chunksize=(5490, 5490), meta=np.ndarray>
+    b06      (y, x) uint8 30MB dask.array<chunksize=(5490, 5490), meta=np.ndarray>
+    b11      (y, x) uint8 30MB dask.array<chunksize=(5490, 5490), meta=np.ndarray>
+    b12      (y, x) uint8 30MB dask.array<chunksize=(5490, 5490), meta=np.ndarray>
+
+
+ +
+
+
+ +
+
+ + 📄 + r60m + (3 variables) + +
+ +
+

Variables

+
+
+ + + + + + + + + + + + + + +
<xarray.Dataset> Size: 10MB
+Dimensions:  (x: 1830, y: 1830)
+Coordinates:
+  * x        (x) int64 15kB 500010 500070 500130 500190 ... 609630 609690 609750
+  * y        (y) int64 15kB 4199970 4199910 4199850 ... 4090350 4090290 4090230
+Data variables:
+    b01      (y, x) uint8 3MB dask.array<chunksize=(1830, 1830), meta=np.ndarray>
+    b09      (y, x) uint8 3MB dask.array<chunksize=(1830, 1830), meta=np.ndarray>
+    b10      (y, x) uint8 3MB dask.array<chunksize=(1830, 1830), meta=np.ndarray>
+
+
+ +
+
+
+ +
+
+ + 📄 + r10m + (4 variables) + +
+ +
+

Variables

+
+
+ + + + + + + + + + + + + + +
<xarray.Dataset> Size: 482MB
+Dimensions:  (y: 10980, x: 10980)
+Coordinates:
+  * y        (y) int64 88kB 4199995 4199985 4199975 ... 4090225 4090215 4090205
+  * x        (x) int64 88kB 499985 499995 500005 500015 ... 609755 609765 609775
+Data variables:
+    b02      (y, x) uint8 121MB dask.array<chunksize=(10980, 10980), meta=np.ndarray>
+    b04      (y, x) uint8 121MB dask.array<chunksize=(10980, 10980), meta=np.ndarray>
+    b03      (y, x) uint8 121MB dask.array<chunksize=(10980, 10980), meta=np.ndarray>
+    b08      (y, x) uint8 121MB dask.array<chunksize=(10980, 10980), meta=np.ndarray>
+
+
+ +
+
+
+ +
+
+ +
+
+
+ +
+
+ + 📁 + probability + (1 subgroups) + +
+ +
+

Subgroups

+
+ +
+
+ + 📄 + r20m + (2 variables) + +
+ +
+

Variables

+
+
+ + + + + + + + + + + + + + +
<xarray.Dataset> Size: 60MB
+Dimensions:  (y: 5490, x: 5490)
+Coordinates:
+    band     int64 8B 1
+  * y        (y) int64 44kB 4199990 4199970 4199950 ... 4090250 4090230 4090210
+  * x        (x) int64 44kB 499990 500010 500030 500050 ... 609730 609750 609770
+Data variables:
+    cld      (y, x) uint8 30MB dask.array<chunksize=(5490, 5490), meta=np.ndarray>
+    snw      (y, x) uint8 30MB dask.array<chunksize=(5490, 5490), meta=np.ndarray>
+
+
+ +
+
+
+ +
+
+ +
+
+
+ +
+
+ + 📁 + l2a_quicklook + (3 subgroups) + +
+ +
+

Subgroups

+
+ +
+
+ + 📁 + r10m + (1 attributes • 6 subgroups) + +
+ +
+

Attributes

+
+ +
+ multiscales: + {'tile_matrix_set': {'id': 'Native_CRS_32633', ... +
+ +
+
+ +
+

Subgroups

+
+ +
+
+ + 📄 + 0 + (2 variables • 1 attributes) + +
+ +
+

Variables

+
+
+ + + + + + + + + + + + + + +
<xarray.Dataset> Size: 362MB
+Dimensions:      (band: 3, x: 10980, y: 10980)
+Coordinates:
+  * band         (band) int64 24B 1 2 3
+  * x            (x) int64 88kB 499985 499995 500005 ... 609755 609765 609775
+  * y            (y) int64 88kB 4199995 4199985 4199975 ... 4090215 4090205
+Data variables:
+    spatial_ref  int64 8B ...
+    tci          (band, y, x) uint8 362MB dask.array<chunksize=(3, 3992, 10980), meta=np.ndarray>
+
+
+ +
+

Attributes

+
+ +
+ grid_mapping: + spatial_ref +
+ +
+
+ +
+
+
+ +
+
+ + 📄 + 1 + (2 variables • 1 attributes) + +
+ +
+

Variables

+
+
+ + + + + + + + + + + + + + +
<xarray.Dataset> Size: 91MB
+Dimensions:      (x: 5490, y: 5490, band: 3)
+Coordinates:
+  * x            (x) float64 44kB 5e+05 5e+05 5e+05 ... 6.097e+05 6.098e+05
+  * y            (y) float64 44kB 4.2e+06 4.2e+06 4.2e+06 ... 4.09e+06 4.09e+06
+Dimensions without coordinates: band
+Data variables:
+    spatial_ref  int64 8B ...
+    tci          (band, y, x) uint8 90MB dask.array<chunksize=(3, 5490, 5490), meta=np.ndarray>
+
+
+ +
+

Attributes

+
+ +
+ grid_mapping: + spatial_ref +
+ +
+
+ +
+
+
+ +
+
+ + 📄 + 2 + (2 variables • 1 attributes) + +
+ +
+

Variables

+
+
+ + + + + + + + + + + + + + +
<xarray.Dataset> Size: 23MB
+Dimensions:      (x: 2745, y: 2745, band: 3)
+Coordinates:
+  * x            (x) float64 22kB 5e+05 5e+05 5.001e+05 ... 6.097e+05 6.097e+05
+  * y            (y) float64 22kB 4.2e+06 4.2e+06 4.2e+06 ... 4.09e+06 4.09e+06
+Dimensions without coordinates: band
+Data variables:
+    spatial_ref  int64 8B ...
+    tci          (band, y, x) uint8 23MB dask.array<chunksize=(3, 2745, 2745), meta=np.ndarray>
+
+
+ +
+

Attributes

+
+ +
+ grid_mapping: + spatial_ref +
+ +
+
+ +
+
+
+ +
+
+ + 📄 + 3 + (2 variables • 1 attributes) + +
+ +
+

Variables

+
+
+ + + + + + + + + + + + + + +
<xarray.Dataset> Size: 6MB
+Dimensions:      (x: 1372, y: 1372, band: 3)
+Coordinates:
+  * x            (x) float64 11kB 5e+05 5.001e+05 ... 6.096e+05 6.097e+05
+  * y            (y) float64 11kB 4.2e+06 4.2e+06 4.2e+06 ... 4.09e+06 4.09e+06
+Dimensions without coordinates: band
+Data variables:
+    spatial_ref  int64 8B ...
+    tci          (band, y, x) uint8 6MB dask.array<chunksize=(3, 1372, 1372), meta=np.ndarray>
+
+
+ +
+

Attributes

+
+ +
+ grid_mapping: + spatial_ref +
+ +
+
+ +
+
+
+ +
+
+ + 📄 + 4 + (2 variables • 1 attributes) + +
+ +
+

Variables

+
+
+ + + + + + + + + + + + + + +
<xarray.Dataset> Size: 1MB
+Dimensions:      (x: 686, y: 686, band: 3)
+Coordinates:
+  * x            (x) float64 5kB 5e+05 5.001e+05 ... 6.095e+05 6.096e+05
+  * y            (y) float64 5kB 4.2e+06 4.2e+06 4.2e+06 ... 4.091e+06 4.09e+06
+Dimensions without coordinates: band
+Data variables:
+    spatial_ref  int64 8B ...
+    tci          (band, y, x) uint8 1MB dask.array<chunksize=(3, 686, 686), meta=np.ndarray>
+
+
+ +
+

Attributes

+
+ +
+ grid_mapping: + spatial_ref +
+ +
+
+ +
+
+
+ +
+
+ + 📄 + 5 + (2 variables • 1 attributes) + +
+ +
+

Variables

+
+
+ + + + + + + + + + + + + + +
<xarray.Dataset> Size: 358kB
+Dimensions:      (x: 343, y: 343, band: 3)
+Coordinates:
+  * x            (x) float64 3kB 5e+05 5.003e+05 ... 6.091e+05 6.095e+05
+  * y            (y) float64 3kB 4.2e+06 4.2e+06 ... 4.091e+06 4.091e+06
+Dimensions without coordinates: band
+Data variables:
+    spatial_ref  int64 8B ...
+    tci          (band, y, x) uint8 353kB dask.array<chunksize=(3, 343, 343), meta=np.ndarray>
+
+
+ +
+

Attributes

+
+ +
+ grid_mapping: + spatial_ref +
+ +
+
+ +
+
+
+ +
+
+ +
+
+
+ +
+
+ + 📄 + r20m + (1 variables) + +
+ +
+

Variables

+
+
+ + + + + + + + + + + + + + +
<xarray.Dataset> Size: 91MB
+Dimensions:  (band: 3, x: 5490, y: 5490)
+Coordinates:
+  * band     (band) int64 24B 1 2 3
+  * x        (x) int64 44kB 499990 500010 500030 500050 ... 609730 609750 609770
+  * y        (y) int64 44kB 4199990 4199970 4199950 ... 4090250 4090230 4090210
+Data variables:
+    tci      (band, y, x) uint8 90MB dask.array<chunksize=(3, 5490, 5490), meta=np.ndarray>
+
+
+ +
+
+
+ +
+
+ + 📄 + r60m + (1 variables) + +
+ +
+

Variables

+
+
+ + + + + + + + + + + + + + +
<xarray.Dataset> Size: 10MB
+Dimensions:  (band: 3, x: 1830, y: 1830)
+Coordinates:
+  * band     (band) int64 24B 1 2 3
+  * x        (x) int64 15kB 500010 500070 500130 500190 ... 609630 609690 609750
+  * y        (y) int64 15kB 4199970 4199910 4199850 ... 4090350 4090290 4090230
+Data variables:
+    tci      (band, y, x) uint8 10MB dask.array<chunksize=(3, 1830, 1830), meta=np.ndarray>
+
+
+ +
+
+
+ +
+
+ +
+
+
+ +
+
+ +
+
+
+ +
+
+ + 📁 + measurements + (1 subgroups) + +
+ +
+

Subgroups

+
+ +
+
+ + 📁 + reflectance + (3 subgroups) + +
+ +
+

Subgroups

+
+ +
+
+ + 📁 + r20m + (1 attributes • 5 subgroups) + +
+ +
+

Attributes

+
+ +
+ multiscales: + {'tile_matrix_set': {'id': 'Native_CRS_32633', ... +
+ +
+
+ +
+

Subgroups

+
+ +
+
+ + 📄 + 0 + (11 variables • 1 attributes) + +
+ +
+

Variables

+
+
+ + + + + + + + + + + + + + +
<xarray.Dataset> Size: 2GB
+Dimensions:      (x: 5490, y: 5490)
+Coordinates:
+  * x            (x) int64 44kB 499990 500010 500030 ... 609730 609750 609770
+  * y            (y) int64 44kB 4199990 4199970 4199950 ... 4090230 4090210
+Data variables:
+    b01          (y, x) float64 241MB dask.array<chunksize=(3992, 3992), meta=np.ndarray>
+    b02          (y, x) float64 241MB dask.array<chunksize=(3992, 3992), meta=np.ndarray>
+    b03          (y, x) float64 241MB dask.array<chunksize=(3992, 3992), meta=np.ndarray>
+    b04          (y, x) float64 241MB dask.array<chunksize=(3992, 3992), meta=np.ndarray>
+    b05          (y, x) float64 241MB dask.array<chunksize=(3992, 3992), meta=np.ndarray>
+    b06          (y, x) float64 241MB dask.array<chunksize=(3992, 3992), meta=np.ndarray>
+    b07          (y, x) float64 241MB dask.array<chunksize=(3992, 3992), meta=np.ndarray>
+    b11          (y, x) float64 241MB dask.array<chunksize=(3992, 3992), meta=np.ndarray>
+    b12          (y, x) float64 241MB dask.array<chunksize=(3992, 3992), meta=np.ndarray>
+    b8a          (y, x) float64 241MB dask.array<chunksize=(3992, 3992), meta=np.ndarray>
+    spatial_ref  int64 8B ...
+
+
+ +
+

Attributes

+
+ +
+ grid_mapping: + spatial_ref +
+ +
+
+ +
+
+
+ +
+
+ + 📄 + 1 + (11 variables • 1 attributes) + +
+ +
+

Variables

+
+
+ + + + + + + + + + + + + + +
<xarray.Dataset> Size: 603MB
+Dimensions:      (x: 2745, y: 2745)
+Coordinates:
+  * x            (x) float64 22kB 5e+05 5e+05 5.001e+05 ... 6.097e+05 6.097e+05
+  * y            (y) float64 22kB 4.2e+06 4.2e+06 4.2e+06 ... 4.09e+06 4.09e+06
+Data variables:
+    b01          (y, x) float64 60MB dask.array<chunksize=(2745, 2745), meta=np.ndarray>
+    b02          (y, x) float64 60MB dask.array<chunksize=(2745, 2745), meta=np.ndarray>
+    b03          (y, x) float64 60MB dask.array<chunksize=(2745, 2745), meta=np.ndarray>
+    b04          (y, x) float64 60MB dask.array<chunksize=(2745, 2745), meta=np.ndarray>
+    b05          (y, x) float64 60MB dask.array<chunksize=(2745, 2745), meta=np.ndarray>
+    b06          (y, x) float64 60MB dask.array<chunksize=(2745, 2745), meta=np.ndarray>
+    b07          (y, x) float64 60MB dask.array<chunksize=(2745, 2745), meta=np.ndarray>
+    b11          (y, x) float64 60MB dask.array<chunksize=(2745, 2745), meta=np.ndarray>
+    b12          (y, x) float64 60MB dask.array<chunksize=(2745, 2745), meta=np.ndarray>
+    b8a          (y, x) float64 60MB dask.array<chunksize=(2745, 2745), meta=np.ndarray>
+    spatial_ref  int64 8B ...
+
+
+ +
+

Attributes

+
+ +
+ grid_mapping: + spatial_ref +
+ +
+
+ +
+
+
+ +
+
+ + 📄 + 2 + (11 variables • 1 attributes) + +
+ +
+

Variables

+
+
+ + + + + + + + + + + + + + +
<xarray.Dataset> Size: 151MB
+Dimensions:      (x: 1372, y: 1372)
+Coordinates:
+  * x            (x) float64 11kB 5e+05 5.001e+05 ... 6.096e+05 6.097e+05
+  * y            (y) float64 11kB 4.2e+06 4.2e+06 4.2e+06 ... 4.09e+06 4.09e+06
+Data variables:
+    b01          (y, x) float64 15MB dask.array<chunksize=(1372, 1372), meta=np.ndarray>
+    b02          (y, x) float64 15MB dask.array<chunksize=(1372, 1372), meta=np.ndarray>
+    b03          (y, x) float64 15MB dask.array<chunksize=(1372, 1372), meta=np.ndarray>
+    b04          (y, x) float64 15MB dask.array<chunksize=(1372, 1372), meta=np.ndarray>
+    b05          (y, x) float64 15MB dask.array<chunksize=(1372, 1372), meta=np.ndarray>
+    b06          (y, x) float64 15MB dask.array<chunksize=(1372, 1372), meta=np.ndarray>
+    b07          (y, x) float64 15MB dask.array<chunksize=(1372, 1372), meta=np.ndarray>
+    b11          (y, x) float64 15MB dask.array<chunksize=(1372, 1372), meta=np.ndarray>
+    b12          (y, x) float64 15MB dask.array<chunksize=(1372, 1372), meta=np.ndarray>
+    b8a          (y, x) float64 15MB dask.array<chunksize=(1372, 1372), meta=np.ndarray>
+    spatial_ref  int64 8B ...
+
+
+ +
+

Attributes

+
+ +
+ grid_mapping: + spatial_ref +
+ +
+
+ +
+
+
+ +
+
+ + 📄 + 3 + (11 variables • 1 attributes) + +
+ +
+

Variables

+
+
+ + + + + + + + + + + + + + +
<xarray.Dataset> Size: 38MB
+Dimensions:      (x: 686, y: 686)
+Coordinates:
+  * x            (x) float64 5kB 5e+05 5.001e+05 ... 6.095e+05 6.096e+05
+  * y            (y) float64 5kB 4.2e+06 4.2e+06 4.2e+06 ... 4.091e+06 4.09e+06
+Data variables:
+    b01          (y, x) float64 4MB dask.array<chunksize=(686, 686), meta=np.ndarray>
+    b02          (y, x) float64 4MB dask.array<chunksize=(686, 686), meta=np.ndarray>
+    b03          (y, x) float64 4MB dask.array<chunksize=(686, 686), meta=np.ndarray>
+    b04          (y, x) float64 4MB dask.array<chunksize=(686, 686), meta=np.ndarray>
+    b05          (y, x) float64 4MB dask.array<chunksize=(686, 686), meta=np.ndarray>
+    b06          (y, x) float64 4MB dask.array<chunksize=(686, 686), meta=np.ndarray>
+    b07          (y, x) float64 4MB dask.array<chunksize=(686, 686), meta=np.ndarray>
+    b11          (y, x) float64 4MB dask.array<chunksize=(686, 686), meta=np.ndarray>
+    b12          (y, x) float64 4MB dask.array<chunksize=(686, 686), meta=np.ndarray>
+    b8a          (y, x) float64 4MB dask.array<chunksize=(686, 686), meta=np.ndarray>
+    spatial_ref  int64 8B ...
+
+ [HTML tree-viewer markup omitted; recoverable content of the remaining panels:]
+   r20m / level 4: 11 variables, 1 attribute; <xarray.Dataset> Size: 9MB, 343 x 343
+     (b01-b07, b11, b12, b8a, spatial_ref); attribute grid_mapping: spatial_ref
+   r60m group: 1 attribute (multiscales: {'tile_matrix_set': {'id': 'Native_CRS_32633', ...), 3 subgroups:
+     level 0: 295MB, 1830 x 1830; level 1: 74MB, 915 x 915; level 2: 18MB, 457 x 457
+     (each 12 variables: b01-b07, b09, b11, b12, b8a, spatial_ref; attribute grid_mapping: spatial_ref)
+   r10m group: 1 attribute (multiscales: {'tile_matrix_set': {'id': 'Native_CRS_32633', ...), 6 subgroups:
+     level 0: 4GB, 10980 x 10980; level 1: 965MB, 5490 x 5490; level 2: 241MB, 2745 x 2745;
+     level 3: 60MB, 1372 x 1372; level 4: 15MB, 686 x 686; level 5: 4MB, 343 x 343
+     (each 5 variables: b02, b03, b04, b08, spatial_ref; attribute grid_mapping: spatial_ref)
+
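The panels above describe a nested multiscale layout (one subgroup per overview level under each resolution group). As a quick orientation only, the sketch below shows how such a store can be read back with xarray's zarr backend, matching the `xr.open_datatree(..., engine="zarr", chunks="auto")` calls used elsewhere in this patch series; the local store path and the chosen group are assumptions for illustration, not part of the patches.

```python
# Illustrative sketch: open the converted store and inspect one overview level.
# Assumes a local copy of the store named as in the tree dump below.
import xarray as xr

store = "S2A_MSIL2A_20250704T094051_N0511_R036_T33SWB_20250704T115824.zarr"

# Lazily open the full hierarchy (dask-backed arrays).
dt = xr.open_datatree(store, engine="zarr", chunks="auto")

# Pick overview level 2 of the 10 m reflectance group (2745 x 2745 in the dump above).
overview = dt["measurements/reflectance/r10m/2"].to_dataset()
print(overview["b04"].shape, overview["b04"].encoding.get("chunks"))
```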
+ + diff --git a/dataset_tree_simplified.txt b/dataset_tree_simplified.txt new file mode 100644 index 00000000..fbacd381 --- /dev/null +++ b/dataset_tree_simplified.txt @@ -0,0 +1,112 @@ +# Simplified Data Tree Structure +# Extracted from: S2A_MSIL2A_20250704T094051_N0511_R036_T33SWB_20250704T115824.zarr +# Generated: 2025-09-26 11:25:09 + +Dataset: S2A_MSIL2A_20250704T094051_N0511_R036_T33SWB_20250704T115824.zarr +Path: s3://esa-zarr-sentinel-explorer-fra/tests-output/sentinel-2-l2a/S2A_MSIL2A_20250704T094051_N0511_R036_T33SWB_20250704T115824.zarr +Total Groups: 3 + +root/ +├── conditions/ +│ ├── geometry/ +│ │ └── Variables: mean_sun_angles, mean_viewing_incidence_angles, spatial_ref, sun_angles, viewing_incidence_angles +│ ├── mask/ +│ │ ├── l1c_classification/ +│ │ │ └── r60m/ +│ │ │ └── Variables: b00 +│ │ ├── l2a_classification/ +│ │ │ ├── r20m/ +│ │ │ │ └── Variables: scl +│ │ │ └── r60m/ +│ │ │ └── Variables: scl +│ │ └── detector_footprint/ +│ │ ├── r10m/ +│ │ │ └── Variables: b03, b04, b02, b08 +│ │ ├── r20m/ +│ │ │ └── Variables: b05, b06, b07, b11, b12, b8a +│ │ └── r60m/ +│ │ └── Variables: b09, b01, b10 +│ └── meteorology/ +│ ├── cams/ +│ │ └── Variables: aod1240, aod469, aod550, bcaod550, duaod550, aod670, aod865, ssaod550, omaod550, suaod550, z +│ └── ecmwf/ +│ └── Variables: r, msl, tco3, tcwv, u10, v10 +├── quality/ +│ ├── atmosphere/ +│ │ ├── r20m/ +│ │ │ └── Variables: aot, wvp +│ │ ├── r60m/ +│ │ │ └── Variables: aot, wvp +│ │ └── r10m/ +│ │ └── Variables: aot, wvp +│ ├── mask/ +│ │ ├── r20m/ +│ │ │ └── Variables: b05, b8a, b07, b06, b11, b12 +│ │ ├── r60m/ +│ │ │ └── Variables: b01, b09, b10 +│ │ └── r10m/ +│ │ └── Variables: b02, b04, b03, b08 +│ ├── probability/ +│ │ └── r20m/ +│ │ └── Variables: cld, snw +│ └── l2a_quicklook/ +│ ├── r10m/ +│ │ ├── 0/ +│ │ │ └── Variables: spatial_ref, tci +│ │ ├── 1/ +│ │ │ └── Variables: spatial_ref, tci +│ │ ├── 2/ +│ │ │ └── Variables: spatial_ref, tci +│ │ ├── 3/ +│ │ │ └── Variables: spatial_ref, tci +│ │ ├── 4/ +│ │ │ └── Variables: spatial_ref, tci +│ │ └── 5/ +│ │ └── Variables: spatial_ref, tci +│ ├── r20m/ +│ │ └── Variables: tci +│ └── r60m/ +│ └── Variables: tci +└── measurements/ + └── reflectance/ + ├── r20m/ + │ ├── 0/ + │ │ └── Variables: b01, b02, b03, b04, b05, b06, b07, b11, b12, b8a, spatial_ref + │ ├── 1/ + │ │ └── Variables: b01, b02, b03, b04, b05, b06, b07, b11, b12, b8a, spatial_ref + │ ├── 2/ + │ │ └── Variables: b01, b02, b03, b04, b05, b06, b07, b11, b12, b8a, spatial_ref + │ ├── 3/ + │ │ └── Variables: b01, b02, b03, b04, b05, b06, b07, b11, b12, b8a, spatial_ref + │ └── 4/ + │ └── Variables: b01, b02, b03, b04, b05, b06, b07, b11, b12, b8a, spatial_ref + ├── r60m/ + │ ├── 0/ + │ │ └── Variables: b01, b02, b03, b04, b05, b06, b07, b09, b11, b12, b8a, spatial_ref + │ ├── 1/ + │ │ └── Variables: b01, b02, b03, b04, b05, b06, b07, b09, b11, b12, b8a, spatial_ref + │ └── 2/ + │ └── Variables: b01, b02, b03, b04, b05, b06, b07, b09, b11, b12, b8a, spatial_ref + └── r10m/ + ├── 0/ + │ └── Variables: b02, b03, b04, b08, spatial_ref + ├── 1/ + │ └── Variables: b02, b03, b04, b08, spatial_ref + ├── 2/ + │ └── Variables: b02, b03, b04, b08, spatial_ref + ├── 3/ + │ └── Variables: b02, b03, b04, b08, spatial_ref + ├── 4/ + │ └── Variables: b02, b03, b04, b08, spatial_ref + └── 5/ + └── Variables: b02, b03, b04, b08, spatial_ref + +## Summary: +- Main categories: conditions, quality, measurements +- Resolution levels: r10m (10m), r20m (20m), r60m (60m) +- Spectral bands: b01-b12, b8a +- Mask types: 
l1c_classification, l2a_classification, detector_footprint +- Quality metrics: atmosphere (aot, wvp), probability (cld, snw) +- Meteorological data: CAMS aerosol data, ECMWF weather data +- Geometric data: sun angles, viewing angles, spatial reference +- Multiscale data: Multiple zoom levels (0-5) for visualization diff --git a/src/eopf_geozarr/conversion/geozarr.py b/src/eopf_geozarr/conversion/geozarr.py index 19eae088..05908eb8 100644 --- a/src/eopf_geozarr/conversion/geozarr.py +++ b/src/eopf_geozarr/conversion/geozarr.py @@ -502,6 +502,7 @@ def write_geozarr_group( tile_width=tile_width, spatial_chunk=spatial_chunk, ds_gcp=ds_gcp, + enable_sharding=enable_sharding, ) except Exception as e: print( @@ -527,6 +528,7 @@ def create_geozarr_compliant_multiscales( tile_width: int = 256, spatial_chunk: int = 4096, ds_gcp: xr.Dataset | None = None, + enable_sharding: bool = False ) -> Dict[str, Any]: """ Create GeoZarr-spec compliant multiscales following the specification exactly. @@ -684,10 +686,11 @@ def create_geozarr_compliant_multiscales( native_bounds, data_vars, ds_gcp_overview, + enable_sharding, ) # Create encoding for this overview level - encoding = _create_geozarr_encoding(overview_ds, compressor, spatial_chunk) + encoding = _create_geozarr_encoding(overview_ds, compressor, spatial_chunk, enable_sharding) # Write overview level overview_path = fs_utils.normalize_path(f"{output_path}/{group_name}/{level}") @@ -895,6 +898,7 @@ def create_overview_dataset_all_vars( native_bounds: Tuple[float, float, float, float], data_vars: Sequence[Hashable], ds_gcp: xr.Dataset | None = None, + enable_sharding: bool = False, ) -> xr.Dataset: """ Create an overview dataset containing all variables for a specific level. diff --git a/test_sharding_fix.py b/test_sharding_fix.py new file mode 100644 index 00000000..af8ad669 --- /dev/null +++ b/test_sharding_fix.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python3 +""" +Test script to verify the sharding fix for GeoZarr conversion. +This script tests the _calculate_shard_dimension function and validates +that shard dimensions are properly divisible by chunk dimensions. 
+""" + +import sys +import os +sys.path.insert(0, 'src') + +def test_calculate_shard_dimension(): + """Test the _calculate_shard_dimension function.""" + from eopf_geozarr.conversion.geozarr import _calculate_shard_dimension + + print("🧪 Testing _calculate_shard_dimension function...") + + # Test cases: (data_dim, chunk_dim, description) + test_cases = [ + (10980, 4096, "Sentinel-2 10m resolution typical case"), + (5490, 2048, "Half resolution case"), + (8192, 4096, "Perfect 2x multiple"), + (12288, 4096, "Perfect 3x multiple"), + (16384, 4096, "Perfect 4x multiple"), + (20480, 4096, "Perfect 5x multiple"), + (24576, 4096, "Perfect 6x multiple"), + (1000, 512, "Small dimension case"), + (256, 512, "Chunk larger than data"), + (1024, 256, "4x multiple case"), + ] + + print("\nTest Results:") + print("=" * 80) + print(f"{'Data Dim':<10} {'Chunk Dim':<10} {'Shard Dim':<10} {'Divisible?':<12} {'Description'}") + print("-" * 80) + + all_passed = True + for data_dim, chunk_dim, description in test_cases: + shard_dim = _calculate_shard_dimension(data_dim, chunk_dim) + + # When chunk_dim >= data_dim, the effective chunk size is data_dim + effective_chunk_dim = min(chunk_dim, data_dim) + is_divisible = shard_dim % effective_chunk_dim == 0 + status = "✅ YES" if is_divisible else "❌ NO" + + print(f"{data_dim:<10} {chunk_dim:<10} {shard_dim:<10} {status:<12} {description}") + + if not is_divisible: + all_passed = False + print(f" ⚠️ ERROR: {shard_dim} % {effective_chunk_dim} = {shard_dim % effective_chunk_dim}") + + print("-" * 80) + if all_passed: + print("✅ All tests passed! Shard dimensions are properly divisible by chunk dimensions.") + else: + print("❌ Some tests failed! Check the implementation.") + + return all_passed + + +def test_encoding_creation(): + """Test the encoding creation with sharding enabled.""" + import numpy as np + import xarray as xr + from zarr.codecs import BloscCodec + from eopf_geozarr.conversion.geozarr import _create_geozarr_encoding + + print("\n🧪 Testing encoding creation with sharding...") + + # Create a test dataset + data = np.random.rand(1, 10980, 10980).astype(np.float32) + ds = xr.Dataset({ + 'b02': (['time', 'y', 'x'], data), + }, coords={ + 'time': [np.datetime64('2023-01-01')], + 'y': np.arange(10980), + 'x': np.arange(10980), + }) + + compressor = BloscCodec(cname="zstd", clevel=3, shuffle="shuffle", blocksize=0) + spatial_chunk = 4096 + + # Test with sharding enabled + print("\nTesting with sharding enabled:") + encoding = _create_geozarr_encoding(ds, compressor, spatial_chunk, enable_sharding=True) + + for var, enc in encoding.items(): + if 'shards' in enc and enc['shards'] is not None: + chunks = enc['chunks'] + shards = enc['shards'] + print(f"Variable: {var}") + print(f" Data shape: {ds[var].shape}") + print(f" Chunks: {chunks}") + print(f" Shards: {shards}") + + # Validate divisibility + valid = True + for i, (shard_dim, chunk_dim) in enumerate(zip(shards, chunks)): + if shard_dim % chunk_dim != 0: + print(f" ❌ Axis {i}: {shard_dim} % {chunk_dim} = {shard_dim % chunk_dim}") + valid = False + else: + print(f" ✅ Axis {i}: {shard_dim} % {chunk_dim} = 0") + + if valid: + print(" ✅ All shard dimensions are divisible by chunk dimensions") + else: + print(" ❌ Some shard dimensions are not divisible by chunk dimensions") + + print("\nTesting with sharding disabled:") + encoding_no_shard = _create_geozarr_encoding(ds, compressor, spatial_chunk, enable_sharding=False) + + for var, enc in encoding_no_shard.items(): + if 'shards' in enc: + print(f"Variable: {var}, 
Shards: {enc['shards']}") + + +def main(): + """Run all tests.""" + print("🔧 Testing Zarr v3 Sharding Fix for GeoZarr") + print("=" * 50) + + # Test the shard dimension calculation + test1_passed = test_calculate_shard_dimension() + + # Test the encoding creation + test_encoding_creation() + + print("\n" + "=" * 50) + if test1_passed: + print("✅ All critical tests passed!") + print("🎉 The sharding fix should resolve the checksum mismatch issues.") + print("\nKey improvements:") + print("- Shard dimensions are now evenly divisible by chunk dimensions") + print("- Added debugging output to show sharding configuration") + print("- Enhanced shard calculation with preference for larger multipliers") + else: + print("❌ Some tests failed. Please review the implementation.") + + return test1_passed + + +if __name__ == "__main__": + success = main() + sys.exit(0 if success else 1) From 82305a6353ca060b1c092f2f86d6c1eae8ba06f3 Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Fri, 26 Sep 2025 10:10:32 +0000 Subject: [PATCH 10/83] feat: enable sharding in Dask cluster setup and enhance chunking logic for sharded variables --- .vscode/launch.json | 1 + src/eopf_geozarr/cli.py | 5 +++-- src/eopf_geozarr/conversion/geozarr.py | 12 +++++++++++- src/eopf_geozarr/conversion/utils.py | 2 +- 4 files changed, 16 insertions(+), 4 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index e7bd8e81..6d595477 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -120,6 +120,7 @@ "--tile-width", "256", "--max-retries", "2", "--dask-cluster", + "--enable-sharding", "--verbose" ], "cwd": "${workspaceFolder}", diff --git a/src/eopf_geozarr/cli.py b/src/eopf_geozarr/cli.py index 77e24e22..3bd9df75 100644 --- a/src/eopf_geozarr/cli.py +++ b/src/eopf_geozarr/cli.py @@ -51,8 +51,9 @@ def setup_dask_cluster(enable_dask: bool, verbose: bool = False) -> Optional[Any try: from dask.distributed import Client - # Set up local cluster - client = Client() # set up local cluster + # Set up local cluster with high memory limits + client = Client(memory_limit="8GB") # set up local cluster + # client = Client() # set up local cluster if verbose: print(f"🚀 Dask cluster started: {client}") diff --git a/src/eopf_geozarr/conversion/geozarr.py b/src/eopf_geozarr/conversion/geozarr.py index 19eae088..f6ba611c 100644 --- a/src/eopf_geozarr/conversion/geozarr.py +++ b/src/eopf_geozarr/conversion/geozarr.py @@ -1100,7 +1100,17 @@ def write_dataset_band_by_band_with_validation( for attempt in range(max_retries): try: # Ensure the dataset is properly chunked to align with encoding - if var in var_encoding and "chunks" in var_encoding[var]: + if var in var_encoding and "shards" in var_encoding[var] and var_encoding[var]["shards"] is not None: + # For sharded variables, use the shards dimensions + shard_dims = var_encoding[var].get("shards", None) + if shard_dims is not None: + var_dims = single_var_ds[var].dims + chunk_dict = {} + for i, dim in enumerate(var_dims): + if i < len(shard_dims): + chunk_dict[dim] = shard_dims[i] + single_var_ds[var] = single_var_ds[var].chunk(chunk_dict) + elif var in var_encoding and "chunks" in var_encoding[var]: target_chunks = var_encoding[var]["chunks"] # Create chunk dict using the actual dimensions of the variable var_dims = single_var_ds[var].dims diff --git a/src/eopf_geozarr/conversion/utils.py b/src/eopf_geozarr/conversion/utils.py index d6ccd033..7b9087aa 100644 --- a/src/eopf_geozarr/conversion/utils.py +++ b/src/eopf_geozarr/conversion/utils.py @@ -124,7 +124,7 @@ def 
calculate_aligned_chunk_size(dimension_size: int, target_chunk_size: int) -> return dimension_size # Find the largest divisor of dimension_size that is <= target_chunk_size - for chunk_size in range(target_chunk_size, 0, -1): + for chunk_size in range(target_chunk_size, int(target_chunk_size * 0.51), -1): if dimension_size % chunk_size == 0: return chunk_size From 3232a8e56458fd41aa45a11c219cb2f1b1c81577 Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Fri, 26 Sep 2025 14:18:50 +0200 Subject: [PATCH 11/83] Add Sentinel-2 Optimization Module with CLI Integration and Data Processing - Created the `s2_optimization` module for optimizing Sentinel-2 Zarr datasets. - Implemented CLI commands for converting Sentinel-2 datasets to optimized structures. - Developed band mapping and resolution definitions for Sentinel-2 optimization. - Added the `S2OptimizedConverter` class for handling the conversion process. - Implemented data consolidation logic to reorganize Sentinel-2 structure. - Created multiscale pyramid generation for optimized data. - Added downsampling operations for various data types (reflectance, classification, quality masks). - Implemented validation logic for optimized Sentinel-2 datasets. - Developed unit tests for band mapping, converter functionality, and resampling operations. --- s2_implementation_guidelines.md | 1519 +++++++++++++++++ s2_optimization_plan.md | 363 ++++ src/eopf_geozarr/s2_optimization/__init__.py | 2 + .../s2_optimization/cli_integration.py | 45 + .../s2_optimization/s2_band_mapping.py | 80 + .../s2_optimization/s2_converter.py | 36 + .../s2_optimization/s2_data_consolidator.py | 43 + .../s2_optimization/s2_multiscale.py | 27 + .../s2_optimization/s2_resampling.py | 236 +++ .../s2_optimization/s2_validation.py | 29 + .../tests/test_s2_band_mapping.py | 48 + src/eopf_geozarr/tests/test_s2_converter.py | 33 + src/eopf_geozarr/tests/test_s2_resampling.py | 539 ++++++ 13 files changed, 3000 insertions(+) create mode 100644 s2_implementation_guidelines.md create mode 100644 s2_optimization_plan.md create mode 100644 src/eopf_geozarr/s2_optimization/__init__.py create mode 100644 src/eopf_geozarr/s2_optimization/cli_integration.py create mode 100644 src/eopf_geozarr/s2_optimization/s2_band_mapping.py create mode 100644 src/eopf_geozarr/s2_optimization/s2_converter.py create mode 100644 src/eopf_geozarr/s2_optimization/s2_data_consolidator.py create mode 100644 src/eopf_geozarr/s2_optimization/s2_multiscale.py create mode 100644 src/eopf_geozarr/s2_optimization/s2_resampling.py create mode 100644 src/eopf_geozarr/s2_optimization/s2_validation.py create mode 100644 src/eopf_geozarr/tests/test_s2_band_mapping.py create mode 100644 src/eopf_geozarr/tests/test_s2_converter.py create mode 100644 src/eopf_geozarr/tests/test_s2_resampling.py diff --git a/s2_implementation_guidelines.md b/s2_implementation_guidelines.md new file mode 100644 index 00000000..db6f072e --- /dev/null +++ b/s2_implementation_guidelines.md @@ -0,0 +1,1519 @@ +# Sentinel-2 Optimization - Detailed Implementation Guidelines + +## Module Architecture + +### Core Files to Create + +``` +eopf_geozarr/s2_optimization/ +├── __init__.py +├── s2_converter.py # Main conversion orchestration +├── s2_band_mapping.py # Band resolution and availability mapping +├── s2_resampling.py # Downsampling operations only +├── s2_data_consolidator.py # Data reorganization logic +├── s2_multiscale.py # Multiscale pyramid creation +├── s2_validation.py # Validation and integrity checks +└── cli_integration.py # CLI 
command integration +``` + +## Implementation Specifications + +### 1. s2_band_mapping.py + +```python +""" +Band mapping and resolution definitions for Sentinel-2 optimization. +""" + +from typing import Dict, List, Set +from dataclasses import dataclass + +@dataclass +class BandInfo: + """Information about a spectral band.""" + name: str + native_resolution: int # meters + data_type: str + wavelength_center: float # nanometers + wavelength_width: float # nanometers + +# Native resolution definitions - CRITICAL: Only these bands exist at these resolutions +NATIVE_BANDS: Dict[int, List[str]] = { + 10: ['b02', 'b03', 'b04', 'b08'], # Blue, Green, Red, NIR + 20: ['b05', 'b06', 'b07', 'b11', 'b12', 'b8a'], # Red Edge, SWIR + 60: ['b01', 'b09'] # Coastal, Water Vapor +} + +# Complete band information +BAND_INFO: Dict[str, BandInfo] = { + 'b01': BandInfo('b01', 60, 'uint16', 443, 21), # Coastal aerosol + 'b02': BandInfo('b02', 10, 'uint16', 490, 66), # Blue + 'b03': BandInfo('b03', 10, 'uint16', 560, 36), # Green + 'b04': BandInfo('b04', 10, 'uint16', 665, 31), # Red + 'b05': BandInfo('b05', 20, 'uint16', 705, 15), # Red Edge 1 + 'b06': BandInfo('b06', 20, 'uint16', 740, 15), # Red Edge 2 + 'b07': BandInfo('b07', 20, 'uint16', 783, 20), # Red Edge 3 + 'b08': BandInfo('b08', 10, 'uint16', 842, 106), # NIR + 'b8a': BandInfo('b8a', 20, 'uint16', 865, 21), # NIR Narrow + 'b09': BandInfo('b09', 60, 'uint16', 945, 20), # Water Vapor + 'b11': BandInfo('b11', 20, 'uint16', 1614, 91), # SWIR 1 + 'b12': BandInfo('b12', 20, 'uint16', 2202, 175), # SWIR 2 +} + +# Quality data mapping - defines which auxiliary data exists at which resolutions +QUALITY_DATA_NATIVE: Dict[str, int] = { + 'scl': 20, # Scene Classification Layer - native 20m + 'aot': 20, # Aerosol Optical Thickness - native 20m + 'wvp': 20, # Water Vapor - native 20m + 'cld': 20, # Cloud probability - native 20m + 'snw': 20, # Snow probability - native 20m +} + +# Detector footprint availability - matches spectral bands +DETECTOR_FOOTPRINT_NATIVE: Dict[int, List[str]] = { + 10: ['b02', 'b03', 'b04', 'b08'], + 20: ['b05', 'b06', 'b07', 'b11', 'b12', 'b8a'], + 60: ['b01', 'b09'] +} + +def get_bands_for_level(level: int) -> Set[str]: + """ + Get all bands available at a given pyramid level. + + Args: + level: Pyramid level (0=10m, 1=20m, 2=60m, 3+=downsampled) + + Returns: + Set of band names available at this level + """ + if level == 0: # 10m - only native 10m bands + return set(NATIVE_BANDS[10]) + elif level == 1: # 20m - all bands (native + downsampled from 10m) + return set(NATIVE_BANDS[10] + NATIVE_BANDS[20] + NATIVE_BANDS[60]) + elif level == 2: # 60m - all bands downsampled + return set(NATIVE_BANDS[10] + NATIVE_BANDS[20] + NATIVE_BANDS[60]) + else: # Further downsampling - all bands + return set(NATIVE_BANDS[10] + NATIVE_BANDS[20] + NATIVE_BANDS[60]) + +def get_quality_data_for_level(level: int) -> Set[str]: + """Get quality data available at a given level (no upsampling).""" + if level == 0: # 10m - no quality data (would require upsampling) + return set() + elif level >= 1: # 20m and below - all quality data available + return set(QUALITY_DATA_NATIVE.keys()) +``` + +### 2. s2_resampling.py + +```python +""" +Downsampling operations for Sentinel-2 data (no upsampling). 
+""" + +import numpy as np +import xarray as xr +from scipy.ndimage import zoom +from sklearn.preprocessing import mode +import warnings + +class S2ResamplingEngine: + """Handles downsampling operations for S2 multiscale creation.""" + + def __init__(self): + self.resampling_methods = { + 'reflectance': self._downsample_reflectance, + 'classification': self._downsample_classification, + 'quality_mask': self._downsample_quality_mask, + 'probability': self._downsample_probability, + 'detector_footprint': self._downsample_quality_mask, # Same as quality mask + } + + def downsample_variable(self, data: xr.DataArray, target_height: int, + target_width: int, var_type: str) -> xr.DataArray: + """ + Downsample a variable to target dimensions. + + Args: + data: Input data array + target_height: Target height in pixels + target_width: Target width in pixels + var_type: Type of variable ('reflectance', 'classification', etc.) + + Returns: + Downsampled data array + """ + if var_type not in self.resampling_methods: + raise ValueError(f"Unknown variable type: {var_type}") + + method = self.resampling_methods[var_type] + return method(data, target_height, target_width) + + def _downsample_reflectance(self, data: xr.DataArray, target_height: int, + target_width: int) -> xr.DataArray: + """Block averaging for reflectance bands.""" + # Calculate block sizes + current_height, current_width = data.shape[-2:] + block_h = current_height // target_height + block_w = current_width // target_width + + # Ensure exact divisibility + if current_height % target_height != 0 or current_width % target_width != 0: + # Crop to make it divisible + new_height = (current_height // block_h) * block_h + new_width = (current_width // block_w) * block_w + data = data[..., :new_height, :new_width] + + # Perform block averaging + if data.ndim == 3: # (time, y, x) or similar + reshaped = data.values.reshape( + data.shape[0], target_height, block_h, target_width, block_w + ) + downsampled = reshaped.mean(axis=(2, 4)) + else: # (y, x) + reshaped = data.values.reshape(target_height, block_h, target_width, block_w) + downsampled = reshaped.mean(axis=(1, 3)) + + # Create new coordinates + y_coords = data.coords[data.dims[-2]][::block_h][:target_height] + x_coords = data.coords[data.dims[-1]][::block_w][:target_width] + + # Create new DataArray + if data.ndim == 3: + coords = { + data.dims[0]: data.coords[data.dims[0]], + data.dims[-2]: y_coords, + data.dims[-1]: x_coords + } + else: + coords = { + data.dims[-2]: y_coords, + data.dims[-1]: x_coords + } + + return xr.DataArray( + downsampled, + dims=data.dims, + coords=coords, + attrs=data.attrs.copy() + ) + + def _downsample_classification(self, data: xr.DataArray, target_height: int, + target_width: int) -> xr.DataArray: + """Mode-based downsampling for classification data.""" + from scipy import stats + + current_height, current_width = data.shape[-2:] + block_h = current_height // target_height + block_w = current_width // target_width + + # Crop to make divisible + new_height = (current_height // block_h) * block_h + new_width = (current_width // block_w) * block_w + data = data[..., :new_height, :new_width] + + # Reshape for block processing + if data.ndim == 3: + reshaped = data.values.reshape( + data.shape[0], target_height, block_h, target_width, block_w + ) + # Compute mode for each block + downsampled = np.zeros((data.shape[0], target_height, target_width), dtype=data.dtype) + for t in range(data.shape[0]): + for i in range(target_height): + for j in range(target_width): + 
block = reshaped[t, i, :, j, :].flatten() + mode_val = stats.mode(block, keepdims=False)[0] + downsampled[t, i, j] = mode_val + else: + reshaped = data.values.reshape(target_height, block_h, target_width, block_w) + downsampled = np.zeros((target_height, target_width), dtype=data.dtype) + for i in range(target_height): + for j in range(target_width): + block = reshaped[i, :, j, :].flatten() + mode_val = stats.mode(block, keepdims=False)[0] + downsampled[i, j] = mode_val + + # Create coordinates + y_coords = data.coords[data.dims[-2]][::block_h][:target_height] + x_coords = data.coords[data.dims[-1]][::block_w][:target_width] + + if data.ndim == 3: + coords = { + data.dims[0]: data.coords[data.dims[0]], + data.dims[-2]: y_coords, + data.dims[-1]: x_coords + } + else: + coords = { + data.dims[-2]: y_coords, + data.dims[-1]: x_coords + } + + return xr.DataArray( + downsampled, + dims=data.dims, + coords=coords, + attrs=data.attrs.copy() + ) + + def _downsample_quality_mask(self, data: xr.DataArray, target_height: int, + target_width: int) -> xr.DataArray: + """Logical OR downsampling for quality masks (any bad pixel = bad block).""" + current_height, current_width = data.shape[-2:] + block_h = current_height // target_height + block_w = current_width // target_width + + # Crop to make divisible + new_height = (current_height // block_h) * block_h + new_width = (current_width // block_w) * block_w + data = data[..., :new_height, :new_width] + + if data.ndim == 3: + reshaped = data.values.reshape( + data.shape[0], target_height, block_h, target_width, block_w + ) + # Any non-zero value in block makes the downsampled pixel non-zero + downsampled = (reshaped.sum(axis=(2, 4)) > 0).astype(data.dtype) + else: + reshaped = data.values.reshape(target_height, block_h, target_width, block_w) + downsampled = (reshaped.sum(axis=(1, 3)) > 0).astype(data.dtype) + + # Create coordinates + y_coords = data.coords[data.dims[-2]][::block_h][:target_height] + x_coords = data.coords[data.dims[-1]][::block_w][:target_width] + + if data.ndim == 3: + coords = { + data.dims[0]: data.coords[data.dims[0]], + data.dims[-2]: y_coords, + data.dims[-1]: x_coords + } + else: + coords = { + data.dims[-2]: y_coords, + data.dims[-1]: x_coords + } + + return xr.DataArray( + downsampled, + dims=data.dims, + coords=coords, + attrs=data.attrs.copy() + ) + + def _downsample_probability(self, data: xr.DataArray, target_height: int, + target_width: int) -> xr.DataArray: + """Average downsampling for probability data.""" + # Use same method as reflectance but ensure values stay in [0,1] or [0,100] range + result = self._downsample_reflectance(data, target_height, target_width) + + # Clamp values to valid probability range + if result.max() <= 1.0: # [0,1] probabilities + result.values = np.clip(result.values, 0, 1) + else: # [0,100] percentages + result.values = np.clip(result.values, 0, 100) + + return result + +def determine_variable_type(var_name: str, var_data: xr.DataArray) -> str: + """ + Determine the type of a variable for appropriate resampling. 
+ + Args: + var_name: Name of the variable + var_data: The data array + + Returns: + Variable type string + """ + # Spectral bands + if var_name.startswith('b') and (var_name[1:].isdigit() or var_name == 'b8a'): + return 'reflectance' + + # Quality data + if var_name in ['scl']: # Scene Classification Layer + return 'classification' + + if var_name in ['cld', 'snw']: # Probability data + return 'probability' + + if var_name in ['aot', 'wvp']: # Atmosphere quality - treat as reflectance + return 'reflectance' + + if var_name.startswith('detector_footprint_') or var_name.startswith('quality_'): + return 'quality_mask' + + # Default to reflectance for unknown variables + return 'reflectance' +``` + +### 3. s2_data_consolidator.py + +```python +""" +Data consolidation logic for reorganizing S2 structure. +""" + +import xarray as xr +from typing import Dict, List, Tuple, Optional +from .s2_band_mapping import ( + NATIVE_BANDS, QUALITY_DATA_NATIVE, DETECTOR_FOOTPRINT_NATIVE, + get_bands_for_level, get_quality_data_for_level +) + +class S2DataConsolidator: + """Consolidates S2 data from scattered structure into organized groups.""" + + def __init__(self, dt_input: xr.DataTree): + self.dt_input = dt_input + self.measurements_data = {} + self.geometry_data = {} + self.meteorology_data = {} + + def consolidate_all_data(self) -> Tuple[Dict, Dict, Dict]: + """ + Consolidate all data into three main categories. + + Returns: + Tuple of (measurements, geometry, meteorology) data dictionaries + """ + self._extract_measurements_data() + self._extract_geometry_data() + self._extract_meteorology_data() + + return self.measurements_data, self.geometry_data, self.meteorology_data + + def _extract_measurements_data(self) -> None: + """Extract and organize all measurement-related data by native resolution.""" + + # Initialize resolution groups + for resolution in [10, 20, 60]: + self.measurements_data[resolution] = { + 'bands': {}, + 'quality': {}, + 'detector_footprints': {}, + 'classification': {}, + 'atmosphere': {}, + 'probability': {} + } + + # Extract reflectance bands + if '/measurements/reflectance' in self.dt_input.groups: + self._extract_reflectance_bands() + + # Extract quality data + self._extract_quality_data() + + # Extract detector footprints + self._extract_detector_footprints() + + # Extract atmosphere quality + self._extract_atmosphere_data() + + # Extract classification data + self._extract_classification_data() + + # Extract probability data + self._extract_probability_data() + + def _extract_reflectance_bands(self) -> None: + """Extract reflectance bands from measurements/reflectance groups.""" + for resolution in ['r10m', 'r20m', 'r60m']: + res_num = int(resolution[1:-1]) # Extract number from 'r10m' + group_path = f'/measurements/reflectance/{resolution}' + + if group_path in self.dt_input.groups: + # Check if this is a multiscale group (has numeric subgroups) + group_node = self.dt_input[group_path] + if hasattr(group_node, 'children') and group_node.children: + # Take level 0 (native resolution) + native_path = f'{group_path}/0' + if native_path in self.dt_input.groups: + ds = self.dt_input[native_path].to_dataset() + else: + ds = group_node.to_dataset() + else: + ds = group_node.to_dataset() + + # Extract only native bands for this resolution + native_bands = NATIVE_BANDS.get(res_num, []) + for band in native_bands: + if band in ds.data_vars: + self.measurements_data[res_num]['bands'][band] = ds[band] + + def _extract_quality_data(self) -> None: + """Extract quality mask data.""" + 
quality_base = '/quality/mask' + + for resolution in ['r10m', 'r20m', 'r60m']: + res_num = int(resolution[1:-1]) + group_path = f'{quality_base}/{resolution}' + + if group_path in self.dt_input.groups: + ds = self.dt_input[group_path].to_dataset() + + # Only extract quality for native bands at this resolution + native_bands = NATIVE_BANDS.get(res_num, []) + for band in native_bands: + if band in ds.data_vars: + self.measurements_data[res_num]['quality'][f'quality_{band}'] = ds[band] + + def _extract_detector_footprints(self) -> None: + """Extract detector footprint data.""" + footprint_base = '/conditions/mask/detector_footprint' + + for resolution in ['r10m', 'r20m', 'r60m']: + res_num = int(resolution[1:-1]) + group_path = f'{footprint_base}/{resolution}' + + if group_path in self.dt_input.groups: + ds = self.dt_input[group_path].to_dataset() + + # Only extract footprints for native bands + native_bands = NATIVE_BANDS.get(res_num, []) + for band in native_bands: + if band in ds.data_vars: + var_name = f'detector_footprint_{band}' + self.measurements_data[res_num]['detector_footprints'][var_name] = ds[band] + + def _extract_atmosphere_data(self) -> None: + """Extract atmosphere quality data (aot, wvp) - native at 20m.""" + atm_base = '/quality/atmosphere' + + # Atmosphere data is native at 20m resolution + group_path = f'{atm_base}/r20m' + if group_path in self.dt_input.groups: + ds = self.dt_input[group_path].to_dataset() + + for var in ['aot', 'wvp']: + if var in ds.data_vars: + self.measurements_data[20]['atmosphere'][var] = ds[var] + + def _extract_classification_data(self) -> None: + """Extract scene classification data - native at 20m.""" + class_base = '/conditions/mask/l2a_classification' + + # Classification is native at 20m + group_path = f'{class_base}/r20m' + if group_path in self.dt_input.groups: + ds = self.dt_input[group_path].to_dataset() + + if 'scl' in ds.data_vars: + self.measurements_data[20]['classification']['scl'] = ds['scl'] + + def _extract_probability_data(self) -> None: + """Extract cloud and snow probability data - native at 20m.""" + prob_base = '/quality/probability/r20m' + + if prob_base in self.dt_input.groups: + ds = self.dt_input[prob_base].to_dataset() + + for var in ['cld', 'snw']: + if var in ds.data_vars: + self.measurements_data[20]['probability'][var] = ds[var] + + def _extract_geometry_data(self) -> None: + """Extract all geometry-related data into single group.""" + geom_base = '/conditions/geometry' + + if geom_base in self.dt_input.groups: + ds = self.dt_input[geom_base].to_dataset() + + # Consolidate all geometry variables + for var_name in ds.data_vars: + self.geometry_data[var_name] = ds[var_name] + + def _extract_meteorology_data(self) -> None: + """Extract meteorological data (CAMS and ECMWF).""" + # CAMS data + cams_path = '/conditions/meteorology/cams' + if cams_path in self.dt_input.groups: + ds = self.dt_input[cams_path].to_dataset() + for var_name in ds.data_vars: + self.meteorology_data[f'cams_{var_name}'] = ds[var_name] + + # ECMWF data + ecmwf_path = '/conditions/meteorology/ecmwf' + if ecmwf_path in self.dt_input.groups: + ds = self.dt_input[ecmwf_path].to_dataset() + for var_name in ds.data_vars: + self.meteorology_data[f'ecmwf_{var_name}'] = ds[var_name] + +def create_consolidated_dataset(data_dict: Dict, resolution: int) -> xr.Dataset: + """ + Create a consolidated dataset from categorized data. 
+ + Args: + data_dict: Dictionary with categorized data + resolution: Target resolution in meters + + Returns: + Consolidated xarray Dataset + """ + all_vars = {} + + # Combine all data variables + for category, vars_dict in data_dict.items(): + all_vars.update(vars_dict) + + if not all_vars: + return xr.Dataset() + + # Create dataset + ds = xr.Dataset(all_vars) + + # Set up coordinate system and metadata + if 'x' in ds.coords and 'y' in ds.coords: + # Ensure CRS information is present + if ds.rio.crs is None: + # Try to infer CRS from one of the variables + for var_name, var_data in all_vars.items(): + if hasattr(var_data, 'rio') and var_data.rio.crs: + ds.rio.write_crs(var_data.rio.crs, inplace=True) + break + + # Add resolution metadata + ds.attrs['native_resolution_meters'] = resolution + ds.attrs['processing_level'] = 'L2A' + ds.attrs['product_type'] = 'S2MSI2A' + + return ds +``` + +### 4. s2_multiscale.py + +```python +""" +Multiscale pyramid creation for optimized S2 structure. +""" + +import numpy as np +import xarray as xr +from typing import Dict, List, Tuple +from .s2_resampling import S2ResamplingEngine, determine_variable_type +from .s2_band_mapping import get_bands_for_level, get_quality_data_for_level + +class S2MultiscalePyramid: + """Creates multiscale pyramids for consolidated S2 data.""" + + def __init__(self, enable_sharding: bool = True, spatial_chunk: int = 1024): + self.enable_sharding = enable_sharding + self.spatial_chunk = spatial_chunk + self.resampler = S2ResamplingEngine() + + # Define pyramid levels: resolution in meters + self.pyramid_levels = { + 0: 10, # Level 0: 10m (native for b02,b03,b04,b08) + 1: 20, # Level 1: 20m (native for b05,b06,b07,b11,b12,b8a + all quality) + 2: 40, # Level 2: 40m (2x downsampling from 20m) + 3: 80, # Level 3: 80m + 4: 160, # Level 4: 160m + 5: 320, # Level 5: 320m + 6: 640, # Level 6: 640m + } + + def create_multiscale_measurements( + self, + measurements_by_resolution: Dict[int, Dict], + output_path: str + ) -> Dict[int, xr.Dataset]: + """ + Create multiscale pyramid from consolidated measurements. 
+ + Args: + measurements_by_resolution: Data organized by native resolution + output_path: Base output path + + Returns: + Dictionary of datasets by pyramid level + """ + pyramid_datasets = {} + + # Create each pyramid level + for level, target_resolution in self.pyramid_levels.items(): + print(f"Creating pyramid level {level} ({target_resolution}m)...") + + dataset = self._create_level_dataset( + level, target_resolution, measurements_by_resolution + ) + + if dataset and len(dataset.data_vars) > 0: + pyramid_datasets[level] = dataset + + # Write this level + level_path = f"{output_path}/measurements/{level}" + self._write_level_dataset(dataset, level_path, level) + + return pyramid_datasets + + def _create_level_dataset( + self, + level: int, + target_resolution: int, + measurements_by_resolution: Dict[int, Dict] + ) -> xr.Dataset: + """Create dataset for a specific pyramid level.""" + + if level == 0: + # Level 0: Only native 10m data + return self._create_level_0_dataset(measurements_by_resolution) + elif level == 1: + # Level 1: All data at 20m (native + downsampled from 10m) + return self._create_level_1_dataset(measurements_by_resolution) + else: + # Levels 2+: Downsample from level 1 + return self._create_downsampled_dataset( + level, target_resolution, measurements_by_resolution + ) + + def _create_level_0_dataset(self, measurements_by_resolution: Dict) -> xr.Dataset: + """Create level 0 dataset with only native 10m data.""" + if 10 not in measurements_by_resolution: + return xr.Dataset() + + data_10m = measurements_by_resolution[10] + all_vars = {} + + # Add only native 10m bands and their associated data + for category, vars_dict in data_10m.items(): + all_vars.update(vars_dict) + + if not all_vars: + return xr.Dataset() + + # Create consolidated dataset + dataset = xr.Dataset(all_vars) + dataset.attrs['pyramid_level'] = 0 + dataset.attrs['resolution_meters'] = 10 + + return dataset + + def _create_level_1_dataset(self, measurements_by_resolution: Dict) -> xr.Dataset: + """Create level 1 dataset with all data at 20m resolution.""" + all_vars = {} + reference_coords = None + + # Start with native 20m data + if 20 in measurements_by_resolution: + data_20m = measurements_by_resolution[20] + for category, vars_dict in data_20m.items(): + all_vars.update(vars_dict) + + # Get reference coordinates from 20m data + if all_vars: + first_var = next(iter(all_vars.values())) + reference_coords = { + 'x': first_var.coords['x'], + 'y': first_var.coords['y'] + } + + # Add downsampled 10m data + if 10 in measurements_by_resolution: + data_10m = measurements_by_resolution[10] + + for category, vars_dict in data_10m.items(): + for var_name, var_data in vars_dict.items(): + if reference_coords: + # Downsample to match 20m grid + target_height = len(reference_coords['y']) + target_width = len(reference_coords['x']) + + var_type = determine_variable_type(var_name, var_data) + downsampled = self.resampler.downsample_variable( + var_data, target_height, target_width, var_type + ) + + # Align coordinates + downsampled = downsampled.assign_coords(reference_coords) + all_vars[var_name] = downsampled + + if not all_vars: + return xr.Dataset() + + # Create consolidated dataset + dataset = xr.Dataset(all_vars) + dataset.attrs['pyramid_level'] = 1 + dataset.attrs['resolution_meters'] = 20 + + return dataset + + def _create_downsampled_dataset( + self, + level: int, + target_resolution: int, + measurements_by_resolution: Dict + ) -> xr.Dataset: + """Create downsampled dataset for levels 2+.""" + # Start from 
level 1 data (20m) and downsample + level_1_dataset = self._create_level_1_dataset(measurements_by_resolution) + + if len(level_1_dataset.data_vars) == 0: + return xr.Dataset() + + # Calculate target dimensions (downsample by factor of 2^(level-1)) + downsample_factor = 2 ** (level - 1) + + # Get reference dimensions from level 1 + ref_var = next(iter(level_1_dataset.data_vars.values())) + current_height, current_width = ref_var.shape[-2:] + target_height = current_height // downsample_factor + target_width = current_width // downsample_factor + + downsampled_vars = {} + + for var_name, var_data in level_1_dataset.data_vars.items(): + var_type = determine_variable_type(var_name, var_data) + downsampled = self.resampler.downsample_variable( + var_data, target_height, target_width, var_type + ) + downsampled_vars[var_name] = downsampled + + # Create dataset + dataset = xr.Dataset(downsampled_vars) + dataset.attrs['pyramid_level'] = level + dataset.attrs['resolution_meters'] = target_resolution + + return dataset + + def _write_level_dataset(self, dataset: xr.Dataset, level_path: str, level: int) -> None: + """Write a pyramid level dataset to storage.""" + # Create encoding + encoding = self._create_level_encoding(dataset, level) + + # Write dataset + print(f" Writing level {level} to {level_path}") + dataset.to_zarr( + level_path, + mode='w', + consolidated=True, + zarr_format=3, + encoding=encoding + ) + + def _create_level_encoding(self, dataset: xr.Dataset, level: int) -> Dict: + """Create optimized encoding for a pyramid level.""" + encoding = {} + + # Calculate level-appropriate chunk sizes + chunk_size = max(256, self.spatial_chunk // (2 ** level)) + + for var_name, var_data in dataset.data_vars.items(): + if var_data.ndim >= 2: + height, width = var_data.shape[-2:] + + # Adjust chunk size to data dimensions + chunk_y = min(chunk_size, height) + chunk_x = min(chunk_size, width) + + if var_data.ndim == 3: + chunks = (1, chunk_y, chunk_x) + else: + chunks = (chunk_y, chunk_x) + else: + chunks = (min(chunk_size, var_data.shape[0]),) + + # Configure encoding + var_encoding = { + 'chunks': chunks, + 'compressor': 'default' + } + + # Add sharding if enabled + if self.enable_sharding and var_data.ndim >= 2: + shard_dims = self._calculate_shard_dimensions(var_data.shape, chunks) + var_encoding['shards'] = shard_dims + + encoding[var_name] = var_encoding + + # Add coordinate encoding + for coord_name in dataset.coords: + encoding[coord_name] = {'compressor': None} + + return encoding + + def _calculate_shard_dimensions(self, data_shape: Tuple, chunks: Tuple) -> Tuple: + """Calculate shard dimensions for Zarr v3 sharding.""" + shard_dims = [] + + for i, (dim_size, chunk_size) in enumerate(zip(data_shape, chunks)): + # Ensure shard dimension is evenly divisible by chunk dimension + if chunk_size >= dim_size: + shard_dim = dim_size + else: + # Calculate largest multiple of chunk_size that fits + num_chunks = dim_size // chunk_size + if num_chunks >= 4: # Use 4 chunks per shard if possible + shard_dim = min(4 * chunk_size, dim_size) + else: + shard_dim = num_chunks * chunk_size + + shard_dims.append(shard_dim) + + return tuple(shard_dims) +``` + +### 5. s2_converter.py (Main orchestration) + +```python +""" +Main S2 optimization converter. 
+""" + +import os +import time +from pathlib import Path +from typing import Dict, Optional, List +import xarray as xr + +from .s2_data_consolidator import S2DataConsolidator, create_consolidated_dataset +from .s2_multiscale import S2MultiscalePyramid +from .s2_validation import S2OptimizationValidator +from ..fs_utils import get_storage_options, normalize_path + +class S2OptimizedConverter: + """Optimized Sentinel-2 to GeoZarr converter.""" + + def __init__( + self, + enable_sharding: bool = True, + spatial_chunk: int = 1024, + compression_level: int = 3, + max_retries: int = 3 + ): + self.enable_sharding = enable_sharding + self.spatial_chunk = spatial_chunk + self.compression_level = compression_level + self.max_retries = max_retries + + # Initialize components + self.pyramid_creator = S2MultiscalePyramid(enable_sharding, spatial_chunk) + self.validator = S2OptimizationValidator() + + def convert_s2_optimized( + self, + dt_input: xr.DataTree, + output_path: str, + create_geometry_group: bool = True, + create_meteorology_group: bool = True, + validate_output: bool = True, + verbose: bool = False + ) -> xr.DataTree: + """ + Convert S2 dataset to optimized structure. + + Args: + dt_input: Input Sentinel-2 DataTree + output_path: Output path for optimized dataset + create_geometry_group: Whether to create geometry group + create_meteorology_group: Whether to create meteorology group + validate_output: Whether to validate the output + verbose: Enable verbose logging + + Returns: + Optimized DataTree + """ + start_time = time.time() + + if verbose: + print(f"Starting S2 optimization conversion...") + print(f"Input: {len(dt_input.groups)} groups") + print(f"Output: {output_path}") + + # Validate input is S2 + if not self._is_sentinel2_dataset(dt_input): + raise ValueError("Input dataset is not a Sentinel-2 product") + + # Step 1: Consolidate data from scattered structure + print("Step 1: Consolidating scattered data structure...") + consolidator = S2DataConsolidator(dt_input) + measurements_data, geometry_data, meteorology_data = consolidator.consolidate_all_data() + + if verbose: + print(f" Measurements data extracted: {sum(len(d['bands']) for d in measurements_data.values())} bands") + print(f" Geometry variables: {len(geometry_data)}") + print(f" Meteorology variables: {len(meteorology_data)}") + + # Step 2: Create multiscale measurements + print("Step 2: Creating multiscale measurements pyramid...") + pyramid_datasets = self.pyramid_creator.create_multiscale_measurements( + measurements_data, output_path + ) + + print(f" Created {len(pyramid_datasets)} pyramid levels") + + # Step 3: Create geometry group + if create_geometry_group and geometry_data: + print("Step 3: Creating consolidated geometry group...") + geometry_ds = xr.Dataset(geometry_data) + geometry_path = f"{output_path}/geometry" + self._write_auxiliary_group(geometry_ds, geometry_path, "geometry", verbose) + + # Step 4: Create meteorology group + if create_meteorology_group and meteorology_data: + print("Step 4: Creating consolidated meteorology group...") + meteorology_ds = xr.Dataset(meteorology_data) + meteorology_path = f"{output_path}/meteorology" + self._write_auxiliary_group(meteorology_ds, meteorology_path, "meteorology", verbose) + + # Step 5: Create root-level multiscales metadata + print("Step 5: Adding multiscales metadata...") + self._add_root_multiscales_metadata(output_path, pyramid_datasets) + + # Step 6: Consolidate metadata + print("Step 6: Consolidating metadata...") + 
self._consolidate_root_metadata(output_path) + + # Step 7: Validation + if validate_output: + print("Step 7: Validating optimized dataset...") + validation_results = self.validator.validate_optimized_dataset(output_path) + if not validation_results['is_valid']: + print(" Warning: Validation issues found:") + for issue in validation_results['issues']: + print(f" - {issue}") + + # Create result DataTree + result_dt = self._create_result_datatree(output_path) + + total_time = time.time() - start_time + print(f"Optimization complete in {total_time:.2f}s") + + if verbose: + self._print_optimization_summary(dt_input, result_dt, output_path) + + return result_dt + + def _is_sentinel2_dataset(self, dt: xr.DataTree) -> bool: + """Check if dataset is Sentinel-2.""" + # Check STAC properties + stac_props = dt.attrs.get('stac_discovery', {}).get('properties', {}) + mission = stac_props.get('mission', '') + + if mission.lower().startswith('sentinel-2'): + return True + + # Check for characteristic S2 groups + s2_indicators = [ + '/measurements/reflectance', + '/conditions/geometry', + '/quality/atmosphere' + ] + + found_indicators = sum(1 for indicator in s2_indicators if indicator in dt.groups) + return found_indicators >= 2 + + def _write_auxiliary_group( + self, + dataset: xr.Dataset, + group_path: str, + group_type: str, + verbose: bool + ) -> None: + """Write auxiliary group (geometry or meteorology).""" + # Create simple encoding + encoding = {} + for var_name in dataset.data_vars: + encoding[var_name] = {'compressor': 'default'} + for coord_name in dataset.coords: + encoding[coord_name] = {'compressor': None} + + # Write dataset + storage_options = get_storage_options(group_path) + dataset.to_zarr( + group_path, + mode='w', + consolidated=True, + zarr_format=3, + encoding=encoding, + storage_options=storage_options + ) + + if verbose: + print(f" {group_type.title()} group written: {len(dataset.data_vars)} variables") + + def _add_root_multiscales_metadata( + self, + output_path: str, + pyramid_datasets: Dict[int, xr.Dataset] + ) -> None: + """Add multiscales metadata at root level.""" + from ..geozarr import create_native_crs_tile_matrix_set, calculate_overview_levels + + # Get information from level 0 dataset + if 0 not in pyramid_datasets: + return + + level_0_ds = pyramid_datasets[0] + if not level_0_ds.data_vars: + return + + # Get spatial info from first variable + first_var = next(iter(level_0_ds.data_vars.values())) + native_height, native_width = first_var.shape[-2:] + native_crs = level_0_ds.rio.crs + native_bounds = level_0_ds.rio.bounds() + + # Calculate overview levels + overview_levels = [] + for level, resolution in self.pyramid_creator.pyramid_levels.items(): + if level in pyramid_datasets: + level_ds = pyramid_datasets[level] + level_var = next(iter(level_ds.data_vars.values())) + level_height, level_width = level_var.shape[-2:] + + overview_levels.append({ + 'level': level, + 'resolution': resolution, + 'width': level_width, + 'height': level_height, + 'scale_factor': 2 ** level if level > 0 else 1 + }) + + # Create tile matrix set + tile_matrix_set = create_native_crs_tile_matrix_set( + native_crs, native_bounds, overview_levels, "measurements" + ) + + # Add metadata to measurements group + measurements_zarr_path = normalize_path(f"{output_path}/measurements/zarr.json") + if os.path.exists(measurements_zarr_path): + import json + with open(measurements_zarr_path, 'r') as f: + zarr_json = json.load(f) + + zarr_json.setdefault('attributes', {}) + 
zarr_json['attributes']['multiscales'] = { + 'tile_matrix_set': tile_matrix_set, + 'resampling_method': 'average', + 'datasets': [{'path': str(level)} for level in sorted(pyramid_datasets.keys())] + } + + with open(measurements_zarr_path, 'w') as f: + json.dump(zarr_json, f, indent=2) + + def _consolidate_root_metadata(self, output_path: str) -> None: + """Consolidate metadata at root level.""" + try: + from ..geozarr import consolidate_metadata + from ..fs_utils import open_zarr_group + + zarr_group = open_zarr_group(output_path, mode="r+") + consolidate_metadata(zarr_group.store) + except Exception as e: + print(f" Warning: Root metadata consolidation failed: {e}") + + def _create_result_datatree(self, output_path: str) -> xr.DataTree: + """Create result DataTree from written output.""" + try: + storage_options = get_storage_options(output_path) + return xr.open_datatree( + output_path, + engine='zarr', + chunks='auto', + storage_options=storage_options + ) + except Exception as e: + print(f"Warning: Could not open result DataTree: {e}") + return xr.DataTree() + + def _print_optimization_summary( + self, + dt_input: xr.DataTree, + dt_output: xr.DataTree, + output_path: str + ) -> None: + """Print optimization summary statistics.""" + print("\n" + "="*50) + print("OPTIMIZATION SUMMARY") + print("="*50) + + # Count groups + input_groups = len(dt_input.groups) if hasattr(dt_input, 'groups') else 0 + output_groups = len(dt_output.groups) if hasattr(dt_output, 'groups') else 0 + + print(f"Groups: {input_groups} → {output_groups} ({((output_groups-input_groups)/input_groups*100):+.1f}%)") + + # Estimate file count reduction + estimated_input_files = input_groups * 10 # Rough estimate + estimated_output_files = output_groups * 5 # Fewer files per group + print(f"Estimated files: {estimated_input_files} → {estimated_output_files} ({((estimated_output_files-estimated_input_files)/estimated_input_files*100):+.1f}%)") + + # Show structure + print(f"\nNew structure:") + print(f" /measurements/ (multiscale: levels 0-6)") + if f"{output_path}/geometry" in str(dt_output): + print(f" /geometry/ (consolidated)") + if f"{output_path}/meteorology" in str(dt_output): + print(f" /meteorology/ (consolidated)") + + print("="*50) + + +def convert_s2_optimized( + dt_input: xr.DataTree, + output_path: str, + **kwargs +) -> xr.DataTree: + """ + Convenience function for S2 optimization. + + Args: + dt_input: Input Sentinel-2 DataTree + output_path: Output path + **kwargs: Additional arguments for S2OptimizedConverter + + Returns: + Optimized DataTree + """ + converter = S2OptimizedConverter(**kwargs) + return converter.convert_s2_optimized(dt_input, output_path, **kwargs) +``` + +### 6. s2_validation.py + +```python +""" +Validation for S2 optimized datasets. +""" + +import os +from typing import Dict, List, Any +import xarray as xr +from ..fs_utils import get_storage_options + +class S2OptimizationValidator: + """Validates S2 optimized dataset structure and integrity.""" + + def validate_optimized_dataset(self, dataset_path: str) -> Dict[str, Any]: + """ + Validate an optimized S2 dataset. 
+ + Args: + dataset_path: Path to the optimized dataset + + Returns: + Validation results dictionary + """ + results = { + 'is_valid': True, + 'issues': [], + 'warnings': [], + 'summary': {} + } + + try: + storage_options = get_storage_options(dataset_path) + dt = xr.open_datatree( + dataset_path, + engine='zarr', + chunks='auto', + storage_options=storage_options + ) + + # Check required groups + self._validate_group_structure(dt, results) + + # Check multiscale structure + self._validate_multiscale_structure(dt, results) + + # Check data integrity + self._validate_data_integrity(dt, results) + + # Check metadata compliance + self._validate_metadata_compliance(dt, results) + + except Exception as e: + results['is_valid'] = False + results['issues'].append(f"Failed to open dataset: {e}") + + return results + + def _validate_group_structure(self, dt: xr.DataTree, results: Dict) -> None: + """Validate the expected group structure.""" + required_groups = ['/measurements'] + optional_groups = ['/geometry', '/meteorology'] + + existing_groups = set(dt.groups.keys()) if hasattr(dt, 'groups') else set() + + # Check required groups + for group in required_groups: + if group not in existing_groups: + results['issues'].append(f"Missing required group: {group}") + results['is_valid'] = False + + # Check for unexpected groups + expected_groups = set(required_groups + optional_groups) + unexpected = existing_groups - expected_groups - {'.'} # Exclude root + if unexpected: + results['warnings'].append(f"Unexpected groups found: {list(unexpected)}") + + results['summary']['groups_found'] = len(existing_groups) + + def _validate_multiscale_structure(self, dt: xr.DataTree, results: Dict) -> None: + """Validate multiscale pyramid structure in measurements.""" + if '/measurements' not in dt.groups: + return + + measurements_group = dt['/measurements'] + + # Check for numeric subgroups (pyramid levels) + if not hasattr(measurements_group, 'children'): + results['issues'].append("Measurements group has no pyramid levels") + results['is_valid'] = False + return + + pyramid_levels = [] + for child_name in measurements_group.children: + if child_name.isdigit(): + pyramid_levels.append(int(child_name)) + + pyramid_levels.sort() + + if not pyramid_levels: + results['issues'].append("No pyramid levels found in measurements") + results['is_valid'] = False + return + + # Validate level 0 exists (native resolution) + if 0 not in pyramid_levels: + results['issues'].append("Missing pyramid level 0 (native resolution)") + results['is_valid'] = False + + # Check for reasonable progression + if len(pyramid_levels) < 2: + results['warnings'].append("Only one pyramid level found - not truly multiscale") + + results['summary']['pyramid_levels'] = pyramid_levels + results['summary']['max_pyramid_level'] = max(pyramid_levels) if pyramid_levels else 0 + + def _validate_data_integrity(self, dt: xr.DataTree, results: Dict) -> None: + """Validate data integrity across pyramid levels.""" + if '/measurements' not in dt.groups: + return + + measurements = dt['/measurements'] + pyramid_levels = [] + + # Get pyramid levels + if hasattr(measurements, 'children'): + for child_name in measurements.children: + if child_name.isdigit(): + pyramid_levels.append(int(child_name)) + + if not pyramid_levels: + return + + pyramid_levels.sort() + + # Check coordinate consistency across levels + reference_crs = None + for level in pyramid_levels: + level_path = f'/measurements/{level}' + if level_path in dt.groups: + level_ds = 
dt[level_path].to_dataset() + + # Check CRS consistency + if hasattr(level_ds, 'rio') and level_ds.rio.crs: + if reference_crs is None: + reference_crs = level_ds.rio.crs + elif reference_crs != level_ds.rio.crs: + results['issues'].append(f"CRS mismatch at level {level}") + results['is_valid'] = False + + # Check for data variables + if not level_ds.data_vars: + results['warnings'].append(f"Pyramid level {level} has no data variables") + + results['summary']['reference_crs'] = str(reference_crs) if reference_crs else None + + def _validate_metadata_compliance(self, dt: xr.DataTree, results: Dict) -> None: + """Validate GeoZarr and CF compliance.""" + compliance_issues = [] + + # Check for multiscales metadata + if '/measurements' in dt.groups: + measurements = dt['/measurements'] + if hasattr(measurements, 'attrs'): + if 'multiscales' not in measurements.attrs: + results['warnings'].append("Missing multiscales metadata in measurements group") + + # Check variable attributes + total_vars = 0 + compliant_vars = 0 + + for group_path in dt.groups: + if group_path == '.': + continue + + group_ds = dt[group_path].to_dataset() + for var_name, var_data in group_ds.data_vars.items(): + total_vars += 1 + + # Check required attributes + var_issues = [] + if '_ARRAY_DIMENSIONS' not in var_data.attrs: + var_issues.append('Missing _ARRAY_DIMENSIONS') + + if 'standard_name' not in var_data.attrs: + var_issues.append('Missing standard_name') + + if not var_issues: + compliant_vars += 1 + else: + compliance_issues.append(f"{group_path}/{var_name}: {', '.join(var_issues)}") + + if compliance_issues: + results['warnings'].extend(compliance_issues[:5]) # Show first 5 + if len(compliance_issues) > 5: + results['warnings'].append(f"... and {len(compliance_issues) - 5} more compliance issues") + + results['summary']['total_variables'] = total_vars + results['summary']['compliant_variables'] = compliant_vars + results['summary']['compliance_rate'] = f"{compliant_vars/total_vars*100:.1f}%" if total_vars > 0 else "0%" +``` + +### 7. CLI Integration + +```python +""" +CLI integration for S2 optimization. 
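+
+Registers the proposed convert-s2-optimized subcommand (chunking, sharding,
+compression and skip options) on an argparse subparsers object.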
+""" + +import argparse +from pathlib import Path +from .s2_converter import convert_s2_optimized +from ..fs_utils import get_storage_options +import xarray as xr + +def add_s2_optimization_commands(subparsers): + """Add S2 optimization commands to CLI parser.""" + + # Convert S2 optimized command + s2_parser = subparsers.add_parser( + 'convert-s2-optimized', + help='Convert Sentinel-2 dataset to optimized structure' + ) + s2_parser.add_argument( + 'input_path', + type=str, + help='Path to input Sentinel-2 dataset (Zarr format)' + ) + s2_parser.add_argument( + 'output_path', + type=str, + help='Path for output optimized dataset' + ) + s2_parser.add_argument( + '--spatial-chunk', + type=int, + default=1024, + help='Spatial chunk size (default: 1024)' + ) + s2_parser.add_argument( + '--enable-sharding', + action='store_true', + help='Enable Zarr v3 sharding' + ) + s2_parser.add_argument( + '--compression-level', + type=int, + default=3, + choices=range(1, 10), + help='Compression level 1-9 (default: 3)' + ) + s2_parser.add_argument( + '--skip-geometry', + action='store_true', + help='Skip creating geometry group' + ) + s2_parser.add_argument( + '--skip-meteorology', + action='store_true', + help='Skip creating meteorology group' + ) + s2_parser.add_argument( + '--skip-validation', + action='store_true', + help='Skip output validation' + ) + s2_parser.add_argument( + '--verbose', + action='store_true', + help='Enable verbose output' + ) + s2_parser.set_defaults(func=convert_s2_optimized_command) + +def convert_s2_optimized_command(args): + """Execute S2 optimized conversion command.""" + try: + # Validate input + input_path = Path(args.input_path) + if not input_path.exists(): + print(f"Error: Input path {input_path} does not exist") + return 1 + + # Load input dataset + print(f"Loading Sentinel-2 dataset from: {args.input_path}") + storage_options = get_storage_options(str(input_path)) + dt_input = xr.open_datatree( + str(input_path), + engine='zarr', + chunks='auto', + storage_options=storage_options + ) + + # Convert + dt_optimized = convert_s2_optimized( + dt_input=dt_input, + output_path=args.output_path, + enable_sharding=args.enable_sharding, + spatial_chunk=args.spatial_chunk, + compression_level=args.compression_level, + create_geometry_group=not args.skip_geometry, + create_meteorology_group=not args.skip_meteorology, + validate_output=not args.skip_validation, + verbose=args.verbose + ) + + print(f"✅ S2 optimization completed: {args.output_path}") + return 0 + + except Exception as e: + print(f"❌ Error during S2 optimization: {e}") + if args.verbose: + import traceback + traceback.print_exc() + return 1 +``` + +## Testing Strategy + +### Unit Tests Structure +``` +tests/s2_optimization/ +├── test_band_mapping.py +├── test_resampling.py +├── test_consolidator.py +├── test_multiscale.py +├── test_converter.py +├── test_validation.py +└── fixtures/ + ├── sample_s2_structure.zarr/ + └── expected_outputs/ +``` + +### Integration Test Scenarios +1. **Complete S2 L2A conversion** with all groups +2. **Minimal conversion** with measurements only +3. **Large dataset handling** (>10GB) +4. **Error recovery** scenarios +5. **Performance benchmarking** vs. 
current implementation + +### Validation Checklist +- [ ] All native resolution data preserved exactly +- [ ] Proper downsampling applied at each level +- [ ] Coordinate systems consistent across levels +- [ ] Metadata compliance maintained +- [ ] File count reduction achieved +- [ ] Access time improvements verified +- [ ] Memory usage optimized \ No newline at end of file diff --git a/s2_optimization_plan.md b/s2_optimization_plan.md new file mode 100644 index 00000000..5ca06548 --- /dev/null +++ b/s2_optimization_plan.md @@ -0,0 +1,363 @@ +# Sentinel-2 Zarr Conversion Optimization Plan + +## Overview +This plan outlines the development of an optimized Sentinel-2 converter (`convert_s2`) that dramatically simplifies the dataset structure while maintaining scientific integrity and improving storage efficiency. + +## Current State Analysis + +### Problems with Current Structure +- **File proliferation**: Multiple resolution groups create numerous zarr chunks +- **Data redundancy**: Repeated metadata across similar resolution groups +- **Complex navigation**: Deep nested structure makes data discovery difficult +- **Storage inefficiency**: Many small groups instead of consolidated datasets +- **Inconsistent multiscale**: Each resolution treated separately rather than as pyramid + +### Current Structure Issues +``` +root/ +├── conditions/ # Mixed resolution data scattered +├── quality/ # Atmosphere data duplicated across resolutions +└── measurements/ # Resolution-based grouping creates complexity + └── reflectance/ + ├── r10m/ (+ 6 pyramid levels) + ├── r20m/ (+ 5 pyramid levels) + └── r60m/ (+ 3 pyramid levels) +``` + +## Proposed Optimized Structure + +### New Simplified Structure +``` +root/ +├── measurements/ # Single multiscale group +│ ├── 0/ # 10m native (b02,b03,b04,b08 + all derived data) +│ ├── 1/ # 20m native (all b01-b12,b8a + derived data) +│ ├── 2/ # 60m native (minimal additional bands) +│ ├── 3/ # 120m (downsampled) +│ ├── 4/ # 240m (downsampled) +│ ├── 5/ # 480m (downsampled) +│ └── 6/ # 960m (downsampled) +├── geometry/ # Consolidated geometric data +└── meteorology/ # Consolidated weather data +``` + +## Technical Design Specifications + +### 1. Band Distribution Strategy + +#### Level 0 (10m) - Native Resolution +**Variables:** +- `b02`, `b03`, `b04`, `b08` (native 10m bands) +- `detector_footprint_b02`, `detector_footprint_b03`, `detector_footprint_b04`, `detector_footprint_b08` +- `quality_b02`, `quality_b03`, `quality_b04`, `quality_b08` + +#### Level 1 (20m) - Native Resolution +**Variables:** +- `b02`, `b03`, `b04`, `b05`, `b06`, `b07`, `b08`, `b11`, `b12`, `b8a` (native 20m bands + all 10m bands downsampled) +- detector footprints of all 10m and 20m bands +- `aot`, `wvp` (native 20m atmosphere quality) +- `scl` (native 20m classification) +- `cld`, `snw` (cloud and snow probability) +- All quality masks for each band (downsampled from 10m where applicable) + +#### Level 2 (60m) - Native Resolution +**Variables:** +- All Bands: `b01`, `b02`, `b03`, `b04`, `b05`, `b06`, `b07`, `b08`, `b09`, `b11`, `b12`, `b8a` +- Detector footprints for all bands +- `scl` (native 60m classification) +- `aot`, `wvp` (downsampled from 20m) +- `cld`, `snw` (downsampled from 20m) +- Quality masks for all bands (downsampled from 20m where applicable) + +#### Levels 3-6 (120m, 240m, 480m, 960m) +**Variables:** +- All bands downsampled using appropriate resampling methods +- All quality and classification data downsampled accordingly + +### 2. 
Data Consolidation Rules + +#### Measurements Group +- **Resampling Strategy**: + - Upsampling: Bilinear interpolation for reflectance bands + - Downsampling: Block averaging for reflectance, mode for classifications + - Quality data: Logical operations (any/all) for binary masks + +- **Variable Naming Convention**: + - Spectral bands: `b01`, `b02`, ..., `b12`, `b8a` + - Detector footprints: `detector_footprint_{band}` + - Atmosphere quality: `aot`, `wvp` + - Classification: `scl` + - Probability: `cld`, `snw` + - Quality: `quality_{band}` + +#### Geometry Group +- `sun_angles` (consolidated) +- `viewing_incidence_angles` (consolidated) +- `mean_sun_angles`, `mean_viewing_incidence_angles` +- `spatial_ref` + +#### Meteorology Group +- CAMS data: `aod*`, `*aod550`, etc. +- ECMWF data: `msl`, `tco3`, `tcwv`, `u10`, `v10`, `r` + +### 3. Multiscale Implementation + +#### Pyramid Strategy +- **Native data preservation**: Only store at native resolution +- **No upsampling**: Higher resolution levels only contain natively available bands +- **Consistent downsampling**: /2 decimation between consecutive levels where appropriate +- **Smart resampling**: + - Reflectance: Bilinear → Block average + - Classifications: Mode + - Quality masks: Logical operations + - Probabilities: Average + +#### Storage Optimization +- **Chunking**: Align chunks across all levels (e.g., 1024×1024) +- **Compression**: Consistent codec across all variables +- **Sharding**: Enable for the full shape dimension +- **Metadata consolidation**: Single zarr.json with complete multiscales info + +## Implementation Plan + +### Phase 1: Core Infrastructure (New Files) + +#### File: `s2_converter.py` +```python +class S2OptimizedConverter: + """Optimized Sentinel-2 to GeoZarr converter""" + + def __init__(self, enable_sharding=True, spatial_chunk=1024): + self.enable_sharding = enable_sharding + self.spatial_chunk = spatial_chunk + + def convert_s2(self, dt_input, output_path, **kwargs): + """Main conversion entry point""" + + def _create_optimized_structure(self, dt_input): + """Reorganize data into 3 main groups""" + + def _create_measurements_multiscale(self, measurements_data, output_path): + """Create consolidated multiscale measurements""" + + def _consolidate_geometry_data(self, dt_input): + """Consolidate all geometry-related data""" + + def _consolidate_meteorology_data(self, dt_input): + """Consolidate CAMS and ECMWF data""" +``` + +#### File: `s2_band_mapping.py` +```python +# Band availability by native resolution +NATIVE_BANDS = { + 10: ['b02', 'b03', 'b04', 'b08'], + 20: ['b05', 'b06', 'b07', 'b11', 'b12', 'b8a'], + 60: ['b01', 'b09'] # Only these are truly native to 60m +} + +# Quality data mapping +QUALITY_DATA_MAPPING = { + 'atmosphere': ['aot', 'wvp'], + 'classification': ['scl'], + 'probability': ['cld', 'snw'], + 'detector_footprint': NATIVE_BANDS, + 'quality_masks': NATIVE_BANDS +} +``` + +#### File: `s2_resampling.py` +```python +class S2ResamplingEngine: + """Handles all resampling operations for S2 data""" + + def upsample_reflectance(self, data, target_resolution): + """Bilinear upsampling for reflectance bands (REMOVED - no upsampling)""" + raise NotImplementedError("Upsampling disabled - only native resolution data stored") + + def downsample_reflectance(self, data, target_resolution): + """Block averaging for reflectance bands""" + + def resample_classification(self, data, target_resolution, method='mode'): + """Mode-based resampling for classification data""" + + def resample_quality_masks(self, data, 
target_resolution, operation='any'): + """Logical operations for quality masks""" +``` + +### Phase 2: Data Processing Pipeline + +#### Measurement Processing Pipeline +1. **Data Inventory**: Scan input structure and catalog all variables +2. **Resolution Analysis**: Determine native resolution for each variable +3. **Level Planning**: Calculate which variables belong to each pyramid level +4. **Resampling Execution**: Apply appropriate resampling for each variable/level combination +5. **Multiscale Writing**: Write consolidated datasets with proper metadata + +#### Quality Assurance +- **Data Integrity**: Verify no data loss during consolidation +- **Coordinate Consistency**: Ensure all variables share consistent coordinate systems +- **Metadata Compliance**: Maintain GeoZarr-spec compliance +- **Performance Validation**: Measure storage and access improvements + +### Phase 3: Advanced Features + +#### Smart Chunking Strategy +```python +def calculate_optimal_chunks(self, resolution_level, data_shape): + """Calculate chunks that optimize both storage and access patterns""" + base_chunk = 1024 + level_factor = 2 ** resolution_level + optimal_chunk = min(base_chunk, data_shape[-1] // level_factor) + return (1, optimal_chunk, optimal_chunk) # For 3D (band, y, x) +``` + +#### Compression Optimization +```python +COMPRESSION_CONFIG = { + 'reflectance': {'codec': 'blosc', 'level': 5, 'shuffle': True}, + 'classification': {'codec': 'blosc', 'level': 9, 'shuffle': False}, + 'quality': {'codec': 'blosc', 'level': 7, 'shuffle': True}, + 'probability': {'codec': 'blosc', 'level': 6, 'shuffle': True} +} +``` + +## Expected Benefits Analysis + +### Storage Optimization +- **Estimated reduction**: 40-60% fewer zarr chunks +- **Metadata efficiency**: ~90% reduction in .zmetadata files +- **Redundancy elimination**: Remove duplicate spatial reference data +- **Compression synergy**: Better compression ratios with consolidated data + +### Access Pattern Improvements +- **Faster discovery**: 3 top-level groups vs current 15+ +- **Consistent multiscale**: Single pyramid instead of separate resolution trees +- **Simplified APIs**: Users access data by scale level, not resolution group +- **Better caching**: Consolidated chunks improve filesystem performance + +### Scientific Workflow Benefits +- **Band co-registration**: All bands at same level guaranteed co-registered +- **Quality correlation**: Quality data co-located with measurements +- **Scale-aware processing**: Natural support for multi-resolution analysis +- **Simplified subsetting**: Single coordinate system across all variables at each level + +## Technical Challenges & Solutions + +### Challenge 1: Mixed Resolution Data +**Problem**: Different bands have different native resolutions +**Solution**: Store only at native resolution, use resampling for access at other levels + +### Challenge 2: Quality Data Alignment +**Problem**: Quality data needs to align with corresponding measurement bands +**Solution**: Resample quality data to match measurement resolution at each level + +### Challenge 3: Coordinate System Consistency +**Problem**: Ensuring consistent coordinates across all variables +**Solution**: Use master coordinate grid for each level, snap all data to this grid + +### Challenge 4: Backward Compatibility +**Problem**: Existing tools expect current structure +**Solution**: Provide mapping utilities and clear migration documentation + +## Implementation Timeline + +### Week 1-2: Infrastructure +- Create new module files +- Implement band 
mapping and resolution logic +- Develop resampling engine +- Create basic converter structure + +### Week 3-4: Core Conversion +- Implement measurements multiscale creation +- Develop geometry and meteorology consolidation +- Add multiscale metadata generation +- Create chunking and compression optimization + +### Week 5-6: Validation & Testing +- Implement data integrity checks +- Create performance benchmarking +- Add error handling and edge cases +- Develop unit and integration tests + +### Week 7-8: Documentation & Examples +- Create user documentation +- Develop example notebooks +- Add CLI integration +- Performance comparison analysis + +## Usage Interface + +### CLI Integration +```bash +# New optimized conversion +eopf-geozarr convert-s2-optimized input.zarr output.zarr \ + --spatial-chunk 1024 \ + --enable-sharding \ + --compression-level 5 + +# Validation +eopf-geozarr validate-s2 output.zarr --check-optimization +``` + +### Python API +```python +from eopf_geozarr.s2_converter import convert_s2_optimized + +# Convert with optimization +convert_s2_optimized( + input_datatree=dt, + output_path="optimized.zarr", + spatial_chunk=1024, + enable_sharding=True, + compression_preset='balanced' +) +``` + +## Validation Criteria + +### Storage Efficiency +- [ ] <50% of original zarr chunk count +- [ ] <30% of original metadata files +- [ ] ≥20% reduction in total storage size +- [ ] Consistent compression ratios across levels + +### Data Integrity +- [ ] Bit-exact preservation of native resolution data +- [ ] Consistent coordinate systems across all levels +- [ ] Proper handling of nodata/fill values +- [ ] Metadata preservation and enhancement + +### Performance +- [ ] ≥2x faster dataset opening time +- [ ] ≥1.5x faster band access time +- [ ] Reduced memory overhead for multiscale operations +- [ ] Better parallel access patterns + +### Compliance +- [ ] Full GeoZarr-spec compliance maintained +- [ ] CF conventions adherence +- [ ] STAC metadata compatibility +- [ ] Cloud-optimized access patterns + +## Risk Mitigation + +### Data Loss Prevention +- Comprehensive validation pipeline +- Bit-level comparison tools +- Automated regression testing +- Rollback capabilities + +### Performance Regression +- Benchmarking against current implementation +- Memory usage monitoring +- Access pattern optimization +- Chunking strategy validation + +### User Adoption +- Clear migration guides +- Backward compatibility tools +- Performance demonstrations +- Community feedback integration + +This plan provides a roadmap for creating a significantly more efficient and user-friendly Sentinel-2 zarr format while maintaining scientific integrity and improving performance across multiple use cases. \ No newline at end of file diff --git a/src/eopf_geozarr/s2_optimization/__init__.py b/src/eopf_geozarr/s2_optimization/__init__.py new file mode 100644 index 00000000..b5e84226 --- /dev/null +++ b/src/eopf_geozarr/s2_optimization/__init__.py @@ -0,0 +1,2 @@ +# Sentinel-2 Optimization Module +# This package contains tools for optimizing Sentinel-2 Zarr datasets. diff --git a/src/eopf_geozarr/s2_optimization/cli_integration.py b/src/eopf_geozarr/s2_optimization/cli_integration.py new file mode 100644 index 00000000..582dd96b --- /dev/null +++ b/src/eopf_geozarr/s2_optimization/cli_integration.py @@ -0,0 +1,45 @@ +""" +CLI integration for Sentinel-2 optimization. 
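+
+Registers the convert-s2-optimized subcommand; the command handler itself is
+still a placeholder.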
+""" + +import argparse +from .s2_converter import S2OptimizedConverter + +def add_s2_optimization_commands(subparsers): + """Add Sentinel-2 optimization commands to CLI parser.""" + + s2_parser = subparsers.add_parser( + 'convert-s2-optimized', + help='Convert Sentinel-2 dataset to optimized structure' + ) + s2_parser.add_argument( + 'input_path', + type=str, + help='Path to input Sentinel-2 dataset (Zarr format)' + ) + s2_parser.add_argument( + 'output_path', + type=str, + help='Path for output optimized dataset' + ) + s2_parser.add_argument( + '--spatial-chunk', + type=int, + default=1024, + help='Spatial chunk size (default: 1024)' + ) + s2_parser.add_argument( + '--enable-sharding', + action='store_true', + help='Enable Zarr v3 sharding' + ) + s2_parser.set_defaults(func=convert_s2_optimized_command) + +def convert_s2_optimized_command(args): + """Execute Sentinel-2 optimized conversion command.""" + converter = S2OptimizedConverter( + enable_sharding=args.enable_sharding, + spatial_chunk=args.spatial_chunk + ) + # Placeholder for CLI command execution logic + pass diff --git a/src/eopf_geozarr/s2_optimization/s2_band_mapping.py b/src/eopf_geozarr/s2_optimization/s2_band_mapping.py new file mode 100644 index 00000000..af4bbe67 --- /dev/null +++ b/src/eopf_geozarr/s2_optimization/s2_band_mapping.py @@ -0,0 +1,80 @@ +""" +Band mapping and resolution definitions for Sentinel-2 optimization. +""" + +from typing import Dict, List, Set +from dataclasses import dataclass + +@dataclass +class BandInfo: + """Information about a spectral band.""" + name: str + native_resolution: int # meters + data_type: str + wavelength_center: float # nanometers + wavelength_width: float # nanometers + +# Native resolution definitions +NATIVE_BANDS: Dict[int, List[str]] = { + 10: ['b02', 'b03', 'b04', 'b08'], # Blue, Green, Red, NIR + 20: ['b05', 'b06', 'b07', 'b11', 'b12', 'b8a'], # Red Edge, SWIR + 60: ['b01', 'b09'] # Coastal, Water Vapor +} + +# Complete band information +BAND_INFO: Dict[str, BandInfo] = { + 'b01': BandInfo('b01', 60, 'uint16', 443, 21), # Coastal aerosol + 'b02': BandInfo('b02', 10, 'uint16', 490, 66), # Blue + 'b03': BandInfo('b03', 10, 'uint16', 560, 36), # Green + 'b04': BandInfo('b04', 10, 'uint16', 665, 31), # Red + 'b05': BandInfo('b05', 20, 'uint16', 705, 15), # Red Edge 1 + 'b06': BandInfo('b06', 20, 'uint16', 740, 15), # Red Edge 2 + 'b07': BandInfo('b07', 20, 'uint16', 783, 20), # Red Edge 3 + 'b08': BandInfo('b08', 10, 'uint16', 842, 106), # NIR + 'b8a': BandInfo('b8a', 20, 'uint16', 865, 21), # NIR Narrow + 'b09': BandInfo('b09', 60, 'uint16', 945, 20), # Water Vapor + 'b11': BandInfo('b11', 20, 'uint16', 1614, 91), # SWIR 1 + 'b12': BandInfo('b12', 20, 'uint16', 2202, 175), # SWIR 2 +} + +# Quality data mapping - defines which auxiliary data exists at which resolutions +QUALITY_DATA_NATIVE: Dict[str, int] = { + 'scl': 20, # Scene Classification Layer - native 20m + 'aot': 20, # Aerosol Optical Thickness - native 20m + 'wvp': 20, # Water Vapor - native 20m + 'cld': 20, # Cloud probability - native 20m + 'snw': 20, # Snow probability - native 20m +} + +# Detector footprint availability - matches spectral bands +DETECTOR_FOOTPRINT_NATIVE: Dict[int, List[str]] = { + 10: ['b02', 'b03', 'b04', 'b08'], + 20: ['b05', 'b06', 'b07', 'b11', 'b12', 'b8a'], + 60: ['b01', 'b09'] +} + +def get_bands_for_level(level: int) -> Set[str]: + """ + Get all bands available at a given pyramid level. 
+ + Args: + level: Pyramid level (0=10m, 1=20m, 2=60m, 3+=downsampled) + + Returns: + Set of band names available at this level + """ + if level == 0: # 10m - only native 10m bands + return set(NATIVE_BANDS[10]) + elif level == 1: # 20m - all bands (native + downsampled from 10m) + return set(NATIVE_BANDS[10] + NATIVE_BANDS[20] + NATIVE_BANDS[60]) + elif level == 2: # 60m - all bands downsampled + return set(NATIVE_BANDS[10] + NATIVE_BANDS[20] + NATIVE_BANDS[60]) + else: # Further downsampling - all bands + return set(NATIVE_BANDS[10] + NATIVE_BANDS[20] + NATIVE_BANDS[60]) + +def get_quality_data_for_level(level: int) -> Set[str]: + """Get quality data available at a given level (no upsampling).""" + if level == 0: # 10m - no quality data (would require upsampling) + return set() + elif level >= 1: # 20m and below - all quality data available + return set(QUALITY_DATA_NATIVE.keys()) \ No newline at end of file diff --git a/src/eopf_geozarr/s2_optimization/s2_converter.py b/src/eopf_geozarr/s2_optimization/s2_converter.py new file mode 100644 index 00000000..d828879d --- /dev/null +++ b/src/eopf_geozarr/s2_optimization/s2_converter.py @@ -0,0 +1,36 @@ +""" +Main S2 optimization converter. +""" + +class S2OptimizedConverter: + """Optimized Sentinel-2 to GeoZarr converter.""" + + def __init__(self, enable_sharding=True, spatial_chunk=1024): + self.enable_sharding = enable_sharding + self.spatial_chunk = spatial_chunk + + def convert_s2(self, dt_input, output_path, **kwargs): + """Main conversion entry point.""" + from .s2_data_consolidator import S2DataConsolidator + from .s2_multiscale import S2MultiscalePyramid + from .s2_validation import S2OptimizationValidator + + # Consolidate data + consolidator = S2DataConsolidator(dt_input) + measurements, geometry, meteorology = consolidator.consolidate_all_data() + + # Create multiscale pyramids + pyramid = S2MultiscalePyramid( + enable_sharding=self.enable_sharding, + spatial_chunk=self.spatial_chunk + ) + multiscale_data = pyramid.create_multiscale_measurements(measurements, output_path) + + # Validate the output + validator = S2OptimizationValidator() + validation_results = validator.validate_optimized_dataset(output_path) + + return { + "multiscale_data": multiscale_data, + "validation_results": validation_results + } diff --git a/src/eopf_geozarr/s2_optimization/s2_data_consolidator.py b/src/eopf_geozarr/s2_optimization/s2_data_consolidator.py new file mode 100644 index 00000000..119dc55e --- /dev/null +++ b/src/eopf_geozarr/s2_optimization/s2_data_consolidator.py @@ -0,0 +1,43 @@ +""" +Data consolidation logic for reorganizing Sentinel-2 structure. +""" + +import xarray as xr +from typing import Dict, Tuple + +class S2DataConsolidator: + """Consolidates Sentinel-2 data from scattered structure into organized groups.""" + + def __init__(self, dt_input: xr.DataTree): + self.dt_input = dt_input + self.measurements_data = {} + self.geometry_data = {} + self.meteorology_data = {} + + def consolidate_all_data(self) -> Tuple[Dict, Dict, Dict]: + """ + Consolidate all data into three main categories. 
+ + Returns: + Tuple of (measurements, geometry, meteorology) data dictionaries + """ + self._extract_measurements_data() + self._extract_geometry_data() + self._extract_meteorology_data() + + return self.measurements_data, self.geometry_data, self.meteorology_data + + def _extract_measurements_data(self) -> None: + """Extract and organize all measurement-related data.""" + # Placeholder for measurement extraction logic + pass + + def _extract_geometry_data(self) -> None: + """Extract all geometry-related data into a single group.""" + # Placeholder for geometry extraction logic + pass + + def _extract_meteorology_data(self) -> None: + """Extract meteorological data (CAMS and ECMWF).""" + # Placeholder for meteorology extraction logic + pass diff --git a/src/eopf_geozarr/s2_optimization/s2_multiscale.py b/src/eopf_geozarr/s2_optimization/s2_multiscale.py new file mode 100644 index 00000000..ca1fe78a --- /dev/null +++ b/src/eopf_geozarr/s2_optimization/s2_multiscale.py @@ -0,0 +1,27 @@ +""" +Multiscale pyramid creation for optimized Sentinel-2 structure. +""" + +from typing import Dict +import xarray as xr + +class S2MultiscalePyramid: + """Creates multiscale pyramids for consolidated Sentinel-2 data.""" + + def __init__(self, enable_sharding: bool = True, spatial_chunk: int = 1024): + self.enable_sharding = enable_sharding + self.spatial_chunk = spatial_chunk + + def create_multiscale_measurements(self, measurements_by_resolution: Dict[int, Dict], output_path: str) -> Dict[int, xr.Dataset]: + """ + Create multiscale pyramid from consolidated measurements. + + Args: + measurements_by_resolution: Data organized by native resolution + output_path: Base output path + + Returns: + Dictionary of datasets by pyramid level + """ + # Placeholder for multiscale creation logic + pass diff --git a/src/eopf_geozarr/s2_optimization/s2_resampling.py b/src/eopf_geozarr/s2_optimization/s2_resampling.py new file mode 100644 index 00000000..5f835f04 --- /dev/null +++ b/src/eopf_geozarr/s2_optimization/s2_resampling.py @@ -0,0 +1,236 @@ +""" +Downsampling operations for Sentinel-2 data (no upsampling). +""" + +import numpy as np +import xarray as xr + +class S2ResamplingEngine: + """Handles downsampling operations for S2 multiscale creation.""" + + def __init__(self): + self.resampling_methods = { + 'reflectance': self._downsample_reflectance, + 'classification': self._downsample_classification, + 'quality_mask': self._downsample_quality_mask, + 'probability': self._downsample_probability, + 'detector_footprint': self._downsample_quality_mask, # Same as quality mask + } + + def downsample_variable(self, data: xr.DataArray, target_height: int, + target_width: int, var_type: str) -> xr.DataArray: + """ + Downsample a variable to target dimensions. + + Args: + data: Input data array + target_height: Target height in pixels + target_width: Target width in pixels + var_type: Type of variable ('reflectance', 'classification', etc.) 
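+                Recognised values: 'reflectance', 'classification',
+                'quality_mask', 'probability' and 'detector_footprint'.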
+ + Returns: + Downsampled data array + """ + if var_type not in self.resampling_methods: + raise ValueError(f"Unknown variable type: {var_type}") + + method = self.resampling_methods[var_type] + return method(data, target_height, target_width) + + def _downsample_reflectance(self, data: xr.DataArray, target_height: int, + target_width: int) -> xr.DataArray: + """Block averaging for reflectance bands.""" + # Calculate block sizes + current_height, current_width = data.shape[-2:] + block_h = current_height // target_height + block_w = current_width // target_width + + # Ensure exact divisibility + if current_height % target_height != 0 or current_width % target_width != 0: + # Crop to make it divisible + new_height = (current_height // block_h) * block_h + new_width = (current_width // block_w) * block_w + data = data[..., :new_height, :new_width] + + # Perform block averaging + if data.ndim == 3: # (time, y, x) or similar + reshaped = data.values.reshape( + data.shape[0], target_height, block_h, target_width, block_w + ) + downsampled = reshaped.mean(axis=(2, 4)) + else: # (y, x) + reshaped = data.values.reshape(target_height, block_h, target_width, block_w) + downsampled = reshaped.mean(axis=(1, 3)) + + # Create new coordinates + y_coords = data.coords[data.dims[-2]][::block_h][:target_height] + x_coords = data.coords[data.dims[-1]][::block_w][:target_width] + + # Create new DataArray + if data.ndim == 3: + coords = { + data.dims[0]: data.coords[data.dims[0]], + data.dims[-2]: y_coords, + data.dims[-1]: x_coords + } + else: + coords = { + data.dims[-2]: y_coords, + data.dims[-1]: x_coords + } + + return xr.DataArray( + downsampled, + dims=data.dims, + coords=coords, + attrs=data.attrs.copy() + ) + + def _downsample_classification(self, data: xr.DataArray, target_height: int, + target_width: int) -> xr.DataArray: + """Mode-based downsampling for classification data.""" + from scipy import stats + + current_height, current_width = data.shape[-2:] + block_h = current_height // target_height + block_w = current_width // target_width + + # Crop to make divisible + new_height = (current_height // block_h) * block_h + new_width = (current_width // block_w) * block_w + data = data[..., :new_height, :new_width] + + # Reshape for block processing + if data.ndim == 3: + reshaped = data.values.reshape( + data.shape[0], target_height, block_h, target_width, block_w + ) + # Compute mode for each block + downsampled = np.zeros((data.shape[0], target_height, target_width), dtype=data.dtype) + for t in range(data.shape[0]): + for i in range(target_height): + for j in range(target_width): + block = reshaped[t, i, :, j, :].flatten() + mode_val = stats.mode(block, keepdims=False)[0] + downsampled[t, i, j] = mode_val + else: + reshaped = data.values.reshape(target_height, block_h, target_width, block_w) + downsampled = np.zeros((target_height, target_width), dtype=data.dtype) + for i in range(target_height): + for j in range(target_width): + block = reshaped[i, :, j, :].flatten() + mode_val = stats.mode(block, keepdims=False)[0] + downsampled[i, j] = mode_val + + # Create coordinates + y_coords = data.coords[data.dims[-2]][::block_h][:target_height] + x_coords = data.coords[data.dims[-1]][::block_w][:target_width] + + if data.ndim == 3: + coords = { + data.dims[0]: data.coords[data.dims[0]], + data.dims[-2]: y_coords, + data.dims[-1]: x_coords + } + else: + coords = { + data.dims[-2]: y_coords, + data.dims[-1]: x_coords + } + + return xr.DataArray( + downsampled, + dims=data.dims, + coords=coords, + 
attrs=data.attrs.copy() + ) + + def _downsample_quality_mask(self, data: xr.DataArray, target_height: int, + target_width: int) -> xr.DataArray: + """Logical OR downsampling for quality masks (any bad pixel = bad block).""" + current_height, current_width = data.shape[-2:] + block_h = current_height // target_height + block_w = current_width // target_width + + # Crop to make divisible + new_height = (current_height // block_h) * block_h + new_width = (current_width // block_w) * block_w + data = data[..., :new_height, :new_width] + + if data.ndim == 3: + reshaped = data.values.reshape( + data.shape[0], target_height, block_h, target_width, block_w + ) + # Any non-zero value in block makes the downsampled pixel non-zero + downsampled = (reshaped.sum(axis=(2, 4)) > 0).astype(data.dtype) + else: + reshaped = data.values.reshape(target_height, block_h, target_width, block_w) + downsampled = (reshaped.sum(axis=(1, 3)) > 0).astype(data.dtype) + + # Create coordinates + y_coords = data.coords[data.dims[-2]][::block_h][:target_height] + x_coords = data.coords[data.dims[-1]][::block_w][:target_width] + + if data.ndim == 3: + coords = { + data.dims[0]: data.coords[data.dims[0]], + data.dims[-2]: y_coords, + data.dims[-1]: x_coords + } + else: + coords = { + data.dims[-2]: y_coords, + data.dims[-1]: x_coords + } + + return xr.DataArray( + downsampled, + dims=data.dims, + coords=coords, + attrs=data.attrs.copy() + ) + + def _downsample_probability(self, data: xr.DataArray, target_height: int, + target_width: int) -> xr.DataArray: + """Average downsampling for probability data.""" + # Use same method as reflectance but ensure values stay in [0,1] or [0,100] range + result = self._downsample_reflectance(data, target_height, target_width) + + # Clamp values to valid probability range + if result.max() <= 1.0: # [0,1] probabilities + result.values = np.clip(result.values, 0, 1) + else: # [0,100] percentages + result.values = np.clip(result.values, 0, 100) + + return result + +def determine_variable_type(var_name: str, var_data: xr.DataArray) -> str: + """ + Determine the type of a variable for appropriate resampling. + + Args: + var_name: Name of the variable + var_data: The data array + + Returns: + Variable type string + """ + # Spectral bands + if var_name.startswith('b') and (var_name[1:].isdigit() or var_name == 'b8a'): + return 'reflectance' + + # Quality data + if var_name in ['scl']: # Scene Classification Layer + return 'classification' + + if var_name in ['cld', 'snw']: # Probability data + return 'probability' + + if var_name in ['aot', 'wvp']: # Atmosphere quality - treat as reflectance + return 'reflectance' + + if var_name.startswith('detector_footprint_') or var_name.startswith('quality_'): + return 'quality_mask' + + # Default to reflectance for unknown variables + return 'reflectance' \ No newline at end of file diff --git a/src/eopf_geozarr/s2_optimization/s2_validation.py b/src/eopf_geozarr/s2_optimization/s2_validation.py new file mode 100644 index 00000000..030ac5c2 --- /dev/null +++ b/src/eopf_geozarr/s2_optimization/s2_validation.py @@ -0,0 +1,29 @@ +""" +Validation for optimized Sentinel-2 datasets. +""" + +from typing import Dict, Any +import xarray as xr + +class S2OptimizationValidator: + """Validates optimized Sentinel-2 dataset structure and integrity.""" + + def validate_optimized_dataset(self, dataset_path: str) -> Dict[str, Any]: + """ + Validate an optimized Sentinel-2 dataset. 
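+        Currently a structural placeholder that always reports the dataset
+        as valid.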
+ + Args: + dataset_path: Path to the optimized dataset + + Returns: + Validation results dictionary + """ + results = { + 'is_valid': True, + 'issues': [], + 'warnings': [], + 'summary': {} + } + + # Placeholder for validation logic + return results diff --git a/src/eopf_geozarr/tests/test_s2_band_mapping.py b/src/eopf_geozarr/tests/test_s2_band_mapping.py new file mode 100644 index 00000000..e18f5051 --- /dev/null +++ b/src/eopf_geozarr/tests/test_s2_band_mapping.py @@ -0,0 +1,48 @@ +import pytest +from eopf_geozarr.s2_optimization.s2_band_mapping import ( + BandInfo, + NATIVE_BANDS, + BAND_INFO, + QUALITY_DATA_NATIVE, + DETECTOR_FOOTPRINT_NATIVE, + get_bands_for_level, + get_quality_data_for_level, +) + +def test_bandinfo_initialization(): + band = BandInfo("b01", 60, "uint16", 443, 21) + assert band.name == "b01" + assert band.native_resolution == 60 + assert band.data_type == "uint16" + assert band.wavelength_center == 443 + assert band.wavelength_width == 21 + +def test_native_bands(): + assert NATIVE_BANDS[10] == ["b02", "b03", "b04", "b08"] + assert NATIVE_BANDS[20] == ["b05", "b06", "b07", "b11", "b12", "b8a"] + assert NATIVE_BANDS[60] == ["b01", "b09"] + +def test_band_info(): + assert BAND_INFO["b01"].name == "b01" + assert BAND_INFO["b01"].native_resolution == 60 + assert BAND_INFO["b01"].data_type == "uint16" + assert BAND_INFO["b01"].wavelength_center == 443 + assert BAND_INFO["b01"].wavelength_width == 21 + +def test_quality_data_native(): + assert QUALITY_DATA_NATIVE["scl"] == 20 + assert QUALITY_DATA_NATIVE["aot"] == 20 + assert QUALITY_DATA_NATIVE["wvp"] == 20 + assert QUALITY_DATA_NATIVE["cld"] == 20 + assert QUALITY_DATA_NATIVE["snw"] == 20 + +def test_get_bands_for_level(): + assert get_bands_for_level(0) == set(NATIVE_BANDS[10]) + assert get_bands_for_level(1) == set(NATIVE_BANDS[10] + NATIVE_BANDS[20] + NATIVE_BANDS[60]) + assert get_bands_for_level(2) == set(NATIVE_BANDS[10] + NATIVE_BANDS[20] + NATIVE_BANDS[60]) + assert get_bands_for_level(3) == set(NATIVE_BANDS[10] + NATIVE_BANDS[20] + NATIVE_BANDS[60]) + +def test_get_quality_data_for_level(): + assert get_quality_data_for_level(0) == set() + assert get_quality_data_for_level(1) == set(QUALITY_DATA_NATIVE.keys()) + assert get_quality_data_for_level(2) == set(QUALITY_DATA_NATIVE.keys()) diff --git a/src/eopf_geozarr/tests/test_s2_converter.py b/src/eopf_geozarr/tests/test_s2_converter.py new file mode 100644 index 00000000..723a2220 --- /dev/null +++ b/src/eopf_geozarr/tests/test_s2_converter.py @@ -0,0 +1,33 @@ +""" +Unit tests for S2OptimizedConverter. 
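+
+Exercises the convert_s2 pipeline end to end against a mock DataTree.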
+""" + +import pytest +import xarray as xr +from xarray import DataTree + +from eopf_geozarr.s2_optimization.s2_converter import S2OptimizedConverter + +@pytest.fixture +def mock_input_data(): + """Create mock input DataTree for testing.""" + # Placeholder for creating mock DataTree + return DataTree() + +def test_conversion_pipeline(mock_input_data, tmp_path): + """Test the full conversion pipeline.""" + output_path = tmp_path / "optimized_output" + converter = S2OptimizedConverter(enable_sharding=True, spatial_chunk=1024) + + result = converter.convert_s2(mock_input_data, str(output_path)) + + # Validate multiscale data + assert "multiscale_data" in result + assert isinstance(result["multiscale_data"], dict) + + # Validate output path + assert output_path.exists() + + # Validate validation results + assert "validation_results" in result + assert result["validation_results"]["is_valid"] diff --git a/src/eopf_geozarr/tests/test_s2_resampling.py b/src/eopf_geozarr/tests/test_s2_resampling.py new file mode 100644 index 00000000..b49392f0 --- /dev/null +++ b/src/eopf_geozarr/tests/test_s2_resampling.py @@ -0,0 +1,539 @@ +""" +Unit tests for S2 resampling functionality. +""" + +import pytest +import numpy as np +import xarray as xr + +from eopf_geozarr.s2_optimization.s2_resampling import ( + S2ResamplingEngine, + determine_variable_type, +) + + +@pytest.fixture +def sample_reflectance_data_2d(): + """Create a 2D reflectance data array for testing.""" + # Create a 4x4 array with known values + data = np.array([ + [100, 200, 300, 400], + [150, 250, 350, 450], + [110, 210, 310, 410], + [160, 260, 360, 460] + ], dtype=np.uint16) + + coords = { + 'y': np.array([1000, 990, 980, 970]), + 'x': np.array([500000, 500010, 500020, 500030]) + } + + return xr.DataArray( + data, + dims=['y', 'x'], + coords=coords, + attrs={'units': 'reflectance', 'scale_factor': 0.0001} + ) + + +@pytest.fixture +def sample_reflectance_data_3d(): + """Create a 3D reflectance data array with time dimension for testing.""" + # Create a 2x4x4 array (time, y, x) + data = np.array([ + [[100, 200, 300, 400], + [150, 250, 350, 450], + [110, 210, 310, 410], + [160, 260, 360, 460]], + [[120, 220, 320, 420], + [170, 270, 370, 470], + [130, 230, 330, 430], + [180, 280, 380, 480]] + ], dtype=np.uint16) + + coords = { + 'time': np.array(['2023-01-01', '2023-01-02'], dtype='datetime64[D]'), + 'y': np.array([1000, 990, 980, 970]), + 'x': np.array([500000, 500010, 500020, 500030]) + } + + return xr.DataArray( + data, + dims=['time', 'y', 'x'], + coords=coords, + attrs={'units': 'reflectance', 'scale_factor': 0.0001} + ) + + +@pytest.fixture +def sample_classification_data(): + """Create classification data for testing.""" + # SCL values: 0=no_data, 1=saturated, 4=vegetation, 6=water, etc. 
+ data = np.array([ + [0, 1, 4, 4], + [1, 4, 6, 6], + [4, 4, 6, 8], + [4, 6, 8, 8] + ], dtype=np.uint8) + + coords = { + 'y': np.array([1000, 990, 980, 970]), + 'x': np.array([500000, 500010, 500020, 500030]) + } + + return xr.DataArray( + data, + dims=['y', 'x'], + coords=coords, + attrs={'long_name': 'Scene Classification Layer'} + ) + + +@pytest.fixture +def sample_quality_mask(): + """Create quality mask data for testing.""" + # Binary mask: 0=good, 1=bad + data = np.array([ + [0, 0, 1, 0], + [0, 1, 0, 0], + [1, 0, 0, 1], + [0, 0, 1, 1] + ], dtype=np.uint8) + + coords = { + 'y': np.array([1000, 990, 980, 970]), + 'x': np.array([500000, 500010, 500020, 500030]) + } + + return xr.DataArray( + data, + dims=['y', 'x'], + coords=coords, + attrs={'long_name': 'Quality mask'} + ) + + +@pytest.fixture +def sample_probability_data(): + """Create probability data for testing.""" + # Cloud probabilities in percent (0-100) + data = np.array([ + [10.5, 20.3, 85.7, 92.1], + [15.2, 75.8, 88.3, 95.6], + [12.7, 18.9, 90.2, 87.4], + [8.1, 22.4, 78.9, 99.0] + ], dtype=np.float32) + + coords = { + 'y': np.array([1000, 990, 980, 970]), + 'x': np.array([500000, 500010, 500020, 500030]) + } + + return xr.DataArray( + data, + dims=['y', 'x'], + coords=coords, + attrs={'long_name': 'Cloud probability', 'units': 'percent'} + ) + + +class TestS2ResamplingEngine: + """Test cases for S2ResamplingEngine class.""" + + def test_initialization(self): + """Test engine initialization.""" + engine = S2ResamplingEngine() + + assert hasattr(engine, 'resampling_methods') + assert len(engine.resampling_methods) == 5 + assert 'reflectance' in engine.resampling_methods + assert 'classification' in engine.resampling_methods + assert 'quality_mask' in engine.resampling_methods + assert 'probability' in engine.resampling_methods + assert 'detector_footprint' in engine.resampling_methods + + def test_downsample_reflectance_2d(self, sample_reflectance_data_2d): + """Test reflectance downsampling for 2D data.""" + engine = S2ResamplingEngine() + + # Downsample from 4x4 to 2x2 + result = engine.downsample_variable( + sample_reflectance_data_2d, 2, 2, 'reflectance' + ) + + # Check dimensions + assert result.shape == (2, 2) + assert result.dims == ('y', 'x') + + # Check that values are averages of 2x2 blocks + # Top-left block: mean of [100, 200, 150, 250] = 175 + assert result.values[0, 0] == 175.0 + + # Top-right block: mean of [300, 400, 350, 450] = 375 + assert result.values[0, 1] == 375.0 + + # Check coordinates are properly subsampled + assert len(result.coords['y']) == 2 + assert len(result.coords['x']) == 2 + np.testing.assert_array_equal(result.coords['y'].values, [1000, 980]) + np.testing.assert_array_equal(result.coords['x'].values, [500000, 500020]) + + # Check attributes are preserved + assert result.attrs == sample_reflectance_data_2d.attrs + + def test_downsample_reflectance_3d(self, sample_reflectance_data_3d): + """Test reflectance downsampling for 3D data.""" + engine = S2ResamplingEngine() + + # Downsample from 2x4x4 to 2x2x2 + result = engine.downsample_variable( + sample_reflectance_data_3d, 2, 2, 'reflectance' + ) + + # Check dimensions + assert result.shape == (2, 2, 2) + assert result.dims == ('time', 'y', 'x') + + # Check first time slice values + # Top-left block: mean of [100, 200, 150, 250] = 175 + assert result.values[0, 0, 0] == 175.0 + + # Check second time slice values + # Top-left block: mean of [120, 220, 170, 270] = 195 + assert result.values[1, 0, 0] == 195.0 + + # Check coordinates + assert 
len(result.coords['time']) == 2 + assert len(result.coords['y']) == 2 + assert len(result.coords['x']) == 2 + + def test_downsample_classification(self, sample_classification_data): + """Test classification downsampling using mode.""" + engine = S2ResamplingEngine() + + # Downsample from 4x4 to 2x2 + result = engine.downsample_variable( + sample_classification_data, 2, 2, 'classification' + ) + + # Check dimensions + assert result.shape == (2, 2) + assert result.dims == ('y', 'x') + + # Check mode values + # Top-left block: [0, 1, 1, 4] -> mode should be 1 (most frequent) + # Top-right block: [4, 4, 6, 6] -> mode could be either 4 or 6 (both appear twice) + assert result.values[0, 0] in [0, 1, 4] # Allow for mode calculation variations + + # Check data type is preserved + assert result.dtype == sample_classification_data.dtype + + def test_downsample_quality_mask(self, sample_quality_mask): + """Test quality mask downsampling using logical OR.""" + engine = S2ResamplingEngine() + + # Downsample from 4x4 to 2x2 + result = engine.downsample_variable( + sample_quality_mask, 2, 2, 'quality_mask' + ) + + # Check dimensions + assert result.shape == (2, 2) + assert result.dims == ('y', 'x') + + # Check logical OR behavior + # Top-left block: [0, 0, 0, 1] -> any non-zero = 1 + assert result.values[0, 0] == 1 + + # Top-right block: [1, 0, 0, 0] -> any non-zero = 1 + assert result.values[0, 1] == 1 + + # Bottom-left block: [1, 0, 0, 0] -> any non-zero = 1 + assert result.values[1, 0] == 1 + + # Bottom-right block: [0, 1, 1, 1] -> any non-zero = 1 + assert result.values[1, 1] == 1 + + def test_downsample_probability(self, sample_probability_data): + """Test probability downsampling with value clamping.""" + engine = S2ResamplingEngine() + + # Downsample from 4x4 to 2x2 + result = engine.downsample_variable( + sample_probability_data, 2, 2, 'probability' + ) + + # Check dimensions + assert result.shape == (2, 2) + assert result.dims == ('y', 'x') + + # Values should be averages and clamped to [0, 100] + assert np.all(result.values >= 0) + assert np.all(result.values <= 100) + + # Check specific average calculation + # Top-left block: mean of [10.5, 20.3, 15.2, 75.8] ≈ 30.45 + expected_val = (10.5 + 20.3 + 15.2 + 75.8) / 4 + np.testing.assert_almost_equal(result.values[0, 0], expected_val, decimal=2) + + def test_detector_footprint_same_as_quality_mask(self, sample_quality_mask): + """Test that detector footprint uses same method as quality mask.""" + engine = S2ResamplingEngine() + + result_quality = engine.downsample_variable( + sample_quality_mask, 2, 2, 'quality_mask' + ) + result_detector = engine.downsample_variable( + sample_quality_mask, 2, 2, 'detector_footprint' + ) + + # Results should be identical + np.testing.assert_array_equal(result_quality.values, result_detector.values) + + def test_invalid_variable_type(self, sample_reflectance_data_2d): + """Test error handling for invalid variable type.""" + engine = S2ResamplingEngine() + + with pytest.raises(ValueError, match="Unknown variable type"): + engine.downsample_variable( + sample_reflectance_data_2d, 2, 2, 'invalid_type' + ) + + def test_non_divisible_dimensions(self): + """Test handling of non-divisible dimensions.""" + engine = S2ResamplingEngine() + + # Create 5x5 data (not evenly divisible by 2) + data = np.random.rand(5, 5).astype(np.float32) + coords = { + 'y': np.arange(5), + 'x': np.arange(5) + } + da = xr.DataArray(data, dims=['y', 'x'], coords=coords) + + # Should crop to make it divisible + result = 
engine.downsample_variable(da, 2, 2, 'reflectance') + + # Should result in 2x2 output (cropped from 4x4) + assert result.shape == (2, 2) + + def test_single_pixel_downsampling(self): + """Test downsampling to single pixel.""" + engine = S2ResamplingEngine() + + # Create 4x4 data + data = np.ones((4, 4), dtype=np.float32) * 100 + coords = { + 'y': np.arange(4), + 'x': np.arange(4) + } + da = xr.DataArray(data, dims=['y', 'x'], coords=coords) + + # Downsample to 1x1 + result = engine.downsample_variable(da, 1, 1, 'reflectance') + + assert result.shape == (1, 1) + assert result.values[0, 0] == 100.0 + + +class TestDetermineVariableType: + """Test cases for determine_variable_type function.""" + + def test_spectral_bands(self): + """Test recognition of spectral bands.""" + dummy_data = xr.DataArray([1, 2, 3]) + + # Test standard bands + assert determine_variable_type('b01', dummy_data) == 'reflectance' + assert determine_variable_type('b02', dummy_data) == 'reflectance' + assert determine_variable_type('b8a', dummy_data) == 'reflectance' + + # Test specific non-band variables that should be classified differently + assert determine_variable_type('scl', dummy_data) == 'classification' + assert determine_variable_type('cld', dummy_data) == 'probability' + assert determine_variable_type('quality_b01', dummy_data) == 'quality_mask' + + def test_classification_data(self): + """Test recognition of classification data.""" + dummy_data = xr.DataArray([1, 2, 3]) + + assert determine_variable_type('scl', dummy_data) == 'classification' + + def test_probability_data(self): + """Test recognition of probability data.""" + dummy_data = xr.DataArray([1, 2, 3]) + + assert determine_variable_type('cld', dummy_data) == 'probability' + assert determine_variable_type('snw', dummy_data) == 'probability' + + def test_atmospheric_quality(self): + """Test recognition of atmospheric quality data.""" + dummy_data = xr.DataArray([1, 2, 3]) + + assert determine_variable_type('aot', dummy_data) == 'reflectance' + assert determine_variable_type('wvp', dummy_data) == 'reflectance' + + def test_quality_masks(self): + """Test recognition of quality mask data.""" + dummy_data = xr.DataArray([1, 2, 3]) + + assert determine_variable_type('detector_footprint_b01', dummy_data) == 'quality_mask' + assert determine_variable_type('quality_b02', dummy_data) == 'quality_mask' + + def test_unknown_variable_defaults_to_reflectance(self): + """Test that unknown variables default to reflectance.""" + dummy_data = xr.DataArray([1, 2, 3]) + + assert determine_variable_type('unknown_var', dummy_data) == 'reflectance' + assert determine_variable_type('custom_band', dummy_data) == 'reflectance' + + +class TestEdgeCases: + """Test edge cases and error conditions.""" + + def test_empty_data_array(self): + """Test handling of empty data arrays.""" + engine = S2ResamplingEngine() + + # Create minimal data array + data = np.array([[1]]) + coords = {'y': [0], 'x': [0]} + da = xr.DataArray(data, dims=['y', 'x'], coords=coords) + + # This should work for 1x1 -> 1x1 downsampling + result = engine.downsample_variable(da, 1, 1, 'reflectance') + assert result.shape == (1, 1) + assert result.values[0, 0] == 1 + + def test_preserve_attributes_and_encoding(self): + """Test that attributes and encoding are preserved.""" + engine = S2ResamplingEngine() + + data = np.ones((4, 4), dtype=np.uint16) * 1000 + coords = { + 'y': np.arange(4), + 'x': np.arange(4) + } + + attrs = { + 'long_name': 'Test reflectance', + 'units': 'reflectance', + 'scale_factor': 0.0001, + 
'add_offset': 0 + } + + da = xr.DataArray(data, dims=['y', 'x'], coords=coords, attrs=attrs) + + result = engine.downsample_variable(da, 2, 2, 'reflectance') + + # Attributes should be preserved + assert result.attrs == attrs + + def test_coordinate_names_preserved(self): + """Test that coordinate names are preserved during downsampling.""" + engine = S2ResamplingEngine() + + data = np.ones((4, 4), dtype=np.float32) + coords = { + 'latitude': np.arange(4), + 'longitude': np.arange(4) + } + + da = xr.DataArray(data, dims=['latitude', 'longitude'], coords=coords) + + result = engine.downsample_variable(da, 2, 2, 'reflectance') + + # Coordinate names should be preserved + assert 'latitude' in result.coords + assert 'longitude' in result.coords + assert result.dims == ('latitude', 'longitude') + + +class TestIntegrationScenarios: + """Integration test scenarios.""" + + def test_multiscale_pyramid_creation(self): + """Test creating a complete multiscale pyramid.""" + engine = S2ResamplingEngine() + + # Start with 32x32 data + original_size = 32 + data = np.random.rand(original_size, original_size).astype(np.float32) * 1000 + coords = { + 'y': np.arange(original_size), + 'x': np.arange(original_size) + } + + da = xr.DataArray(data, dims=['y', 'x'], coords=coords) + + # Create pyramid levels: 32x32 -> 16x16 -> 8x8 -> 4x4 -> 2x2 -> 1x1 + levels = [] + current_data = da + current_size = original_size + + while current_size >= 2: + next_size = current_size // 2 + downsampled = engine.downsample_variable( + current_data, next_size, next_size, 'reflectance' + ) + levels.append(downsampled) + current_data = downsampled + current_size = next_size + + # Verify pyramid structure + expected_sizes = [16, 8, 4, 2, 1] + for i, level in enumerate(levels): + expected_size = expected_sizes[i] + assert level.shape == (expected_size, expected_size) + + # Verify that values are reasonable (not NaN, not extreme) + for level in levels: + assert not np.isnan(level.values).any() + assert np.all(level.values >= 0) + + def test_mixed_variable_types_processing(self): + """Test processing different variable types together.""" + engine = S2ResamplingEngine() + + # Create base 4x4 data + size = 4 + coords = {'y': np.arange(size), 'x': np.arange(size)} + + # Create different variable types + reflectance_data = xr.DataArray( + np.random.rand(size, size) * 1000, + dims=['y', 'x'], coords=coords + ) + + classification_data = xr.DataArray( + np.random.randint(0, 10, (size, size)), + dims=['y', 'x'], coords=coords + ) + + quality_data = xr.DataArray( + np.random.randint(0, 2, (size, size)), + dims=['y', 'x'], coords=coords + ) + + # Process each with appropriate method + results = {} + for var_name, var_data, var_type in [ + ('b04', reflectance_data, 'reflectance'), + ('scl', classification_data, 'classification'), + ('quality_b04', quality_data, 'quality_mask') + ]: + results[var_name] = engine.downsample_variable( + var_data, 2, 2, var_type + ) + + # Verify all results have same dimensions + for result in results.values(): + assert result.shape == (2, 2) + + # Verify coordinate consistency + y_coords = results['b04'].coords['y'] + x_coords = results['b04'].coords['x'] + + for result in results.values(): + np.testing.assert_array_equal(result.coords['y'].values, y_coords.values) + np.testing.assert_array_equal(result.coords['x'].values, x_coords.values) From 4fc14a39c0e326c9aa45c4f4c79c5c68625715d3 Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Fri, 26 Sep 2025 14:52:55 +0200 Subject: [PATCH 12/83] feat: enhance S2 data 
consolidator with comprehensive extraction methods and testing framework --- .../s2_optimization/s2_data_consolidator.py | 207 +++++++- .../tests/test_s2_data_consolidator.py | 491 ++++++++++++++++++ 2 files changed, 687 insertions(+), 11 deletions(-) create mode 100644 src/eopf_geozarr/tests/test_s2_data_consolidator.py diff --git a/src/eopf_geozarr/s2_optimization/s2_data_consolidator.py b/src/eopf_geozarr/s2_optimization/s2_data_consolidator.py index 119dc55e..1bf1d2bd 100644 --- a/src/eopf_geozarr/s2_optimization/s2_data_consolidator.py +++ b/src/eopf_geozarr/s2_optimization/s2_data_consolidator.py @@ -1,12 +1,16 @@ """ -Data consolidation logic for reorganizing Sentinel-2 structure. +Data consolidation logic for reorganizing S2 structure. """ import xarray as xr -from typing import Dict, Tuple +from typing import Dict, List, Tuple, Optional +from .s2_band_mapping import ( + NATIVE_BANDS, QUALITY_DATA_NATIVE, DETECTOR_FOOTPRINT_NATIVE, + get_bands_for_level, get_quality_data_for_level +) class S2DataConsolidator: - """Consolidates Sentinel-2 data from scattered structure into organized groups.""" + """Consolidates S2 data from scattered structure into organized groups.""" def __init__(self, dt_input: xr.DataTree): self.dt_input = dt_input @@ -28,16 +32,197 @@ def consolidate_all_data(self) -> Tuple[Dict, Dict, Dict]: return self.measurements_data, self.geometry_data, self.meteorology_data def _extract_measurements_data(self) -> None: - """Extract and organize all measurement-related data.""" - # Placeholder for measurement extraction logic - pass + """Extract and organize all measurement-related data by native resolution.""" + + # Initialize resolution groups + for resolution in [10, 20, 60]: + self.measurements_data[resolution] = { + 'bands': {}, + 'quality': {}, + 'detector_footprints': {}, + 'classification': {}, + 'atmosphere': {}, + 'probability': {} + } + + # Extract reflectance bands + if any('/measurements/reflectance' in group for group in self.dt_input.groups): + self._extract_reflectance_bands() + + # Extract quality data + self._extract_quality_data() + + # Extract detector footprints + self._extract_detector_footprints() + + # Extract atmosphere quality + self._extract_atmosphere_data() + + # Extract classification data + self._extract_classification_data() + + # Extract probability data + self._extract_probability_data() + + def _extract_reflectance_bands(self) -> None: + """Extract reflectance bands from measurements/reflectance groups.""" + for resolution in ['r10m', 'r20m', 'r60m']: + res_num = int(resolution[1:-1]) # Extract number from 'r10m' + group_path = f'/measurements/reflectance/{resolution}' + + if group_path in self.dt_input.groups: + # Check if this is a multiscale group (has numeric subgroups) + group_node = self.dt_input[group_path] + if hasattr(group_node, 'children') and group_node.children: + # Take level 0 (native resolution) + native_path = f'{group_path}/0' + if native_path in self.dt_input.groups: + ds = self.dt_input[native_path].to_dataset() + else: + ds = group_node.to_dataset() + else: + ds = group_node.to_dataset() + + # Extract only native bands for this resolution + native_bands = NATIVE_BANDS.get(res_num, []) + for band in native_bands: + if band in ds.data_vars: + self.measurements_data[res_num]['bands'][band] = ds[band] + + def _extract_quality_data(self) -> None: + """Extract quality mask data.""" + quality_base = '/quality/mask' + + for resolution in ['r10m', 'r20m', 'r60m']: + res_num = int(resolution[1:-1]) + group_path = 
f'{quality_base}/{resolution}' + + if group_path in self.dt_input.groups: + ds = self.dt_input[group_path].to_dataset() + + # Only extract quality for native bands at this resolution + native_bands = NATIVE_BANDS.get(res_num, []) + for band in native_bands: + if band in ds.data_vars: + self.measurements_data[res_num]['quality'][f'quality_{band}'] = ds[band] + + def _extract_detector_footprints(self) -> None: + """Extract detector footprint data.""" + footprint_base = '/conditions/mask/detector_footprint' + + for resolution in ['r10m', 'r20m', 'r60m']: + res_num = int(resolution[1:-1]) + group_path = f'{footprint_base}/{resolution}' + + if group_path in self.dt_input.groups: + ds = self.dt_input[group_path].to_dataset() + + # Only extract footprints for native bands + native_bands = NATIVE_BANDS.get(res_num, []) + for band in native_bands: + if band in ds.data_vars: + var_name = f'detector_footprint_{band}' + self.measurements_data[res_num]['detector_footprints'][var_name] = ds[band] + + def _extract_atmosphere_data(self) -> None: + """Extract atmosphere quality data (aot, wvp) - native at 20m.""" + atm_base = '/quality/atmosphere' + + # Atmosphere data is native at 20m resolution + group_path = f'{atm_base}/r20m' + if group_path in self.dt_input.groups: + ds = self.dt_input[group_path].to_dataset() + + for var in ['aot', 'wvp']: + if var in ds.data_vars: + self.measurements_data[20]['atmosphere'][var] = ds[var] + + def _extract_classification_data(self) -> None: + """Extract scene classification data - native at 20m.""" + class_base = '/conditions/mask/l2a_classification' + + # Classification is native at 20m + group_path = f'{class_base}/r20m' + if group_path in self.dt_input.groups: + ds = self.dt_input[group_path].to_dataset() + + if 'scl' in ds.data_vars: + self.measurements_data[20]['classification']['scl'] = ds['scl'] + + def _extract_probability_data(self) -> None: + """Extract cloud and snow probability data - native at 20m.""" + prob_base = '/quality/probability/r20m' + + if prob_base in self.dt_input.groups: + ds = self.dt_input[prob_base].to_dataset() + + for var in ['cld', 'snw']: + if var in ds.data_vars: + self.measurements_data[20]['probability'][var] = ds[var] def _extract_geometry_data(self) -> None: - """Extract all geometry-related data into a single group.""" - # Placeholder for geometry extraction logic - pass + """Extract all geometry-related data into single group.""" + geom_base = '/conditions/geometry' + + if geom_base in self.dt_input.groups: + ds = self.dt_input[geom_base].to_dataset() + + # Consolidate all geometry variables + for var_name in ds.data_vars: + self.geometry_data[var_name] = ds[var_name] def _extract_meteorology_data(self) -> None: """Extract meteorological data (CAMS and ECMWF).""" - # Placeholder for meteorology extraction logic - pass + # CAMS data + cams_path = '/conditions/meteorology/cams' + if cams_path in self.dt_input.groups: + ds = self.dt_input[cams_path].to_dataset() + for var_name in ds.data_vars: + self.meteorology_data[f'cams_{var_name}'] = ds[var_name] + + # ECMWF data + ecmwf_path = '/conditions/meteorology/ecmwf' + if ecmwf_path in self.dt_input.groups: + ds = self.dt_input[ecmwf_path].to_dataset() + for var_name in ds.data_vars: + self.meteorology_data[f'ecmwf_{var_name}'] = ds[var_name] + +def create_consolidated_dataset(data_dict: Dict, resolution: int) -> xr.Dataset: + """ + Create a consolidated dataset from categorized data. 
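+
+    The data_dict is typically the per-resolution layout produced by
+    S2DataConsolidator, i.e. category keys mapping to variable dictionaries,
+    for example (illustrative names only):
+    {'bands': {'b02': DataArray}, 'quality': {'quality_b02': DataArray}}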
+ + Args: + data_dict: Dictionary with categorized data + resolution: Target resolution in meters + + Returns: + Consolidated xarray Dataset + """ + all_vars = {} + + # Combine all data variables + for category, vars_dict in data_dict.items(): + all_vars.update(vars_dict) + + if not all_vars: + return xr.Dataset() + + # Create dataset + ds = xr.Dataset(all_vars) + + # Set up coordinate system and metadata + if 'x' in ds.coords and 'y' in ds.coords: + # Ensure CRS information is present + if ds.rio.crs is None: + # Try to infer CRS from one of the variables + for var_name, var_data in all_vars.items(): + if hasattr(var_data, 'rio') and var_data.rio.crs: + ds.rio.write_crs(var_data.rio.crs, inplace=True) + break + + # Add resolution metadata + ds.attrs['native_resolution_meters'] = resolution + ds.attrs['processing_level'] = 'L2A' + ds.attrs['product_type'] = 'S2MSI2A' + + return ds \ No newline at end of file diff --git a/src/eopf_geozarr/tests/test_s2_data_consolidator.py b/src/eopf_geozarr/tests/test_s2_data_consolidator.py new file mode 100644 index 00000000..3f876e65 --- /dev/null +++ b/src/eopf_geozarr/tests/test_s2_data_consolidator.py @@ -0,0 +1,491 @@ +"""Tests for S2 data consolidator module.""" + +import pytest +import numpy as np +import xarray as xr +from unittest.mock import Mock, MagicMock +from typing import Dict, List, Tuple, Any + +from eopf_geozarr.s2_optimization.s2_data_consolidator import ( + S2DataConsolidator, + create_consolidated_dataset, +) + + +class TestS2DataConsolidator: + """Test S2DataConsolidator class.""" + + @pytest.fixture + def sample_s2_datatree(self): + """Create a sample S2 DataTree structure for testing.""" + # Create coordinate arrays for different resolutions + x_10m = np.linspace(100000, 200000, 1098) + y_10m = np.linspace(5000000, 5100000, 1098) + x_20m = x_10m[::2] # 549 points + y_20m = y_10m[::2] + x_60m = x_10m[::6] # 183 points + y_60m = y_10m[::6] + time = np.array(['2023-01-15'], dtype='datetime64[ns]') + + # Create sample data arrays + data_10m = np.random.randint(0, 10000, (1, 1098, 1098), dtype=np.uint16) + data_20m = np.random.randint(0, 10000, (1, 549, 549), dtype=np.uint16) + data_60m = np.random.randint(0, 10000, (1, 183, 183), dtype=np.uint16) + + # Create datasets for different resolution groups (using lowercase band names) + ds_10m = xr.Dataset({ + 'b02': (['time', 'y', 'x'], data_10m), + 'b03': (['time', 'y', 'x'], data_10m.copy()), + 'b04': (['time', 'y', 'x'], data_10m.copy()), + 'b08': (['time', 'y', 'x'], data_10m.copy()), + }, coords={'time': time, 'x': x_10m, 'y': y_10m}) + + ds_20m = xr.Dataset({ + 'b05': (['time', 'y', 'x'], data_20m), + 'b06': (['time', 'y', 'x'], data_20m.copy()), + 'b07': (['time', 'y', 'x'], data_20m.copy()), + 'b8a': (['time', 'y', 'x'], data_20m.copy()), + 'b11': (['time', 'y', 'x'], data_20m.copy()), + 'b12': (['time', 'y', 'x'], data_20m.copy()), + 'aot': (['time', 'y', 'x'], data_20m.copy()), # atmosphere + 'wvp': (['time', 'y', 'x'], data_20m.copy()), + 'scl': (['time', 'y', 'x'], data_20m.copy()), # classification + 'cld': (['time', 'y', 'x'], data_20m.copy()), # probability + 'snw': (['time', 'y', 'x'], data_20m.copy()), + }, coords={'time': time, 'x': x_20m, 'y': y_20m}) + + ds_60m = xr.Dataset({ + 'b01': (['time', 'y', 'x'], data_60m), + 'b09': (['time', 'y', 'x'], data_60m.copy()), + }, coords={'time': time, 'x': x_60m, 'y': y_60m}) + + # Create quality datasets (using lowercase band names) + quality_10m = xr.Dataset({ + 'b02': (['time', 'y', 'x'], np.random.randint(0, 2, (1, 1098, 1098), 
dtype=np.uint8)), + 'b03': (['time', 'y', 'x'], np.random.randint(0, 2, (1, 1098, 1098), dtype=np.uint8)), + 'b04': (['time', 'y', 'x'], np.random.randint(0, 2, (1, 1098, 1098), dtype=np.uint8)), + 'b08': (['time', 'y', 'x'], np.random.randint(0, 2, (1, 1098, 1098), dtype=np.uint8)), + }, coords={'time': time, 'x': x_10m, 'y': y_10m}) + + # Create detector footprint datasets (using lowercase band names) + detector_10m = xr.Dataset({ + 'b02': (['time', 'y', 'x'], np.random.randint(0, 13, (1, 1098, 1098), dtype=np.uint8)), + 'b03': (['time', 'y', 'x'], np.random.randint(0, 13, (1, 1098, 1098), dtype=np.uint8)), + 'b04': (['time', 'y', 'x'], np.random.randint(0, 13, (1, 1098, 1098), dtype=np.uint8)), + 'b08': (['time', 'y', 'x'], np.random.randint(0, 13, (1, 1098, 1098), dtype=np.uint8)), + }, coords={'time': time, 'x': x_10m, 'y': y_10m}) + + # Create geometry data + geometry_ds = xr.Dataset({ + 'solar_zenith_angle': (['time', 'y', 'x'], np.random.uniform(0, 90, (1, 549, 549))), + 'solar_azimuth_angle': (['time', 'y', 'x'], np.random.uniform(0, 360, (1, 549, 549))), + 'view_zenith_angle': (['time', 'y', 'x'], np.random.uniform(0, 90, (1, 549, 549))), + 'view_azimuth_angle': (['time', 'y', 'x'], np.random.uniform(0, 360, (1, 549, 549))), + }, coords={'time': time, 'x': x_20m, 'y': y_20m}) + + # Create meteorology data + cams_ds = xr.Dataset({ + 'total_ozone': (['time', 'y', 'x'], np.random.uniform(200, 400, (1, 183, 183))), + 'relative_humidity': (['time', 'y', 'x'], np.random.uniform(0, 100, (1, 183, 183))), + }, coords={'time': time, 'x': x_60m, 'y': y_60m}) + + ecmwf_ds = xr.Dataset({ + 'temperature': (['time', 'y', 'x'], np.random.uniform(250, 320, (1, 183, 183))), + 'pressure': (['time', 'y', 'x'], np.random.uniform(950, 1050, (1, 183, 183))), + }, coords={'time': time, 'x': x_60m, 'y': y_60m}) + + # Build the mock DataTree structure + mock_dt = MagicMock() + mock_dt.groups = { + '/measurements/reflectance/r10m': Mock(), + '/measurements/reflectance/r20m': Mock(), + '/measurements/reflectance/r60m': Mock(), + '/quality/mask/r10m': Mock(), + '/quality/mask/r20m': Mock(), + '/quality/mask/r60m': Mock(), + '/conditions/mask/detector_footprint/r10m': Mock(), + '/conditions/mask/detector_footprint/r20m': Mock(), + '/conditions/mask/detector_footprint/r60m': Mock(), + '/quality/atmosphere/r20m': Mock(), + '/conditions/mask/l2a_classification/r20m': Mock(), + '/quality/probability/r20m': Mock(), + '/conditions/geometry': Mock(), + '/conditions/meteorology/cams': Mock(), + '/conditions/meteorology/ecmwf': Mock(), + } + + # Mock the dataset access + def mock_getitem(self, path): + mock_node = MagicMock() + if 'r10m' in path: + if 'reflectance' in path: + mock_node.to_dataset.return_value = ds_10m + elif 'quality/mask' in path: + mock_node.to_dataset.return_value = quality_10m + elif 'detector_footprint' in path: + mock_node.to_dataset.return_value = detector_10m + elif 'r20m' in path: + if 'reflectance' in path: + mock_node.to_dataset.return_value = ds_20m + elif 'atmosphere' in path: + mock_node.to_dataset.return_value = ds_20m[['aot', 'wvp']] + elif 'classification' in path: + mock_node.to_dataset.return_value = ds_20m[['scl']] + elif 'probability' in path: + mock_node.to_dataset.return_value = ds_20m[['cld', 'snw']] + elif 'r60m' in path: + if 'reflectance' in path: + mock_node.to_dataset.return_value = ds_60m + elif 'geometry' in path: + mock_node.to_dataset.return_value = geometry_ds + elif 'cams' in path: + mock_node.to_dataset.return_value = cams_ds + elif 'ecmwf' in path: + 
mock_node.to_dataset.return_value = ecmwf_ds + + return mock_node + + mock_dt.__getitem__ = mock_getitem + return mock_dt + + def test_init(self, sample_s2_datatree): + """Test consolidator initialization.""" + consolidator = S2DataConsolidator(sample_s2_datatree) + + assert consolidator.dt_input == sample_s2_datatree + assert consolidator.measurements_data == {} + assert consolidator.geometry_data == {} + assert consolidator.meteorology_data == {} + + def test_consolidate_all_data(self, sample_s2_datatree): + """Test complete data consolidation.""" + consolidator = S2DataConsolidator(sample_s2_datatree) + measurements, geometry, meteorology = consolidator.consolidate_all_data() + + # Check that all three categories are returned + assert isinstance(measurements, dict) + assert isinstance(geometry, dict) + assert isinstance(meteorology, dict) + + # Check resolution groups in measurements + assert 10 in measurements + assert 20 in measurements + assert 60 in measurements + + # Check data categories exist + for resolution in [10, 20, 60]: + assert 'bands' in measurements[resolution] + assert 'quality' in measurements[resolution] + assert 'detector_footprints' in measurements[resolution] + assert 'classification' in measurements[resolution] + assert 'atmosphere' in measurements[resolution] + assert 'probability' in measurements[resolution] + + def test_extract_reflectance_bands(self, sample_s2_datatree): + """Test reflectance band extraction.""" + consolidator = S2DataConsolidator(sample_s2_datatree) + consolidator._extract_measurements_data() + + # Check 10m bands + assert 'b02' in consolidator.measurements_data[10]['bands'] + assert 'b03' in consolidator.measurements_data[10]['bands'] + assert 'b04' in consolidator.measurements_data[10]['bands'] + assert 'b08' in consolidator.measurements_data[10]['bands'] + + # Check 20m bands + assert 'b05' in consolidator.measurements_data[20]['bands'] + assert 'b06' in consolidator.measurements_data[20]['bands'] + assert 'b11' in consolidator.measurements_data[20]['bands'] + assert 'b12' in consolidator.measurements_data[20]['bands'] + + # Check 60m bands + assert 'b01' in consolidator.measurements_data[60]['bands'] + assert 'b09' in consolidator.measurements_data[60]['bands'] + + def test_extract_quality_data(self, sample_s2_datatree): + """Test quality data extraction.""" + consolidator = S2DataConsolidator(sample_s2_datatree) + consolidator._extract_measurements_data() + + # Check quality data exists for native bands + assert 'quality_b02' in consolidator.measurements_data[10]['quality'] + assert 'quality_b03' in consolidator.measurements_data[10]['quality'] + + def test_extract_detector_footprints(self, sample_s2_datatree): + """Test detector footprint extraction.""" + consolidator = S2DataConsolidator(sample_s2_datatree) + consolidator._extract_measurements_data() + + # Check detector footprint data + assert 'detector_footprint_b02' in consolidator.measurements_data[10]['detector_footprints'] + assert 'detector_footprint_b03' in consolidator.measurements_data[10]['detector_footprints'] + + def test_extract_atmosphere_data(self, sample_s2_datatree): + """Test atmosphere data extraction.""" + consolidator = S2DataConsolidator(sample_s2_datatree) + consolidator._extract_measurements_data() + + # Atmosphere data should be at 20m resolution + assert 'aot' in consolidator.measurements_data[20]['atmosphere'] + assert 'wvp' in consolidator.measurements_data[20]['atmosphere'] + + def test_extract_classification_data(self, sample_s2_datatree): + """Test 
classification data extraction.""" + consolidator = S2DataConsolidator(sample_s2_datatree) + consolidator._extract_measurements_data() + + # Classification should be at 20m resolution + assert 'scl' in consolidator.measurements_data[20]['classification'] + + def test_extract_probability_data(self, sample_s2_datatree): + """Test probability data extraction.""" + consolidator = S2DataConsolidator(sample_s2_datatree) + consolidator._extract_measurements_data() + + # Probability data should be at 20m resolution + assert 'cld' in consolidator.measurements_data[20]['probability'] + assert 'snw' in consolidator.measurements_data[20]['probability'] + + def test_extract_geometry_data(self, sample_s2_datatree): + """Test geometry data extraction.""" + consolidator = S2DataConsolidator(sample_s2_datatree) + consolidator._extract_geometry_data() + + # Check that geometry variables are extracted + assert 'solar_zenith_angle' in consolidator.geometry_data + assert 'solar_azimuth_angle' in consolidator.geometry_data + assert 'view_zenith_angle' in consolidator.geometry_data + assert 'view_azimuth_angle' in consolidator.geometry_data + + def test_extract_meteorology_data(self, sample_s2_datatree): + """Test meteorology data extraction.""" + consolidator = S2DataConsolidator(sample_s2_datatree) + consolidator._extract_meteorology_data() + + # Check CAMS data + assert 'cams_total_ozone' in consolidator.meteorology_data + assert 'cams_relative_humidity' in consolidator.meteorology_data + + # Check ECMWF data + assert 'ecmwf_temperature' in consolidator.meteorology_data + assert 'ecmwf_pressure' in consolidator.meteorology_data + + def test_missing_groups_handling(self): + """Test handling of missing data groups.""" + # Create DataTree with missing groups + mock_dt = MagicMock() + mock_dt.groups = {} # No groups present + + consolidator = S2DataConsolidator(mock_dt) + measurements, geometry, meteorology = consolidator.consolidate_all_data() + + # Should handle missing groups gracefully + assert isinstance(measurements, dict) + assert isinstance(geometry, dict) + assert isinstance(meteorology, dict) + + # Data structures should be initialized but empty + for resolution in [10, 20, 60]: + assert resolution in measurements + for category in ['bands', 'quality', 'detector_footprints', 'classification', 'atmosphere', 'probability']: + assert category in measurements[resolution] + assert len(measurements[resolution][category]) == 0 + + +class TestCreateConsolidatedDataset: + """Test the create_consolidated_dataset function.""" + + @pytest.fixture + def sample_data_dict(self): + """Create sample consolidated data dictionary.""" + # Create coordinate arrays + x = np.linspace(100000, 200000, 100) + y = np.linspace(5000000, 5100000, 100) + time = np.array(['2023-01-15'], dtype='datetime64[ns]') + + # Create sample data arrays + data = np.random.randint(0, 10000, (1, 100, 100), dtype=np.uint16) + + return { + 'bands': { + 'b02': xr.DataArray(data, dims=['time', 'y', 'x'], + coords={'time': time, 'x': x, 'y': y}), + 'b03': xr.DataArray(data.copy(), dims=['time', 'y', 'x'], + coords={'time': time, 'x': x, 'y': y}), + }, + 'quality': { + 'quality_b02': xr.DataArray(np.random.randint(0, 2, (1, 100, 100), dtype=np.uint8), + dims=['time', 'y', 'x'], + coords={'time': time, 'x': x, 'y': y}), + }, + 'atmosphere': { + 'aot': xr.DataArray(np.random.uniform(0.1, 0.5, (1, 100, 100)), + dims=['time', 'y', 'x'], + coords={'time': time, 'x': x, 'y': y}), + } + } + + def test_create_consolidated_dataset_success(self, sample_data_dict): 
+ """Test successful dataset creation.""" + ds = create_consolidated_dataset(sample_data_dict, resolution=10) + + assert isinstance(ds, xr.Dataset) + + # Check that all variables are included + expected_vars = {'b02', 'b03', 'quality_b02', 'aot'} + assert set(ds.data_vars.keys()) == expected_vars + + # Check metadata + assert ds.attrs['native_resolution_meters'] == 10 + assert ds.attrs['processing_level'] == 'L2A' + assert ds.attrs['product_type'] == 'S2MSI2A' + + # Check coordinates + assert 'x' in ds.coords + assert 'y' in ds.coords + assert 'time' in ds.coords + + def test_create_consolidated_dataset_empty_data(self): + """Test dataset creation with empty data.""" + empty_data_dict = {'bands': {}, 'quality': {}, 'atmosphere': {}} + ds = create_consolidated_dataset(empty_data_dict, resolution=20) + + # Should return empty dataset + assert isinstance(ds, xr.Dataset) + assert len(ds.data_vars) == 0 + + def test_create_consolidated_dataset_with_crs(self, sample_data_dict): + """Test dataset creation with CRS information.""" + # Add CRS to one of the data arrays + sample_data_dict['bands']['b02'] = sample_data_dict['bands']['b02'].rio.write_crs('EPSG:32632') + + ds = create_consolidated_dataset(sample_data_dict, resolution=10) + + assert isinstance(ds, xr.Dataset) + # Check that CRS is propagated (assuming rio accessor is available) + if hasattr(ds, 'rio'): + assert ds.rio.crs is not None + + +class TestIntegration: + """Integration tests combining consolidator and dataset creation.""" + + @pytest.fixture + def complete_s2_datatree(self): + """Create a complete S2 DataTree for integration testing.""" + # This would be similar to the fixture in TestS2DataConsolidator + # but with all data present for end-to-end testing + x_10m = np.linspace(100000, 200000, 100) + y_10m = np.linspace(5000000, 5100000, 100) + x_20m = x_10m[::2] + y_20m = y_10m[::2] + time = np.array(['2023-01-15'], dtype='datetime64[ns]') + + # Create complete mock DataTree (simplified for integration test) + mock_dt = MagicMock() + mock_dt.groups = { + '/measurements/reflectance/r10m': Mock(), + '/conditions/geometry': Mock(), + '/conditions/meteorology/cams': Mock(), + } + + # Mock datasets + reflectance_10m = xr.Dataset({ + 'b02': (['time', 'y', 'x'], np.random.randint(0, 10000, (1, 100, 100), dtype=np.uint16)), + 'b03': (['time', 'y', 'x'], np.random.randint(0, 10000, (1, 100, 100), dtype=np.uint16)), + }, coords={'time': time, 'x': x_10m, 'y': y_10m}) + + geometry_ds = xr.Dataset({ + 'solar_zenith_angle': (['time', 'y', 'x'], np.random.uniform(0, 90, (1, 50, 50))), + }, coords={'time': time, 'x': x_20m, 'y': y_20m}) + + cams_ds = xr.Dataset({ + 'total_ozone': (['time', 'y', 'x'], np.random.uniform(200, 400, (1, 50, 50))), + }, coords={'time': time, 'x': x_20m, 'y': y_20m}) + + def mock_getitem(self, path): + mock_node = MagicMock() + if '/measurements/reflectance/r10m' in path: + mock_node.to_dataset.return_value = reflectance_10m + elif '/conditions/geometry' in path: + mock_node.to_dataset.return_value = geometry_ds + elif '/conditions/meteorology/cams' in path: + mock_node.to_dataset.return_value = cams_ds + return mock_node + + mock_dt.__getitem__ = mock_getitem + return mock_dt + + def test_end_to_end_consolidation(self, complete_s2_datatree): + """Test complete end-to-end consolidation and dataset creation.""" + # Step 1: Consolidate data + consolidator = S2DataConsolidator(complete_s2_datatree) + measurements, geometry, meteorology = consolidator.consolidate_all_data() + + # Step 2: Create consolidated datasets for 
each resolution + consolidated_datasets = {} + for resolution in [10, 20, 60]: + if measurements[resolution]: # Only create if data exists + ds = create_consolidated_dataset(measurements[resolution], resolution) + if len(ds.data_vars) > 0: # Only keep non-empty datasets + consolidated_datasets[resolution] = ds + + # Step 3: Verify results + assert len(consolidated_datasets) > 0 + + # Check that 10m data is present (from our mock) + if 10 in consolidated_datasets: + ds_10m = consolidated_datasets[10] + assert 'b02' in ds_10m.data_vars + assert 'b03' in ds_10m.data_vars + assert ds_10m.attrs['native_resolution_meters'] == 10 + + # Verify geometry data + assert len(geometry) > 0 + geometry_ds = create_consolidated_dataset({'geometry': geometry}, resolution=20) + if len(geometry_ds.data_vars) > 0: + assert 'solar_zenith_angle' in geometry_ds.data_vars + + # Verify meteorology data + assert len(meteorology) > 0 + met_ds = create_consolidated_dataset({'meteorology': meteorology}, resolution=60) + if len(met_ds.data_vars) > 0: + assert 'cams_total_ozone' in met_ds.data_vars + + +class TestEdgeCases: + """Test edge cases and error conditions.""" + + def test_create_dataset_with_inconsistent_coordinates(self): + """Test dataset creation with inconsistent coordinate systems.""" + # Create data with mismatched coordinates + x1 = np.linspace(100000, 200000, 50) + y1 = np.linspace(5000000, 5100000, 50) + x2 = np.linspace(100000, 200000, 100) # Different size + y2 = np.linspace(5000000, 5100000, 100) + time = np.array(['2023-01-15'], dtype='datetime64[ns]') + + inconsistent_data = { + 'bands': { + 'b02': xr.DataArray(np.random.randint(0, 10000, (1, 50, 50), dtype=np.uint16), + dims=['time', 'y', 'x'], + coords={'time': time, 'x': x1, 'y': y1}), + 'b03': xr.DataArray(np.random.randint(0, 10000, (1, 100, 100), dtype=np.uint16), + dims=['time', 'y', 'x'], + coords={'time': time, 'x': x2, 'y': y2}), + } + } + + # Should handle inconsistent coordinates gracefully or raise appropriate error + # The exact behavior depends on xarray's handling of mixed coordinates + try: + ds = create_consolidated_dataset(inconsistent_data, resolution=10) + # If successful, verify it's a valid dataset + assert isinstance(ds, xr.Dataset) + except (ValueError, KeyError) as e: + # Expected error due to coordinate mismatch + assert "coordinate" in str(e).lower() or "dimension" in str(e).lower() + + +if __name__ == "__main__": + pytest.main([__file__]) From 3d1ea51a235d757e54d26cf552e5d8a074876949 Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Fri, 26 Sep 2025 17:42:16 +0200 Subject: [PATCH 13/83] Add comprehensive tests for S2MultiscalePyramid class - Implement unit tests for initialization, pyramid levels structure, chunk alignment, and shard dimension calculations. - Create tests for encoding generation, dataset writing, and level dataset creation with various resolutions. - Include integration tests for realistic measurements data and edge cases handling. - Ensure coverage for time separation logic and coordinate preservation during processing. 
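
A minimal usage sketch of the class under test (the output path and the
measurements_by_resolution input are illustrative; the input is assumed to be
the per-resolution dictionary produced by S2DataConsolidator):

    pyramid = S2MultiscalePyramid(enable_sharding=True, spatial_chunk=1024)
    # Builds up to seven levels (10, 20, 60, 120, 240, 480, 960 m) and writes
    # each non-empty level under <output_path>/measurements/<level>.
    datasets = pyramid.create_multiscale_measurements(
        measurements_by_resolution, "s2_optimized.zarr"
    )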
--- .../s2_optimization/s2_multiscale.py | 467 +++++++++- src/eopf_geozarr/tests/test_s2_multiscale.py | 842 ++++++++++++++++++ 2 files changed, 1303 insertions(+), 6 deletions(-) create mode 100644 src/eopf_geozarr/tests/test_s2_multiscale.py diff --git a/src/eopf_geozarr/s2_optimization/s2_multiscale.py b/src/eopf_geozarr/s2_optimization/s2_multiscale.py index ca1fe78a..4c2e4144 100644 --- a/src/eopf_geozarr/s2_optimization/s2_multiscale.py +++ b/src/eopf_geozarr/s2_optimization/s2_multiscale.py @@ -1,18 +1,37 @@ """ -Multiscale pyramid creation for optimized Sentinel-2 structure. +Multiscale pyramid creation for optimized S2 structure. """ -from typing import Dict +import numpy as np import xarray as xr +from typing import Dict, List, Tuple +from .s2_resampling import S2ResamplingEngine, determine_variable_type +from .s2_band_mapping import get_bands_for_level, get_quality_data_for_level class S2MultiscalePyramid: - """Creates multiscale pyramids for consolidated Sentinel-2 data.""" + """Creates multiscale pyramids for consolidated S2 data.""" def __init__(self, enable_sharding: bool = True, spatial_chunk: int = 1024): self.enable_sharding = enable_sharding self.spatial_chunk = spatial_chunk + self.resampler = S2ResamplingEngine() + + # Define pyramid levels: resolution in meters + self.pyramid_levels = { + 0: 10, # Level 0: 10m (native for b02,b03,b04,b08) + 1: 20, # Level 1: 20m (native for b05,b06,b07,b11,b12,b8a + all quality) + 2: 60, # Level 2: 60m (3x downsampling from 20m) + 3: 120, # Level 3: 120m (2x downsampling from 60m) + 4: 240, # Level 4: 240m (2x downsampling from 120m) + 5: 480, # Level 5: 480m (2x downsampling from 240m) + 6: 960 # Level 6: 960m (2x downsampling from 480m) + } - def create_multiscale_measurements(self, measurements_by_resolution: Dict[int, Dict], output_path: str) -> Dict[int, xr.Dataset]: + def create_multiscale_measurements( + self, + measurements_by_resolution: Dict[int, Dict], + output_path: str + ) -> Dict[int, xr.Dataset]: """ Create multiscale pyramid from consolidated measurements. 
@@ -23,5 +42,441 @@ def create_multiscale_measurements(self, measurements_by_resolution: Dict[int, D Returns: Dictionary of datasets by pyramid level """ - # Placeholder for multiscale creation logic - pass + pyramid_datasets = {} + + # Create each pyramid level + for level, target_resolution in self.pyramid_levels.items(): + print(f"Creating pyramid level {level} ({target_resolution}m)...") + + dataset = self._create_level_dataset( + level, target_resolution, measurements_by_resolution + ) + + if dataset and len(dataset.data_vars) > 0: + pyramid_datasets[level] = dataset + + # Write this level + level_path = f"{output_path}/measurements/{level}" + self._write_level_dataset(dataset, level_path, level) + + return pyramid_datasets + + def _create_level_dataset( + self, + level: int, + target_resolution: int, + measurements_by_resolution: Dict[int, Dict] + ) -> xr.Dataset: + """Create dataset for a specific pyramid level.""" + + if level == 0: + # Level 0: Only native 10m data + return self._create_level_0_dataset(measurements_by_resolution) + elif level == 1: + # Level 1: All data at 20m (native + downsampled from 10m) + return self._create_level_1_dataset(measurements_by_resolution) + elif level == 2: + # Level 2: All data at 60m (native + downsampled from 20m) + return self._create_level_2_dataset(measurements_by_resolution) + else: + # Levels 3+: Downsample from level 2 + return self._create_downsampled_dataset( + level, target_resolution, measurements_by_resolution + ) + + def _create_level_0_dataset(self, measurements_by_resolution: Dict) -> xr.Dataset: + """Create level 0 dataset with only native 10m data.""" + if 10 not in measurements_by_resolution: + return xr.Dataset() + + data_10m = measurements_by_resolution[10] + all_vars = {} + + # Add only native 10m bands and their associated data + for category, vars_dict in data_10m.items(): + all_vars.update(vars_dict) + + if not all_vars: + return xr.Dataset() + + # Create consolidated dataset + dataset = xr.Dataset(all_vars) + dataset.attrs['pyramid_level'] = 0 + dataset.attrs['resolution_meters'] = 10 + + return dataset + + def _create_level_1_dataset(self, measurements_by_resolution: Dict) -> xr.Dataset: + """Create level 1 dataset with all data at 20m resolution.""" + all_vars = {} + reference_coords = None + + # Start with native 20m data + if 20 in measurements_by_resolution: + data_20m = measurements_by_resolution[20] + for category, vars_dict in data_20m.items(): + all_vars.update(vars_dict) + + # Get reference coordinates from 20m data + if all_vars: + first_var = next(iter(all_vars.values())) + reference_coords = { + 'x': first_var.coords['x'], + 'y': first_var.coords['y'] + } + + # Add downsampled 10m data + if 10 in measurements_by_resolution: + data_10m = measurements_by_resolution[10] + + for category, vars_dict in data_10m.items(): + for var_name, var_data in vars_dict.items(): + if reference_coords: + # Downsample to match 20m grid + target_height = len(reference_coords['y']) + target_width = len(reference_coords['x']) + + var_type = determine_variable_type(var_name, var_data) + downsampled = self.resampler.downsample_variable( + var_data, target_height, target_width, var_type + ) + + # Align coordinates + downsampled = downsampled.assign_coords(reference_coords) + all_vars[var_name] = downsampled + + if not all_vars: + return xr.Dataset() + + # Create consolidated dataset + dataset = xr.Dataset(all_vars) + dataset.attrs['pyramid_level'] = 1 + dataset.attrs['resolution_meters'] = 20 + + return dataset + + def 
_create_level_2_dataset(self, measurements_by_resolution: Dict) -> xr.Dataset: + """Create level 2 dataset with all data at 60m resolution.""" + all_vars = {} + reference_coords = None + + # Start with native 60m data + if 60 in measurements_by_resolution: + data_60m = measurements_by_resolution[60] + for category, vars_dict in data_60m.items(): + all_vars.update(vars_dict) + + # Get reference coordinates from 60m data + if all_vars: + first_var = next(iter(all_vars.values())) + reference_coords = { + 'x': first_var.coords['x'], + 'y': first_var.coords['y'] + } + + # Add downsampled 20m data + if 20 in measurements_by_resolution: + data_20m = measurements_by_resolution[20] + + for category, vars_dict in data_20m.items(): + for var_name, var_data in vars_dict.items(): + if reference_coords: + # Downsample to match 60m grid + target_height = len(reference_coords['y']) + target_width = len(reference_coords['x']) + + var_type = determine_variable_type(var_name, var_data) + downsampled = self.resampler.downsample_variable( + var_data, target_height, target_width, var_type + ) + + # Align coordinates + downsampled = downsampled.assign_coords(reference_coords) + all_vars[var_name] = downsampled + + if not all_vars: + return xr.Dataset() + + # Create consolidated dataset + dataset = xr.Dataset(all_vars) + dataset.attrs['pyramid_level'] = 2 + dataset.attrs['resolution_meters'] = 60 + + return dataset + + def _create_downsampled_dataset( + self, + level: int, + target_resolution: int, + measurements_by_resolution: Dict + ) -> xr.Dataset: + """Create downsampled dataset for levels 2+.""" + # Start from level 1 data (20m) and downsample + level_1_dataset = self._create_level_1_dataset(measurements_by_resolution) + + if len(level_1_dataset.data_vars) == 0: + return xr.Dataset() + + # Calculate target dimensions (downsample by factor of 2^(level-1)) + downsample_factor = 2 ** (level - 1) + + # Get reference dimensions from level 1 + ref_var = next(iter(level_1_dataset.data_vars.values())) + current_height, current_width = ref_var.shape[-2:] + target_height = current_height // downsample_factor + target_width = current_width // downsample_factor + + downsampled_vars = {} + + for var_name, var_data in level_1_dataset.data_vars.items(): + var_type = determine_variable_type(var_name, var_data) + downsampled = self.resampler.downsample_variable( + var_data, target_height, target_width, var_type + ) + downsampled_vars[var_name] = downsampled + + # Create dataset + dataset = xr.Dataset(downsampled_vars) + dataset.attrs['pyramid_level'] = level + dataset.attrs['resolution_meters'] = target_resolution + + return dataset + + def _write_level_dataset(self, dataset: xr.Dataset, level_path: str, level: int) -> None: + """ + Write a pyramid level dataset to storage with xy-aligned sharding. + + Ensures single file per variable per time point when time dimension exists. 
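+
+        When more than one time step is present, each time slice is written to
+        <level_path>/time_<index> (see _write_time_separated_dataset); otherwise
+        the level is written as a single consolidated Zarr v3 store.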
+ """ + # Create encoding with xy-aligned sharding + encoding = self._create_level_encoding(dataset, level) + + # Check if we have time dimension for single file per time handling + has_time_dim = any('time' in str(var.dims) for var in dataset.data_vars.values()) + + if has_time_dim and self._should_separate_time_files(dataset): + # Write each time slice separately to ensure single file per variable per time + self._write_time_separated_dataset(dataset, level_path, level, encoding) + else: + # Write as single dataset with xy-aligned sharding + print(f" Writing level {level} to {level_path} (xy-aligned sharding)") + dataset.to_zarr( + level_path, + mode='w', + consolidated=True, + zarr_format=3, + encoding=encoding + ) + + def _should_separate_time_files(self, dataset: xr.Dataset) -> bool: + """Determine if time files should be separated for single file per variable per time.""" + for var in dataset.data_vars.values(): + if 'time' in var.dims and len(var.coords.get('time', [])) > 1: + return True + return False + + def _write_time_separated_dataset( + self, + dataset: xr.Dataset, + level_path: str, + level: int, + encoding: Dict + ) -> None: + """Write dataset with separate files for each time point.""" + import os + + # Get time coordinate + time_coord = None + for var in dataset.data_vars.values(): + if 'time' in var.dims: + time_coord = var.coords['time'] + break + + if time_coord is None: + # Fallback to regular writing if no time found + print(f" Writing level {level} to {level_path} (no time coord found)") + dataset.to_zarr( + level_path, + mode='w', + consolidated=True, + zarr_format=3, + encoding=encoding + ) + return + + print(f" Writing level {level} with time separation to {level_path}") + + # Write each time slice separately + for t_idx, time_val in enumerate(time_coord.values): + time_slice = dataset.isel(time=t_idx) + time_path = os.path.join(level_path, f"time_{t_idx:04d}") + + # Update encoding for time slice (remove time dimension) + time_encoding = self._update_encoding_for_time_slice(encoding, time_slice) + + print(f" Writing time slice {t_idx} to {time_path}") + time_slice.to_zarr( + time_path, + mode='w', + consolidated=True, + zarr_format=3, + encoding=time_encoding + ) + + def _update_encoding_for_time_slice(self, encoding: Dict, time_slice: xr.Dataset) -> Dict: + """Update encoding configuration for time slice data.""" + updated_encoding = {} + + for var_name, var_encoding in encoding.items(): + if var_name in time_slice.data_vars: + var_data = time_slice[var_name] + + # Update chunks and shards for time slice (remove time dimension) + if 'chunks' in var_encoding and len(var_encoding['chunks']) > 2: + # Remove time dimension from chunks (first dimension) + updated_chunks = var_encoding['chunks'][1:] + updated_encoding[var_name] = var_encoding.copy() + updated_encoding[var_name]['chunks'] = updated_chunks + + # Update shards if present + if 'shards' in var_encoding and len(var_encoding['shards']) > 2: + updated_shards = var_encoding['shards'][1:] + updated_encoding[var_name]['shards'] = updated_shards + else: + updated_encoding[var_name] = var_encoding + else: + # Coordinate or other variable + updated_encoding[var_name] = encoding[var_name] + + return updated_encoding + + def _create_level_encoding(self, dataset: xr.Dataset, level: int) -> Dict: + """Create optimized encoding for a pyramid level with xy-aligned sharding.""" + encoding = {} + + # Calculate level-appropriate chunk sizes + chunk_size = max(256, self.spatial_chunk // (2 ** level)) + + for var_name, 
var_data in dataset.data_vars.items(): + if var_data.ndim >= 2: + height, width = var_data.shape[-2:] + + # Ensure x/y alignment: adjust chunk sizes to align with sharding + chunk_y = self._align_chunk_to_xy_dimensions(chunk_size, height) + chunk_x = self._align_chunk_to_xy_dimensions(chunk_size, width) + + if var_data.ndim == 3: + # Single file per variable per time: chunk time dimension to 1 + chunks = (1, chunk_y, chunk_x) + else: + chunks = (chunk_y, chunk_x) + else: + chunks = (min(chunk_size, var_data.shape[0]),) + + # Configure encoding + var_encoding = { + 'chunks': chunks, + 'compressor': 'default' + } + + # Add xy-aligned sharding if enabled + if self.enable_sharding and var_data.ndim >= 2: + shard_dims = self._calculate_xy_aligned_shard_dimensions(var_data.shape, chunks) + var_encoding['shards'] = shard_dims + + encoding[var_name] = var_encoding + + # Add coordinate encoding + for coord_name in dataset.coords: + encoding[coord_name] = {'compressor': None} + + return encoding + + def _align_chunk_to_xy_dimensions(self, chunk_size: int, dimension_size: int) -> int: + """ + Align chunk size to be compatible with x/y dimension sharding requirements. + + Args: + chunk_size: Requested chunk size + dimension_size: Total size of the dimension + + Returns: + Aligned chunk size that works well with sharding + """ + if chunk_size >= dimension_size: + return dimension_size + + # Find a good divisor that's close to the requested size + best_chunk = chunk_size + best_remainder = dimension_size % chunk_size + + # Try nearby values to find better alignment + search_range = min(50, chunk_size // 4) + for offset in range(-search_range, search_range + 1): + candidate = chunk_size + offset + if candidate > 0 and candidate <= dimension_size: + remainder = dimension_size % candidate + if remainder < best_remainder or (remainder == best_remainder and candidate > best_chunk): + best_chunk = candidate + best_remainder = remainder + + return best_chunk + + def _calculate_xy_aligned_shard_dimensions(self, data_shape: Tuple, chunks: Tuple) -> Tuple: + """ + Calculate shard dimensions for Zarr v3 sharding with x/y alignment. + + Ensures shards are properly aligned with spatial dimensions (x, y) + and maintains single file per variable per time point. 
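+
+        For example, data_shape=(1, 1098, 1098) with chunks=(1, 256, 256)
+        yields shard dimensions (1, 1024, 1024): one time slice per shard and
+        four 256-pixel chunks per spatial shard dimension.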
+ """ + shard_dims = [] + + for i, (dim_size, chunk_size) in enumerate(zip(data_shape, chunks)): + # Special handling for different dimensions + if i == 0 and len(data_shape) == 3: + # First dimension in 3D data (time) - use single time slice per shard + shard_dim = 1 + elif i >= len(data_shape) - 2: + # Last two dimensions (y, x) - ensure proper spatial alignment + shard_dim = self._calculate_spatial_shard_dim(dim_size, chunk_size) + else: + # Other dimensions - standard calculation + shard_dim = self._calculate_standard_shard_dim(dim_size, chunk_size) + + shard_dims.append(shard_dim) + + return tuple(shard_dims) + + def _calculate_spatial_shard_dim(self, dim_size: int, chunk_size: int) -> int: + """Calculate shard dimension for spatial dimensions (x, y).""" + if chunk_size >= dim_size: + return dim_size + + # For spatial dimensions, align shard boundaries with chunk boundaries + # Use multiple of chunk_size that provides good balance + num_chunks = dim_size // chunk_size + if num_chunks >= 4: + # Use 4 chunks per shard if possible for good balance + shard_dim = min(4 * chunk_size, dim_size) + elif num_chunks >= 2: + # Use 2 chunks per shard as minimum + shard_dim = 2 * chunk_size + else: + # Single chunk per shard + shard_dim = chunk_size + + return shard_dim + + def _calculate_standard_shard_dim(self, dim_size: int, chunk_size: int) -> int: + """Calculate shard dimension for non-spatial dimensions.""" + if chunk_size >= dim_size: + return dim_size + + # For non-spatial dimensions, use standard calculation + num_chunks = dim_size // chunk_size + if num_chunks >= 4: + shard_dim = min(4 * chunk_size, dim_size) + else: + shard_dim = num_chunks * chunk_size + + return shard_dim diff --git a/src/eopf_geozarr/tests/test_s2_multiscale.py b/src/eopf_geozarr/tests/test_s2_multiscale.py new file mode 100644 index 00000000..c6f86e38 --- /dev/null +++ b/src/eopf_geozarr/tests/test_s2_multiscale.py @@ -0,0 +1,842 @@ +""" +Tests for S2 multiscale pyramid creation with xy-aligned sharding. 
+""" + +import numpy as np +import pytest +import tempfile +import shutil +from pathlib import Path +from unittest.mock import Mock, patch, MagicMock +import xarray as xr + +from eopf_geozarr.s2_optimization.s2_multiscale import S2MultiscalePyramid + + +class TestS2MultiscalePyramid: + """Test suite for S2MultiscalePyramid class.""" + + @pytest.fixture + def pyramid(self): + """Create a basic S2MultiscalePyramid instance.""" + return S2MultiscalePyramid(enable_sharding=True, spatial_chunk=1024) + + @pytest.fixture + def sample_dataset(self): + """Create a sample xarray dataset for testing.""" + x = np.linspace(0, 1000, 100) + y = np.linspace(0, 1000, 100) + time = np.array(['2023-01-01', '2023-01-02'], dtype='datetime64[ns]') + + # Create sample variables with different dimensions + b02 = xr.DataArray( + np.random.randint(0, 4000, (2, 100, 100)), + dims=['time', 'y', 'x'], + coords={'time': time, 'y': y, 'x': x}, + name='b02' + ) + + b05 = xr.DataArray( + np.random.randint(0, 4000, (2, 100, 100)), + dims=['time', 'y', 'x'], + coords={'time': time, 'y': y, 'x': x}, + name='b05' + ) + + scl = xr.DataArray( + np.random.randint(0, 11, (2, 100, 100)), + dims=['time', 'y', 'x'], + coords={'time': time, 'y': y, 'x': x}, + name='scl' + ) + + dataset = xr.Dataset({ + 'b02': b02, + 'b05': b05, + 'scl': scl + }) + + return dataset + + @pytest.fixture + def sample_measurements_by_resolution(self): + """Create sample measurements organized by resolution.""" + x_10m = np.linspace(0, 1000, 200) + y_10m = np.linspace(0, 1000, 200) + x_20m = np.linspace(0, 1000, 100) + y_20m = np.linspace(0, 1000, 100) + x_60m = np.linspace(0, 1000, 50) + y_60m = np.linspace(0, 1000, 50) + time = np.array(['2023-01-01'], dtype='datetime64[ns]') + + # 10m data + b02_10m = xr.DataArray( + np.random.randint(0, 4000, (1, 200, 200)), + dims=['time', 'y', 'x'], + coords={'time': time, 'y': y_10m, 'x': x_10m}, + name='b02' + ) + + # 20m data + b05_20m = xr.DataArray( + np.random.randint(0, 4000, (1, 100, 100)), + dims=['time', 'y', 'x'], + coords={'time': time, 'y': y_20m, 'x': x_20m}, + name='b05' + ) + + scl_20m = xr.DataArray( + np.random.randint(0, 11, (1, 100, 100)), + dims=['time', 'y', 'x'], + coords={'time': time, 'y': y_20m, 'x': x_20m}, + name='scl' + ) + + # 60m data + b01_60m = xr.DataArray( + np.random.randint(0, 4000, (1, 50, 50)), + dims=['time', 'y', 'x'], + coords={'time': time, 'y': y_60m, 'x': x_60m}, + name='b01' + ) + + return { + 10: { + 'reflectance': {'b02': b02_10m} + }, + 20: { + 'reflectance': {'b05': b05_20m}, + 'quality': {'scl': scl_20m} + }, + 60: { + 'reflectance': {'b01': b01_60m} + } + } + + @pytest.fixture + def temp_dir(self): + """Create a temporary directory for testing.""" + temp_dir = tempfile.mkdtemp() + yield temp_dir + shutil.rmtree(temp_dir) + + def test_init(self): + """Test S2MultiscalePyramid initialization.""" + pyramid = S2MultiscalePyramid(enable_sharding=True, spatial_chunk=512) + + assert pyramid.enable_sharding is True + assert pyramid.spatial_chunk == 512 + assert hasattr(pyramid, 'resampler') + assert len(pyramid.pyramid_levels) == 7 + assert pyramid.pyramid_levels[0] == 10 + assert pyramid.pyramid_levels[1] == 20 + assert pyramid.pyramid_levels[2] == 60 + + def test_pyramid_levels_structure(self, pyramid): + """Test the pyramid levels structure.""" + expected_levels = { + 0: 10, # Level 0: 10m + 1: 20, # Level 1: 20m + 2: 60, # Level 2: 60m + 3: 120, # Level 3: 120m + 4: 240, # Level 4: 240m + 5: 480, # Level 5: 480m + 6: 960 # Level 6: 960m + } + + assert 
pyramid.pyramid_levels == expected_levels + + def test_align_chunk_to_xy_dimensions(self, pyramid): + """Test chunk alignment for x/y dimensions.""" + # Test case 1: chunk size smaller than dimension + aligned = pyramid._align_chunk_to_xy_dimensions(256, 1000) + assert aligned > 0 + assert aligned <= 1000 + assert 1000 % aligned <= 1000 % 256 # Should have better or equal alignment + + # Test case 2: chunk size larger than dimension + aligned = pyramid._align_chunk_to_xy_dimensions(2000, 1000) + assert aligned == 1000 + + # Test case 3: exact divisor + aligned = pyramid._align_chunk_to_xy_dimensions(250, 1000) + assert aligned > 0 + assert aligned <= 1000 + + def test_calculate_spatial_shard_dim(self, pyramid): + """Test spatial shard dimension calculation.""" + # Test case 1: chunk size smaller than dimension, multiple chunks possible + shard_dim = pyramid._calculate_spatial_shard_dim(1000, 200) + assert shard_dim == 800 # 4 chunks * 200 + + # Test case 2: chunk size larger than dimension + shard_dim = pyramid._calculate_spatial_shard_dim(100, 200) + assert shard_dim == 100 + + # Test case 3: only 2 chunks possible + shard_dim = pyramid._calculate_spatial_shard_dim(600, 250) + assert shard_dim == 500 # 2 chunks * 250 + + # Test case 4: single chunk + shard_dim = pyramid._calculate_spatial_shard_dim(300, 250) + assert shard_dim == 250 + + def test_calculate_standard_shard_dim(self, pyramid): + """Test standard shard dimension calculation.""" + # Test case 1: multiple chunks possible + shard_dim = pyramid._calculate_standard_shard_dim(1000, 200) + assert shard_dim == 800 # 4 chunks * 200 + + # Test case 2: chunk size larger than dimension + shard_dim = pyramid._calculate_standard_shard_dim(100, 200) + assert shard_dim == 100 + + # Test case 3: few chunks + shard_dim = pyramid._calculate_standard_shard_dim(600, 250) + assert shard_dim == 500 # 2 chunks * 250 + + def test_calculate_xy_aligned_shard_dimensions(self, pyramid): + """Test xy-aligned shard dimensions calculation.""" + # Test 3D data (time, y, x) + data_shape = (5, 1000, 1000) + chunks = (1, 256, 256) + + shard_dims = pyramid._calculate_xy_aligned_shard_dimensions(data_shape, chunks) + + assert len(shard_dims) == 3 + assert shard_dims[0] == 1 # Time dimension should be 1 + assert shard_dims[1] > 0 # Y dimension + assert shard_dims[2] > 0 # X dimension + + # Test 2D data (y, x) + data_shape = (1000, 1000) + chunks = (256, 256) + + shard_dims = pyramid._calculate_xy_aligned_shard_dimensions(data_shape, chunks) + + assert len(shard_dims) == 2 + assert shard_dims[0] > 0 # Y dimension + assert shard_dims[1] > 0 # X dimension + + def test_create_level_encoding(self, pyramid, sample_dataset): + """Test level encoding creation with xy-aligned sharding.""" + encoding = pyramid._create_level_encoding(sample_dataset, level=1) + + # Check that encoding is created for all variables + for var_name in sample_dataset.data_vars: + assert var_name in encoding + var_encoding = encoding[var_name] + + # Check basic encoding structure + assert 'chunks' in var_encoding + assert 'compressor' in var_encoding + + # Check sharding is included when enabled + if pyramid.enable_sharding: + assert 'shards' in var_encoding + + # Check coordinate encoding + for coord_name in sample_dataset.coords: + if coord_name in encoding: + assert encoding[coord_name]['compressor'] is None + + def test_create_level_encoding_time_chunking(self, pyramid, sample_dataset): + """Test that time dimension is chunked to 1 for single file per time.""" + encoding = 
pyramid._create_level_encoding(sample_dataset, level=0) + + for var_name in sample_dataset.data_vars: + if sample_dataset[var_name].ndim == 3: # 3D variable with time + chunks = encoding[var_name]['chunks'] + assert chunks[0] == 1 # Time dimension should be chunked to 1 + + def test_should_separate_time_files(self, pyramid): + """Test time file separation detection.""" + # Create dataset with multiple time points + time = np.array(['2023-01-01', '2023-01-02'], dtype='datetime64[ns]') + x = np.linspace(0, 100, 10) + y = np.linspace(0, 100, 10) + + data_multi_time = xr.DataArray( + np.random.rand(2, 10, 10), + dims=['time', 'y', 'x'], + coords={'time': time, 'y': y, 'x': x} + ) + + dataset_multi_time = xr.Dataset({'var1': data_multi_time}) + assert pyramid._should_separate_time_files(dataset_multi_time) is True + + # Create dataset with single time point + data_single_time = xr.DataArray( + np.random.rand(1, 10, 10), + dims=['time', 'y', 'x'], + coords={'time': time[:1], 'y': y, 'x': x} + ) + + dataset_single_time = xr.Dataset({'var1': data_single_time}) + assert pyramid._should_separate_time_files(dataset_single_time) is False + + # Create dataset with no time dimension + data_no_time = xr.DataArray( + np.random.rand(10, 10), + dims=['y', 'x'], + coords={'y': y, 'x': x} + ) + + dataset_no_time = xr.Dataset({'var1': data_no_time}) + assert pyramid._should_separate_time_files(dataset_no_time) is False + + def test_update_encoding_for_time_slice(self, pyramid): + """Test encoding update for time slices.""" + # Original encoding with 3D chunks + original_encoding = { + 'var1': { + 'chunks': (1, 100, 100), + 'shards': (1, 200, 200), + 'compressor': 'default' + }, + 'x': {'compressor': None}, + 'y': {'compressor': None} + } + + # Create a time slice dataset + x = np.linspace(0, 100, 100) + y = np.linspace(0, 100, 100) + + time_slice = xr.Dataset({ + 'var1': xr.DataArray( + np.random.rand(100, 100), + dims=['y', 'x'], + coords={'y': y, 'x': x} + ) + }) + + updated_encoding = pyramid._update_encoding_for_time_slice(original_encoding, time_slice) + + # Check that time dimension is removed from chunks and shards + assert updated_encoding['var1']['chunks'] == (100, 100) + assert updated_encoding['var1']['shards'] == (200, 200) + assert updated_encoding['var1']['compressor'] == 'default' + + # Check coordinates are preserved + assert updated_encoding['x']['compressor'] is None + assert updated_encoding['y']['compressor'] is None + + @patch('builtins.print') + @patch('xarray.Dataset.to_zarr') + def test_write_level_dataset_no_time(self, mock_to_zarr, mock_print, pyramid, sample_dataset, temp_dir): + """Test writing level dataset without time separation.""" + # Create dataset without multiple time points + single_time_dataset = sample_dataset.isel(time=0) + + pyramid._write_level_dataset(single_time_dataset, temp_dir, level=0) + + # Should call to_zarr once (no time separation) + mock_to_zarr.assert_called_once() + args, kwargs = mock_to_zarr.call_args + + assert kwargs['mode'] == 'w' + assert kwargs['consolidated'] is True + assert kwargs['zarr_format'] == 3 + assert 'encoding' in kwargs + + @patch('builtins.print') + def test_write_level_dataset_with_time_separation(self, mock_print, pyramid, sample_dataset, temp_dir): + """Test writing level dataset with time separation.""" + with patch.object(pyramid, '_write_time_separated_dataset') as mock_time_sep: + pyramid._write_level_dataset(sample_dataset, temp_dir, level=0) + + # Should call time separation method + mock_time_sep.assert_called_once() + + def 
test_create_level_0_dataset(self, pyramid, sample_measurements_by_resolution): + """Test level 0 dataset creation.""" + dataset = pyramid._create_level_0_dataset(sample_measurements_by_resolution) + + assert len(dataset.data_vars) > 0 + assert dataset.attrs['pyramid_level'] == 0 + assert dataset.attrs['resolution_meters'] == 10 + + # Should only contain 10m native data + assert 'b02' in dataset.data_vars + + def test_create_level_0_dataset_no_10m_data(self, pyramid): + """Test level 0 dataset creation with no 10m data.""" + measurements_no_10m = { + 20: {'reflectance': {'b05': Mock()}}, + 60: {'reflectance': {'b01': Mock()}} + } + + dataset = pyramid._create_level_0_dataset(measurements_no_10m) + assert len(dataset.data_vars) == 0 + + @patch.object(S2MultiscalePyramid, '_create_level_0_dataset') + @patch.object(S2MultiscalePyramid, '_create_level_1_dataset') + @patch.object(S2MultiscalePyramid, '_create_level_2_dataset') + @patch.object(S2MultiscalePyramid, '_create_downsampled_dataset') + def test_create_level_dataset_routing(self, mock_downsampled, mock_level2, mock_level1, mock_level0, pyramid): + """Test that _create_level_dataset routes to correct methods.""" + measurements = {} + + # Test level 0 + pyramid._create_level_dataset(0, 10, measurements) + mock_level0.assert_called_once_with(measurements) + + # Test level 1 + pyramid._create_level_dataset(1, 20, measurements) + mock_level1.assert_called_once_with(measurements) + + # Test level 2 + pyramid._create_level_dataset(2, 60, measurements) + mock_level2.assert_called_once_with(measurements) + + # Test level 3+ + pyramid._create_level_dataset(3, 120, measurements) + mock_downsampled.assert_called_once_with(3, 120, measurements) + + @patch('builtins.print') + @patch.object(S2MultiscalePyramid, '_write_level_dataset') + @patch.object(S2MultiscalePyramid, '_create_level_dataset') + def test_create_multiscale_measurements(self, mock_create, mock_write, mock_print, pyramid, temp_dir): + """Test multiscale measurements creation.""" + # Mock dataset creation + mock_dataset = Mock() + mock_dataset.data_vars = {'b02': Mock()} # Non-empty dataset + mock_create.return_value = mock_dataset + + measurements = {10: {'reflectance': {'b02': Mock()}}} + + result = pyramid.create_multiscale_measurements(measurements, temp_dir) + + # Should create all pyramid levels + assert len(result) == len(pyramid.pyramid_levels) + assert mock_create.call_count == len(pyramid.pyramid_levels) + assert mock_write.call_count == len(pyramid.pyramid_levels) + + @patch('builtins.print') + @patch.object(S2MultiscalePyramid, '_write_level_dataset') + @patch.object(S2MultiscalePyramid, '_create_level_dataset') + def test_create_multiscale_measurements_empty_dataset(self, mock_create, mock_write, mock_print, pyramid, temp_dir): + """Test multiscale measurements creation with empty dataset.""" + # Mock empty dataset creation + mock_dataset = Mock() + mock_dataset.data_vars = {} # Empty dataset + mock_create.return_value = mock_dataset + + measurements = {} + + result = pyramid.create_multiscale_measurements(measurements, temp_dir) + + # Should not include empty datasets + assert len(result) == 0 + assert mock_write.call_count == 0 + + def test_create_level_1_dataset_with_downsampling(self, pyramid): + """Test level 1 dataset creation with downsampling from 10m.""" + # Create mock measurements with both 10m and 20m data + x_20m = np.linspace(0, 1000, 100) + y_20m = np.linspace(0, 1000, 100) + x_10m = np.linspace(0, 1000, 200) + y_10m = np.linspace(0, 1000, 200) + time = 
np.array(['2023-01-01'], dtype='datetime64[ns]') + + # 20m native data + b05_20m = xr.DataArray( + np.random.randint(0, 4000, (1, 100, 100)), + dims=['time', 'y', 'x'], + coords={'time': time, 'y': y_20m, 'x': x_20m}, + name='b05' + ) + + # 10m data to be downsampled + b02_10m = xr.DataArray( + np.random.randint(0, 4000, (1, 200, 200)), + dims=['time', 'y', 'x'], + coords={'time': time, 'y': y_10m, 'x': x_10m}, + name='b02' + ) + + measurements = { + 20: {'reflectance': {'b05': b05_20m}}, + 10: {'reflectance': {'b02': b02_10m}} + } + + with patch.object(pyramid.resampler, 'downsample_variable') as mock_downsample: + # Mock the downsampling to return a properly shaped array + mock_downsampled = xr.DataArray( + np.random.randint(0, 4000, (1, 100, 100)), + dims=['time', 'y', 'x'], + coords={'time': time, 'y': y_20m, 'x': x_20m}, + name='b02' + ) + mock_downsample.return_value = mock_downsampled + + dataset = pyramid._create_level_1_dataset(measurements) + + # Should call downsampling for 10m data + mock_downsample.assert_called() + + # Should contain both native 20m and downsampled 10m data + assert 'b05' in dataset.data_vars + assert 'b02' in dataset.data_vars + assert dataset.attrs['pyramid_level'] == 1 + assert dataset.attrs['resolution_meters'] == 20 + + def test_create_level_2_dataset_structure(self, pyramid, sample_measurements_by_resolution): + """Test level 2 dataset creation according to optimization plan.""" + dataset = pyramid._create_level_2_dataset(sample_measurements_by_resolution) + + # Check basic structure + assert dataset.attrs['pyramid_level'] == 2 + assert dataset.attrs['resolution_meters'] == 60 + + # Should contain 60m native data + assert 'b01' in dataset.data_vars + + def test_create_level_2_dataset_with_downsampling(self, pyramid): + """Test level 2 dataset creation with 20m data downsampling.""" + # Create measurements with 60m and 20m data + x_60m = np.linspace(0, 1000, 50) + y_60m = np.linspace(0, 1000, 50) + x_20m = np.linspace(0, 1000, 100) + y_20m = np.linspace(0, 1000, 100) + time = np.array(['2023-01-01'], dtype='datetime64[ns]') + + # 60m native data + b01_60m = xr.DataArray( + np.random.randint(0, 4000, (1, 50, 50)), + dims=['time', 'y', 'x'], + coords={'time': time, 'y': y_60m, 'x': x_60m}, + name='b01' + ) + + # 20m data to be downsampled + scl_20m = xr.DataArray( + np.random.randint(0, 11, (1, 100, 100)), + dims=['time', 'y', 'x'], + coords={'time': time, 'y': y_20m, 'x': x_20m}, + name='scl' + ) + + measurements = { + 60: {'reflectance': {'b01': b01_60m}}, + 20: {'quality': {'scl': scl_20m}} + } + + with patch.object(pyramid.resampler, 'downsample_variable') as mock_downsample: + # Mock the downsampling to return a properly shaped array + mock_downsampled = xr.DataArray( + np.random.randint(0, 11, (1, 50, 50)), + dims=['time', 'y', 'x'], + coords={'time': time, 'y': y_60m, 'x': x_60m}, + name='scl' + ) + mock_downsample.return_value = mock_downsampled + + dataset = pyramid._create_level_2_dataset(measurements) + + # Should call downsampling for 20m data + mock_downsample.assert_called() + + # Should contain both native 60m and downsampled 20m data + assert 'b01' in dataset.data_vars + assert 'scl' in dataset.data_vars + assert dataset.attrs['pyramid_level'] == 2 + assert dataset.attrs['resolution_meters'] == 60 + + def test_error_handling_invalid_level(self, pyramid): + """Test error handling for invalid pyramid levels.""" + measurements = {} + + # Test with invalid level (should work but return empty dataset if no source data) + dataset = 
pyramid._create_level_dataset(-1, 5, measurements) + # Should create downsampled dataset (empty in this case) + assert isinstance(dataset, xr.Dataset) + + +class TestS2MultiscalePyramidIntegration: + """Integration tests for S2MultiscalePyramid.""" + + @pytest.fixture + def real_measurements_data(self): + """Create realistic measurements data for integration testing.""" + time = np.array(['2023-06-15T10:30:00'], dtype='datetime64[ns]') + + # 10m resolution data (200x200 pixels) + x_10m = np.linspace(300000, 310000, 200) # UTM coordinates + y_10m = np.linspace(4900000, 4910000, 200) + + # 20m resolution data (100x100 pixels) + x_20m = np.linspace(300000, 310000, 100) + y_20m = np.linspace(4900000, 4910000, 100) + + # 60m resolution data (50x50 pixels) + x_60m = np.linspace(300000, 310000, 50) + y_60m = np.linspace(4900000, 4910000, 50) + + # Create realistic spectral bands + measurements = { + 10: { + 'reflectance': { + 'b02': xr.DataArray( + np.random.randint(500, 3000, (1, 200, 200), dtype=np.int16), + dims=['time', 'y', 'x'], + coords={'time': time, 'y': y_10m, 'x': x_10m}, + attrs={'long_name': 'Blue band', 'units': 'digital_number'} + ), + 'b03': xr.DataArray( + np.random.randint(600, 3500, (1, 200, 200), dtype=np.int16), + dims=['time', 'y', 'x'], + coords={'time': time, 'y': y_10m, 'x': x_10m}, + attrs={'long_name': 'Green band', 'units': 'digital_number'} + ), + 'b04': xr.DataArray( + np.random.randint(400, 3200, (1, 200, 200), dtype=np.int16), + dims=['time', 'y', 'x'], + coords={'time': time, 'y': y_10m, 'x': x_10m}, + attrs={'long_name': 'Red band', 'units': 'digital_number'} + ), + 'b08': xr.DataArray( + np.random.randint(3000, 6000, (1, 200, 200), dtype=np.int16), + dims=['time', 'y', 'x'], + coords={'time': time, 'y': y_10m, 'x': x_10m}, + attrs={'long_name': 'NIR band', 'units': 'digital_number'} + ) + } + }, + 20: { + 'reflectance': { + 'b05': xr.DataArray( + np.random.randint(2000, 4000, (1, 100, 100), dtype=np.int16), + dims=['time', 'y', 'x'], + coords={'time': time, 'y': y_20m, 'x': x_20m}, + attrs={'long_name': 'Red edge 1', 'units': 'digital_number'} + ), + 'b06': xr.DataArray( + np.random.randint(2500, 4500, (1, 100, 100), dtype=np.int16), + dims=['time', 'y', 'x'], + coords={'time': time, 'y': y_20m, 'x': x_20m}, + attrs={'long_name': 'Red edge 2', 'units': 'digital_number'} + ), + 'b07': xr.DataArray( + np.random.randint(2800, 4800, (1, 100, 100), dtype=np.int16), + dims=['time', 'y', 'x'], + coords={'time': time, 'y': y_20m, 'x': x_20m}, + attrs={'long_name': 'Red edge 3', 'units': 'digital_number'} + ), + 'b11': xr.DataArray( + np.random.randint(1000, 3000, (1, 100, 100), dtype=np.int16), + dims=['time', 'y', 'x'], + coords={'time': time, 'y': y_20m, 'x': x_20m}, + attrs={'long_name': 'SWIR 1', 'units': 'digital_number'} + ), + 'b12': xr.DataArray( + np.random.randint(500, 2500, (1, 100, 100), dtype=np.int16), + dims=['time', 'y', 'x'], + coords={'time': time, 'y': y_20m, 'x': x_20m}, + attrs={'long_name': 'SWIR 2', 'units': 'digital_number'} + ), + 'b8a': xr.DataArray( + np.random.randint(2800, 5500, (1, 100, 100), dtype=np.int16), + dims=['time', 'y', 'x'], + coords={'time': time, 'y': y_20m, 'x': x_20m}, + attrs={'long_name': 'NIR narrow', 'units': 'digital_number'} + ) + }, + 'quality': { + 'scl': xr.DataArray( + np.random.randint(0, 11, (1, 100, 100), dtype=np.uint8), + dims=['time', 'y', 'x'], + coords={'time': time, 'y': y_20m, 'x': x_20m}, + attrs={'long_name': 'Scene classification', 'units': 'class'} + ), + 'aot': xr.DataArray( + np.random.randint(0, 
1000, (1, 100, 100), dtype=np.uint16), + dims=['time', 'y', 'x'], + coords={'time': time, 'y': y_20m, 'x': x_20m}, + attrs={'long_name': 'Aerosol optical thickness', 'units': 'dimensionless'} + ), + 'wvp': xr.DataArray( + np.random.randint(0, 5000, (1, 100, 100), dtype=np.uint16), + dims=['time', 'y', 'x'], + coords={'time': time, 'y': y_20m, 'x': x_20m}, + attrs={'long_name': 'Water vapor', 'units': 'kg/m^2'} + ) + } + }, + 60: { + 'reflectance': { + 'b01': xr.DataArray( + np.random.randint(1500, 3500, (1, 50, 50), dtype=np.int16), + dims=['time', 'y', 'x'], + coords={'time': time, 'y': y_60m, 'x': x_60m}, + attrs={'long_name': 'Coastal aerosol', 'units': 'digital_number'} + ), + 'b09': xr.DataArray( + np.random.randint(100, 1000, (1, 50, 50), dtype=np.int16), + dims=['time', 'y', 'x'], + coords={'time': time, 'y': y_60m, 'x': x_60m}, + attrs={'long_name': 'Water vapor', 'units': 'digital_number'} + ) + } + } + } + + return measurements + + @patch('builtins.print') # Mock print to avoid test output + def test_full_pyramid_creation(self, mock_print, real_measurements_data, tmp_path): + """Test complete pyramid creation with realistic data.""" + pyramid = S2MultiscalePyramid(enable_sharding=True, spatial_chunk=512) + + output_path = str(tmp_path) + + with patch.object(pyramid, '_write_level_dataset') as mock_write: + result = pyramid.create_multiscale_measurements(real_measurements_data, output_path) + + # Should create all 7 pyramid levels + assert len(result) == 7 + + # Check that each level has appropriate characteristics + for level, dataset in result.items(): + assert dataset.attrs['pyramid_level'] == level + assert dataset.attrs['resolution_meters'] == pyramid.pyramid_levels[level] + assert len(dataset.data_vars) > 0 + + # Verify write was called for each level + assert mock_write.call_count == 7 + + def test_level_specific_content(self, real_measurements_data): + """Test that each pyramid level contains appropriate content.""" + pyramid = S2MultiscalePyramid(enable_sharding=False, spatial_chunk=256) # Disable sharding for simpler testing + + # Test level 0 (10m native) + level_0 = pyramid._create_level_0_dataset(real_measurements_data) + level_0_vars = set(level_0.data_vars.keys()) + expected_10m_vars = {'b02', 'b03', 'b04', 'b08'} + assert len(expected_10m_vars.intersection(level_0_vars)) > 0 + + # Test level 1 (20m consolidated) + level_1 = pyramid._create_level_1_dataset(real_measurements_data) + # Should contain both native 20m and downsampled 10m data + level_1_vars = set(level_1.data_vars.keys()) + # Check some expected variables are present + expected_vars = {'b05', 'b06', 'b07', 'b11', 'b12', 'b8a', 'scl', 'aot', 'wvp'} + assert len(expected_vars.intersection(level_1_vars)) > 0 + + # Test level 2 (60m consolidated) + level_2 = pyramid._create_level_2_dataset(real_measurements_data) + # Should contain native 60m and processed 20m data + level_2_vars = set(level_2.data_vars.keys()) + expected_60m_vars = {'b01', 'b09'} + assert len(expected_60m_vars.intersection(level_2_vars)) > 0 + + def test_sharding_configuration_integration(self, real_measurements_data): + """Test sharding configuration with realistic data.""" + pyramid = S2MultiscalePyramid(enable_sharding=True, spatial_chunk=256) + + # Create a test dataset + level_0 = pyramid._create_level_0_dataset(real_measurements_data) + + if len(level_0.data_vars) > 0: + encoding = pyramid._create_level_encoding(level_0, level=0) + + # Check encoding structure + for var_name, var_data in level_0.data_vars.items(): + assert var_name 
in encoding + var_encoding = encoding[var_name] + + # Check sharding configuration + if var_data.ndim >= 2: + assert 'shards' in var_encoding + shards = var_encoding['shards'] + + # Verify shard dimensions are reasonable + if var_data.ndim == 3: + assert shards[0] == 1 # Time dimension + assert shards[1] > 0 # Y dimension + assert shards[2] > 0 # X dimension + elif var_data.ndim == 2: + assert shards[0] > 0 # Y dimension + assert shards[1] > 0 # X dimension + + +class TestEdgeCases: + """Test edge cases and error conditions.""" + + def test_empty_measurements_data(self): + """Test handling of empty measurements data.""" + pyramid = S2MultiscalePyramid() + + empty_measurements = {} + + with patch('builtins.print'): + with patch.object(pyramid, '_write_level_dataset'): + result = pyramid.create_multiscale_measurements(empty_measurements, "/tmp") + + # Should return empty results + assert len(result) == 0 + + def test_missing_resolution_data(self): + """Test handling when specific resolution data is missing.""" + pyramid = S2MultiscalePyramid() + + # Only provide 20m data, missing 10m and 60m + measurements_partial = { + 20: { + 'reflectance': { + 'b05': xr.DataArray( + np.random.rand(1, 50, 50), + dims=['time', 'y', 'x'], + coords={ + 'time': ['2023-01-01'], + 'y': np.arange(50), + 'x': np.arange(50) + } + ) + } + } + } + + # Should handle gracefully + level_0 = pyramid._create_level_0_dataset(measurements_partial) + assert len(level_0.data_vars) == 0 # No 10m data available + + level_1 = pyramid._create_level_1_dataset(measurements_partial) + assert len(level_1.data_vars) > 0 # Should have 20m data + + def test_coordinate_preservation(self): + """Test that coordinate systems are preserved through processing.""" + pyramid = S2MultiscalePyramid() + + # Create data with specific coordinate attributes + x = np.linspace(300000, 310000, 100) + y = np.linspace(4900000, 4910000, 100) + time = np.array(['2023-01-01'], dtype='datetime64[ns]') + + # Add coordinate attributes + x_coord = xr.DataArray(x, dims=['x'], attrs={'units': 'm', 'crs': 'EPSG:32633'}) + y_coord = xr.DataArray(y, dims=['y'], attrs={'units': 'm', 'crs': 'EPSG:32633'}) + time_coord = xr.DataArray(time, dims=['time'], attrs={'calendar': 'gregorian'}) + + test_data = xr.DataArray( + np.random.rand(1, 100, 100), + dims=['time', 'y', 'x'], + coords={'time': time_coord, 'y': y_coord, 'x': x_coord}, + name='b05' + ) + + measurements = { + 20: {'reflectance': {'b05': test_data}} + } + + dataset = pyramid._create_level_1_dataset(measurements) + + # Check that coordinate attributes are preserved + if 'b05' in dataset.data_vars: + assert 'x' in dataset.coords + assert 'y' in dataset.coords + assert 'time' in dataset.coords + + # Check coordinate attributes preservation + assert dataset.coords['x'].attrs.get('units') == 'm' + assert dataset.coords['y'].attrs.get('units') == 'm' From 1ae2c19d3a1950b37f95fe81c50d79d551dfbf4b Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Fri, 26 Sep 2025 17:50:01 +0200 Subject: [PATCH 14/83] feat: simplify chunk alignment and sharding logic in S2MultiscalePyramid --- .../s2_optimization/s2_multiscale.py | 110 +++++------------- src/eopf_geozarr/tests/test_s2_multiscale.py | 75 ++---------- 2 files changed, 42 insertions(+), 143 deletions(-) diff --git a/src/eopf_geozarr/s2_optimization/s2_multiscale.py b/src/eopf_geozarr/s2_optimization/s2_multiscale.py index 4c2e4144..edfa9834 100644 --- a/src/eopf_geozarr/s2_optimization/s2_multiscale.py +++ b/src/eopf_geozarr/s2_optimization/s2_multiscale.py @@ -361,15 
+361,18 @@ def _create_level_encoding(self, dataset: xr.Dataset, level: int) -> Dict: if var_data.ndim >= 2: height, width = var_data.shape[-2:] - # Ensure x/y alignment: adjust chunk sizes to align with sharding - chunk_y = self._align_chunk_to_xy_dimensions(chunk_size, height) - chunk_x = self._align_chunk_to_xy_dimensions(chunk_size, width) + # Use original geozarr.py chunk alignment logic + spatial_chunk_aligned = min( + chunk_size, + self._calculate_aligned_chunk_size(width, chunk_size), + self._calculate_aligned_chunk_size(height, chunk_size), + ) if var_data.ndim == 3: # Single file per variable per time: chunk time dimension to 1 - chunks = (1, chunk_y, chunk_x) + chunks = (1, spatial_chunk_aligned, spatial_chunk_aligned) else: - chunks = (chunk_y, chunk_x) + chunks = (spatial_chunk_aligned, spatial_chunk_aligned) else: chunks = (min(chunk_size, var_data.shape[0]),) @@ -379,9 +382,9 @@ def _create_level_encoding(self, dataset: xr.Dataset, level: int) -> Dict: 'compressor': 'default' } - # Add xy-aligned sharding if enabled + # Add simplified sharding if enabled - shards match x/y dimensions exactly if self.enable_sharding and var_data.ndim >= 2: - shard_dims = self._calculate_xy_aligned_shard_dimensions(var_data.shape, chunks) + shard_dims = self._calculate_simple_shard_dimensions(var_data.shape) var_encoding['shards'] = shard_dims encoding[var_name] = var_encoding @@ -391,92 +394,39 @@ def _create_level_encoding(self, dataset: xr.Dataset, level: int) -> Dict: encoding[coord_name] = {'compressor': None} return encoding - - def _align_chunk_to_xy_dimensions(self, chunk_size: int, dimension_size: int) -> int: + + def _calculate_aligned_chunk_size(self, dimension_size: int, target_chunk: int) -> int: """ - Align chunk size to be compatible with x/y dimension sharding requirements. + Calculate aligned chunk size following geozarr.py logic. - Args: - chunk_size: Requested chunk size - dimension_size: Total size of the dimension - - Returns: - Aligned chunk size that works well with sharding + This ensures good chunk alignment without complex calculations. """ - if chunk_size >= dimension_size: + if target_chunk >= dimension_size: return dimension_size - - # Find a good divisor that's close to the requested size - best_chunk = chunk_size - best_remainder = dimension_size % chunk_size - - # Try nearby values to find better alignment - search_range = min(50, chunk_size // 4) - for offset in range(-search_range, search_range + 1): - candidate = chunk_size + offset - if candidate > 0 and candidate <= dimension_size: - remainder = dimension_size % candidate - if remainder < best_remainder or (remainder == best_remainder and candidate > best_chunk): - best_chunk = candidate - best_remainder = remainder - + + # Find the largest divisor of dimension_size that's close to target_chunk + best_chunk = target_chunk + for chunk_candidate in range(target_chunk, max(target_chunk // 2, 1), -1): + if dimension_size % chunk_candidate == 0: + best_chunk = chunk_candidate + break + return best_chunk - def _calculate_xy_aligned_shard_dimensions(self, data_shape: Tuple, chunks: Tuple) -> Tuple: + def _calculate_simple_shard_dimensions(self, data_shape: Tuple) -> Tuple: """ - Calculate shard dimensions for Zarr v3 sharding with x/y alignment. + Calculate shard dimensions that simply match x/y dimensions exactly. - Ensures shards are properly aligned with spatial dimensions (x, y) - and maintains single file per variable per time point. + Shards dimensions will always be the same as the x and y dimensions. 
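+
+        For example (illustrative; ``pyramid`` stands for an
+        ``S2MultiscalePyramid`` instance), a 3D ``(time, y, x)`` array keeps a
+        single time slice per shard while each spatial dimension is covered by
+        one shard, and a 2D ``(y, x)`` array becomes a single shard:
+
+            pyramid._calculate_simple_shard_dimensions((5, 1000, 1000))  # -> (1, 1000, 1000)
+            pyramid._calculate_simple_shard_dimensions((500, 800))       # -> (500, 800)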
""" shard_dims = [] - for i, (dim_size, chunk_size) in enumerate(zip(data_shape, chunks)): - # Special handling for different dimensions + for i, dim_size in enumerate(data_shape): if i == 0 and len(data_shape) == 3: # First dimension in 3D data (time) - use single time slice per shard - shard_dim = 1 - elif i >= len(data_shape) - 2: - # Last two dimensions (y, x) - ensure proper spatial alignment - shard_dim = self._calculate_spatial_shard_dim(dim_size, chunk_size) + shard_dims.append(1) else: - # Other dimensions - standard calculation - shard_dim = self._calculate_standard_shard_dim(dim_size, chunk_size) - - shard_dims.append(shard_dim) + # For x/y dimensions, shard dimension equals the full dimension size + shard_dims.append(dim_size) return tuple(shard_dims) - - def _calculate_spatial_shard_dim(self, dim_size: int, chunk_size: int) -> int: - """Calculate shard dimension for spatial dimensions (x, y).""" - if chunk_size >= dim_size: - return dim_size - - # For spatial dimensions, align shard boundaries with chunk boundaries - # Use multiple of chunk_size that provides good balance - num_chunks = dim_size // chunk_size - if num_chunks >= 4: - # Use 4 chunks per shard if possible for good balance - shard_dim = min(4 * chunk_size, dim_size) - elif num_chunks >= 2: - # Use 2 chunks per shard as minimum - shard_dim = 2 * chunk_size - else: - # Single chunk per shard - shard_dim = chunk_size - - return shard_dim - - def _calculate_standard_shard_dim(self, dim_size: int, chunk_size: int) -> int: - """Calculate shard dimension for non-spatial dimensions.""" - if chunk_size >= dim_size: - return dim_size - - # For non-spatial dimensions, use standard calculation - num_chunks = dim_size // chunk_size - if num_chunks >= 4: - shard_dim = min(4 * chunk_size, dim_size) - else: - shard_dim = num_chunks * chunk_size - - return shard_dim diff --git a/src/eopf_geozarr/tests/test_s2_multiscale.py b/src/eopf_geozarr/tests/test_s2_multiscale.py index c6f86e38..6a03f1cd 100644 --- a/src/eopf_geozarr/tests/test_s2_multiscale.py +++ b/src/eopf_geozarr/tests/test_s2_multiscale.py @@ -146,77 +146,26 @@ def test_pyramid_levels_structure(self, pyramid): assert pyramid.pyramid_levels == expected_levels - def test_align_chunk_to_xy_dimensions(self, pyramid): - """Test chunk alignment for x/y dimensions.""" - # Test case 1: chunk size smaller than dimension - aligned = pyramid._align_chunk_to_xy_dimensions(256, 1000) - assert aligned > 0 - assert aligned <= 1000 - assert 1000 % aligned <= 1000 % 256 # Should have better or equal alignment - - # Test case 2: chunk size larger than dimension - aligned = pyramid._align_chunk_to_xy_dimensions(2000, 1000) - assert aligned == 1000 - - # Test case 3: exact divisor - aligned = pyramid._align_chunk_to_xy_dimensions(250, 1000) - assert aligned > 0 - assert aligned <= 1000 - - def test_calculate_spatial_shard_dim(self, pyramid): - """Test spatial shard dimension calculation.""" - # Test case 1: chunk size smaller than dimension, multiple chunks possible - shard_dim = pyramid._calculate_spatial_shard_dim(1000, 200) - assert shard_dim == 800 # 4 chunks * 200 - - # Test case 2: chunk size larger than dimension - shard_dim = pyramid._calculate_spatial_shard_dim(100, 200) - assert shard_dim == 100 - - # Test case 3: only 2 chunks possible - shard_dim = pyramid._calculate_spatial_shard_dim(600, 250) - assert shard_dim == 500 # 2 chunks * 250 - - # Test case 4: single chunk - shard_dim = pyramid._calculate_spatial_shard_dim(300, 250) - assert shard_dim == 250 - - def 
test_calculate_standard_shard_dim(self, pyramid): - """Test standard shard dimension calculation.""" - # Test case 1: multiple chunks possible - shard_dim = pyramid._calculate_standard_shard_dim(1000, 200) - assert shard_dim == 800 # 4 chunks * 200 - - # Test case 2: chunk size larger than dimension - shard_dim = pyramid._calculate_standard_shard_dim(100, 200) - assert shard_dim == 100 - - # Test case 3: few chunks - shard_dim = pyramid._calculate_standard_shard_dim(600, 250) - assert shard_dim == 500 # 2 chunks * 250 - - def test_calculate_xy_aligned_shard_dimensions(self, pyramid): - """Test xy-aligned shard dimensions calculation.""" - # Test 3D data (time, y, x) + def test_calculate_simple_shard_dimensions(self, pyramid): + """Test simplified shard dimensions calculation.""" + # Test 3D data (time, y, x) - shards match dimensions exactly data_shape = (5, 1000, 1000) - chunks = (1, 256, 256) - shard_dims = pyramid._calculate_xy_aligned_shard_dimensions(data_shape, chunks) + shard_dims = pyramid._calculate_simple_shard_dimensions(data_shape) assert len(shard_dims) == 3 - assert shard_dims[0] == 1 # Time dimension should be 1 - assert shard_dims[1] > 0 # Y dimension - assert shard_dims[2] > 0 # X dimension + assert shard_dims[0] == 1 # Time dimension should be 1 + assert shard_dims[1] == 1000 # Y dimension matches exactly + assert shard_dims[2] == 1000 # X dimension matches exactly - # Test 2D data (y, x) - data_shape = (1000, 1000) - chunks = (256, 256) + # Test 2D data (y, x) - shards match dimensions exactly + data_shape = (500, 800) - shard_dims = pyramid._calculate_xy_aligned_shard_dimensions(data_shape, chunks) + shard_dims = pyramid._calculate_simple_shard_dimensions(data_shape) assert len(shard_dims) == 2 - assert shard_dims[0] > 0 # Y dimension - assert shard_dims[1] > 0 # X dimension + assert shard_dims[0] == 500 # Y dimension matches exactly + assert shard_dims[1] == 800 # X dimension matches exactly def test_create_level_encoding(self, pyramid, sample_dataset): """Test level encoding creation with xy-aligned sharding.""" From fca840ecb5db73628c3ef142f26f0392be50df5a Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Fri, 26 Sep 2025 18:01:56 +0200 Subject: [PATCH 15/83] feat: integrate S2 optimization commands into CLI and enhance converter functionality --- src/eopf_geozarr/cli.py | 105 ++++++ .../s2_optimization/cli_integration.py | 45 --- .../s2_optimization/s2_converter.py | 320 ++++++++++++++++-- 3 files changed, 403 insertions(+), 67 deletions(-) delete mode 100644 src/eopf_geozarr/s2_optimization/cli_integration.py diff --git a/src/eopf_geozarr/cli.py b/src/eopf_geozarr/cli.py index 3bd9df75..3c99ec73 100644 --- a/src/eopf_geozarr/cli.py +++ b/src/eopf_geozarr/cli.py @@ -13,6 +13,8 @@ import xarray as xr +from eopf_geozarr.s2_optimization.s2_converter import convert_s2_optimized + from . 
import create_geozarr_dataset from .conversion.fs_utils import ( get_s3_credentials_info, @@ -1144,9 +1146,112 @@ def create_parser() -> argparse.ArgumentParser: "--verbose", action="store_true", help="Enable verbose output" ) validate_parser.set_defaults(func=validate_command) + + # Add S2 optimization commands + add_s2_optimization_commands(subparsers) return parser +def add_s2_optimization_commands(subparsers): + """Add S2 optimization commands to CLI parser.""" + + # Convert S2 optimized command + s2_parser = subparsers.add_parser( + 'convert-s2-optimized', + help='Convert Sentinel-2 dataset to optimized structure' + ) + s2_parser.add_argument( + 'input_path', + type=str, + help='Path to input Sentinel-2 dataset (Zarr format)' + ) + s2_parser.add_argument( + 'output_path', + type=str, + help='Path for output optimized dataset' + ) + s2_parser.add_argument( + '--spatial-chunk', + type=int, + default=1024, + help='Spatial chunk size (default: 1024)' + ) + s2_parser.add_argument( + '--enable-sharding', + action='store_true', + help='Enable Zarr v3 sharding' + ) + s2_parser.add_argument( + '--compression-level', + type=int, + default=3, + choices=range(1, 10), + help='Compression level 1-9 (default: 3)' + ) + s2_parser.add_argument( + '--skip-geometry', + action='store_true', + help='Skip creating geometry group' + ) + s2_parser.add_argument( + '--skip-meteorology', + action='store_true', + help='Skip creating meteorology group' + ) + s2_parser.add_argument( + '--skip-validation', + action='store_true', + help='Skip output validation' + ) + s2_parser.add_argument( + '--verbose', + action='store_true', + help='Enable verbose output' + ) + s2_parser.set_defaults(func=convert_s2_optimized_command) + +def convert_s2_optimized_command(args): + """Execute S2 optimized conversion command.""" + try: + # Validate input + input_path = Path(args.input_path) + if not input_path.exists(): + print(f"Error: Input path {input_path} does not exist") + return 1 + + # Load input dataset + print(f"Loading Sentinel-2 dataset from: {args.input_path}") + storage_options = get_storage_options(str(input_path)) + dt_input = xr.open_datatree( + str(input_path), + engine='zarr', + chunks='auto', + storage_options=storage_options + ) + + # Convert + dt_optimized = convert_s2_optimized( + dt_input=dt_input, + output_path=args.output_path, + enable_sharding=args.enable_sharding, + spatial_chunk=args.spatial_chunk, + compression_level=args.compression_level, + create_geometry_group=not args.skip_geometry, + create_meteorology_group=not args.skip_meteorology, + validate_output=not args.skip_validation, + verbose=args.verbose + ) + + print(f"✅ S2 optimization completed: {args.output_path}") + return 0 + + except Exception as e: + print(f"❌ Error during S2 optimization: {e}") + if args.verbose: + import traceback + traceback.print_exc() + return 1 + def main() -> None: """Execute main entry point for the CLI.""" diff --git a/src/eopf_geozarr/s2_optimization/cli_integration.py b/src/eopf_geozarr/s2_optimization/cli_integration.py deleted file mode 100644 index 582dd96b..00000000 --- a/src/eopf_geozarr/s2_optimization/cli_integration.py +++ /dev/null @@ -1,45 +0,0 @@ -""" -CLI integration for Sentinel-2 optimization. 
-""" - -import argparse -from .s2_converter import S2OptimizedConverter - -def add_s2_optimization_commands(subparsers): - """Add Sentinel-2 optimization commands to CLI parser.""" - - s2_parser = subparsers.add_parser( - 'convert-s2-optimized', - help='Convert Sentinel-2 dataset to optimized structure' - ) - s2_parser.add_argument( - 'input_path', - type=str, - help='Path to input Sentinel-2 dataset (Zarr format)' - ) - s2_parser.add_argument( - 'output_path', - type=str, - help='Path for output optimized dataset' - ) - s2_parser.add_argument( - '--spatial-chunk', - type=int, - default=1024, - help='Spatial chunk size (default: 1024)' - ) - s2_parser.add_argument( - '--enable-sharding', - action='store_true', - help='Enable Zarr v3 sharding' - ) - s2_parser.set_defaults(func=convert_s2_optimized_command) - -def convert_s2_optimized_command(args): - """Execute Sentinel-2 optimized conversion command.""" - converter = S2OptimizedConverter( - enable_sharding=args.enable_sharding, - spatial_chunk=args.spatial_chunk - ) - # Placeholder for CLI command execution logic - pass diff --git a/src/eopf_geozarr/s2_optimization/s2_converter.py b/src/eopf_geozarr/s2_optimization/s2_converter.py index d828879d..5c3b3e5f 100644 --- a/src/eopf_geozarr/s2_optimization/s2_converter.py +++ b/src/eopf_geozarr/s2_optimization/s2_converter.py @@ -2,35 +2,311 @@ Main S2 optimization converter. """ +import os +import time +from pathlib import Path +from typing import Dict, Optional, List +import xarray as xr + +from .s2_data_consolidator import S2DataConsolidator, create_consolidated_dataset +from .s2_multiscale import S2MultiscalePyramid +from .s2_validation import S2OptimizationValidator +from eopf_geozarr.conversion.fs_utils import get_storage_options, normalize_path + class S2OptimizedConverter: """Optimized Sentinel-2 to GeoZarr converter.""" - def __init__(self, enable_sharding=True, spatial_chunk=1024): + def __init__( + self, + enable_sharding: bool = True, + spatial_chunk: int = 1024, + compression_level: int = 3, + max_retries: int = 3 + ): self.enable_sharding = enable_sharding self.spatial_chunk = spatial_chunk + self.compression_level = compression_level + self.max_retries = max_retries - def convert_s2(self, dt_input, output_path, **kwargs): - """Main conversion entry point.""" - from .s2_data_consolidator import S2DataConsolidator - from .s2_multiscale import S2MultiscalePyramid - from .s2_validation import S2OptimizationValidator - - # Consolidate data + # Initialize components + self.pyramid_creator = S2MultiscalePyramid(enable_sharding, spatial_chunk) + self.validator = S2OptimizationValidator() + + def convert_s2_optimized( + self, + dt_input: xr.DataTree, + output_path: str, + create_geometry_group: bool = True, + create_meteorology_group: bool = True, + validate_output: bool = True, + verbose: bool = False + ) -> xr.DataTree: + """ + Convert S2 dataset to optimized structure. 
+ + Args: + dt_input: Input Sentinel-2 DataTree + output_path: Output path for optimized dataset + create_geometry_group: Whether to create geometry group + create_meteorology_group: Whether to create meteorology group + validate_output: Whether to validate the output + verbose: Enable verbose logging + + Returns: + Optimized DataTree + """ + start_time = time.time() + + if verbose: + print(f"Starting S2 optimized conversion...") + print(f"Input: {len(dt_input.groups)} groups") + print(f"Output: {output_path}") + + # Validate input is S2 + if not self._is_sentinel2_dataset(dt_input): + raise ValueError("Input dataset is not a Sentinel-2 product") + + # Step 1: Consolidate data from scattered structure + print("Step 1: Consolidating EOPF data structure...") consolidator = S2DataConsolidator(dt_input) - measurements, geometry, meteorology = consolidator.consolidate_all_data() - - # Create multiscale pyramids - pyramid = S2MultiscalePyramid( - enable_sharding=self.enable_sharding, - spatial_chunk=self.spatial_chunk + measurements_data, geometry_data, meteorology_data = consolidator.consolidate_all_data() + + if verbose: + print(f" Measurements data extracted: {sum(len(d['bands']) for d in measurements_data.values())} bands") + print(f" Geometry variables: {len(geometry_data)}") + print(f" Meteorology variables: {len(meteorology_data)}") + + # Step 2: Create multiscale measurements + print("Step 2: Creating multiscale measurements pyramid...") + pyramid_datasets = self.pyramid_creator.create_multiscale_measurements( + measurements_data, output_path ) - multiscale_data = pyramid.create_multiscale_measurements(measurements, output_path) + + print(f" Created {len(pyramid_datasets)} pyramid levels") + + # Step 3: Create geometry group + if create_geometry_group and geometry_data: + print("Step 3: Creating consolidated geometry group...") + geometry_ds = xr.Dataset(geometry_data) + geometry_path = f"{output_path}/geometry" + self._write_auxiliary_group(geometry_ds, geometry_path, "geometry", verbose) + + # Step 4: Create meteorology group + if create_meteorology_group and meteorology_data: + print("Step 4: Creating consolidated meteorology group...") + meteorology_ds = xr.Dataset(meteorology_data) + meteorology_path = f"{output_path}/meteorology" + self._write_auxiliary_group(meteorology_ds, meteorology_path, "meteorology", verbose) + + # Step 5: Create root-level multiscales metadata + print("Step 5: Adding multiscales metadata...") + self._add_root_multiscales_metadata(output_path, pyramid_datasets) + + # Step 6: Consolidate metadata + print("Step 6: Consolidating metadata...") + self._consolidate_root_metadata(output_path) + + # Step 7: Validation + if validate_output: + print("Step 7: Validating optimized dataset...") + validation_results = self.validator.validate_optimized_dataset(output_path) + if not validation_results['is_valid']: + print(" Warning: Validation issues found:") + for issue in validation_results['issues']: + print(f" - {issue}") + + # Create result DataTree + result_dt = self._create_result_datatree(output_path) + + total_time = time.time() - start_time + print(f"Optimization complete in {total_time:.2f}s") + + if verbose: + self._print_optimization_summary(dt_input, result_dt, output_path) + + return result_dt + + def _is_sentinel2_dataset(self, dt: xr.DataTree) -> bool: + """Check if dataset is Sentinel-2.""" + # Check STAC properties + stac_props = dt.attrs.get('stac_discovery', {}).get('properties', {}) + mission = stac_props.get('mission', '') + + if 
mission.lower().startswith('sentinel-2'): + return True + + # Check for characteristic S2 groups + s2_indicators = [ + '/measurements/reflectance', + '/conditions/geometry', + '/quality/atmosphere' + ] + + found_indicators = sum(1 for indicator in s2_indicators if indicator in dt.groups) + return found_indicators >= 2 + + def _write_auxiliary_group( + self, + dataset: xr.Dataset, + group_path: str, + group_type: str, + verbose: bool + ) -> None: + """Write auxiliary group (geometry or meteorology).""" + # Create simple encoding + encoding = {} + for var_name in dataset.data_vars: + encoding[var_name] = {'compressor': 'default'} + for coord_name in dataset.coords: + encoding[coord_name] = {'compressor': None} + + # Write dataset + storage_options = get_storage_options(group_path) + dataset.to_zarr( + group_path, + mode='w', + consolidated=True, + zarr_format=3, + encoding=encoding, + storage_options=storage_options + ) + + if verbose: + print(f" {group_type.title()} group written: {len(dataset.data_vars)} variables") + + def _add_root_multiscales_metadata( + self, + output_path: str, + pyramid_datasets: Dict[int, xr.Dataset] + ) -> None: + """Add multiscales metadata at root level.""" + from ..geozarr import create_native_crs_tile_matrix_set, calculate_overview_levels + + # Get information from level 0 dataset + if 0 not in pyramid_datasets: + return + + level_0_ds = pyramid_datasets[0] + if not level_0_ds.data_vars: + return + + # Get spatial info from first variable + first_var = next(iter(level_0_ds.data_vars.values())) + native_height, native_width = first_var.shape[-2:] + native_crs = level_0_ds.rio.crs + native_bounds = level_0_ds.rio.bounds() + + # Calculate overview levels + overview_levels = [] + for level, resolution in self.pyramid_creator.pyramid_levels.items(): + if level in pyramid_datasets: + level_ds = pyramid_datasets[level] + level_var = next(iter(level_ds.data_vars.values())) + level_height, level_width = level_var.shape[-2:] + + overview_levels.append({ + 'level': level, + 'resolution': resolution, + 'width': level_width, + 'height': level_height, + 'scale_factor': 2 ** level if level > 0 else 1 + }) + + # Create tile matrix set + tile_matrix_set = create_native_crs_tile_matrix_set( + native_crs, native_bounds, overview_levels, "measurements" + ) + + # Add metadata to measurements group + measurements_zarr_path = normalize_path(f"{output_path}/measurements/zarr.json") + if os.path.exists(measurements_zarr_path): + import json + with open(measurements_zarr_path, 'r') as f: + zarr_json = json.load(f) + + zarr_json.setdefault('attributes', {}) + zarr_json['attributes']['multiscales'] = { + 'tile_matrix_set': tile_matrix_set, + 'resampling_method': 'average', + 'datasets': [{'path': str(level)} for level in sorted(pyramid_datasets.keys())] + } + + with open(measurements_zarr_path, 'w') as f: + json.dump(zarr_json, f, indent=2) + + def _consolidate_root_metadata(self, output_path: str) -> None: + """Consolidate metadata at root level.""" + try: + from eopf_geozarr.conversion.geozarr import consolidate_metadata + from eopf_geozarr.conversion.fs_utils import open_zarr_group + + zarr_group = open_zarr_group(output_path, mode="r+") + consolidate_metadata(zarr_group.store) + except Exception as e: + print(f" Warning: Root metadata consolidation failed: {e}") + + def _create_result_datatree(self, output_path: str) -> xr.DataTree: + """Create result DataTree from written output.""" + try: + storage_options = get_storage_options(output_path) + return xr.open_datatree( + output_path, 
+ engine='zarr', + chunks='auto', + storage_options=storage_options + ) + except Exception as e: + print(f"Warning: Could not open result DataTree: {e}") + return xr.DataTree() + + def _print_optimization_summary( + self, + dt_input: xr.DataTree, + dt_output: xr.DataTree, + output_path: str + ) -> None: + """Print optimization summary statistics.""" + print("\n" + "="*50) + print("OPTIMIZATION SUMMARY") + print("="*50) + + # Count groups + input_groups = len(dt_input.groups) if hasattr(dt_input, 'groups') else 0 + output_groups = len(dt_output.groups) if hasattr(dt_output, 'groups') else 0 + + print(f"Groups: {input_groups} → {output_groups} ({((output_groups-input_groups)/input_groups*100):+.1f}%)") + + # Estimate file count reduction + estimated_input_files = input_groups * 10 # Rough estimate + estimated_output_files = output_groups * 5 # Fewer files per group + print(f"Estimated files: {estimated_input_files} → {estimated_output_files} ({((estimated_output_files-estimated_input_files)/estimated_input_files*100):+.1f}%)") + + # Show structure + print(f"\nNew structure:") + print(f" /measurements/ (multiscale: levels 0-6)") + if f"{output_path}/geometry" in str(dt_output): + print(f" /geometry/ (consolidated)") + if f"{output_path}/meteorology" in str(dt_output): + print(f" /meteorology/ (consolidated)") + + print("="*50) - # Validate the output - validator = S2OptimizationValidator() - validation_results = validator.validate_optimized_dataset(output_path) - return { - "multiscale_data": multiscale_data, - "validation_results": validation_results - } +def convert_s2_optimized( + dt_input: xr.DataTree, + output_path: str, + **kwargs +) -> xr.DataTree: + """ + Convenience function for S2 optimization. + + Args: + dt_input: Input Sentinel-2 DataTree + output_path: Output path + **kwargs: Additional arguments for S2OptimizedConverter + + Returns: + Optimized DataTree + """ + converter = S2OptimizedConverter(**kwargs) + return converter.convert_s2_optimized(dt_input, output_path, **kwargs) \ No newline at end of file From 72546bc5410fe0820e50f3f5170e95ce4b89228c Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Fri, 26 Sep 2025 18:18:13 +0200 Subject: [PATCH 16/83] feat: add S2L2A optimized conversion command to CLI and update launch configuration --- .vscode/launch.json | 26 ++++++++++++++++++++++++++ src/eopf_geozarr/cli.py | 10 ++-------- 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 3beff475..6e74c339 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -133,6 +133,32 @@ "AWS_ENDPOINT_URL": "https://s3.de.io.cloud.ovh.net/" }, + }, + { + // eopf_geozarr convert https://objectstore.eodc.eu:2222/e05ab01a9d56408d82ac32d69a5aae2a:sample-data/tutorial_data/cpm_v253/S2B_MSIL1C_20250113T103309_N0511_R108_T32TLQ_20250113T122458.zarr /tmp/tmp7mmjkjk3/s2b_subset_test.zarr --groups /measurements/reflectance/r10m --spatial-chunk 512 --min-dimension 128 --tile-width 256 --max-retries 2 --verbose + "name": "Convert to GeoZarr S2L2A Optimized (S3)", + "type": "debugpy", + "request": "launch", + "module": "eopf_geozarr", + "args": [ + "convert-s2-optimized", + "https://objects.eodc.eu/e05ab01a9d56408d82ac32d69a5aae2a:202509-s02msil2a/08/products/cpm_v256/S2A_MSIL2A_20250908T100041_N0511_R122_T32TQM_20250908T115116.zarr", + // "s3://esa-zarr-sentinel-explorer-fra/tests-output/sentinel-2-l2a-opt/S2A_MSIL2A_20250908T100041_N0511_R122_T32TQM_20250908T115116.zarr", + "./tests-output/eopf_geozarr/s2l2_optimized.zarr", + 
"--spatial-chunk", "256", + "--enable-sharding", + "--verbose" + ], + "cwd": "${workspaceFolder}", + "justMyCode": false, + "console": "integratedTerminal", + "env": { + "PYTHONPATH": "${workspaceFolder}/.venv/bin", + "AWS_PROFILE": "eopf-explorer", + "AWS_DEFAULT_REGION": "de", + "AWS_ENDPOINT_URL": "https://s3.de.io.cloud.ovh.net/" + }, + }, { "name": "Convert to GeoZarr Sentinel-1 GRD (Local)", diff --git a/src/eopf_geozarr/cli.py b/src/eopf_geozarr/cli.py index 3c99ec73..092dea1b 100644 --- a/src/eopf_geozarr/cli.py +++ b/src/eopf_geozarr/cli.py @@ -1213,17 +1213,11 @@ def add_s2_optimization_commands(subparsers): def convert_s2_optimized_command(args): """Execute S2 optimized conversion command.""" try: - # Validate input - input_path = Path(args.input_path) - if not input_path.exists(): - print(f"Error: Input path {input_path} does not exist") - return 1 - # Load input dataset print(f"Loading Sentinel-2 dataset from: {args.input_path}") - storage_options = get_storage_options(str(input_path)) + storage_options = get_storage_options(str(args.input_path)) dt_input = xr.open_datatree( - str(input_path), + str(args.input_path), engine='zarr', chunks='auto', storage_options=storage_options From 4d079073c6656e00f31ccb18dc745ea5f54efb2f Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Sat, 27 Sep 2025 20:27:28 +0200 Subject: [PATCH 17/83] feat: enhance S2 converter and multiscale pyramid with optimized encoding and rechunking --- .../s2_optimization/s2_converter.py | 23 ++++++-- .../s2_optimization/s2_multiscale.py | 55 +++++++++++++++++-- 2 files changed, 69 insertions(+), 9 deletions(-) diff --git a/src/eopf_geozarr/s2_optimization/s2_converter.py b/src/eopf_geozarr/s2_optimization/s2_converter.py index 5c3b3e5f..927d07a6 100644 --- a/src/eopf_geozarr/s2_optimization/s2_converter.py +++ b/src/eopf_geozarr/s2_optimization/s2_converter.py @@ -153,12 +153,14 @@ def _write_auxiliary_group( verbose: bool ) -> None: """Write auxiliary group (geometry or meteorology).""" - # Create simple encoding + # Create simple encoding following geozarr.py pattern + from zarr.codecs import BloscCodec + compressor = BloscCodec(cname="zstd", clevel=3, shuffle="shuffle", blocksize=0) encoding = {} for var_name in dataset.data_vars: - encoding[var_name] = {'compressor': 'default'} + encoding[var_name] = {'compressors': [compressor]} for coord_name in dataset.coords: - encoding[coord_name] = {'compressor': None} + encoding[coord_name] = {'compressors': None} # Write dataset storage_options = get_storage_options(group_path) @@ -308,5 +310,16 @@ def convert_s2_optimized( Returns: Optimized DataTree """ - converter = S2OptimizedConverter(**kwargs) - return converter.convert_s2_optimized(dt_input, output_path, **kwargs) \ No newline at end of file + # Separate constructor args from method args + constructor_args = { + 'enable_sharding': kwargs.pop('enable_sharding', True), + 'spatial_chunk': kwargs.pop('spatial_chunk', 1024), + 'compression_level': kwargs.pop('compression_level', 3), + 'max_retries': kwargs.pop('max_retries', 3) + } + + # Remaining kwargs are for the convert_s2_optimized method + method_args = kwargs + + converter = S2OptimizedConverter(**constructor_args) + return converter.convert_s2_optimized(dt_input, output_path, **method_args) diff --git a/src/eopf_geozarr/s2_optimization/s2_multiscale.py b/src/eopf_geozarr/s2_optimization/s2_multiscale.py index edfa9834..a7dbf9f4 100644 --- a/src/eopf_geozarr/s2_optimization/s2_multiscale.py +++ b/src/eopf_geozarr/s2_optimization/s2_multiscale.py @@ 
-260,7 +260,11 @@ def _write_level_dataset(self, dataset: xr.Dataset, level_path: str, level: int) else: # Write as single dataset with xy-aligned sharding print(f" Writing level {level} to {level_path} (xy-aligned sharding)") - dataset.to_zarr( + + # Rechunk the dataset to align with encoding chunks (following geozarr.py pattern) + rechunked_dataset = self._rechunk_dataset_for_encoding(dataset, encoding) + + rechunked_dataset.to_zarr( level_path, mode='w', consolidated=True, @@ -376,10 +380,12 @@ def _create_level_encoding(self, dataset: xr.Dataset, level: int) -> Dict: else: chunks = (min(chunk_size, var_data.shape[0]),) - # Configure encoding + # Configure encoding - use proper compressor following geozarr.py pattern + from zarr.codecs import BloscCodec + compressor = BloscCodec(cname="zstd", clevel=3, shuffle="shuffle", blocksize=0) var_encoding = { 'chunks': chunks, - 'compressor': 'default' + 'compressors': [compressor] } # Add simplified sharding if enabled - shards match x/y dimensions exactly @@ -391,7 +397,7 @@ def _create_level_encoding(self, dataset: xr.Dataset, level: int) -> Dict: # Add coordinate encoding for coord_name in dataset.coords: - encoding[coord_name] = {'compressor': None} + encoding[coord_name] = {'compressors': None} return encoding @@ -430,3 +436,44 @@ def _calculate_simple_shard_dimensions(self, data_shape: Tuple) -> Tuple: shard_dims.append(dim_size) return tuple(shard_dims) + + def _rechunk_dataset_for_encoding(self, dataset: xr.Dataset, encoding: Dict) -> xr.Dataset: + """ + Rechunk dataset variables to align with sharding dimensions when sharding is enabled. + + When using Zarr v3 sharding, Dask chunks must align with shard dimensions to avoid + checksum validation errors. + """ + rechunked_vars = {} + + for var_name, var_data in dataset.data_vars.items(): + if var_name in encoding: + var_encoding = encoding[var_name] + + # If sharding is enabled, rechunk based on shard dimensions + if 'shards' in var_encoding and var_encoding['shards'] is not None: + target_chunks = var_encoding['shards'] # Use shard dimensions for rechunking + elif 'chunks' in var_encoding: + target_chunks = var_encoding['chunks'] # Fallback to chunk dimensions + else: + # No specific chunking needed, use original variable + rechunked_vars[var_name] = var_data + continue + + # Create chunk dict using the actual dimensions of the variable + var_dims = var_data.dims + chunk_dict = {} + for i, dim in enumerate(var_dims): + if i < len(target_chunks): + chunk_dict[dim] = target_chunks[i] + + # Rechunk the variable to match the target dimensions + rechunked_vars[var_name] = var_data.chunk(chunk_dict) + else: + # No specific chunking needed, use original variable + rechunked_vars[var_name] = var_data + + # Create new dataset with rechunked variables, preserving coordinates + rechunked_dataset = xr.Dataset(rechunked_vars, coords=dataset.coords, attrs=dataset.attrs) + + return rechunked_dataset From 89bc6cfc1ac9503f8c9b241ea965312bab9f4424 Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Sat, 27 Sep 2025 23:22:21 +0200 Subject: [PATCH 18/83] feat: enhance sharding logic to ensure compatibility with chunk dimensions in S2MultiscalePyramid --- .../s2_optimization/s2_multiscale.py | 27 ++++++++++++++----- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/src/eopf_geozarr/s2_optimization/s2_multiscale.py b/src/eopf_geozarr/s2_optimization/s2_multiscale.py index a7dbf9f4..06c64c11 100644 --- a/src/eopf_geozarr/s2_optimization/s2_multiscale.py +++ 
b/src/eopf_geozarr/s2_optimization/s2_multiscale.py @@ -390,7 +390,7 @@ def _create_level_encoding(self, dataset: xr.Dataset, level: int) -> Dict: # Add simplified sharding if enabled - shards match x/y dimensions exactly if self.enable_sharding and var_data.ndim >= 2: - shard_dims = self._calculate_simple_shard_dimensions(var_data.shape) + shard_dims = self._calculate_simple_shard_dimensions(var_data.shape, chunks) var_encoding['shards'] = shard_dims encoding[var_name] = var_encoding @@ -419,21 +419,34 @@ def _calculate_aligned_chunk_size(self, dimension_size: int, target_chunk: int) return best_chunk - def _calculate_simple_shard_dimensions(self, data_shape: Tuple) -> Tuple: + def _calculate_simple_shard_dimensions(self, data_shape: Tuple, chunks: Tuple) -> Tuple: """ - Calculate shard dimensions that simply match x/y dimensions exactly. + Calculate shard dimensions that are compatible with chunk dimensions. - Shards dimensions will always be the same as the x and y dimensions. + Shard dimensions must be evenly divisible by chunk dimensions for Zarr v3. + When possible, shards should match x/y dimensions exactly as required. """ shard_dims = [] - for i, dim_size in enumerate(data_shape): + for i, (dim_size, chunk_size) in enumerate(zip(data_shape, chunks)): if i == 0 and len(data_shape) == 3: # First dimension in 3D data (time) - use single time slice per shard shard_dims.append(1) else: - # For x/y dimensions, shard dimension equals the full dimension size - shard_dims.append(dim_size) + # For x/y dimensions, try to use full dimension size + # But ensure it's divisible by chunk size + if dim_size % chunk_size == 0: + # Perfect: full dimension is divisible by chunk + shard_dims.append(dim_size) + else: + # Find the largest multiple of chunk_size that fits + num_chunks = dim_size // chunk_size + if num_chunks > 0: + shard_size = num_chunks * chunk_size + shard_dims.append(shard_size) + else: + # Fallback: use chunk size itself + shard_dims.append(chunk_size) return tuple(shard_dims) From 3aff8d399c6be8b6d29c28cf5521e1c73cd88e84 Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Sat, 27 Sep 2025 23:28:14 +0200 Subject: [PATCH 19/83] feat: add downsampling for 10m data and adjust dataset creation for levels 3+ --- .../s2_optimization/s2_multiscale.py | 38 ++++++++++++++----- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/src/eopf_geozarr/s2_optimization/s2_multiscale.py b/src/eopf_geozarr/s2_optimization/s2_multiscale.py index 06c64c11..a266a4a0 100644 --- a/src/eopf_geozarr/s2_optimization/s2_multiscale.py +++ b/src/eopf_geozarr/s2_optimization/s2_multiscale.py @@ -194,6 +194,26 @@ def _create_level_2_dataset(self, measurements_by_resolution: Dict) -> xr.Datase downsampled = downsampled.assign_coords(reference_coords) all_vars[var_name] = downsampled + # Add downsampled 10m data + if 10 in measurements_by_resolution: + data_10m = measurements_by_resolution[10] + + for category, vars_dict in data_10m.items(): + for var_name, var_data in vars_dict.items(): + if reference_coords: + # Downsample to match 60m grid + target_height = len(reference_coords['y']) + target_width = len(reference_coords['x']) + + var_type = determine_variable_type(var_name, var_data) + downsampled = self.resampler.downsample_variable( + var_data, target_height, target_width, var_type + ) + + # Align coordinates + downsampled = downsampled.assign_coords(reference_coords) + all_vars[var_name] = downsampled + if not all_vars: return xr.Dataset() @@ -210,25 +230,25 @@ def _create_downsampled_dataset( 
target_resolution: int, measurements_by_resolution: Dict ) -> xr.Dataset: - """Create downsampled dataset for levels 2+.""" - # Start from level 1 data (20m) and downsample - level_1_dataset = self._create_level_1_dataset(measurements_by_resolution) + """Create downsampled dataset for levels 3+.""" + # Start from level 2 data (60m) which includes all bands, and downsample + level_2_dataset = self._create_level_2_dataset(measurements_by_resolution) - if len(level_1_dataset.data_vars) == 0: + if len(level_2_dataset.data_vars) == 0: return xr.Dataset() - # Calculate target dimensions (downsample by factor of 2^(level-1)) - downsample_factor = 2 ** (level - 1) + # Calculate target dimensions (downsample by factor of 2^(level-2)) + downsample_factor = 2 ** (level - 2) - # Get reference dimensions from level 1 - ref_var = next(iter(level_1_dataset.data_vars.values())) + # Get reference dimensions from level 2 + ref_var = next(iter(level_2_dataset.data_vars.values())) current_height, current_width = ref_var.shape[-2:] target_height = current_height // downsample_factor target_width = current_width // downsample_factor downsampled_vars = {} - for var_name, var_data in level_1_dataset.data_vars.items(): + for var_name, var_data in level_2_dataset.data_vars.items(): var_type = determine_variable_type(var_name, var_data) downsampled = self.resampler.downsample_variable( var_data, target_height, target_width, var_type From 1cd42815444d39fb742dcf0edb16bf4cbb76eb26 Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Sun, 28 Sep 2025 09:24:03 +0200 Subject: [PATCH 20/83] feat: add support for Dask cluster in S2 optimization commands and enhance progress tracking for Zarr writes --- .vscode/launch.json | 1 + src/eopf_geozarr/cli.py | 21 +++++++++ .../s2_optimization/s2_converter.py | 29 ++++++++++-- .../s2_optimization/s2_multiscale.py | 44 +++++++++++++++++-- 4 files changed, 88 insertions(+), 7 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 6e74c339..47b91448 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -147,6 +147,7 @@ "./tests-output/eopf_geozarr/s2l2_optimized.zarr", "--spatial-chunk", "256", "--enable-sharding", + "--dask-cluster", "--verbose" ], "cwd": "${workspaceFolder}", diff --git a/src/eopf_geozarr/cli.py b/src/eopf_geozarr/cli.py index 092dea1b..6b5b35ca 100644 --- a/src/eopf_geozarr/cli.py +++ b/src/eopf_geozarr/cli.py @@ -1208,10 +1208,20 @@ def add_s2_optimization_commands(subparsers): action='store_true', help='Enable verbose output' ) + s2_parser.add_argument( + '--dask-cluster', + action='store_true', + help='Start a local dask cluster for parallel processing and progress bars' + ) s2_parser.set_defaults(func=convert_s2_optimized_command) def convert_s2_optimized_command(args): """Execute S2 optimized conversion command.""" + # Set up dask cluster if requested + dask_client = setup_dask_cluster( + enable_dask=getattr(args, "dask_cluster", False), verbose=args.verbose + ) + try: # Load input dataset print(f"Loading Sentinel-2 dataset from: {args.input_path}") @@ -1245,6 +1255,17 @@ def convert_s2_optimized_command(args): import traceback traceback.print_exc() return 1 + finally: + # Clean up dask client if it was created + if dask_client is not None: + try: + if hasattr(dask_client, "close"): + dask_client.close() + if args.verbose: + print("🔄 Dask cluster closed") + except Exception as e: + if args.verbose: + print(f"Warning: Error closing dask cluster: {e}") def main() -> None: diff --git a/src/eopf_geozarr/s2_optimization/s2_converter.py 
b/src/eopf_geozarr/s2_optimization/s2_converter.py index 927d07a6..127da4ba 100644 --- a/src/eopf_geozarr/s2_optimization/s2_converter.py +++ b/src/eopf_geozarr/s2_optimization/s2_converter.py @@ -13,6 +13,12 @@ from .s2_validation import S2OptimizationValidator from eopf_geozarr.conversion.fs_utils import get_storage_options, normalize_path +try: + import distributed + DISTRIBUTED_AVAILABLE = True +except ImportError: + DISTRIBUTED_AVAILABLE = False + class S2OptimizedConverter: """Optimized Sentinel-2 to GeoZarr converter.""" @@ -162,16 +168,33 @@ def _write_auxiliary_group( for coord_name in dataset.coords: encoding[coord_name] = {'compressors': None} - # Write dataset + # Write dataset with progress bar storage_options = get_storage_options(group_path) - dataset.to_zarr( + + # Create zarr write job with progress bar + write_job = dataset.to_zarr( group_path, mode='w', consolidated=True, zarr_format=3, encoding=encoding, - storage_options=storage_options + storage_options=storage_options, + compute=False ) + write_job = write_job.persist() + + # Show progress bar if distributed is available + if DISTRIBUTED_AVAILABLE: + try: + # this will return an interactive (non-blocking) widget if in a notebook + # environment. To force the widget to block, provide notebook=False. + distributed.progress(write_job, notebook=False) + except Exception as e: + print(f" Warning: Could not display progress bar: {e}") + write_job.compute() + else: + print(f" Writing {group_type} zarr file...") + write_job.compute() if verbose: print(f" {group_type.title()} group written: {len(dataset.data_vars)} variables") diff --git a/src/eopf_geozarr/s2_optimization/s2_multiscale.py b/src/eopf_geozarr/s2_optimization/s2_multiscale.py index a266a4a0..e932bb5a 100644 --- a/src/eopf_geozarr/s2_optimization/s2_multiscale.py +++ b/src/eopf_geozarr/s2_optimization/s2_multiscale.py @@ -8,6 +8,12 @@ from .s2_resampling import S2ResamplingEngine, determine_variable_type from .s2_band_mapping import get_bands_for_level, get_quality_data_for_level +try: + import distributed + DISTRIBUTED_AVAILABLE = True +except ImportError: + DISTRIBUTED_AVAILABLE = False + class S2MultiscalePyramid: """Creates multiscale pyramids for consolidated S2 data.""" @@ -284,13 +290,29 @@ def _write_level_dataset(self, dataset: xr.Dataset, level_path: str, level: int) # Rechunk the dataset to align with encoding chunks (following geozarr.py pattern) rechunked_dataset = self._rechunk_dataset_for_encoding(dataset, encoding) - rechunked_dataset.to_zarr( + # Create zarr write job with progress bar + write_job = rechunked_dataset.to_zarr( level_path, mode='w', consolidated=True, zarr_format=3, - encoding=encoding + encoding=encoding, + compute=False ) + write_job = write_job.persist() + + # Show progress bar if distributed is available + if DISTRIBUTED_AVAILABLE: + try: + # this will return an interactive (non-blocking) widget if in a notebook + # environment. To force the widget to block, provide notebook=False. 
+ distributed.progress(write_job, notebook=False) + except Exception as e: + print(f" Warning: Could not display progress bar: {e}") + write_job.compute() + else: + print(f" Writing zarr file...") + write_job.compute() def _should_separate_time_files(self, dataset: xr.Dataset) -> bool: """Determine if time files should be separated for single file per variable per time.""" @@ -339,13 +361,27 @@ def _write_time_separated_dataset( time_encoding = self._update_encoding_for_time_slice(encoding, time_slice) print(f" Writing time slice {t_idx} to {time_path}") - time_slice.to_zarr( + + # Create zarr write job with progress bar for time slice + write_job = time_slice.to_zarr( time_path, mode='w', consolidated=True, zarr_format=3, - encoding=time_encoding + encoding=time_encoding, + compute=False ) + write_job = write_job.persist() + + # Show progress bar if distributed is available + if DISTRIBUTED_AVAILABLE: + try: + distributed.progress(write_job, notebook=False) + except Exception as e: + print(f" Warning: Could not display progress bar: {e}") + write_job.compute() + else: + write_job.compute() def _update_encoding_for_time_slice(self, encoding: Dict, time_slice: xr.Dataset) -> Dict: """Update encoding configuration for time slice data.""" From dfca7e1d84c9eaafdbbc877f3ff5bad5b2d5aac0 Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Sun, 28 Sep 2025 07:24:46 +0000 Subject: [PATCH 21/83] feat: add compression level option for GeoZarr conversion --- .vscode/launch.json | 1 + 1 file changed, 1 insertion(+) diff --git a/.vscode/launch.json b/.vscode/launch.json index 6e74c339..53b91c8a 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -146,6 +146,7 @@ // "s3://esa-zarr-sentinel-explorer-fra/tests-output/sentinel-2-l2a-opt/S2A_MSIL2A_20250908T100041_N0511_R122_T32TQM_20250908T115116.zarr", "./tests-output/eopf_geozarr/s2l2_optimized.zarr", "--spatial-chunk", "256", + "--compression-level", "5", "--enable-sharding", "--verbose" ], From d4c74871547ed9bd72bb8e3ae4d82ee8f18807da Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Sun, 28 Sep 2025 10:08:03 +0200 Subject: [PATCH 22/83] feat: implement Dask parallelization for multiscale pyramid creation and downsampling --- .../s2_optimization/s2_multiscale.py | 293 +++++++++++++++--- 1 file changed, 255 insertions(+), 38 deletions(-) diff --git a/src/eopf_geozarr/s2_optimization/s2_multiscale.py b/src/eopf_geozarr/s2_optimization/s2_multiscale.py index e932bb5a..3f8bd8b7 100644 --- a/src/eopf_geozarr/s2_optimization/s2_multiscale.py +++ b/src/eopf_geozarr/s2_optimization/s2_multiscale.py @@ -10,9 +10,17 @@ try: import distributed + from dask import delayed, compute DISTRIBUTED_AVAILABLE = True + DASK_AVAILABLE = True except ImportError: DISTRIBUTED_AVAILABLE = False + DASK_AVAILABLE = False + # Create dummy delayed function for non-dask environments + def delayed(func): + return func + def compute(*args, **kwargs): + return args class S2MultiscalePyramid: """Creates multiscale pyramids for consolidated S2 data.""" @@ -39,7 +47,7 @@ def create_multiscale_measurements( output_path: str ) -> Dict[int, xr.Dataset]: """ - Create multiscale pyramid from consolidated measurements. + Create multiscale pyramid from consolidated measurements with parallelization. 
Args: measurements_by_resolution: Data organized by native resolution @@ -48,6 +56,131 @@ def create_multiscale_measurements( Returns: Dictionary of datasets by pyramid level """ + if DASK_AVAILABLE: + return self._create_multiscale_measurements_parallel( + measurements_by_resolution, output_path + ) + else: + return self._create_multiscale_measurements_sequential( + measurements_by_resolution, output_path + ) + + def _create_multiscale_measurements_parallel( + self, + measurements_by_resolution: Dict[int, Dict], + output_path: str + ) -> Dict[int, xr.Dataset]: + """ + Create multiscale pyramid with Dask parallelization. + + Strategy: + 1. Levels 0, 1, 2 can be created in parallel (they use different source data) + 2. Levels 3+ depend on level 2, so must be created after level 2 + 3. Writing can be parallelized across all levels + """ + print("Creating multiscale pyramid with Dask parallelization...") + pyramid_datasets = {} + + # Phase 1: Create base levels (0, 1, 2) in parallel + print("Phase 1: Creating base pyramid levels (0, 1, 2) in parallel...") + + @delayed + def create_and_prepare_level(level: int, target_resolution: int): + """Create a single pyramid level and prepare for writing.""" + print(f" Creating pyramid level {level} ({target_resolution}m)...") + dataset = self._create_level_dataset( + level, target_resolution, measurements_by_resolution + ) + if dataset and len(dataset.data_vars) > 0: + return level, dataset + return level, None + + # Create base levels in parallel + base_level_tasks = [] + for level in [0, 1, 2]: # Base levels that don't depend on each other + if level in self.pyramid_levels: + target_resolution = self.pyramid_levels[level] + task = create_and_prepare_level(level, target_resolution) + base_level_tasks.append(task) + + # Compute base levels + if base_level_tasks: + print(" Computing base levels...") + base_results = compute(*base_level_tasks) + + # Store results + for level, dataset in base_results: + if dataset is not None: + pyramid_datasets[level] = dataset + + # Phase 2: Create higher levels (3+) that depend on level 2 + if 2 in pyramid_datasets: + print("Phase 2: Creating higher pyramid levels (3+) in parallel...") + + @delayed + def create_higher_level(level: int, target_resolution: int, level_2_data): + """Create higher pyramid levels from level 2 data.""" + print(f" Creating pyramid level {level} ({target_resolution}m)...") + dataset = self._create_downsampled_dataset_from_level2( + level, target_resolution, level_2_data + ) + if dataset and len(dataset.data_vars) > 0: + return level, dataset + return level, None + + # Create higher levels in parallel + higher_level_tasks = [] + level_2_dataset = pyramid_datasets[2] + + for level in range(3, max(self.pyramid_levels.keys()) + 1): + if level in self.pyramid_levels: + target_resolution = self.pyramid_levels[level] + task = create_higher_level(level, target_resolution, level_2_dataset) + higher_level_tasks.append(task) + + # Compute higher levels + if higher_level_tasks: + print(" Computing higher levels...") + higher_results = compute(*higher_level_tasks) + + # Store results + for level, dataset in higher_results: + if dataset is not None: + pyramid_datasets[level] = dataset + + # Phase 3: Write all levels in parallel + print("Phase 3: Writing all pyramid levels in parallel...") + + @delayed + def write_level(level: int, dataset: xr.Dataset): + """Write a single pyramid level.""" + level_path = f"{output_path}/measurements/{level}" + self._write_level_dataset(dataset, level_path, level) + return 
level + + # Create write tasks for all levels + write_tasks = [] + for level, dataset in pyramid_datasets.items(): + task = write_level(level, dataset) + write_tasks.append(task) + + # Execute all writes in parallel + if write_tasks: + print(" Writing levels...") + compute(*write_tasks) + + print(f"✅ Parallel pyramid creation complete: {len(pyramid_datasets)} levels") + return pyramid_datasets + + def _create_multiscale_measurements_sequential( + self, + measurements_by_resolution: Dict[int, Dict], + output_path: str + ) -> Dict[int, xr.Dataset]: + """ + Create multiscale pyramid sequentially (fallback for non-Dask environments). + """ + print("Creating multiscale pyramid sequentially...") pyramid_datasets = {} # Create each pyramid level @@ -180,45 +313,58 @@ def _create_level_2_dataset(self, measurements_by_resolution: Dict) -> xr.Datase 'y': first_var.coords['y'] } - # Add downsampled 20m data - if 20 in measurements_by_resolution: - data_20m = measurements_by_resolution[20] + # Collect all variables that need downsampling to 60m + vars_to_downsample = [] + if reference_coords: + target_height = len(reference_coords['y']) + target_width = len(reference_coords['x']) - for category, vars_dict in data_20m.items(): - for var_name, var_data in vars_dict.items(): - if reference_coords: - # Downsample to match 60m grid - target_height = len(reference_coords['y']) - target_width = len(reference_coords['x']) - - var_type = determine_variable_type(var_name, var_data) - downsampled = self.resampler.downsample_variable( - var_data, target_height, target_width, var_type - ) - - # Align coordinates - downsampled = downsampled.assign_coords(reference_coords) - all_vars[var_name] = downsampled - - # Add downsampled 10m data - if 10 in measurements_by_resolution: - data_10m = measurements_by_resolution[10] + # Add 20m data for downsampling + if 20 in measurements_by_resolution: + data_20m = measurements_by_resolution[20] + for category, vars_dict in data_20m.items(): + for var_name, var_data in vars_dict.items(): + vars_to_downsample.append((var_name, var_data, '20m')) - for category, vars_dict in data_10m.items(): - for var_name, var_data in vars_dict.items(): - if reference_coords: - # Downsample to match 60m grid - target_height = len(reference_coords['y']) - target_width = len(reference_coords['x']) - - var_type = determine_variable_type(var_name, var_data) - downsampled = self.resampler.downsample_variable( - var_data, target_height, target_width, var_type - ) - - # Align coordinates - downsampled = downsampled.assign_coords(reference_coords) - all_vars[var_name] = downsampled + # Add 10m data for downsampling + if 10 in measurements_by_resolution: + data_10m = measurements_by_resolution[10] + for category, vars_dict in data_10m.items(): + for var_name, var_data in vars_dict.items(): + vars_to_downsample.append((var_name, var_data, '10m')) + + # Process all downsampling in parallel if Dask is available + if DASK_AVAILABLE and vars_to_downsample: + @delayed + def downsample_to_60m_variable(var_name: str, var_data: xr.DataArray, source_res: str): + var_type = determine_variable_type(var_name, var_data) + downsampled = self.resampler.downsample_variable( + var_data, target_height, target_width, var_type + ) + # Align coordinates + downsampled = downsampled.assign_coords(reference_coords) + return var_name, downsampled + + # Create tasks for all variables + downsample_tasks = [ + downsample_to_60m_variable(var_name, var_data, source_res) + for var_name, var_data, source_res in vars_to_downsample + ] + 
+ # Compute all in parallel + results = compute(*downsample_tasks) + for var_name, downsampled_var in results: + all_vars[var_name] = downsampled_var + else: + # Sequential fallback + for var_name, var_data, source_res in vars_to_downsample: + var_type = determine_variable_type(var_name, var_data) + downsampled = self.resampler.downsample_variable( + var_data, target_height, target_width, var_type + ) + # Align coordinates + downsampled = downsampled.assign_coords(reference_coords) + all_vars[var_name] = downsampled if not all_vars: return xr.Dataset() @@ -240,6 +386,20 @@ def _create_downsampled_dataset( # Start from level 2 data (60m) which includes all bands, and downsample level_2_dataset = self._create_level_2_dataset(measurements_by_resolution) + if len(level_2_dataset.data_vars) == 0: + return xr.Dataset() + + return self._create_downsampled_dataset_from_level2( + level, target_resolution, level_2_dataset + ) + + def _create_downsampled_dataset_from_level2( + self, + level: int, + target_resolution: int, + level_2_dataset: xr.Dataset + ) -> xr.Dataset: + """Create downsampled dataset from existing level 2 data.""" if len(level_2_dataset.data_vars) == 0: return xr.Dataset() @@ -252,6 +412,63 @@ def _create_downsampled_dataset( target_height = current_height // downsample_factor target_width = current_width // downsample_factor + # Parallelize variable downsampling if Dask is available + if DASK_AVAILABLE: + return self._downsample_variables_parallel( + level_2_dataset, level, target_resolution, target_height, target_width + ) + else: + return self._downsample_variables_sequential( + level_2_dataset, level, target_resolution, target_height, target_width + ) + + def _downsample_variables_parallel( + self, + level_2_dataset: xr.Dataset, + level: int, + target_resolution: int, + target_height: int, + target_width: int + ) -> xr.Dataset: + """Downsample all variables in parallel using Dask.""" + @delayed + def downsample_single_variable(var_name: str, var_data: xr.DataArray): + """Downsample a single variable.""" + var_type = determine_variable_type(var_name, var_data) + downsampled = self.resampler.downsample_variable( + var_data, target_height, target_width, var_type + ) + return var_name, downsampled + + # Create downsampling tasks for all variables + downsample_tasks = [] + for var_name, var_data in level_2_dataset.data_vars.items(): + task = downsample_single_variable(var_name, var_data) + downsample_tasks.append(task) + + # Compute all downsampling in parallel + if downsample_tasks: + results = compute(*downsample_tasks) + downsampled_vars = dict(results) + else: + downsampled_vars = {} + + # Create dataset + dataset = xr.Dataset(downsampled_vars) + dataset.attrs['pyramid_level'] = level + dataset.attrs['resolution_meters'] = target_resolution + + return dataset + + def _downsample_variables_sequential( + self, + level_2_dataset: xr.Dataset, + level: int, + target_resolution: int, + target_height: int, + target_width: int + ) -> xr.Dataset: + """Downsample all variables sequentially (fallback).""" downsampled_vars = {} for var_name, var_data in level_2_dataset.data_vars.items(): From a1539fbf86a24c203671d552a8bfe4fd056809d5 Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Sun, 28 Sep 2025 14:24:22 +0200 Subject: [PATCH 23/83] feat: enhance multiscale pyramid creation with streaming Dask parallelization and improved memory management --- .../s2_optimization/s2_multiscale.py | 347 +++++++++++++----- 1 file changed, 253 insertions(+), 94 deletions(-) diff --git 
a/src/eopf_geozarr/s2_optimization/s2_multiscale.py b/src/eopf_geozarr/s2_optimization/s2_multiscale.py index 3f8bd8b7..193a9cc6 100644 --- a/src/eopf_geozarr/s2_optimization/s2_multiscale.py +++ b/src/eopf_geozarr/s2_optimization/s2_multiscale.py @@ -71,105 +71,59 @@ def _create_multiscale_measurements_parallel( output_path: str ) -> Dict[int, xr.Dataset]: """ - Create multiscale pyramid with Dask parallelization. - + Create multiscale pyramid with streaming Dask parallelization. + Strategy: - 1. Levels 0, 1, 2 can be created in parallel (they use different source data) - 2. Levels 3+ depend on level 2, so must be created after level 2 - 3. Writing can be parallelized across all levels + 1. Process levels sequentially to minimize memory usage + 2. Use parallel variable processing within each level + 3. Write each level immediately after creation + 4. Keep only necessary data from previous levels for dependencies """ - print("Creating multiscale pyramid with Dask parallelization...") + print("Creating multiscale pyramid with streaming Dask parallelization...") pyramid_datasets = {} - - # Phase 1: Create base levels (0, 1, 2) in parallel - print("Phase 1: Creating base pyramid levels (0, 1, 2) in parallel...") - - @delayed - def create_and_prepare_level(level: int, target_resolution: int): - """Create a single pyramid level and prepare for writing.""" - print(f" Creating pyramid level {level} ({target_resolution}m)...") - dataset = self._create_level_dataset( - level, target_resolution, measurements_by_resolution - ) + + # Process levels sequentially but with parallel variable processing + for level in sorted(self.pyramid_levels.keys()): + target_resolution = self.pyramid_levels[level] + print(f"Creating pyramid level {level} ({target_resolution}m)...") + + # Create the level dataset with parallel variable processing + if level <= 2: + # Base levels: use source measurements data + dataset = self._create_level_dataset_parallel( + level, target_resolution, measurements_by_resolution + ) + else: + # Higher levels: use level 2 data if available + if 2 in pyramid_datasets: + dataset = self._create_downsampled_dataset_from_level2_parallel( + level, target_resolution, pyramid_datasets[2] + ) + else: + print(f" Skipping level {level} - level 2 not available") + continue + if dataset and len(dataset.data_vars) > 0: - return level, dataset - return level, None - - # Create base levels in parallel - base_level_tasks = [] - for level in [0, 1, 2]: # Base levels that don't depend on each other - if level in self.pyramid_levels: - target_resolution = self.pyramid_levels[level] - task = create_and_prepare_level(level, target_resolution) - base_level_tasks.append(task) - - # Compute base levels - if base_level_tasks: - print(" Computing base levels...") - base_results = compute(*base_level_tasks) - - # Store results - for level, dataset in base_results: - if dataset is not None: + # Write immediately to avoid memory buildup + level_path = f"{output_path}/measurements/{level}" + print(f" Writing level {level} to {level_path}") + self._write_level_dataset(dataset, level_path, level) + + # Store only essential levels for dependencies + if level == 2: + # Keep level 2 for creating higher levels pyramid_datasets[level] = dataset - - # Phase 2: Create higher levels (3+) that depend on level 2 - if 2 in pyramid_datasets: - print("Phase 2: Creating higher pyramid levels (3+) in parallel...") - - @delayed - def create_higher_level(level: int, target_resolution: int, level_2_data): - """Create higher pyramid levels 
from level 2 data.""" - print(f" Creating pyramid level {level} ({target_resolution}m)...") - dataset = self._create_downsampled_dataset_from_level2( - level, target_resolution, level_2_data - ) - if dataset and len(dataset.data_vars) > 0: - return level, dataset - return level, None - - # Create higher levels in parallel - higher_level_tasks = [] - level_2_dataset = pyramid_datasets[2] - - for level in range(3, max(self.pyramid_levels.keys()) + 1): - if level in self.pyramid_levels: - target_resolution = self.pyramid_levels[level] - task = create_higher_level(level, target_resolution, level_2_dataset) - higher_level_tasks.append(task) - - # Compute higher levels - if higher_level_tasks: - print(" Computing higher levels...") - higher_results = compute(*higher_level_tasks) - - # Store results - for level, dataset in higher_results: - if dataset is not None: - pyramid_datasets[level] = dataset - - # Phase 3: Write all levels in parallel - print("Phase 3: Writing all pyramid levels in parallel...") - - @delayed - def write_level(level: int, dataset: xr.Dataset): - """Write a single pyramid level.""" - level_path = f"{output_path}/measurements/{level}" - self._write_level_dataset(dataset, level_path, level) - return level - - # Create write tasks for all levels - write_tasks = [] - for level, dataset in pyramid_datasets.items(): - task = write_level(level, dataset) - write_tasks.append(task) - - # Execute all writes in parallel - if write_tasks: - print(" Writing levels...") - compute(*write_tasks) - - print(f"✅ Parallel pyramid creation complete: {len(pyramid_datasets)} levels") + elif level < 2: + # Keep reference but could be cleaned up if memory is tight + pyramid_datasets[level] = dataset + + # Clean up memory for higher levels (they're already written) + if level > 2: + pyramid_datasets[level] = None # Just track that it was created + else: + print(f" Skipping empty level {level}") + + print(f"✅ Streaming pyramid creation complete: {len([k for k, v in pyramid_datasets.items() if v is not None])} levels") return pyramid_datasets def _create_multiscale_measurements_sequential( @@ -223,6 +177,211 @@ def _create_level_dataset( level, target_resolution, measurements_by_resolution ) + def _create_level_dataset_parallel( + self, + level: int, + target_resolution: int, + measurements_by_resolution: Dict[int, Dict] + ) -> xr.Dataset: + """Create dataset for a specific pyramid level with parallel processing.""" + + if level == 0: + # Level 0: Only native 10m data (no parallelization needed) + return self._create_level_0_dataset(measurements_by_resolution) + elif level == 1: + # Level 1: All data at 20m (native + downsampled from 10m) - parallel downsampling + return self._create_level_1_dataset_parallel(measurements_by_resolution) + elif level == 2: + # Level 2: All data at 60m (native + downsampled from 20m) - parallel downsampling + return self._create_level_2_dataset_parallel(measurements_by_resolution) + else: + # This shouldn't be called for levels 3+ in the streaming approach + return self._create_downsampled_dataset( + level, target_resolution, measurements_by_resolution + ) + + def _create_level_1_dataset_parallel(self, measurements_by_resolution: Dict) -> xr.Dataset: + """Create level 1 dataset with parallel downsampling from 10m data.""" + all_vars = {} + reference_coords = None + + # Start with native 20m data + if 20 in measurements_by_resolution: + data_20m = measurements_by_resolution[20] + for category, vars_dict in data_20m.items(): + all_vars.update(vars_dict) + + # Get reference 
coordinates from 20m data + if all_vars: + first_var = next(iter(all_vars.values())) + reference_coords = { + 'x': first_var.coords['x'], + 'y': first_var.coords['y'] + } + + # Add downsampled 10m data with parallelization + if 10 in measurements_by_resolution and reference_coords: + data_10m = measurements_by_resolution[10] + target_height = len(reference_coords['y']) + target_width = len(reference_coords['x']) + + # Collect all 10m variables for parallel processing + vars_to_downsample = [] + for category, vars_dict in data_10m.items(): + for var_name, var_data in vars_dict.items(): + vars_to_downsample.append((var_name, var_data)) + + # Process variables in parallel if Dask is available + if DASK_AVAILABLE and vars_to_downsample: + @delayed + def downsample_10m_variable(var_name: str, var_data: xr.DataArray): + var_type = determine_variable_type(var_name, var_data) + downsampled = self.resampler.downsample_variable( + var_data, target_height, target_width, var_type + ) + # Align coordinates + downsampled = downsampled.assign_coords(reference_coords) + return var_name, downsampled + + # Create tasks for all variables + downsample_tasks = [ + downsample_10m_variable(var_name, var_data) + for var_name, var_data in vars_to_downsample + ] + + # Compute all in parallel + print(f" Parallel downsampling {len(downsample_tasks)} variables from 10m to 20m...") + results = compute(*downsample_tasks) + for var_name, downsampled_var in results: + all_vars[var_name] = downsampled_var + else: + # Sequential fallback + for var_name, var_data in vars_to_downsample: + var_type = determine_variable_type(var_name, var_data) + downsampled = self.resampler.downsample_variable( + var_data, target_height, target_width, var_type + ) + # Align coordinates + downsampled = downsampled.assign_coords(reference_coords) + all_vars[var_name] = downsampled + + if not all_vars: + return xr.Dataset() + + # Create consolidated dataset + dataset = xr.Dataset(all_vars) + dataset.attrs['pyramid_level'] = 1 + dataset.attrs['resolution_meters'] = 20 + + return dataset + + def _create_level_2_dataset_parallel(self, measurements_by_resolution: Dict) -> xr.Dataset: + """Create level 2 dataset with parallel downsampling to 60m.""" + all_vars = {} + reference_coords = None + + # Start with native 60m data + if 60 in measurements_by_resolution: + data_60m = measurements_by_resolution[60] + for category, vars_dict in data_60m.items(): + all_vars.update(vars_dict) + + # Get reference coordinates from 60m data + if all_vars: + first_var = next(iter(all_vars.values())) + reference_coords = { + 'x': first_var.coords['x'], + 'y': first_var.coords['y'] + } + + # Collect all variables that need downsampling to 60m + vars_to_downsample = [] + if reference_coords: + target_height = len(reference_coords['y']) + target_width = len(reference_coords['x']) + + # Add 20m data for downsampling + if 20 in measurements_by_resolution: + data_20m = measurements_by_resolution[20] + for category, vars_dict in data_20m.items(): + for var_name, var_data in vars_dict.items(): + vars_to_downsample.append((var_name, var_data, '20m')) + + # Add 10m data for downsampling + if 10 in measurements_by_resolution: + data_10m = measurements_by_resolution[10] + for category, vars_dict in data_10m.items(): + for var_name, var_data in vars_dict.items(): + vars_to_downsample.append((var_name, var_data, '10m')) + + # Process all downsampling in parallel if Dask is available + if DASK_AVAILABLE and vars_to_downsample: + @delayed + def downsample_to_60m_variable(var_name: str, 
var_data: xr.DataArray, source_res: str): + var_type = determine_variable_type(var_name, var_data) + downsampled = self.resampler.downsample_variable( + var_data, target_height, target_width, var_type + ) + # Align coordinates + downsampled = downsampled.assign_coords(reference_coords) + return var_name, downsampled + + # Create tasks for all variables + downsample_tasks = [ + downsample_to_60m_variable(var_name, var_data, source_res) + for var_name, var_data, source_res in vars_to_downsample + ] + + # Compute all in parallel + print(f" Parallel downsampling {len(downsample_tasks)} variables to 60m...") + results = compute(*downsample_tasks) + for var_name, downsampled_var in results: + all_vars[var_name] = downsampled_var + else: + # Sequential fallback + for var_name, var_data, source_res in vars_to_downsample: + var_type = determine_variable_type(var_name, var_data) + downsampled = self.resampler.downsample_variable( + var_data, target_height, target_width, var_type + ) + # Align coordinates + downsampled = downsampled.assign_coords(reference_coords) + all_vars[var_name] = downsampled + + if not all_vars: + return xr.Dataset() + + # Create consolidated dataset + dataset = xr.Dataset(all_vars) + dataset.attrs['pyramid_level'] = 2 + dataset.attrs['resolution_meters'] = 60 + + return dataset + + def _create_downsampled_dataset_from_level2_parallel( + self, + level: int, + target_resolution: int, + level_2_dataset: xr.Dataset + ) -> xr.Dataset: + """Create downsampled dataset from level 2 with parallel processing.""" + if len(level_2_dataset.data_vars) == 0: + return xr.Dataset() + + # Calculate target dimensions (downsample by factor of 2^(level-2)) + downsample_factor = 2 ** (level - 2) + + # Get reference dimensions from level 2 + ref_var = next(iter(level_2_dataset.data_vars.values())) + current_height, current_width = ref_var.shape[-2:] + target_height = current_height // downsample_factor + target_width = current_width // downsample_factor + + # Always use parallel processing for higher levels + return self._downsample_variables_parallel( + level_2_dataset, level, target_resolution, target_height, target_width + ) + def _create_level_0_dataset(self, measurements_by_resolution: Dict) -> xr.Dataset: """Create level 0 dataset with only native 10m data.""" if 10 not in measurements_by_resolution: From ded0f61f03f2352f9635dfcbb4136a14e3ce99b6 Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Sun, 28 Sep 2025 14:46:17 +0200 Subject: [PATCH 24/83] feat: configure Dask client to use 3 workers with 8GB memory each for improved parallel processing --- src/eopf_geozarr/cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/eopf_geozarr/cli.py b/src/eopf_geozarr/cli.py index 6b5b35ca..75583654 100644 --- a/src/eopf_geozarr/cli.py +++ b/src/eopf_geozarr/cli.py @@ -54,7 +54,7 @@ def setup_dask_cluster(enable_dask: bool, verbose: bool = False) -> Optional[Any from dask.distributed import Client # Set up local cluster with high memory limits - client = Client(memory_limit="8GB") # set up local cluster + client = Client(n_workers=3, memory_limit="8GB") # set up local cluster with 3 workers and 8GB memory each # client = Client() # set up local cluster if verbose: From 0a2cc41a23d8945628f7226d2f60908998886a05 Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Sun, 28 Sep 2025 15:30:11 +0200 Subject: [PATCH 25/83] fix: update import path for geozarr functions in S2OptimizedConverter --- src/eopf_geozarr/s2_optimization/s2_converter.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/src/eopf_geozarr/s2_optimization/s2_converter.py b/src/eopf_geozarr/s2_optimization/s2_converter.py index 127da4ba..c909a819 100644 --- a/src/eopf_geozarr/s2_optimization/s2_converter.py +++ b/src/eopf_geozarr/s2_optimization/s2_converter.py @@ -205,7 +205,7 @@ def _add_root_multiscales_metadata( pyramid_datasets: Dict[int, xr.Dataset] ) -> None: """Add multiscales metadata at root level.""" - from ..geozarr import create_native_crs_tile_matrix_set, calculate_overview_levels + from eopf_geozarr.conversion.geozarr import create_native_crs_tile_matrix_set, calculate_overview_levels # Get information from level 0 dataset if 0 not in pyramid_datasets: From 19cabb84f694656b104fafd48d45779150807c68 Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Sun, 28 Sep 2025 15:47:21 +0200 Subject: [PATCH 26/83] feat: refactor multiscales metadata handling and root consolidation in S2OptimizedConverter --- .../s2_optimization/s2_converter.py | 203 ++++++++++++------ 1 file changed, 135 insertions(+), 68 deletions(-) diff --git a/src/eopf_geozarr/s2_optimization/s2_converter.py b/src/eopf_geozarr/s2_optimization/s2_converter.py index c909a819..60bb6ea4 100644 --- a/src/eopf_geozarr/s2_optimization/s2_converter.py +++ b/src/eopf_geozarr/s2_optimization/s2_converter.py @@ -104,13 +104,13 @@ def convert_s2_optimized( meteorology_path = f"{output_path}/meteorology" self._write_auxiliary_group(meteorology_ds, meteorology_path, "meteorology", verbose) - # Step 5: Create root-level multiscales metadata - print("Step 5: Adding multiscales metadata...") - self._add_root_multiscales_metadata(output_path, pyramid_datasets) + # Step 5: Add multiscales metadata to measurements group + print("Step 5: Adding multiscales metadata to measurements group...") + self._add_measurements_multiscales_metadata(output_path, pyramid_datasets) - # Step 6: Consolidate metadata - print("Step 6: Consolidating metadata...") - self._consolidate_root_metadata(output_path) + # Step 6: Simple root-level consolidation + print("Step 6: Final root-level metadata consolidation...") + self._simple_root_consolidation(output_path, pyramid_datasets) # Step 7: Validation if validate_output: @@ -199,76 +199,143 @@ def _write_auxiliary_group( if verbose: print(f" {group_type.title()} group written: {len(dataset.data_vars)} variables") - def _add_root_multiscales_metadata( - self, - output_path: str, - pyramid_datasets: Dict[int, xr.Dataset] - ) -> None: - """Add multiscales metadata at root level.""" - from eopf_geozarr.conversion.geozarr import create_native_crs_tile_matrix_set, calculate_overview_levels - - # Get information from level 0 dataset - if 0 not in pyramid_datasets: - return - - level_0_ds = pyramid_datasets[0] - if not level_0_ds.data_vars: - return - - # Get spatial info from first variable - first_var = next(iter(level_0_ds.data_vars.values())) - native_height, native_width = first_var.shape[-2:] - native_crs = level_0_ds.rio.crs - native_bounds = level_0_ds.rio.bounds() - - # Calculate overview levels - overview_levels = [] - for level, resolution in self.pyramid_creator.pyramid_levels.items(): - if level in pyramid_datasets: - level_ds = pyramid_datasets[level] - level_var = next(iter(level_ds.data_vars.values())) - level_height, level_width = level_var.shape[-2:] + def _add_measurements_multiscales_metadata(self, output_path: str, pyramid_datasets: Dict[int, xr.Dataset]) -> None: + """Add multiscales metadata to the measurements group using rioxarray.""" + try: + measurements_path = 
f"{output_path}/measurements" + + # Create multiscales metadata using rioxarray .rio accessor + multiscales_metadata = self._create_multiscales_metadata_with_rio(pyramid_datasets) + + if multiscales_metadata: + # Use zarr to add metadata to the measurements group + storage_options = get_storage_options(measurements_path) - overview_levels.append({ - 'level': level, - 'resolution': resolution, - 'width': level_width, - 'height': level_height, - 'scale_factor': 2 ** level if level > 0 else 1 - }) - - # Create tile matrix set - tile_matrix_set = create_native_crs_tile_matrix_set( - native_crs, native_bounds, overview_levels, "measurements" - ) + try: + import zarr + if storage_options: + store = zarr.storage.FSStore(measurements_path, **storage_options) + else: + store = measurements_path + + # Open the measurements group and add multiscales metadata + measurements_group = zarr.open_group(store, mode='r+') + measurements_group.attrs['multiscales'] = multiscales_metadata + + print(" ✅ Added multiscales metadata to measurements group") + + except Exception as e: + print(f" ⚠️ Could not add multiscales metadata: {e}") + + except Exception as e: + print(f" ⚠️ Error adding multiscales metadata: {e}") + + def _create_multiscales_metadata_with_rio(self, pyramid_datasets: Dict[int, xr.Dataset]) -> Dict: + """Create multiscales metadata using rioxarray .rio accessor, following geozarr.py format.""" + if not pyramid_datasets: + return {} + + # Get the first available dataset to extract spatial information using .rio + reference_ds = None + for level in sorted(pyramid_datasets.keys()): + if pyramid_datasets[level] is not None: + reference_ds = pyramid_datasets[level] + break + + if not reference_ds or not reference_ds.data_vars: + return {} - # Add metadata to measurements group - measurements_zarr_path = normalize_path(f"{output_path}/measurements/zarr.json") - if os.path.exists(measurements_zarr_path): - import json - with open(measurements_zarr_path, 'r') as f: - zarr_json = json.load(f) + try: + # Use .rio accessor to get CRS and bounds directly from the dataset + if not hasattr(reference_ds, 'rio') or not reference_ds.rio.crs: + return {} + + native_crs = reference_ds.rio.crs + native_bounds = reference_ds.rio.bounds() + + # Create overview levels list following geozarr.py format + overview_levels = [] + for level in sorted(pyramid_datasets.keys()): + if pyramid_datasets[level] is not None: + level_ds = pyramid_datasets[level] + resolution = self.pyramid_creator.pyramid_levels.get(level, level * 10) + + if hasattr(level_ds, 'rio'): + width = level_ds.rio.width + height = level_ds.rio.height + scale_factor = 2 ** level if level > 0 else 1 + + overview_levels.append({ + 'level': level, + 'width': width, + 'height': height, + 'scale_factor': scale_factor, + 'zoom': max(0, level) # Simple zoom calculation + }) - zarr_json.setdefault('attributes', {}) - zarr_json['attributes']['multiscales'] = { - 'tile_matrix_set': tile_matrix_set, - 'resampling_method': 'average', - 'datasets': [{'path': str(level)} for level in sorted(pyramid_datasets.keys())] + if not overview_levels: + return {} + + # Import the functions from geozarr.py to create proper multiscales metadata + from eopf_geozarr.conversion.geozarr import create_native_crs_tile_matrix_set, _create_tile_matrix_limits + + # Create tile matrix set following geozarr.py exactly + tile_matrix_set = create_native_crs_tile_matrix_set( + native_crs, + native_bounds, + overview_levels, + "measurements" # group prefix + ) + + # Create tile matrix limits 
following geozarr.py exactly + tile_matrix_limits = _create_tile_matrix_limits(overview_levels, 256) # tile_width=256 + + # Create multiscales metadata following geozarr.py format exactly + multiscales_metadata = { + "tile_matrix_set": tile_matrix_set, + "resampling_method": "average", + "tile_matrix_limits": tile_matrix_limits, } - with open(measurements_zarr_path, 'w') as f: - json.dump(zarr_json, f, indent=2) - - def _consolidate_root_metadata(self, output_path: str) -> None: - """Consolidate metadata at root level.""" + return multiscales_metadata + + except Exception as e: + print(f" Warning: Could not create multiscales metadata with .rio accessor: {e}") + return {} + + def _simple_root_consolidation(self, output_path: str, pyramid_datasets: Dict[int, xr.Dataset]) -> None: + """Simple root-level metadata consolidation using only xarray.""" try: - from eopf_geozarr.conversion.geozarr import consolidate_metadata - from eopf_geozarr.conversion.fs_utils import open_zarr_group + # Since each level and auxiliary group was written with consolidated=True, + # we just need to create a simple root-level consolidated metadata + print(" Performing simple root consolidation...") + + # Use xarray to open and immediately close the root group with consolidation + # This creates/updates the root .zmetadata file + storage_options = get_storage_options(output_path) - zarr_group = open_zarr_group(output_path, mode="r+") - consolidate_metadata(zarr_group.store) + # Open the root zarr group and let xarray handle consolidation + try: + # This will create consolidated metadata at the root level + with xr.open_zarr(output_path, storage_options=storage_options, + consolidated=True, chunks={}) as root_ds: + # Just opening and closing with consolidated=True should be enough + pass + print(" ✅ Root consolidation completed") + except Exception as e: + print(f" ⚠️ Root consolidation using xarray failed, trying zarr directly: {e}") + + # Fallback: minimal zarr consolidation if needed + import zarr + store = zarr.storage.FSStore(output_path, **storage_options) if storage_options else output_path + try: + zarr.consolidate_metadata(store) + print(" ✅ Root consolidation completed with zarr") + except Exception as e2: + print(f" ⚠️ Warning: Root consolidation failed: {e2}") + except Exception as e: - print(f" Warning: Root metadata consolidation failed: {e}") + print(f" ⚠️ Warning: Root consolidation failed: {e}") def _create_result_datatree(self, output_path: str) -> xr.DataTree: """Create result DataTree from written output.""" From 75be5f7740ae4fa4b79553e3b57e31930db81f66 Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Sun, 28 Sep 2025 15:53:24 +0200 Subject: [PATCH 27/83] feat: add comprehensive unit tests for S2OptimizedConverter and related functionalities --- .../tests/test_s2_converter_simplified.py | 454 ++++++++++++++++++ 1 file changed, 454 insertions(+) create mode 100644 src/eopf_geozarr/tests/test_s2_converter_simplified.py diff --git a/src/eopf_geozarr/tests/test_s2_converter_simplified.py b/src/eopf_geozarr/tests/test_s2_converter_simplified.py new file mode 100644 index 00000000..24af1a0f --- /dev/null +++ b/src/eopf_geozarr/tests/test_s2_converter_simplified.py @@ -0,0 +1,454 @@ +""" +Unit tests for simplified S2 converter implementation. + +Tests the new simplified approach that uses only xarray and proper metadata consolidation. 
+""" + +import os +import tempfile +import shutil +from unittest.mock import Mock, patch, MagicMock +import pytest +import numpy as np +import xarray as xr +import zarr +from rasterio.crs import CRS + +from eopf_geozarr.s2_optimization.s2_converter import S2OptimizedConverter + + +@pytest.fixture +def mock_s2_dataset(): + """Create a mock S2 dataset for testing.""" + # Create test data arrays + coords = { + 'x': (['x'], np.linspace(0, 1000, 100)), + 'y': (['y'], np.linspace(0, 1000, 100)), + 'time': (['time'], [np.datetime64('2023-01-01')]) + } + + # Create test variables + data_vars = { + 'b02': (['time', 'y', 'x'], np.random.rand(1, 100, 100)), + 'b03': (['time', 'y', 'x'], np.random.rand(1, 100, 100)), + 'b04': (['time', 'y', 'x'], np.random.rand(1, 100, 100)), + } + + ds = xr.Dataset(data_vars, coords=coords) + + # Add rioxarray CRS + ds = ds.rio.write_crs('EPSG:32632') + + # Create datatree + dt = xr.DataTree(ds) + dt.attrs = { + 'stac_discovery': { + 'properties': { + 'mission': 'sentinel-2' + } + } + } + + return dt + + +@pytest.fixture +def temp_output_dir(): + """Create temporary directory for test outputs.""" + temp_dir = tempfile.mkdtemp() + yield temp_dir + shutil.rmtree(temp_dir) + + +class TestS2OptimizedConverter: + """Test the S2OptimizedConverter class.""" + + def test_init(self): + """Test converter initialization.""" + converter = S2OptimizedConverter( + enable_sharding=True, + spatial_chunk=512, + compression_level=5, + max_retries=2 + ) + + assert converter.enable_sharding is True + assert converter.spatial_chunk == 512 + assert converter.compression_level == 5 + assert converter.max_retries == 2 + assert converter.pyramid_creator is not None + assert converter.validator is not None + + def test_is_sentinel2_dataset_with_mission(self): + """Test S2 detection via mission attribute.""" + converter = S2OptimizedConverter() + + # Test with S2 mission + dt = xr.DataTree() + dt.attrs = { + 'stac_discovery': { + 'properties': { + 'mission': 'sentinel-2a' + } + } + } + + assert converter._is_sentinel2_dataset(dt) is True + + # Test with non-S2 mission + dt.attrs['stac_discovery']['properties']['mission'] = 'sentinel-1' + assert converter._is_sentinel2_dataset(dt) is False + + def test_is_sentinel2_dataset_with_groups(self): + """Test S2 detection via characteristic groups.""" + converter = S2OptimizedConverter() + + dt = xr.DataTree() + dt.attrs = {} + + # Mock groups property using patch + with patch.object(type(dt), 'groups', new_callable=lambda: property(lambda self: [ + '/measurements/reflectance', + '/conditions/geometry', + '/quality/atmosphere' + ])): + assert converter._is_sentinel2_dataset(dt) is True + + # Test with insufficient indicators + with patch.object(type(dt), 'groups', new_callable=lambda: property(lambda self: ['/measurements/reflectance'])): + assert converter._is_sentinel2_dataset(dt) is False + + +class TestMultiscalesMetadata: + """Test multiscales metadata creation.""" + + def test_create_multiscales_metadata_with_rio(self, temp_output_dir): + """Test multiscales metadata creation using rioxarray.""" + converter = S2OptimizedConverter() + + # Create mock pyramid datasets with rioxarray + pyramid_datasets = {} + for level in [0, 1, 2]: + # Create test dataset + coords = { + 'x': (['x'], np.linspace(0, 1000, 100 // (2**level))), + 'y': (['y'], np.linspace(0, 1000, 100 // (2**level))) + } + data_vars = { + 'b02': (['y', 'x'], np.random.rand(100 // (2**level), 100 // (2**level))) + } + ds = xr.Dataset(data_vars, coords=coords) + ds = 
ds.rio.write_crs('EPSG:32632') + + pyramid_datasets[level] = ds + + # Test metadata creation + metadata = converter._create_multiscales_metadata_with_rio(pyramid_datasets) + + # Verify structure matches geozarr.py format + assert 'tile_matrix_set' in metadata + assert 'resampling_method' in metadata + assert 'tile_matrix_limits' in metadata + assert metadata['resampling_method'] == 'average' + + # Verify tile matrix set structure + tms = metadata['tile_matrix_set'] + assert 'id' in tms + assert 'crs' in tms + assert 'tileMatrices' in tms + assert len(tms['tileMatrices']) == 3 # 3 levels + + def test_create_multiscales_metadata_no_datasets(self): + """Test metadata creation with no datasets.""" + converter = S2OptimizedConverter() + + metadata = converter._create_multiscales_metadata_with_rio({}) + assert metadata == {} + + def test_create_multiscales_metadata_no_crs(self): + """Test metadata creation with datasets lacking CRS.""" + converter = S2OptimizedConverter() + + # Create dataset without CRS + ds = xr.Dataset({'b02': (['y', 'x'], np.random.rand(10, 10))}) + pyramid_datasets = {0: ds} + + metadata = converter._create_multiscales_metadata_with_rio(pyramid_datasets) + assert metadata == {} + + +class TestAuxiliaryGroupWriting: + """Test auxiliary group writing functionality.""" + + @patch('eopf_geozarr.s2_optimization.s2_converter.distributed') + def test_write_auxiliary_group_with_distributed(self, mock_distributed, temp_output_dir): + """Test auxiliary group writing with distributed available.""" + converter = S2OptimizedConverter() + + # Create test dataset + data_vars = { + 'solar_zenith': (['y', 'x'], np.random.rand(50, 50)), + 'solar_azimuth': (['y', 'x'], np.random.rand(50, 50)) + } + coords = { + 'x': (['x'], np.linspace(0, 1000, 50)), + 'y': (['y'], np.linspace(0, 1000, 50)) + } + dataset = xr.Dataset(data_vars, coords=coords) + + group_path = os.path.join(temp_output_dir, 'geometry') + + # Mock distributed progress + mock_progress = Mock() + mock_distributed.progress = mock_progress + + # Test writing + converter._write_auxiliary_group(dataset, group_path, 'geometry', verbose=True) + + # Verify zarr group was created + assert os.path.exists(group_path) + + # Verify group can be opened + zarr_group = zarr.open_group(group_path, mode='r') + assert 'solar_zenith' in zarr_group + assert 'solar_azimuth' in zarr_group + + def test_write_auxiliary_group_without_distributed(self, temp_output_dir): + """Test auxiliary group writing without distributed.""" + converter = S2OptimizedConverter() + + # Create test dataset + data_vars = { + 'temperature': (['y', 'x'], np.random.rand(30, 30)), + 'pressure': (['y', 'x'], np.random.rand(30, 30)) + } + coords = { + 'x': (['x'], np.linspace(0, 1000, 30)), + 'y': (['y'], np.linspace(0, 1000, 30)) + } + dataset = xr.Dataset(data_vars, coords=coords) + + group_path = os.path.join(temp_output_dir, 'meteorology') + + # Patch DISTRIBUTED_AVAILABLE to False + with patch('eopf_geozarr.s2_optimization.s2_converter.DISTRIBUTED_AVAILABLE', False): + converter._write_auxiliary_group(dataset, group_path, 'meteorology', verbose=False) + + # Verify zarr group was created + assert os.path.exists(group_path) + + # Verify group can be opened + zarr_group = zarr.open_group(group_path, mode='r') + assert 'temperature' in zarr_group + assert 'pressure' in zarr_group + + +class TestMetadataConsolidation: + """Test metadata consolidation functionality.""" + + def test_add_measurements_multiscales_metadata(self, temp_output_dir): + """Test adding multiscales metadata 
to measurements group.""" + converter = S2OptimizedConverter() + + # Create measurements group structure + measurements_path = os.path.join(temp_output_dir, 'measurements') + os.makedirs(measurements_path) + + # Create a minimal zarr group + zarr_group = zarr.open_group(measurements_path, mode='w') + zarr_group.attrs['test'] = 'value' + + # Create mock pyramid datasets + pyramid_datasets = {} + for level in [0, 1]: + coords = { + 'x': (['x'], np.linspace(0, 1000, 50 // (2**level))), + 'y': (['y'], np.linspace(0, 1000, 50 // (2**level))) + } + data_vars = { + 'b02': (['y', 'x'], np.random.rand(50 // (2**level), 50 // (2**level))) + } + ds = xr.Dataset(data_vars, coords=coords) + ds = ds.rio.write_crs('EPSG:32632') + pyramid_datasets[level] = ds + + # Test adding metadata + converter._add_measurements_multiscales_metadata(temp_output_dir, pyramid_datasets) + + # Verify metadata was added + zarr_group = zarr.open_group(measurements_path, mode='r') + assert 'multiscales' in zarr_group.attrs + + multiscales = zarr_group.attrs['multiscales'] + assert 'tile_matrix_set' in multiscales + assert 'resampling_method' in multiscales + assert 'tile_matrix_limits' in multiscales + + def test_add_measurements_multiscales_metadata_error_handling(self, temp_output_dir): + """Test error handling in multiscales metadata addition.""" + converter = S2OptimizedConverter() + + # Test with non-existent measurements path + converter._add_measurements_multiscales_metadata(temp_output_dir, {}) + + # Should not raise an exception, just print warnings + # (We can't easily test print output in unit tests, but the method should handle errors gracefully) + + @patch('xarray.open_zarr') + def test_simple_root_consolidation_success(self, mock_open_zarr, temp_output_dir): + """Test successful root consolidation with xarray.""" + converter = S2OptimizedConverter() + + # Mock successful xarray consolidation + mock_ds = Mock() + mock_open_zarr.return_value.__enter__.return_value = mock_ds + + converter._simple_root_consolidation(temp_output_dir, {}) + + # Verify xarray.open_zarr was called with correct parameters + mock_open_zarr.assert_called_once() + args, kwargs = mock_open_zarr.call_args + assert args[0] == temp_output_dir + assert kwargs['consolidated'] is True + assert kwargs['chunks'] == {} + + @patch('zarr.consolidate_metadata') + @patch('xarray.open_zarr') + def test_simple_root_consolidation_fallback(self, mock_open_zarr, mock_consolidate, temp_output_dir): + """Test fallback to zarr consolidation when xarray fails.""" + converter = S2OptimizedConverter() + + # Mock xarray failure + mock_open_zarr.side_effect = Exception("xarray failed") + + converter._simple_root_consolidation(temp_output_dir, {}) + + # Verify fallback to zarr.consolidate_metadata + mock_consolidate.assert_called_once() + + +class TestEndToEndSimplified: + """Test simplified end-to-end functionality with mocks.""" + + @patch('eopf_geozarr.s2_optimization.s2_converter.S2DataConsolidator') + @patch('eopf_geozarr.s2_optimization.s2_converter.S2MultiscalePyramid') + @patch('eopf_geozarr.s2_optimization.s2_converter.S2OptimizationValidator') + def test_convert_s2_optimized_simplified_flow(self, mock_validator, mock_pyramid, mock_consolidator, + mock_s2_dataset, temp_output_dir): + """Test the simplified conversion flow with all major components mocked.""" + converter = S2OptimizedConverter() + + # Mock consolidator + mock_consolidator_instance = Mock() + mock_consolidator.return_value = mock_consolidator_instance + 
mock_consolidator_instance.consolidate_all_data.return_value = ( + {10: {'bands': {'b02': Mock(), 'b03': Mock()}}}, # measurements + {'solar_zenith': Mock()}, # geometry + {'temperature': Mock()} # meteorology + ) + + # Mock pyramid creator + mock_pyramid_instance = Mock() + mock_pyramid.return_value = mock_pyramid_instance + + # Create mock pyramid datasets with rioxarray + pyramid_datasets = {} + for level in [0, 1]: + coords = { + 'x': (['x'], np.linspace(0, 1000, 50 // (2**level))), + 'y': (['y'], np.linspace(0, 1000, 50 // (2**level))) + } + data_vars = { + 'b02': (['y', 'x'], np.random.rand(50 // (2**level), 50 // (2**level))) + } + ds = xr.Dataset(data_vars, coords=coords) + ds = ds.rio.write_crs('EPSG:32632') + pyramid_datasets[level] = ds + + mock_pyramid_instance.create_multiscale_measurements.return_value = pyramid_datasets + + # Mock validator + mock_validator_instance = Mock() + mock_validator.return_value = mock_validator_instance + mock_validator_instance.validate_optimized_dataset.return_value = { + 'is_valid': True, + 'issues': [] + } + + # Mock the multiscales metadata methods + with patch.object(converter, '_add_measurements_multiscales_metadata') as mock_add_metadata, \ + patch.object(converter, '_simple_root_consolidation') as mock_consolidation, \ + patch.object(converter, '_write_auxiliary_group') as mock_write_aux, \ + patch.object(converter, '_create_result_datatree') as mock_create_result: + + mock_create_result.return_value = xr.DataTree() + + # Run conversion + result = converter.convert_s2_optimized( + mock_s2_dataset, + temp_output_dir, + create_geometry_group=True, + create_meteorology_group=True, + validate_output=True, + verbose=True + ) + + # Verify all steps were called + mock_consolidator_instance.consolidate_all_data.assert_called_once() + mock_pyramid_instance.create_multiscale_measurements.assert_called_once() + mock_write_aux.assert_called() # Should be called twice (geometry + meteorology) + mock_add_metadata.assert_called_once_with(temp_output_dir, pyramid_datasets) + mock_consolidation.assert_called_once_with(temp_output_dir, pyramid_datasets) + mock_validator_instance.validate_optimized_dataset.assert_called_once() + + assert result is not None + + +class TestConvenienceFunction: + """Test the convenience function.""" + + @patch('eopf_geozarr.s2_optimization.s2_converter.S2OptimizedConverter') + def test_convert_s2_optimized_convenience_function(self, mock_converter_class): + """Test the convenience function parameter separation.""" + from eopf_geozarr.s2_optimization.s2_converter import convert_s2_optimized + + mock_converter_instance = Mock() + mock_converter_class.return_value = mock_converter_instance + mock_converter_instance.convert_s2_optimized.return_value = Mock() + + # Test parameter separation + dt_input = Mock() + output_path = "/test/path" + + result = convert_s2_optimized( + dt_input, + output_path, + enable_sharding=False, + spatial_chunk=512, + compression_level=2, + max_retries=5, + create_geometry_group=False, + validate_output=False, + verbose=True + ) + + # Verify constructor was called with correct args + mock_converter_class.assert_called_once_with( + enable_sharding=False, + spatial_chunk=512, + compression_level=2, + max_retries=5 + ) + + # Verify method was called with remaining args + mock_converter_instance.convert_s2_optimized.assert_called_once_with( + dt_input, + output_path, + create_geometry_group=False, + validate_output=False, + verbose=True + ) + + +if __name__ == '__main__': + pytest.main([__file__]) From 
c409af1101840c6d3a80fc1de31ff340ce1b4e79 Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Sun, 28 Sep 2025 20:04:29 +0200 Subject: [PATCH 28/83] feat: implement geographic metadata writing in S2MultiscalePyramid and add corresponding unit tests --- .../s2_optimization/s2_converter.py | 4 +- .../s2_optimization/s2_multiscale.py | 32 ++ .../tests/test_s2_converter_simplified.py | 31 +- .../tests/test_s2_multiscale_geo_metadata.py | 298 ++++++++++++++++++ 4 files changed, 348 insertions(+), 17 deletions(-) create mode 100644 src/eopf_geozarr/tests/test_s2_multiscale_geo_metadata.py diff --git a/src/eopf_geozarr/s2_optimization/s2_converter.py b/src/eopf_geozarr/s2_optimization/s2_converter.py index 60bb6ea4..d1a7aadc 100644 --- a/src/eopf_geozarr/s2_optimization/s2_converter.py +++ b/src/eopf_geozarr/s2_optimization/s2_converter.py @@ -12,6 +12,7 @@ from .s2_multiscale import S2MultiscalePyramid from .s2_validation import S2OptimizationValidator from eopf_geozarr.conversion.fs_utils import get_storage_options, normalize_path +from eopf_geozarr.conversion.geozarr import create_native_crs_tile_matrix_set, _create_tile_matrix_limits try: import distributed @@ -276,9 +277,6 @@ def _create_multiscales_metadata_with_rio(self, pyramid_datasets: Dict[int, xr.D if not overview_levels: return {} - # Import the functions from geozarr.py to create proper multiscales metadata - from eopf_geozarr.conversion.geozarr import create_native_crs_tile_matrix_set, _create_tile_matrix_limits - # Create tile matrix set following geozarr.py exactly tile_matrix_set = create_native_crs_tile_matrix_set( native_crs, diff --git a/src/eopf_geozarr/s2_optimization/s2_multiscale.py b/src/eopf_geozarr/s2_optimization/s2_multiscale.py index 193a9cc6..030bfd92 100644 --- a/src/eopf_geozarr/s2_optimization/s2_multiscale.py +++ b/src/eopf_geozarr/s2_optimization/s2_multiscale.py @@ -3,6 +3,7 @@ """ import numpy as np +from pyproj import CRS import xarray as xr from typing import Dict, List, Tuple from .s2_resampling import S2ResamplingEngine, determine_variable_type @@ -402,6 +403,8 @@ def _create_level_0_dataset(self, measurements_by_resolution: Dict) -> xr.Datase dataset.attrs['pyramid_level'] = 0 dataset.attrs['resolution_meters'] = 10 + self._write_geo_metadata(dataset) + return dataset def _create_level_1_dataset(self, measurements_by_resolution: Dict) -> xr.Dataset: @@ -451,6 +454,8 @@ def _create_level_1_dataset(self, measurements_by_resolution: Dict) -> xr.Datase dataset.attrs['pyramid_level'] = 1 dataset.attrs['resolution_meters'] = 20 + self._write_geo_metadata(dataset) + return dataset def _create_level_2_dataset(self, measurements_by_resolution: Dict) -> xr.Dataset: @@ -533,6 +538,8 @@ def downsample_to_60m_variable(var_name: str, var_data: xr.DataArray, source_res dataset.attrs['pyramid_level'] = 2 dataset.attrs['resolution_meters'] = 60 + self._write_geo_metadata(dataset) + return dataset def _create_downsampled_dataset( @@ -617,6 +624,8 @@ def downsample_single_variable(var_name: str, var_data: xr.DataArray): dataset.attrs['pyramid_level'] = level dataset.attrs['resolution_meters'] = target_resolution + self._write_geo_metadata(dataset) + return dataset def _downsample_variables_sequential( @@ -642,6 +651,8 @@ def _downsample_variables_sequential( dataset.attrs['pyramid_level'] = level dataset.attrs['resolution_meters'] = target_resolution + self._write_geo_metadata(dataset) + return dataset def _write_level_dataset(self, dataset: xr.Dataset, level_path: str, level: int) -> None: @@ -922,3 +933,24 @@ def 
_rechunk_dataset_for_encoding(self, dataset: xr.Dataset, encoding: Dict) -> rechunked_dataset = xr.Dataset(rechunked_vars, coords=dataset.coords, attrs=dataset.attrs) return rechunked_dataset + + def _write_geo_metadata(self, dataset: xr.Dataset, grid_mapping_var_name: str = "spatial_ref") -> None: + """ + Write geographic metadata to the dataset. + Adds a grid_mapping variable and updates all data variables to reference it. + """ + + # take the CRS from one of the data variables if available + crs = None + for var in dataset.data_vars.values(): + if hasattr(var, 'rio') and var.rio.crs: + crs = var.rio.crs + break + elif "proj:epsg" in var.attrs: + epsg = var.attrs["proj:epsg"] + crs = CRS.from_epsg(epsg) + break + + # Use standard CRS and transform if available + if crs is not None: + dataset.rio.write_crs(crs, inplace=True) diff --git a/src/eopf_geozarr/tests/test_s2_converter_simplified.py b/src/eopf_geozarr/tests/test_s2_converter_simplified.py index 24af1a0f..47096485 100644 --- a/src/eopf_geozarr/tests/test_s2_converter_simplified.py +++ b/src/eopf_geozarr/tests/test_s2_converter_simplified.py @@ -336,20 +336,6 @@ class TestEndToEndSimplified: def test_convert_s2_optimized_simplified_flow(self, mock_validator, mock_pyramid, mock_consolidator, mock_s2_dataset, temp_output_dir): """Test the simplified conversion flow with all major components mocked.""" - converter = S2OptimizedConverter() - - # Mock consolidator - mock_consolidator_instance = Mock() - mock_consolidator.return_value = mock_consolidator_instance - mock_consolidator_instance.consolidate_all_data.return_value = ( - {10: {'bands': {'b02': Mock(), 'b03': Mock()}}}, # measurements - {'solar_zenith': Mock()}, # geometry - {'temperature': Mock()} # meteorology - ) - - # Mock pyramid creator - mock_pyramid_instance = Mock() - mock_pyramid.return_value = mock_pyramid_instance # Create mock pyramid datasets with rioxarray pyramid_datasets = {} @@ -365,6 +351,18 @@ def test_convert_s2_optimized_simplified_flow(self, mock_validator, mock_pyramid ds = ds.rio.write_crs('EPSG:32632') pyramid_datasets[level] = ds + # Mock consolidator + mock_consolidator_instance = Mock() + mock_consolidator.return_value = mock_consolidator_instance + mock_consolidator_instance.consolidate_all_data.return_value = ( + {10: {'bands': {'b02': Mock(), 'b03': Mock()}}}, # measurements + {'solar_zenith': Mock()}, # geometry + {'temperature': Mock()} # meteorology + ) + + # Mock pyramid creator + mock_pyramid_instance = Mock() + mock_pyramid.return_value = mock_pyramid_instance mock_pyramid_instance.create_multiscale_measurements.return_value = pyramid_datasets # Mock validator @@ -375,6 +373,11 @@ def test_convert_s2_optimized_simplified_flow(self, mock_validator, mock_pyramid 'issues': [] } + # Create converter and replace instances that were created during initialization + converter = S2OptimizedConverter() + converter.pyramid_creator = mock_pyramid_instance + converter.validator = mock_validator_instance + # Mock the multiscales metadata methods with patch.object(converter, '_add_measurements_multiscales_metadata') as mock_add_metadata, \ patch.object(converter, '_simple_root_consolidation') as mock_consolidation, \ diff --git a/src/eopf_geozarr/tests/test_s2_multiscale_geo_metadata.py b/src/eopf_geozarr/tests/test_s2_multiscale_geo_metadata.py new file mode 100644 index 00000000..cd450cdd --- /dev/null +++ b/src/eopf_geozarr/tests/test_s2_multiscale_geo_metadata.py @@ -0,0 +1,298 @@ +""" +Unit tests for _write_geo_metadata method in S2MultiscalePyramid. 
+ +Tests the geographic metadata writing functionality added to level creation. +""" + +import pytest +import numpy as np +import xarray as xr +from unittest.mock import Mock, patch +from pyproj import CRS + +from eopf_geozarr.s2_optimization.s2_multiscale import S2MultiscalePyramid + + +@pytest.fixture +def pyramid_creator(): + """Create a S2MultiscalePyramid instance for testing.""" + return S2MultiscalePyramid(enable_sharding=True, spatial_chunk=1024) + + +@pytest.fixture +def sample_dataset_with_crs(): + """Create a sample dataset with CRS information.""" + coords = { + 'x': (['x'], np.linspace(0, 1000, 100)), + 'y': (['y'], np.linspace(0, 1000, 100)), + 'time': (['time'], [np.datetime64('2023-01-01')]) + } + + data_vars = { + 'b02': (['time', 'y', 'x'], np.random.rand(1, 100, 100)), + 'b03': (['time', 'y', 'x'], np.random.rand(1, 100, 100)), + 'b04': (['y', 'x'], np.random.rand(100, 100)) + } + + ds = xr.Dataset(data_vars, coords=coords) + + ds['b02'].attrs['proj:epsg'] = 32632 + ds['b03'].attrs['proj:epsg'] = 32632 + ds['b04'].attrs['proj:epsg'] = 32632 + + return ds + + +@pytest.fixture +def sample_dataset_with_epsg_attrs(): + """Create a sample dataset with EPSG in attributes.""" + coords = { + 'x': (['x'], np.linspace(0, 1000, 50)), + 'y': (['y'], np.linspace(0, 1000, 50)) + } + + data_vars = { + 'b05': (['y', 'x'], np.random.rand(50, 50)), + 'b06': (['y', 'x'], np.random.rand(50, 50)) + } + + ds = xr.Dataset(data_vars, coords=coords) + + # Add EPSG to variable attributes + ds['b05'].attrs['proj:epsg'] = 32632 + ds['b06'].attrs['proj:epsg'] = 32632 + + return ds + + +@pytest.fixture +def sample_dataset_no_crs(): + """Create a sample dataset without CRS information.""" + coords = { + 'x': (['x'], np.linspace(0, 1000, 25)), + 'y': (['y'], np.linspace(0, 1000, 25)) + } + + data_vars = { + 'b11': (['y', 'x'], np.random.rand(25, 25)), + 'b12': (['y', 'x'], np.random.rand(25, 25)) + } + + return xr.Dataset(data_vars, coords=coords) + + +class TestWriteGeoMetadata: + """Test the _write_geo_metadata method.""" + + def test_write_geo_metadata_with_rio_crs(self, pyramid_creator, sample_dataset_with_crs): + """Test _write_geo_metadata with dataset that has rioxarray CRS.""" + + # Call the method + pyramid_creator._write_geo_metadata(sample_dataset_with_crs) + + # Verify CRS was written + assert hasattr(sample_dataset_with_crs, 'rio') + assert sample_dataset_with_crs.rio.crs is not None + assert sample_dataset_with_crs.rio.crs.to_epsg() == 32632 + + def test_write_geo_metadata_with_epsg_attrs(self, pyramid_creator, sample_dataset_with_epsg_attrs): + """Test _write_geo_metadata with dataset that has EPSG in variable attributes.""" + + # Verify initial state - no CRS + assert not hasattr(sample_dataset_with_epsg_attrs, 'rio') or sample_dataset_with_epsg_attrs.rio.crs is None + + # Call the method + pyramid_creator._write_geo_metadata(sample_dataset_with_epsg_attrs) + + # Verify CRS was written from attributes + assert hasattr(sample_dataset_with_epsg_attrs, 'rio') + assert sample_dataset_with_epsg_attrs.rio.crs is not None + assert sample_dataset_with_epsg_attrs.rio.crs.to_epsg() == 32632 + + def test_write_geo_metadata_no_crs(self, pyramid_creator, sample_dataset_no_crs): + """Test _write_geo_metadata with dataset that has no CRS information.""" + + # Verify initial state - no CRS + assert not hasattr(sample_dataset_no_crs, 'rio') or sample_dataset_no_crs.rio.crs is None + + # Call the method - should not fail but also not add CRS + pyramid_creator._write_geo_metadata(sample_dataset_no_crs) + 
+ # Verify no CRS was added (method handles gracefully) + # The method should not fail even when no CRS is available + # This tests the robustness of the method + + def test_write_geo_metadata_custom_grid_mapping_name(self, pyramid_creator, sample_dataset_with_crs): + """Test _write_geo_metadata with custom grid_mapping variable name.""" + + # Call the method with custom grid mapping name + custom_name = "custom_spatial_ref" + pyramid_creator._write_geo_metadata(sample_dataset_with_crs, custom_name) + + # Verify CRS was written + assert hasattr(sample_dataset_with_crs, 'rio') + assert sample_dataset_with_crs.rio.crs is not None + + def test_write_geo_metadata_preserves_existing_data(self, pyramid_creator, sample_dataset_with_crs): + """Test that _write_geo_metadata preserves existing data variables and coordinates.""" + + # Store original data + original_vars = list(sample_dataset_with_crs.data_vars.keys()) + original_coords = list(sample_dataset_with_crs.coords.keys()) + original_b02_data = sample_dataset_with_crs['b02'].values.copy() + + # Call the method + pyramid_creator._write_geo_metadata(sample_dataset_with_crs) + + # Verify all original data is preserved + assert list(sample_dataset_with_crs.data_vars.keys()) == original_vars + assert all(coord in sample_dataset_with_crs.coords for coord in original_coords) + assert np.array_equal(sample_dataset_with_crs['b02'].values, original_b02_data) + + def test_write_geo_metadata_empty_dataset(self, pyramid_creator): + """Test _write_geo_metadata with empty dataset.""" + + empty_ds = xr.Dataset({}, coords={}) + + # Call the method - should handle gracefully + pyramid_creator._write_geo_metadata(empty_ds) + + # Verify method doesn't fail with empty dataset + # This tests robustness + + def test_write_geo_metadata_rio_write_crs_called(self, pyramid_creator, sample_dataset_with_crs): + """Test that rio.write_crs is called correctly.""" + + # Mock the rio.write_crs method + with patch.object(sample_dataset_with_crs.rio, 'write_crs') as mock_write_crs: + # Call the method + pyramid_creator._write_geo_metadata(sample_dataset_with_crs) + + # Verify rio.write_crs was called with correct arguments + mock_write_crs.assert_called_once() + call_args = mock_write_crs.call_args + assert call_args[1]['inplace'] is True # inplace=True should be passed + + def test_write_geo_metadata_crs_from_multiple_sources(self, pyramid_creator): + """Test CRS detection from multiple sources in priority order.""" + + # Create dataset with both rio CRS and EPSG attributes + coords = { + 'x': (['x'], np.linspace(0, 1000, 50)), + 'y': (['y'], np.linspace(0, 1000, 50)) + } + + data_vars = { + 'b08': (['y', 'x'], np.random.rand(50, 50)) + } + + ds = xr.Dataset(data_vars, coords=coords) + + # Add both rio CRS and EPSG attribute (rio should take priority) + ds = ds.rio.write_crs('EPSG:4326') # Rio CRS + ds['b08'].attrs['proj:epsg'] = 32632 # EPSG attribute + + # Call the method + pyramid_creator._write_geo_metadata(ds) + + # Verify rio CRS was used (priority over attributes) + assert ds.rio.crs.to_epsg() == 4326 # Should still be 4326, not 32632 + + def test_write_geo_metadata_integration_with_level_creation(self, pyramid_creator): + """Test that _write_geo_metadata is properly integrated in level creation methods.""" + + # Create mock measurements data + measurements_by_resolution = { + 10: { + 'bands': { + 'b02': xr.DataArray( + np.random.rand(100, 100), + dims=['y', 'x'], + coords={ + 'x': (['x'], np.linspace(0, 1000, 100)), + 'y': (['y'], np.linspace(0, 1000, 100)) + } + 
).rio.write_crs('EPSG:32632') + } + } + } + + # Create level 0 dataset (which should call _write_geo_metadata) + level_0_ds = pyramid_creator._create_level_0_dataset(measurements_by_resolution) + + # Verify CRS was written by _write_geo_metadata + assert hasattr(level_0_ds, 'rio') + assert level_0_ds.rio.crs is not None + assert level_0_ds.rio.crs.to_epsg() == 32632 + + +class TestWriteGeoMetadataEdgeCases: + """Test edge cases for _write_geo_metadata method.""" + + def test_write_geo_metadata_invalid_crs(self, pyramid_creator): + """Test _write_geo_metadata with invalid CRS data.""" + + coords = { + 'x': (['x'], np.linspace(0, 1000, 10)), + 'y': (['y'], np.linspace(0, 1000, 10)) + } + + data_vars = { + 'test_var': (['y', 'x'], np.random.rand(10, 10)) + } + + ds = xr.Dataset(data_vars, coords=coords) + + # Add invalid EPSG code + ds['test_var'].attrs['proj:epsg'] = 'invalid_epsg' + + # Method should raise an exception for invalid CRS (normal behavior) + from pyproj.exceptions import CRSError + with pytest.raises(CRSError): + pyramid_creator._write_geo_metadata(ds) + + def test_write_geo_metadata_mixed_crs_variables(self, pyramid_creator): + """Test _write_geo_metadata with variables having different CRS information.""" + + coords = { + 'x': (['x'], np.linspace(0, 1000, 20)), + 'y': (['y'], np.linspace(0, 1000, 20)) + } + + data_vars = { + 'var1': (['y', 'x'], np.random.rand(20, 20)), + 'var2': (['y', 'x'], np.random.rand(20, 20)) + } + + ds = xr.Dataset(data_vars, coords=coords) + + # Add different EPSG codes to different variables + ds['var1'].attrs['proj:epsg'] = 32632 + ds['var2'].attrs['proj:epsg'] = 4326 + + # Call the method (should use the first CRS found) + pyramid_creator._write_geo_metadata(ds) + + # Verify a CRS was applied (should be the first one found) + assert hasattr(ds, 'rio') + + def test_write_geo_metadata_maintains_dataset_attrs(self, pyramid_creator, sample_dataset_with_crs): + """Test that _write_geo_metadata maintains dataset-level attributes.""" + + # Add some dataset attributes + sample_dataset_with_crs.attrs['pyramid_level'] = 1 + sample_dataset_with_crs.attrs['resolution_meters'] = 20 + sample_dataset_with_crs.attrs['custom_attr'] = 'test_value' + + original_attrs = sample_dataset_with_crs.attrs.copy() + + # Call the method + pyramid_creator._write_geo_metadata(sample_dataset_with_crs) + + # Verify dataset attributes are preserved + for key, value in original_attrs.items(): + assert sample_dataset_with_crs.attrs[key] == value + + +if __name__ == '__main__': + pytest.main([__file__]) From 7df85bd28d4ce617e0698df5cd661cfa5e8af6dd Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Sun, 28 Sep 2025 20:06:39 +0200 Subject: [PATCH 29/83] feat: skip duplicate variables during downsampling in S2MultiscalePyramid --- src/eopf_geozarr/s2_optimization/s2_multiscale.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/eopf_geozarr/s2_optimization/s2_multiscale.py b/src/eopf_geozarr/s2_optimization/s2_multiscale.py index 030bfd92..60f92652 100644 --- a/src/eopf_geozarr/s2_optimization/s2_multiscale.py +++ b/src/eopf_geozarr/s2_optimization/s2_multiscale.py @@ -432,6 +432,9 @@ def _create_level_1_dataset(self, measurements_by_resolution: Dict) -> xr.Datase for category, vars_dict in data_10m.items(): for var_name, var_data in vars_dict.items(): + # skip if already present from 20m data + if var_name in all_vars: + continue if reference_coords: # Downsample to match 20m grid target_height = len(reference_coords['y']) @@ -488,6 +491,9 @@ def 
_create_level_2_dataset(self, measurements_by_resolution: Dict) -> xr.Datase data_20m = measurements_by_resolution[20] for category, vars_dict in data_20m.items(): for var_name, var_data in vars_dict.items(): + # skip if already present from native 60m data + if var_name in all_vars: + continue vars_to_downsample.append((var_name, var_data, '20m')) # Add 10m data for downsampling @@ -495,6 +501,9 @@ def _create_level_2_dataset(self, measurements_by_resolution: Dict) -> xr.Datase data_10m = measurements_by_resolution[10] for category, vars_dict in data_10m.items(): for var_name, var_data in vars_dict.items(): + # skip if already present from native 60m data + if var_name in all_vars: + continue vars_to_downsample.append((var_name, var_data, '10m')) # Process all downsampling in parallel if Dask is available From e7896c4187442044f7d8643ed723a082eb58e9ed Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Sun, 28 Sep 2025 20:10:11 +0200 Subject: [PATCH 30/83] feat: enhance CRS handling by adding grid mapping variable to dataset attributes --- src/eopf_geozarr/s2_optimization/s2_multiscale.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/eopf_geozarr/s2_optimization/s2_multiscale.py b/src/eopf_geozarr/s2_optimization/s2_multiscale.py index 60f92652..d5e55bfc 100644 --- a/src/eopf_geozarr/s2_optimization/s2_multiscale.py +++ b/src/eopf_geozarr/s2_optimization/s2_multiscale.py @@ -962,4 +962,8 @@ def _write_geo_metadata(self, dataset: xr.Dataset, grid_mapping_var_name: str = # Use standard CRS and transform if available if crs is not None: - dataset.rio.write_crs(crs, inplace=True) + dataset.rio.write_crs(crs, grid_mapping_name=grid_mapping_var_name, inplace=True) + + # Set the grid mapping variable for all data variables + for var in dataset.data_vars.values(): + var.attrs['grid_mapping'] = grid_mapping_var_name From eec6b27f9a29805d5fe1cf9af8f6ee8e2526ade3 Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Sun, 28 Sep 2025 20:14:36 +0200 Subject: [PATCH 31/83] feat: add grid mapping variable writing for datasets in S2MultiscalePyramid --- src/eopf_geozarr/s2_optimization/s2_multiscale.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/eopf_geozarr/s2_optimization/s2_multiscale.py b/src/eopf_geozarr/s2_optimization/s2_multiscale.py index d5e55bfc..d68f8037 100644 --- a/src/eopf_geozarr/s2_optimization/s2_multiscale.py +++ b/src/eopf_geozarr/s2_optimization/s2_multiscale.py @@ -963,7 +963,10 @@ def _write_geo_metadata(self, dataset: xr.Dataset, grid_mapping_var_name: str = # Use standard CRS and transform if available if crs is not None: dataset.rio.write_crs(crs, grid_mapping_name=grid_mapping_var_name, inplace=True) + # Add grid mapping variable + dataset.rio.write_grid_mapping(grid_mapping_var_name, inplace=True) - # Set the grid mapping variable for all data variables - for var in dataset.data_vars.values(): - var.attrs['grid_mapping'] = grid_mapping_var_name + # Set the grid mapping variable for all data variables + for var in dataset.data_vars.values(): + var.rio.write_grid_mapping(grid_mapping_var_name, inplace=True) + From d4a5a95f37a8b9a5c68d25b276d9c740b2bd6d6e Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Sun, 28 Sep 2025 20:27:05 +0200 Subject: [PATCH 32/83] feat: skip already present variables during downsampling in S2MultiscalePyramid --- src/eopf_geozarr/s2_optimization/s2_multiscale.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/eopf_geozarr/s2_optimization/s2_multiscale.py 
b/src/eopf_geozarr/s2_optimization/s2_multiscale.py index d68f8037..583fc160 100644 --- a/src/eopf_geozarr/s2_optimization/s2_multiscale.py +++ b/src/eopf_geozarr/s2_optimization/s2_multiscale.py @@ -230,6 +230,9 @@ def _create_level_1_dataset_parallel(self, measurements_by_resolution: Dict) -> vars_to_downsample = [] for category, vars_dict in data_10m.items(): for var_name, var_data in vars_dict.items(): + # skip if already present from 20m data + if var_name in all_vars: + continue vars_to_downsample.append((var_name, var_data)) # Process variables in parallel if Dask is available From df005efe83e68c8ff82465149f0dbf9e90536bb0 Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Sun, 28 Sep 2025 19:42:13 +0000 Subject: [PATCH 33/83] feat: reduce memory limit for Dask client to 4GB and add geographic metadata writing in S2MultiscalePyramid --- src/eopf_geozarr/cli.py | 2 +- src/eopf_geozarr/s2_optimization/s2_multiscale.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/eopf_geozarr/cli.py b/src/eopf_geozarr/cli.py index 75583654..8c0ba32a 100644 --- a/src/eopf_geozarr/cli.py +++ b/src/eopf_geozarr/cli.py @@ -54,7 +54,7 @@ def setup_dask_cluster(enable_dask: bool, verbose: bool = False) -> Optional[Any from dask.distributed import Client # Set up local cluster with high memory limits - client = Client(n_workers=3, memory_limit="8GB") # set up local cluster with 3 workers and 8GB memory each + client = Client(n_workers=3, memory_limit="4GB") # set up local cluster with 3 workers and 4GB memory each # client = Client() # set up local cluster if verbose: diff --git a/src/eopf_geozarr/s2_optimization/s2_multiscale.py b/src/eopf_geozarr/s2_optimization/s2_multiscale.py index 583fc160..d6169fb0 100644 --- a/src/eopf_geozarr/s2_optimization/s2_multiscale.py +++ b/src/eopf_geozarr/s2_optimization/s2_multiscale.py @@ -209,7 +209,7 @@ def _create_level_1_dataset_parallel(self, measurements_by_resolution: Dict) -> # Start with native 20m data if 20 in measurements_by_resolution: data_20m = measurements_by_resolution[20] - for category, vars_dict in data_20m.items(): + for category, vars_dict in data_20m.items(): all_vars.update(vars_dict) # Get reference coordinates from 20m data @@ -277,6 +277,8 @@ def downsample_10m_variable(var_name: str, var_data: xr.DataArray): dataset.attrs['pyramid_level'] = 1 dataset.attrs['resolution_meters'] = 20 + self._write_geo_metadata(dataset) + return dataset def _create_level_2_dataset_parallel(self, measurements_by_resolution: Dict) -> xr.Dataset: @@ -360,6 +362,8 @@ def downsample_to_60m_variable(var_name: str, var_data: xr.DataArray, source_res dataset.attrs['pyramid_level'] = 2 dataset.attrs['resolution_meters'] = 60 + self._write_geo_metadata(dataset) + return dataset def _create_downsampled_dataset_from_level2_parallel( From 0a265aa7d09ac83a104ced6a189d7a04f2468f59 Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Mon, 29 Sep 2025 09:45:52 +0200 Subject: [PATCH 34/83] Refactor test cases and improve code formatting in S2 resampling tests and sharding fix - Reorganized import statements and improved code formatting for better readability in `test_s2_resampling.py`. - Updated sample data creation functions to use consistent array formatting and improved attribute handling. - Enhanced assertions in tests to ensure clarity and consistency. - Improved test output messages in `test_sharding_fix.py` for better debugging and understanding of test results. 
- Ensured that shard dimensions are properly calculated and validated against chunk dimensions in the sharding tests. --- src/eopf_geozarr/cli.py | 68 +- src/eopf_geozarr/conversion/geozarr.py | 2 +- .../s2_optimization/s2_band_mapping.py | 61 +- .../s2_optimization/s2_converter.py | 343 +++++---- .../s2_optimization/s2_data_consolidator.py | 189 ++--- .../s2_optimization/s2_multiscale.py | 558 +++++++------- .../s2_optimization/s2_resampling.py | 181 +++-- .../s2_optimization/s2_validation.py | 19 +- .../tests/test_s2_band_mapping.py | 24 +- src/eopf_geozarr/tests/test_s2_converter.py | 11 +- .../tests/test_s2_converter_simplified.py | 410 +++++----- .../tests/test_s2_data_consolidator.py | 624 +++++++++------ src/eopf_geozarr/tests/test_s2_multiscale.py | 717 +++++++++--------- .../tests/test_s2_multiscale_geo_metadata.py | 281 +++---- src/eopf_geozarr/tests/test_s2_resampling.py | 459 ++++++----- test_sharding_fix.py | 107 +-- 16 files changed, 2150 insertions(+), 1904 deletions(-) diff --git a/src/eopf_geozarr/cli.py b/src/eopf_geozarr/cli.py index f1a2f71d..a4a08975 100644 --- a/src/eopf_geozarr/cli.py +++ b/src/eopf_geozarr/cli.py @@ -1146,75 +1146,64 @@ def create_parser() -> argparse.ArgumentParser: "--verbose", action="store_true", help="Enable verbose output" ) validate_parser.set_defaults(func=validate_command) - + # Add S2 optimization commands add_s2_optimization_commands(subparsers) return parser + def add_s2_optimization_commands(subparsers): """Add S2 optimization commands to CLI parser.""" - + # Convert S2 optimized command s2_parser = subparsers.add_parser( - 'convert-s2-optimized', - help='Convert Sentinel-2 dataset to optimized structure' + "convert-s2-optimized", help="Convert Sentinel-2 dataset to optimized structure" ) s2_parser.add_argument( - 'input_path', - type=str, - help='Path to input Sentinel-2 dataset (Zarr format)' + "input_path", type=str, help="Path to input Sentinel-2 dataset (Zarr format)" ) s2_parser.add_argument( - 'output_path', - type=str, - help='Path for output optimized dataset' + "output_path", type=str, help="Path for output optimized dataset" ) s2_parser.add_argument( - '--spatial-chunk', + "--spatial-chunk", type=int, default=1024, - help='Spatial chunk size (default: 1024)' + help="Spatial chunk size (default: 1024)", ) s2_parser.add_argument( - '--enable-sharding', - action='store_true', - help='Enable Zarr v3 sharding' + "--enable-sharding", action="store_true", help="Enable Zarr v3 sharding" ) s2_parser.add_argument( - '--compression-level', + "--compression-level", type=int, default=3, choices=range(1, 10), - help='Compression level 1-9 (default: 3)' + help="Compression level 1-9 (default: 3)", ) s2_parser.add_argument( - '--skip-geometry', - action='store_true', - help='Skip creating geometry group' + "--skip-geometry", action="store_true", help="Skip creating geometry group" ) s2_parser.add_argument( - '--skip-meteorology', - action='store_true', - help='Skip creating meteorology group' + "--skip-meteorology", + action="store_true", + help="Skip creating meteorology group", ) s2_parser.add_argument( - '--skip-validation', - action='store_true', - help='Skip output validation' + "--skip-validation", action="store_true", help="Skip output validation" ) s2_parser.add_argument( - '--verbose', - action='store_true', - help='Enable verbose output' + "--verbose", action="store_true", help="Enable verbose output" ) s2_parser.add_argument( - '--dask-cluster', - action='store_true', - help='Start a local dask cluster for parallel processing and 
progress bars' + "--dask-cluster", + action="store_true", + help="Start a local dask cluster for parallel processing and progress bars", ) s2_parser.set_defaults(func=convert_s2_optimized_command) + def convert_s2_optimized_command(args): """Execute S2 optimized conversion command.""" # Set up dask cluster if requested @@ -1228,11 +1217,11 @@ def convert_s2_optimized_command(args): storage_options = get_storage_options(str(args.input_path)) dt_input = xr.open_datatree( str(args.input_path), - engine='zarr', - chunks='auto', - storage_options=storage_options + engine="zarr", + chunks="auto", + storage_options=storage_options, ) - + # Convert dt_optimized = convert_s2_optimized( dt_input=dt_input, @@ -1243,16 +1232,17 @@ def convert_s2_optimized_command(args): create_geometry_group=not args.skip_geometry, create_meteorology_group=not args.skip_meteorology, validate_output=not args.skip_validation, - verbose=args.verbose + verbose=args.verbose, ) - + print(f"✅ S2 optimization completed: {args.output_path}") return 0 - + except Exception as e: print(f"❌ Error during S2 optimization: {e}") if args.verbose: import traceback + traceback.print_exc() return 1 finally: diff --git a/src/eopf_geozarr/conversion/geozarr.py b/src/eopf_geozarr/conversion/geozarr.py index 8178c53d..3c300578 100644 --- a/src/eopf_geozarr/conversion/geozarr.py +++ b/src/eopf_geozarr/conversion/geozarr.py @@ -26,7 +26,7 @@ import zarr from pyproj import CRS from rasterio.warp import calculate_default_transform -from zarr.codecs import BloscCodec, ShardingCodec +from zarr.codecs import BloscCodec from zarr.core.sync import sync from zarr.storage import StoreLike from zarr.storage._common import make_store_path diff --git a/src/eopf_geozarr/s2_optimization/s2_band_mapping.py b/src/eopf_geozarr/s2_optimization/s2_band_mapping.py index af4bbe67..1d5f1cec 100644 --- a/src/eopf_geozarr/s2_optimization/s2_band_mapping.py +++ b/src/eopf_geozarr/s2_optimization/s2_band_mapping.py @@ -2,64 +2,68 @@ Band mapping and resolution definitions for Sentinel-2 optimization. 
""" -from typing import Dict, List, Set from dataclasses import dataclass +from typing import Dict, List, Set + @dataclass class BandInfo: """Information about a spectral band.""" + name: str native_resolution: int # meters data_type: str wavelength_center: float # nanometers - wavelength_width: float # nanometers + wavelength_width: float # nanometers + # Native resolution definitions NATIVE_BANDS: Dict[int, List[str]] = { - 10: ['b02', 'b03', 'b04', 'b08'], # Blue, Green, Red, NIR - 20: ['b05', 'b06', 'b07', 'b11', 'b12', 'b8a'], # Red Edge, SWIR - 60: ['b01', 'b09'] # Coastal, Water Vapor + 10: ["b02", "b03", "b04", "b08"], # Blue, Green, Red, NIR + 20: ["b05", "b06", "b07", "b11", "b12", "b8a"], # Red Edge, SWIR + 60: ["b01", "b09"], # Coastal, Water Vapor } # Complete band information BAND_INFO: Dict[str, BandInfo] = { - 'b01': BandInfo('b01', 60, 'uint16', 443, 21), # Coastal aerosol - 'b02': BandInfo('b02', 10, 'uint16', 490, 66), # Blue - 'b03': BandInfo('b03', 10, 'uint16', 560, 36), # Green - 'b04': BandInfo('b04', 10, 'uint16', 665, 31), # Red - 'b05': BandInfo('b05', 20, 'uint16', 705, 15), # Red Edge 1 - 'b06': BandInfo('b06', 20, 'uint16', 740, 15), # Red Edge 2 - 'b07': BandInfo('b07', 20, 'uint16', 783, 20), # Red Edge 3 - 'b08': BandInfo('b08', 10, 'uint16', 842, 106), # NIR - 'b8a': BandInfo('b8a', 20, 'uint16', 865, 21), # NIR Narrow - 'b09': BandInfo('b09', 60, 'uint16', 945, 20), # Water Vapor - 'b11': BandInfo('b11', 20, 'uint16', 1614, 91), # SWIR 1 - 'b12': BandInfo('b12', 20, 'uint16', 2202, 175), # SWIR 2 + "b01": BandInfo("b01", 60, "uint16", 443, 21), # Coastal aerosol + "b02": BandInfo("b02", 10, "uint16", 490, 66), # Blue + "b03": BandInfo("b03", 10, "uint16", 560, 36), # Green + "b04": BandInfo("b04", 10, "uint16", 665, 31), # Red + "b05": BandInfo("b05", 20, "uint16", 705, 15), # Red Edge 1 + "b06": BandInfo("b06", 20, "uint16", 740, 15), # Red Edge 2 + "b07": BandInfo("b07", 20, "uint16", 783, 20), # Red Edge 3 + "b08": BandInfo("b08", 10, "uint16", 842, 106), # NIR + "b8a": BandInfo("b8a", 20, "uint16", 865, 21), # NIR Narrow + "b09": BandInfo("b09", 60, "uint16", 945, 20), # Water Vapor + "b11": BandInfo("b11", 20, "uint16", 1614, 91), # SWIR 1 + "b12": BandInfo("b12", 20, "uint16", 2202, 175), # SWIR 2 } # Quality data mapping - defines which auxiliary data exists at which resolutions QUALITY_DATA_NATIVE: Dict[str, int] = { - 'scl': 20, # Scene Classification Layer - native 20m - 'aot': 20, # Aerosol Optical Thickness - native 20m - 'wvp': 20, # Water Vapor - native 20m - 'cld': 20, # Cloud probability - native 20m - 'snw': 20, # Snow probability - native 20m + "scl": 20, # Scene Classification Layer - native 20m + "aot": 20, # Aerosol Optical Thickness - native 20m + "wvp": 20, # Water Vapor - native 20m + "cld": 20, # Cloud probability - native 20m + "snw": 20, # Snow probability - native 20m } # Detector footprint availability - matches spectral bands DETECTOR_FOOTPRINT_NATIVE: Dict[int, List[str]] = { - 10: ['b02', 'b03', 'b04', 'b08'], - 20: ['b05', 'b06', 'b07', 'b11', 'b12', 'b8a'], - 60: ['b01', 'b09'] + 10: ["b02", "b03", "b04", "b08"], + 20: ["b05", "b06", "b07", "b11", "b12", "b8a"], + 60: ["b01", "b09"], } + def get_bands_for_level(level: int) -> Set[str]: """ Get all bands available at a given pyramid level. 
- + Args: level: Pyramid level (0=10m, 1=20m, 2=60m, 3+=downsampled) - + Returns: Set of band names available at this level """ @@ -72,9 +76,10 @@ def get_bands_for_level(level: int) -> Set[str]: else: # Further downsampling - all bands return set(NATIVE_BANDS[10] + NATIVE_BANDS[20] + NATIVE_BANDS[60]) + def get_quality_data_for_level(level: int) -> Set[str]: """Get quality data available at a given level (no upsampling).""" if level == 0: # 10m - no quality data (would require upsampling) return set() elif level >= 1: # 20m and below - all quality data available - return set(QUALITY_DATA_NATIVE.keys()) \ No newline at end of file + return set(QUALITY_DATA_NATIVE.keys()) diff --git a/src/eopf_geozarr/s2_optimization/s2_converter.py b/src/eopf_geozarr/s2_optimization/s2_converter.py index d1a7aadc..59b50eb1 100644 --- a/src/eopf_geozarr/s2_optimization/s2_converter.py +++ b/src/eopf_geozarr/s2_optimization/s2_converter.py @@ -2,43 +2,48 @@ Main S2 optimization converter. """ -import os import time -from pathlib import Path -from typing import Dict, Optional, List +from typing import Dict + import xarray as xr -from .s2_data_consolidator import S2DataConsolidator, create_consolidated_dataset +from eopf_geozarr.conversion.fs_utils import get_storage_options +from eopf_geozarr.conversion.geozarr import ( + _create_tile_matrix_limits, + create_native_crs_tile_matrix_set, +) + +from .s2_data_consolidator import S2DataConsolidator from .s2_multiscale import S2MultiscalePyramid from .s2_validation import S2OptimizationValidator -from eopf_geozarr.conversion.fs_utils import get_storage_options, normalize_path -from eopf_geozarr.conversion.geozarr import create_native_crs_tile_matrix_set, _create_tile_matrix_limits try: import distributed + DISTRIBUTED_AVAILABLE = True except ImportError: DISTRIBUTED_AVAILABLE = False + class S2OptimizedConverter: """Optimized Sentinel-2 to GeoZarr converter.""" - + def __init__( self, enable_sharding: bool = True, spatial_chunk: int = 1024, compression_level: int = 3, - max_retries: int = 3 + max_retries: int = 3, ): self.enable_sharding = enable_sharding self.spatial_chunk = spatial_chunk self.compression_level = compression_level self.max_retries = max_retries - + # Initialize components self.pyramid_creator = S2MultiscalePyramid(enable_sharding, spatial_chunk) self.validator = S2OptimizationValidator() - + def convert_s2_optimized( self, dt_input: xr.DataTree, @@ -46,11 +51,11 @@ def convert_s2_optimized( create_geometry_group: bool = True, create_meteorology_group: bool = True, validate_output: bool = True, - verbose: bool = False + verbose: bool = False, ) -> xr.DataTree: """ Convert S2 dataset to optimized structure. 
- + Args: dt_input: Input Sentinel-2 DataTree output_path: Output path for optimized dataset @@ -58,132 +63,137 @@ def convert_s2_optimized( create_meteorology_group: Whether to create meteorology group validate_output: Whether to validate the output verbose: Enable verbose logging - + Returns: Optimized DataTree """ start_time = time.time() - + if verbose: - print(f"Starting S2 optimized conversion...") + print("Starting S2 optimized conversion...") print(f"Input: {len(dt_input.groups)} groups") print(f"Output: {output_path}") - + # Validate input is S2 if not self._is_sentinel2_dataset(dt_input): raise ValueError("Input dataset is not a Sentinel-2 product") - + # Step 1: Consolidate data from scattered structure print("Step 1: Consolidating EOPF data structure...") consolidator = S2DataConsolidator(dt_input) - measurements_data, geometry_data, meteorology_data = consolidator.consolidate_all_data() - + measurements_data, geometry_data, meteorology_data = ( + consolidator.consolidate_all_data() + ) + if verbose: - print(f" Measurements data extracted: {sum(len(d['bands']) for d in measurements_data.values())} bands") + print( + f" Measurements data extracted: {sum(len(d['bands']) for d in measurements_data.values())} bands" + ) print(f" Geometry variables: {len(geometry_data)}") print(f" Meteorology variables: {len(meteorology_data)}") - + # Step 2: Create multiscale measurements print("Step 2: Creating multiscale measurements pyramid...") pyramid_datasets = self.pyramid_creator.create_multiscale_measurements( measurements_data, output_path ) - + print(f" Created {len(pyramid_datasets)} pyramid levels") - + # Step 3: Create geometry group if create_geometry_group and geometry_data: print("Step 3: Creating consolidated geometry group...") geometry_ds = xr.Dataset(geometry_data) geometry_path = f"{output_path}/geometry" self._write_auxiliary_group(geometry_ds, geometry_path, "geometry", verbose) - + # Step 4: Create meteorology group if create_meteorology_group and meteorology_data: print("Step 4: Creating consolidated meteorology group...") meteorology_ds = xr.Dataset(meteorology_data) meteorology_path = f"{output_path}/meteorology" - self._write_auxiliary_group(meteorology_ds, meteorology_path, "meteorology", verbose) - + self._write_auxiliary_group( + meteorology_ds, meteorology_path, "meteorology", verbose + ) + # Step 5: Add multiscales metadata to measurements group print("Step 5: Adding multiscales metadata to measurements group...") self._add_measurements_multiscales_metadata(output_path, pyramid_datasets) - + # Step 6: Simple root-level consolidation print("Step 6: Final root-level metadata consolidation...") self._simple_root_consolidation(output_path, pyramid_datasets) - + # Step 7: Validation if validate_output: print("Step 7: Validating optimized dataset...") validation_results = self.validator.validate_optimized_dataset(output_path) - if not validation_results['is_valid']: + if not validation_results["is_valid"]: print(" Warning: Validation issues found:") - for issue in validation_results['issues']: + for issue in validation_results["issues"]: print(f" - {issue}") - + # Create result DataTree result_dt = self._create_result_datatree(output_path) - + total_time = time.time() - start_time print(f"Optimization complete in {total_time:.2f}s") - + if verbose: self._print_optimization_summary(dt_input, result_dt, output_path) - + return result_dt - + def _is_sentinel2_dataset(self, dt: xr.DataTree) -> bool: """Check if dataset is Sentinel-2.""" # Check STAC properties - stac_props = 
dt.attrs.get('stac_discovery', {}).get('properties', {}) - mission = stac_props.get('mission', '') - - if mission.lower().startswith('sentinel-2'): + stac_props = dt.attrs.get("stac_discovery", {}).get("properties", {}) + mission = stac_props.get("mission", "") + + if mission.lower().startswith("sentinel-2"): return True - + # Check for characteristic S2 groups s2_indicators = [ - '/measurements/reflectance', - '/conditions/geometry', - '/quality/atmosphere' + "/measurements/reflectance", + "/conditions/geometry", + "/quality/atmosphere", ] - - found_indicators = sum(1 for indicator in s2_indicators if indicator in dt.groups) + + found_indicators = sum( + 1 for indicator in s2_indicators if indicator in dt.groups + ) return found_indicators >= 2 - + def _write_auxiliary_group( - self, - dataset: xr.Dataset, - group_path: str, - group_type: str, - verbose: bool + self, dataset: xr.Dataset, group_path: str, group_type: str, verbose: bool ) -> None: """Write auxiliary group (geometry or meteorology).""" # Create simple encoding following geozarr.py pattern from zarr.codecs import BloscCodec + compressor = BloscCodec(cname="zstd", clevel=3, shuffle="shuffle", blocksize=0) encoding = {} for var_name in dataset.data_vars: - encoding[var_name] = {'compressors': [compressor]} + encoding[var_name] = {"compressors": [compressor]} for coord_name in dataset.coords: - encoding[coord_name] = {'compressors': None} - + encoding[coord_name] = {"compressors": None} + # Write dataset with progress bar storage_options = get_storage_options(group_path) - + # Create zarr write job with progress bar write_job = dataset.to_zarr( group_path, - mode='w', + mode="w", consolidated=True, zarr_format=3, encoding=encoding, storage_options=storage_options, - compute=False + compute=False, ) write_job = write_job.persist() - + # Show progress bar if distributed is available if DISTRIBUTED_AVAILABLE: try: @@ -196,218 +206,249 @@ def _write_auxiliary_group( else: print(f" Writing {group_type} zarr file...") write_job.compute() - + if verbose: - print(f" {group_type.title()} group written: {len(dataset.data_vars)} variables") - - def _add_measurements_multiscales_metadata(self, output_path: str, pyramid_datasets: Dict[int, xr.Dataset]) -> None: + print( + f" {group_type.title()} group written: {len(dataset.data_vars)} variables" + ) + + def _add_measurements_multiscales_metadata( + self, output_path: str, pyramid_datasets: Dict[int, xr.Dataset] + ) -> None: """Add multiscales metadata to the measurements group using rioxarray.""" try: measurements_path = f"{output_path}/measurements" - + # Create multiscales metadata using rioxarray .rio accessor - multiscales_metadata = self._create_multiscales_metadata_with_rio(pyramid_datasets) - + multiscales_metadata = self._create_multiscales_metadata_with_rio( + pyramid_datasets + ) + if multiscales_metadata: # Use zarr to add metadata to the measurements group storage_options = get_storage_options(measurements_path) - + try: import zarr + if storage_options: - store = zarr.storage.FSStore(measurements_path, **storage_options) + store = zarr.storage.FSStore( + measurements_path, **storage_options + ) else: store = measurements_path - + # Open the measurements group and add multiscales metadata - measurements_group = zarr.open_group(store, mode='r+') - measurements_group.attrs['multiscales'] = multiscales_metadata - + measurements_group = zarr.open_group(store, mode="r+") + measurements_group.attrs["multiscales"] = multiscales_metadata + print(" ✅ Added multiscales metadata to 
measurements group") - + except Exception as e: print(f" ⚠️ Could not add multiscales metadata: {e}") - + except Exception as e: print(f" ⚠️ Error adding multiscales metadata: {e}") - - def _create_multiscales_metadata_with_rio(self, pyramid_datasets: Dict[int, xr.Dataset]) -> Dict: + + def _create_multiscales_metadata_with_rio( + self, pyramid_datasets: Dict[int, xr.Dataset] + ) -> Dict: """Create multiscales metadata using rioxarray .rio accessor, following geozarr.py format.""" if not pyramid_datasets: return {} - + # Get the first available dataset to extract spatial information using .rio reference_ds = None for level in sorted(pyramid_datasets.keys()): if pyramid_datasets[level] is not None: reference_ds = pyramid_datasets[level] break - + if not reference_ds or not reference_ds.data_vars: return {} - + try: # Use .rio accessor to get CRS and bounds directly from the dataset - if not hasattr(reference_ds, 'rio') or not reference_ds.rio.crs: + if not hasattr(reference_ds, "rio") or not reference_ds.rio.crs: return {} - + native_crs = reference_ds.rio.crs native_bounds = reference_ds.rio.bounds() - + # Create overview levels list following geozarr.py format overview_levels = [] for level in sorted(pyramid_datasets.keys()): if pyramid_datasets[level] is not None: level_ds = pyramid_datasets[level] - resolution = self.pyramid_creator.pyramid_levels.get(level, level * 10) - - if hasattr(level_ds, 'rio'): + resolution = self.pyramid_creator.pyramid_levels.get( + level, level * 10 + ) + + if hasattr(level_ds, "rio"): width = level_ds.rio.width height = level_ds.rio.height - scale_factor = 2 ** level if level > 0 else 1 - - overview_levels.append({ - 'level': level, - 'width': width, - 'height': height, - 'scale_factor': scale_factor, - 'zoom': max(0, level) # Simple zoom calculation - }) - + scale_factor = 2**level if level > 0 else 1 + + overview_levels.append( + { + "level": level, + "width": width, + "height": height, + "scale_factor": scale_factor, + "zoom": max(0, level), # Simple zoom calculation + } + ) + if not overview_levels: return {} - + # Create tile matrix set following geozarr.py exactly tile_matrix_set = create_native_crs_tile_matrix_set( - native_crs, - native_bounds, - overview_levels, - "measurements" # group prefix + native_crs, + native_bounds, + overview_levels, + "measurements", # group prefix ) - - # Create tile matrix limits following geozarr.py exactly - tile_matrix_limits = _create_tile_matrix_limits(overview_levels, 256) # tile_width=256 - + + # Create tile matrix limits following geozarr.py exactly + tile_matrix_limits = _create_tile_matrix_limits( + overview_levels, 256 + ) # tile_width=256 + # Create multiscales metadata following geozarr.py format exactly multiscales_metadata = { "tile_matrix_set": tile_matrix_set, "resampling_method": "average", "tile_matrix_limits": tile_matrix_limits, } - + return multiscales_metadata - + except Exception as e: - print(f" Warning: Could not create multiscales metadata with .rio accessor: {e}") + print( + f" Warning: Could not create multiscales metadata with .rio accessor: {e}" + ) return {} - def _simple_root_consolidation(self, output_path: str, pyramid_datasets: Dict[int, xr.Dataset]) -> None: + def _simple_root_consolidation( + self, output_path: str, pyramid_datasets: Dict[int, xr.Dataset] + ) -> None: """Simple root-level metadata consolidation using only xarray.""" try: # Since each level and auxiliary group was written with consolidated=True, # we just need to create a simple root-level consolidated metadata 
print(" Performing simple root consolidation...") - + # Use xarray to open and immediately close the root group with consolidation # This creates/updates the root .zmetadata file storage_options = get_storage_options(output_path) - + # Open the root zarr group and let xarray handle consolidation try: # This will create consolidated metadata at the root level - with xr.open_zarr(output_path, storage_options=storage_options, - consolidated=True, chunks={}) as root_ds: + with xr.open_zarr( + output_path, + storage_options=storage_options, + consolidated=True, + chunks={}, + ) as root_ds: # Just opening and closing with consolidated=True should be enough pass print(" ✅ Root consolidation completed") except Exception as e: - print(f" ⚠️ Root consolidation using xarray failed, trying zarr directly: {e}") - + print( + f" ⚠️ Root consolidation using xarray failed, trying zarr directly: {e}" + ) + # Fallback: minimal zarr consolidation if needed import zarr - store = zarr.storage.FSStore(output_path, **storage_options) if storage_options else output_path + + store = ( + zarr.storage.FSStore(output_path, **storage_options) + if storage_options + else output_path + ) try: zarr.consolidate_metadata(store) print(" ✅ Root consolidation completed with zarr") except Exception as e2: print(f" ⚠️ Warning: Root consolidation failed: {e2}") - + except Exception as e: print(f" ⚠️ Warning: Root consolidation failed: {e}") - + def _create_result_datatree(self, output_path: str) -> xr.DataTree: """Create result DataTree from written output.""" try: storage_options = get_storage_options(output_path) return xr.open_datatree( output_path, - engine='zarr', - chunks='auto', - storage_options=storage_options + engine="zarr", + chunks="auto", + storage_options=storage_options, ) except Exception as e: print(f"Warning: Could not open result DataTree: {e}") return xr.DataTree() - + def _print_optimization_summary( - self, - dt_input: xr.DataTree, - dt_output: xr.DataTree, - output_path: str + self, dt_input: xr.DataTree, dt_output: xr.DataTree, output_path: str ) -> None: """Print optimization summary statistics.""" - print("\n" + "="*50) + print("\n" + "=" * 50) print("OPTIMIZATION SUMMARY") - print("="*50) - + print("=" * 50) + # Count groups - input_groups = len(dt_input.groups) if hasattr(dt_input, 'groups') else 0 - output_groups = len(dt_output.groups) if hasattr(dt_output, 'groups') else 0 - - print(f"Groups: {input_groups} → {output_groups} ({((output_groups-input_groups)/input_groups*100):+.1f}%)") - + input_groups = len(dt_input.groups) if hasattr(dt_input, "groups") else 0 + output_groups = len(dt_output.groups) if hasattr(dt_output, "groups") else 0 + + print( + f"Groups: {input_groups} → {output_groups} ({((output_groups - input_groups) / input_groups * 100):+.1f}%)" + ) + # Estimate file count reduction estimated_input_files = input_groups * 10 # Rough estimate estimated_output_files = output_groups * 5 # Fewer files per group - print(f"Estimated files: {estimated_input_files} → {estimated_output_files} ({((estimated_output_files-estimated_input_files)/estimated_input_files*100):+.1f}%)") - + print( + f"Estimated files: {estimated_input_files} → {estimated_output_files} ({((estimated_output_files - estimated_input_files) / estimated_input_files * 100):+.1f}%)" + ) + # Show structure - print(f"\nNew structure:") - print(f" /measurements/ (multiscale: levels 0-6)") + print("\nNew structure:") + print(" /measurements/ (multiscale: levels 0-6)") if f"{output_path}/geometry" in str(dt_output): - print(f" 
/geometry/ (consolidated)") + print(" /geometry/ (consolidated)") if f"{output_path}/meteorology" in str(dt_output): - print(f" /meteorology/ (consolidated)") - - print("="*50) + print(" /meteorology/ (consolidated)") + + print("=" * 50) def convert_s2_optimized( - dt_input: xr.DataTree, - output_path: str, - **kwargs + dt_input: xr.DataTree, output_path: str, **kwargs ) -> xr.DataTree: """ Convenience function for S2 optimization. - + Args: dt_input: Input Sentinel-2 DataTree output_path: Output path **kwargs: Additional arguments for S2OptimizedConverter - + Returns: Optimized DataTree """ # Separate constructor args from method args constructor_args = { - 'enable_sharding': kwargs.pop('enable_sharding', True), - 'spatial_chunk': kwargs.pop('spatial_chunk', 1024), - 'compression_level': kwargs.pop('compression_level', 3), - 'max_retries': kwargs.pop('max_retries', 3) + "enable_sharding": kwargs.pop("enable_sharding", True), + "spatial_chunk": kwargs.pop("spatial_chunk", 1024), + "compression_level": kwargs.pop("compression_level", 3), + "max_retries": kwargs.pop("max_retries", 3), } - + # Remaining kwargs are for the convert_s2_optimized method method_args = kwargs - + converter = S2OptimizedConverter(**constructor_args) return converter.convert_s2_optimized(dt_input, output_path, **method_args) diff --git a/src/eopf_geozarr/s2_optimization/s2_data_consolidator.py b/src/eopf_geozarr/s2_optimization/s2_data_consolidator.py index 1bf1d2bd..b5452558 100644 --- a/src/eopf_geozarr/s2_optimization/s2_data_consolidator.py +++ b/src/eopf_geozarr/s2_optimization/s2_data_consolidator.py @@ -2,227 +2,234 @@ Data consolidation logic for reorganizing S2 structure. """ +from typing import Dict, Tuple + import xarray as xr -from typing import Dict, List, Tuple, Optional + from .s2_band_mapping import ( - NATIVE_BANDS, QUALITY_DATA_NATIVE, DETECTOR_FOOTPRINT_NATIVE, - get_bands_for_level, get_quality_data_for_level + NATIVE_BANDS, ) + class S2DataConsolidator: """Consolidates S2 data from scattered structure into organized groups.""" - + def __init__(self, dt_input: xr.DataTree): self.dt_input = dt_input self.measurements_data = {} self.geometry_data = {} self.meteorology_data = {} - + def consolidate_all_data(self) -> Tuple[Dict, Dict, Dict]: """ Consolidate all data into three main categories. 
- + Returns: Tuple of (measurements, geometry, meteorology) data dictionaries """ self._extract_measurements_data() self._extract_geometry_data() self._extract_meteorology_data() - + return self.measurements_data, self.geometry_data, self.meteorology_data - + def _extract_measurements_data(self) -> None: """Extract and organize all measurement-related data by native resolution.""" - + # Initialize resolution groups for resolution in [10, 20, 60]: self.measurements_data[resolution] = { - 'bands': {}, - 'quality': {}, - 'detector_footprints': {}, - 'classification': {}, - 'atmosphere': {}, - 'probability': {} + "bands": {}, + "quality": {}, + "detector_footprints": {}, + "classification": {}, + "atmosphere": {}, + "probability": {}, } - + # Extract reflectance bands - if any('/measurements/reflectance' in group for group in self.dt_input.groups): + if any("/measurements/reflectance" in group for group in self.dt_input.groups): self._extract_reflectance_bands() - + # Extract quality data self._extract_quality_data() - + # Extract detector footprints self._extract_detector_footprints() - + # Extract atmosphere quality self._extract_atmosphere_data() - + # Extract classification data self._extract_classification_data() - + # Extract probability data self._extract_probability_data() - + def _extract_reflectance_bands(self) -> None: """Extract reflectance bands from measurements/reflectance groups.""" - for resolution in ['r10m', 'r20m', 'r60m']: + for resolution in ["r10m", "r20m", "r60m"]: res_num = int(resolution[1:-1]) # Extract number from 'r10m' - group_path = f'/measurements/reflectance/{resolution}' - + group_path = f"/measurements/reflectance/{resolution}" + if group_path in self.dt_input.groups: # Check if this is a multiscale group (has numeric subgroups) group_node = self.dt_input[group_path] - if hasattr(group_node, 'children') and group_node.children: + if hasattr(group_node, "children") and group_node.children: # Take level 0 (native resolution) - native_path = f'{group_path}/0' + native_path = f"{group_path}/0" if native_path in self.dt_input.groups: ds = self.dt_input[native_path].to_dataset() else: ds = group_node.to_dataset() else: ds = group_node.to_dataset() - + # Extract only native bands for this resolution native_bands = NATIVE_BANDS.get(res_num, []) for band in native_bands: if band in ds.data_vars: - self.measurements_data[res_num]['bands'][band] = ds[band] - + self.measurements_data[res_num]["bands"][band] = ds[band] + def _extract_quality_data(self) -> None: """Extract quality mask data.""" - quality_base = '/quality/mask' - - for resolution in ['r10m', 'r20m', 'r60m']: + quality_base = "/quality/mask" + + for resolution in ["r10m", "r20m", "r60m"]: res_num = int(resolution[1:-1]) - group_path = f'{quality_base}/{resolution}' - + group_path = f"{quality_base}/{resolution}" + if group_path in self.dt_input.groups: ds = self.dt_input[group_path].to_dataset() - + # Only extract quality for native bands at this resolution native_bands = NATIVE_BANDS.get(res_num, []) for band in native_bands: if band in ds.data_vars: - self.measurements_data[res_num]['quality'][f'quality_{band}'] = ds[band] - + self.measurements_data[res_num]["quality"][ + f"quality_{band}" + ] = ds[band] + def _extract_detector_footprints(self) -> None: """Extract detector footprint data.""" - footprint_base = '/conditions/mask/detector_footprint' - - for resolution in ['r10m', 'r20m', 'r60m']: + footprint_base = "/conditions/mask/detector_footprint" + + for resolution in ["r10m", "r20m", "r60m"]: res_num 
= int(resolution[1:-1]) - group_path = f'{footprint_base}/{resolution}' - + group_path = f"{footprint_base}/{resolution}" + if group_path in self.dt_input.groups: ds = self.dt_input[group_path].to_dataset() - + # Only extract footprints for native bands native_bands = NATIVE_BANDS.get(res_num, []) for band in native_bands: if band in ds.data_vars: - var_name = f'detector_footprint_{band}' - self.measurements_data[res_num]['detector_footprints'][var_name] = ds[band] - + var_name = f"detector_footprint_{band}" + self.measurements_data[res_num]["detector_footprints"][ + var_name + ] = ds[band] + def _extract_atmosphere_data(self) -> None: """Extract atmosphere quality data (aot, wvp) - native at 20m.""" - atm_base = '/quality/atmosphere' - + atm_base = "/quality/atmosphere" + # Atmosphere data is native at 20m resolution - group_path = f'{atm_base}/r20m' + group_path = f"{atm_base}/r20m" if group_path in self.dt_input.groups: ds = self.dt_input[group_path].to_dataset() - - for var in ['aot', 'wvp']: + + for var in ["aot", "wvp"]: if var in ds.data_vars: - self.measurements_data[20]['atmosphere'][var] = ds[var] - + self.measurements_data[20]["atmosphere"][var] = ds[var] + def _extract_classification_data(self) -> None: """Extract scene classification data - native at 20m.""" - class_base = '/conditions/mask/l2a_classification' - + class_base = "/conditions/mask/l2a_classification" + # Classification is native at 20m - group_path = f'{class_base}/r20m' + group_path = f"{class_base}/r20m" if group_path in self.dt_input.groups: ds = self.dt_input[group_path].to_dataset() - - if 'scl' in ds.data_vars: - self.measurements_data[20]['classification']['scl'] = ds['scl'] - + + if "scl" in ds.data_vars: + self.measurements_data[20]["classification"]["scl"] = ds["scl"] + def _extract_probability_data(self) -> None: """Extract cloud and snow probability data - native at 20m.""" - prob_base = '/quality/probability/r20m' - + prob_base = "/quality/probability/r20m" + if prob_base in self.dt_input.groups: ds = self.dt_input[prob_base].to_dataset() - - for var in ['cld', 'snw']: + + for var in ["cld", "snw"]: if var in ds.data_vars: - self.measurements_data[20]['probability'][var] = ds[var] - + self.measurements_data[20]["probability"][var] = ds[var] + def _extract_geometry_data(self) -> None: """Extract all geometry-related data into single group.""" - geom_base = '/conditions/geometry' - + geom_base = "/conditions/geometry" + if geom_base in self.dt_input.groups: ds = self.dt_input[geom_base].to_dataset() - + # Consolidate all geometry variables for var_name in ds.data_vars: self.geometry_data[var_name] = ds[var_name] - + def _extract_meteorology_data(self) -> None: """Extract meteorological data (CAMS and ECMWF).""" # CAMS data - cams_path = '/conditions/meteorology/cams' + cams_path = "/conditions/meteorology/cams" if cams_path in self.dt_input.groups: ds = self.dt_input[cams_path].to_dataset() for var_name in ds.data_vars: - self.meteorology_data[f'cams_{var_name}'] = ds[var_name] - + self.meteorology_data[f"cams_{var_name}"] = ds[var_name] + # ECMWF data - ecmwf_path = '/conditions/meteorology/ecmwf' + ecmwf_path = "/conditions/meteorology/ecmwf" if ecmwf_path in self.dt_input.groups: ds = self.dt_input[ecmwf_path].to_dataset() for var_name in ds.data_vars: - self.meteorology_data[f'ecmwf_{var_name}'] = ds[var_name] + self.meteorology_data[f"ecmwf_{var_name}"] = ds[var_name] + def create_consolidated_dataset(data_dict: Dict, resolution: int) -> xr.Dataset: """ Create a consolidated dataset from 
categorized data. - + Args: data_dict: Dictionary with categorized data resolution: Target resolution in meters - + Returns: Consolidated xarray Dataset """ all_vars = {} - + # Combine all data variables for category, vars_dict in data_dict.items(): all_vars.update(vars_dict) - + if not all_vars: return xr.Dataset() - + # Create dataset ds = xr.Dataset(all_vars) - + # Set up coordinate system and metadata - if 'x' in ds.coords and 'y' in ds.coords: + if "x" in ds.coords and "y" in ds.coords: # Ensure CRS information is present if ds.rio.crs is None: # Try to infer CRS from one of the variables for var_name, var_data in all_vars.items(): - if hasattr(var_data, 'rio') and var_data.rio.crs: + if hasattr(var_data, "rio") and var_data.rio.crs: ds.rio.write_crs(var_data.rio.crs, inplace=True) break - + # Add resolution metadata - ds.attrs['native_resolution_meters'] = resolution - ds.attrs['processing_level'] = 'L2A' - ds.attrs['product_type'] = 'S2MSI2A' - - return ds \ No newline at end of file + ds.attrs["native_resolution_meters"] = resolution + ds.attrs["processing_level"] = "L2A" + ds.attrs["product_type"] = "S2MSI2A" + + return ds diff --git a/src/eopf_geozarr/s2_optimization/s2_multiscale.py b/src/eopf_geozarr/s2_optimization/s2_multiscale.py index d6169fb0..3686ec6e 100644 --- a/src/eopf_geozarr/s2_optimization/s2_multiscale.py +++ b/src/eopf_geozarr/s2_optimization/s2_multiscale.py @@ -2,58 +2,60 @@ Multiscale pyramid creation for optimized S2 structure. """ -import numpy as np -from pyproj import CRS +from typing import Dict, Tuple + import xarray as xr -from typing import Dict, List, Tuple +from pyproj import CRS + from .s2_resampling import S2ResamplingEngine, determine_variable_type -from .s2_band_mapping import get_bands_for_level, get_quality_data_for_level try: import distributed - from dask import delayed, compute + from dask import compute, delayed + DISTRIBUTED_AVAILABLE = True DASK_AVAILABLE = True except ImportError: DISTRIBUTED_AVAILABLE = False DASK_AVAILABLE = False + # Create dummy delayed function for non-dask environments def delayed(func): return func + def compute(*args, **kwargs): return args + class S2MultiscalePyramid: """Creates multiscale pyramids for consolidated S2 data.""" - + def __init__(self, enable_sharding: bool = True, spatial_chunk: int = 1024): self.enable_sharding = enable_sharding self.spatial_chunk = spatial_chunk self.resampler = S2ResamplingEngine() - + # Define pyramid levels: resolution in meters self.pyramid_levels = { - 0: 10, # Level 0: 10m (native for b02,b03,b04,b08) - 1: 20, # Level 1: 20m (native for b05,b06,b07,b11,b12,b8a + all quality) - 2: 60, # Level 2: 60m (3x downsampling from 20m) - 3: 120, # Level 3: 120m (2x downsampling from 60m) - 4: 240, # Level 4: 240m (2x downsampling from 120m) - 5: 480, # Level 5: 480m (2x downsampling from 240m) - 6: 960 # Level 6: 960m (2x downsampling from 480m) + 0: 10, # Level 0: 10m (native for b02,b03,b04,b08) + 1: 20, # Level 1: 20m (native for b05,b06,b07,b11,b12,b8a + all quality) + 2: 60, # Level 2: 60m (3x downsampling from 20m) + 3: 120, # Level 3: 120m (2x downsampling from 60m) + 4: 240, # Level 4: 240m (2x downsampling from 120m) + 5: 480, # Level 5: 480m (2x downsampling from 240m) + 6: 960, # Level 6: 960m (2x downsampling from 480m) } - + def create_multiscale_measurements( - self, - measurements_by_resolution: Dict[int, Dict], - output_path: str + self, measurements_by_resolution: Dict[int, Dict], output_path: str ) -> Dict[int, xr.Dataset]: """ Create multiscale pyramid from 
consolidated measurements with parallelization. - + Args: measurements_by_resolution: Data organized by native resolution output_path: Base output path - + Returns: Dictionary of datasets by pyramid level """ @@ -65,11 +67,9 @@ def create_multiscale_measurements( return self._create_multiscale_measurements_sequential( measurements_by_resolution, output_path ) - + def _create_multiscale_measurements_parallel( - self, - measurements_by_resolution: Dict[int, Dict], - output_path: str + self, measurements_by_resolution: Dict[int, Dict], output_path: str ) -> Dict[int, xr.Dataset]: """ Create multiscale pyramid with streaming Dask parallelization. @@ -124,45 +124,45 @@ def _create_multiscale_measurements_parallel( else: print(f" Skipping empty level {level}") - print(f"✅ Streaming pyramid creation complete: {len([k for k, v in pyramid_datasets.items() if v is not None])} levels") + print( + f"✅ Streaming pyramid creation complete: {len([k for k, v in pyramid_datasets.items() if v is not None])} levels" + ) return pyramid_datasets - + def _create_multiscale_measurements_sequential( - self, - measurements_by_resolution: Dict[int, Dict], - output_path: str + self, measurements_by_resolution: Dict[int, Dict], output_path: str ) -> Dict[int, xr.Dataset]: """ Create multiscale pyramid sequentially (fallback for non-Dask environments). """ print("Creating multiscale pyramid sequentially...") pyramid_datasets = {} - + # Create each pyramid level for level, target_resolution in self.pyramid_levels.items(): print(f"Creating pyramid level {level} ({target_resolution}m)...") - + dataset = self._create_level_dataset( level, target_resolution, measurements_by_resolution ) - + if dataset and len(dataset.data_vars) > 0: pyramid_datasets[level] = dataset - + # Write this level level_path = f"{output_path}/measurements/{level}" self._write_level_dataset(dataset, level_path, level) - + return pyramid_datasets - + def _create_level_dataset( self, level: int, target_resolution: int, - measurements_by_resolution: Dict[int, Dict] + measurements_by_resolution: Dict[int, Dict], ) -> xr.Dataset: """Create dataset for a specific pyramid level.""" - + if level == 0: # Level 0: Only native 10m data return self._create_level_0_dataset(measurements_by_resolution) @@ -177,15 +177,15 @@ def _create_level_dataset( return self._create_downsampled_dataset( level, target_resolution, measurements_by_resolution ) - + def _create_level_dataset_parallel( self, level: int, target_resolution: int, - measurements_by_resolution: Dict[int, Dict] + measurements_by_resolution: Dict[int, Dict], ) -> xr.Dataset: """Create dataset for a specific pyramid level with parallel processing.""" - + if level == 0: # Level 0: Only native 10m data (no parallelization needed) return self._create_level_0_dataset(measurements_by_resolution) @@ -200,32 +200,34 @@ def _create_level_dataset_parallel( return self._create_downsampled_dataset( level, target_resolution, measurements_by_resolution ) - - def _create_level_1_dataset_parallel(self, measurements_by_resolution: Dict) -> xr.Dataset: + + def _create_level_1_dataset_parallel( + self, measurements_by_resolution: Dict + ) -> xr.Dataset: """Create level 1 dataset with parallel downsampling from 10m data.""" all_vars = {} reference_coords = None - + # Start with native 20m data if 20 in measurements_by_resolution: data_20m = measurements_by_resolution[20] - for category, vars_dict in data_20m.items(): + for category, vars_dict in data_20m.items(): all_vars.update(vars_dict) - + # Get reference coordinates from 
20m data if all_vars: first_var = next(iter(all_vars.values())) reference_coords = { - 'x': first_var.coords['x'], - 'y': first_var.coords['y'] + "x": first_var.coords["x"], + "y": first_var.coords["y"], } - + # Add downsampled 10m data with parallelization if 10 in measurements_by_resolution and reference_coords: data_10m = measurements_by_resolution[10] - target_height = len(reference_coords['y']) - target_width = len(reference_coords['x']) - + target_height = len(reference_coords["y"]) + target_width = len(reference_coords["x"]) + # Collect all 10m variables for parallel processing vars_to_downsample = [] for category, vars_dict in data_10m.items(): @@ -234,9 +236,10 @@ def _create_level_1_dataset_parallel(self, measurements_by_resolution: Dict) -> if var_name in all_vars: continue vars_to_downsample.append((var_name, var_data)) - + # Process variables in parallel if Dask is available if DASK_AVAILABLE and vars_to_downsample: + @delayed def downsample_10m_variable(var_name: str, var_data: xr.DataArray): var_type = determine_variable_type(var_name, var_data) @@ -246,15 +249,17 @@ def downsample_10m_variable(var_name: str, var_data: xr.DataArray): # Align coordinates downsampled = downsampled.assign_coords(reference_coords) return var_name, downsampled - + # Create tasks for all variables downsample_tasks = [ downsample_10m_variable(var_name, var_data) for var_name, var_data in vars_to_downsample ] - + # Compute all in parallel - print(f" Parallel downsampling {len(downsample_tasks)} variables from 10m to 20m...") + print( + f" Parallel downsampling {len(downsample_tasks)} variables from 10m to 20m..." + ) results = compute(*downsample_tasks) for var_name, downsampled_var in results: all_vars[var_name] = downsampled_var @@ -268,62 +273,67 @@ def downsample_10m_variable(var_name: str, var_data: xr.DataArray): # Align coordinates downsampled = downsampled.assign_coords(reference_coords) all_vars[var_name] = downsampled - + if not all_vars: return xr.Dataset() - + # Create consolidated dataset dataset = xr.Dataset(all_vars) - dataset.attrs['pyramid_level'] = 1 - dataset.attrs['resolution_meters'] = 20 - + dataset.attrs["pyramid_level"] = 1 + dataset.attrs["resolution_meters"] = 20 + self._write_geo_metadata(dataset) - + return dataset - - def _create_level_2_dataset_parallel(self, measurements_by_resolution: Dict) -> xr.Dataset: + + def _create_level_2_dataset_parallel( + self, measurements_by_resolution: Dict + ) -> xr.Dataset: """Create level 2 dataset with parallel downsampling to 60m.""" all_vars = {} reference_coords = None - + # Start with native 60m data if 60 in measurements_by_resolution: data_60m = measurements_by_resolution[60] for category, vars_dict in data_60m.items(): all_vars.update(vars_dict) - + # Get reference coordinates from 60m data if all_vars: first_var = next(iter(all_vars.values())) reference_coords = { - 'x': first_var.coords['x'], - 'y': first_var.coords['y'] + "x": first_var.coords["x"], + "y": first_var.coords["y"], } - + # Collect all variables that need downsampling to 60m vars_to_downsample = [] if reference_coords: - target_height = len(reference_coords['y']) - target_width = len(reference_coords['x']) - + target_height = len(reference_coords["y"]) + target_width = len(reference_coords["x"]) + # Add 20m data for downsampling if 20 in measurements_by_resolution: data_20m = measurements_by_resolution[20] for category, vars_dict in data_20m.items(): for var_name, var_data in vars_dict.items(): - vars_to_downsample.append((var_name, var_data, '20m')) - + 
vars_to_downsample.append((var_name, var_data, "20m")) + # Add 10m data for downsampling if 10 in measurements_by_resolution: data_10m = measurements_by_resolution[10] for category, vars_dict in data_10m.items(): for var_name, var_data in vars_dict.items(): - vars_to_downsample.append((var_name, var_data, '10m')) - + vars_to_downsample.append((var_name, var_data, "10m")) + # Process all downsampling in parallel if Dask is available if DASK_AVAILABLE and vars_to_downsample: + @delayed - def downsample_to_60m_variable(var_name: str, var_data: xr.DataArray, source_res: str): + def downsample_to_60m_variable( + var_name: str, var_data: xr.DataArray, source_res: str + ): var_type = determine_variable_type(var_name, var_data) downsampled = self.resampler.downsample_variable( var_data, target_height, target_width, var_type @@ -331,15 +341,17 @@ def downsample_to_60m_variable(var_name: str, var_data: xr.DataArray, source_res # Align coordinates downsampled = downsampled.assign_coords(reference_coords) return var_name, downsampled - + # Create tasks for all variables downsample_tasks = [ downsample_to_60m_variable(var_name, var_data, source_res) for var_name, var_data, source_res in vars_to_downsample ] - + # Compute all in parallel - print(f" Parallel downsampling {len(downsample_tasks)} variables to 60m...") + print( + f" Parallel downsampling {len(downsample_tasks)} variables to 60m..." + ) results = compute(*downsample_tasks) for var_name, downsampled_var in results: all_vars[var_name] = downsampled_var @@ -353,90 +365,87 @@ def downsample_to_60m_variable(var_name: str, var_data: xr.DataArray, source_res # Align coordinates downsampled = downsampled.assign_coords(reference_coords) all_vars[var_name] = downsampled - + if not all_vars: return xr.Dataset() - + # Create consolidated dataset dataset = xr.Dataset(all_vars) - dataset.attrs['pyramid_level'] = 2 - dataset.attrs['resolution_meters'] = 60 - + dataset.attrs["pyramid_level"] = 2 + dataset.attrs["resolution_meters"] = 60 + self._write_geo_metadata(dataset) - + return dataset - + def _create_downsampled_dataset_from_level2_parallel( - self, - level: int, - target_resolution: int, - level_2_dataset: xr.Dataset + self, level: int, target_resolution: int, level_2_dataset: xr.Dataset ) -> xr.Dataset: """Create downsampled dataset from level 2 with parallel processing.""" if len(level_2_dataset.data_vars) == 0: return xr.Dataset() - + # Calculate target dimensions (downsample by factor of 2^(level-2)) downsample_factor = 2 ** (level - 2) - + # Get reference dimensions from level 2 ref_var = next(iter(level_2_dataset.data_vars.values())) current_height, current_width = ref_var.shape[-2:] target_height = current_height // downsample_factor target_width = current_width // downsample_factor - + # Always use parallel processing for higher levels return self._downsample_variables_parallel( level_2_dataset, level, target_resolution, target_height, target_width ) - + def _create_level_0_dataset(self, measurements_by_resolution: Dict) -> xr.Dataset: """Create level 0 dataset with only native 10m data.""" if 10 not in measurements_by_resolution: return xr.Dataset() - + data_10m = measurements_by_resolution[10] all_vars = {} - + # Add only native 10m bands and their associated data for category, vars_dict in data_10m.items(): all_vars.update(vars_dict) - + if not all_vars: return xr.Dataset() - + # Create consolidated dataset dataset = xr.Dataset(all_vars) - dataset.attrs['pyramid_level'] = 0 - dataset.attrs['resolution_meters'] = 10 - + 
dataset.attrs["pyramid_level"] = 0 + dataset.attrs["resolution_meters"] = 10 + self._write_geo_metadata(dataset) - + return dataset - + def _create_level_1_dataset(self, measurements_by_resolution: Dict) -> xr.Dataset: """Create level 1 dataset with all data at 20m resolution.""" all_vars = {} reference_coords = None - + # Start with native 20m data if 20 in measurements_by_resolution: data_20m = measurements_by_resolution[20] for category, vars_dict in data_20m.items(): all_vars.update(vars_dict) - + # Get reference coordinates from 20m data if all_vars: first_var = next(iter(all_vars.values())) reference_coords = { - 'x': first_var.coords['x'], - 'y': first_var.coords['y'] + "x": first_var.coords["x"], + "y": first_var.coords["y"], } - + # Add downsampled 10m data if 10 in measurements_by_resolution: data_10m = measurements_by_resolution[10] - + for category, vars_dict in data_10m.items(): for var_name, var_data in vars_dict.items(): # skip if already present from 20m data @@ -444,55 +453,55 @@ def _create_level_1_dataset(self, measurements_by_resolution: Dict) -> xr.Datase continue if reference_coords: # Downsample to match 20m grid - target_height = len(reference_coords['y']) - target_width = len(reference_coords['x']) - + target_height = len(reference_coords["y"]) + target_width = len(reference_coords["x"]) + var_type = determine_variable_type(var_name, var_data) downsampled = self.resampler.downsample_variable( var_data, target_height, target_width, var_type ) - + # Align coordinates downsampled = downsampled.assign_coords(reference_coords) all_vars[var_name] = downsampled - + if not all_vars: return xr.Dataset() - + # Create consolidated dataset dataset = xr.Dataset(all_vars) - dataset.attrs['pyramid_level'] = 1 - dataset.attrs['resolution_meters'] = 20 - + dataset.attrs["pyramid_level"] = 1 + dataset.attrs["resolution_meters"] = 20 + self._write_geo_metadata(dataset) - + return dataset - + def _create_level_2_dataset(self, measurements_by_resolution: Dict) -> xr.Dataset: """Create level 2 dataset with all data at 60m resolution.""" all_vars = {} reference_coords = None - + # Start with native 60m data if 60 in measurements_by_resolution: data_60m = measurements_by_resolution[60] for category, vars_dict in data_60m.items(): all_vars.update(vars_dict) - + # Get reference coordinates from 60m data if all_vars: first_var = next(iter(all_vars.values())) reference_coords = { - 'x': first_var.coords['x'], - 'y': first_var.coords['y'] + "x": first_var.coords["x"], + "y": first_var.coords["y"], } - + # Collect all variables that need downsampling to 60m vars_to_downsample = [] if reference_coords: - target_height = len(reference_coords['y']) - target_width = len(reference_coords['x']) - + target_height = len(reference_coords["y"]) + target_width = len(reference_coords["x"]) + # Add 20m data for downsampling if 20 in measurements_by_resolution: data_20m = measurements_by_resolution[20] @@ -501,8 +510,8 @@ def _create_level_2_dataset(self, measurements_by_resolution: Dict) -> xr.Datase # skip if already present from 20m data if var_name in all_vars: continue - vars_to_downsample.append((var_name, var_data, '20m')) - + vars_to_downsample.append((var_name, var_data, "20m")) + # Add 10m data for downsampling if 10 in measurements_by_resolution: data_10m = measurements_by_resolution[10] @@ -511,12 +520,15 @@ def _create_level_2_dataset(self, measurements_by_resolution: Dict) -> xr.Datase # skip if already present from 20m data if var_name in all_vars: continue - 
vars_to_downsample.append((var_name, var_data, '10m')) - + vars_to_downsample.append((var_name, var_data, "10m")) + # Process all downsampling in parallel if Dask is available if DASK_AVAILABLE and vars_to_downsample: + @delayed - def downsample_to_60m_variable(var_name: str, var_data: xr.DataArray, source_res: str): + def downsample_to_60m_variable( + var_name: str, var_data: xr.DataArray, source_res: str + ): var_type = determine_variable_type(var_name, var_data) downsampled = self.resampler.downsample_variable( var_data, target_height, target_width, var_type @@ -524,13 +536,13 @@ def downsample_to_60m_variable(var_name: str, var_data: xr.DataArray, source_res # Align coordinates downsampled = downsampled.assign_coords(reference_coords) return var_name, downsampled - + # Create tasks for all variables downsample_tasks = [ downsample_to_60m_variable(var_name, var_data, source_res) for var_name, var_data, source_res in vars_to_downsample ] - + # Compute all in parallel results = compute(*downsample_tasks) for var_name, downsampled_var in results: @@ -545,55 +557,49 @@ def downsample_to_60m_variable(var_name: str, var_data: xr.DataArray, source_res # Align coordinates downsampled = downsampled.assign_coords(reference_coords) all_vars[var_name] = downsampled - + if not all_vars: return xr.Dataset() - + # Create consolidated dataset dataset = xr.Dataset(all_vars) - dataset.attrs['pyramid_level'] = 2 - dataset.attrs['resolution_meters'] = 60 - + dataset.attrs["pyramid_level"] = 2 + dataset.attrs["resolution_meters"] = 60 + self._write_geo_metadata(dataset) - + return dataset - + def _create_downsampled_dataset( - self, - level: int, - target_resolution: int, - measurements_by_resolution: Dict + self, level: int, target_resolution: int, measurements_by_resolution: Dict ) -> xr.Dataset: """Create downsampled dataset for levels 3+.""" # Start from level 2 data (60m) which includes all bands, and downsample level_2_dataset = self._create_level_2_dataset(measurements_by_resolution) - + if len(level_2_dataset.data_vars) == 0: return xr.Dataset() - + return self._create_downsampled_dataset_from_level2( level, target_resolution, level_2_dataset ) - + def _create_downsampled_dataset_from_level2( - self, - level: int, - target_resolution: int, - level_2_dataset: xr.Dataset + self, level: int, target_resolution: int, level_2_dataset: xr.Dataset ) -> xr.Dataset: """Create downsampled dataset from existing level 2 data.""" if len(level_2_dataset.data_vars) == 0: return xr.Dataset() - + # Calculate target dimensions (downsample by factor of 2^(level-2)) downsample_factor = 2 ** (level - 2) - + # Get reference dimensions from level 2 ref_var = next(iter(level_2_dataset.data_vars.values())) current_height, current_width = ref_var.shape[-2:] target_height = current_height // downsample_factor target_width = current_width // downsample_factor - + # Parallelize variable downsampling if Dask is available if DASK_AVAILABLE: return self._downsample_variables_parallel( @@ -603,16 +609,17 @@ def _create_downsampled_dataset_from_level2( return self._downsample_variables_sequential( level_2_dataset, level, target_resolution, target_height, target_width ) - + def _downsample_variables_parallel( self, level_2_dataset: xr.Dataset, level: int, target_resolution: int, target_height: int, - target_width: int + target_width: int, ) -> xr.Dataset: """Downsample all variables in parallel using Dask.""" + @delayed def downsample_single_variable(var_name: str, var_data: xr.DataArray): """Downsample a single variable.""" @@ 
-621,89 +628,93 @@ def downsample_single_variable(var_name: str, var_data: xr.DataArray): var_data, target_height, target_width, var_type ) return var_name, downsampled - + # Create downsampling tasks for all variables downsample_tasks = [] for var_name, var_data in level_2_dataset.data_vars.items(): task = downsample_single_variable(var_name, var_data) downsample_tasks.append(task) - + # Compute all downsampling in parallel if downsample_tasks: results = compute(*downsample_tasks) downsampled_vars = dict(results) else: downsampled_vars = {} - + # Create dataset dataset = xr.Dataset(downsampled_vars) - dataset.attrs['pyramid_level'] = level - dataset.attrs['resolution_meters'] = target_resolution - + dataset.attrs["pyramid_level"] = level + dataset.attrs["resolution_meters"] = target_resolution + self._write_geo_metadata(dataset) - + return dataset - + def _downsample_variables_sequential( self, level_2_dataset: xr.Dataset, level: int, target_resolution: int, target_height: int, - target_width: int + target_width: int, ) -> xr.Dataset: """Downsample all variables sequentially (fallback).""" downsampled_vars = {} - + for var_name, var_data in level_2_dataset.data_vars.items(): var_type = determine_variable_type(var_name, var_data) downsampled = self.resampler.downsample_variable( var_data, target_height, target_width, var_type ) downsampled_vars[var_name] = downsampled - + # Create dataset dataset = xr.Dataset(downsampled_vars) - dataset.attrs['pyramid_level'] = level - dataset.attrs['resolution_meters'] = target_resolution - + dataset.attrs["pyramid_level"] = level + dataset.attrs["resolution_meters"] = target_resolution + self._write_geo_metadata(dataset) - + return dataset - - def _write_level_dataset(self, dataset: xr.Dataset, level_path: str, level: int) -> None: + + def _write_level_dataset( + self, dataset: xr.Dataset, level_path: str, level: int + ) -> None: """ Write a pyramid level dataset to storage with xy-aligned sharding. - + Ensures single file per variable per time point when time dimension exists. 
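Illustrative sketch (not part of the patch) of the dask.delayed fan-out pattern the parallel downsampling above relies on: one lazy task per variable, then a single compute() call so Dask schedules them together. The `downsample` helper and the toy arrays are placeholders standing in for `resampler.downsample_variable`.

import numpy as np
from dask import compute, delayed


def downsample(name: str, data: np.ndarray):
    # placeholder for resampler.downsample_variable(...): naive 2x decimation
    return name, data[::2, ::2]


variables = {"b02": np.random.rand(100, 100), "b03": np.random.rand(100, 100)}

# one delayed task per variable, executed together in a single parallel compute()
tasks = [delayed(downsample)(name, data) for name, data in variables.items()]
downsampled_vars = dict(compute(*tasks))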
""" # Create encoding with xy-aligned sharding encoding = self._create_level_encoding(dataset, level) - + # Check if we have time dimension for single file per time handling - has_time_dim = any('time' in str(var.dims) for var in dataset.data_vars.values()) - + has_time_dim = any( + "time" in str(var.dims) for var in dataset.data_vars.values() + ) + if has_time_dim and self._should_separate_time_files(dataset): # Write each time slice separately to ensure single file per variable per time self._write_time_separated_dataset(dataset, level_path, level, encoding) else: # Write as single dataset with xy-aligned sharding print(f" Writing level {level} to {level_path} (xy-aligned sharding)") - + # Rechunk the dataset to align with encoding chunks (following geozarr.py pattern) rechunked_dataset = self._rechunk_dataset_for_encoding(dataset, encoding) - + # Create zarr write job with progress bar write_job = rechunked_dataset.to_zarr( level_path, - mode='w', + mode="w", consolidated=True, zarr_format=3, encoding=encoding, - compute=False + compute=False, ) write_job = write_job.persist() - + # Show progress bar if distributed is available if DISTRIBUTED_AVAILABLE: try: @@ -714,68 +725,64 @@ def _write_level_dataset(self, dataset: xr.Dataset, level_path: str, level: int) print(f" Warning: Could not display progress bar: {e}") write_job.compute() else: - print(f" Writing zarr file...") + print(" Writing zarr file...") write_job.compute() - + def _should_separate_time_files(self, dataset: xr.Dataset) -> bool: """Determine if time files should be separated for single file per variable per time.""" for var in dataset.data_vars.values(): - if 'time' in var.dims and len(var.coords.get('time', [])) > 1: + if "time" in var.dims and len(var.coords.get("time", [])) > 1: return True return False - + def _write_time_separated_dataset( - self, - dataset: xr.Dataset, - level_path: str, - level: int, - encoding: Dict + self, dataset: xr.Dataset, level_path: str, level: int, encoding: Dict ) -> None: """Write dataset with separate files for each time point.""" import os - + # Get time coordinate time_coord = None for var in dataset.data_vars.values(): - if 'time' in var.dims: - time_coord = var.coords['time'] + if "time" in var.dims: + time_coord = var.coords["time"] break - + if time_coord is None: # Fallback to regular writing if no time found print(f" Writing level {level} to {level_path} (no time coord found)") dataset.to_zarr( level_path, - mode='w', + mode="w", consolidated=True, zarr_format=3, - encoding=encoding + encoding=encoding, ) return - + print(f" Writing level {level} with time separation to {level_path}") - + # Write each time slice separately for t_idx, time_val in enumerate(time_coord.values): time_slice = dataset.isel(time=t_idx) time_path = os.path.join(level_path, f"time_{t_idx:04d}") - + # Update encoding for time slice (remove time dimension) time_encoding = self._update_encoding_for_time_slice(encoding, time_slice) - + print(f" Writing time slice {t_idx} to {time_path}") - + # Create zarr write job with progress bar for time slice write_job = time_slice.to_zarr( time_path, - mode='w', + mode="w", consolidated=True, zarr_format=3, encoding=time_encoding, - compute=False + compute=False, ) write_job = write_job.persist() - + # Show progress bar if distributed is available if DISTRIBUTED_AVAILABLE: try: @@ -785,52 +792,54 @@ def _write_time_separated_dataset( write_job.compute() else: write_job.compute() - - def _update_encoding_for_time_slice(self, encoding: Dict, time_slice: xr.Dataset) -> 
Dict: + + def _update_encoding_for_time_slice( + self, encoding: Dict, time_slice: xr.Dataset + ) -> Dict: """Update encoding configuration for time slice data.""" updated_encoding = {} - + for var_name, var_encoding in encoding.items(): if var_name in time_slice.data_vars: var_data = time_slice[var_name] - + # Update chunks and shards for time slice (remove time dimension) - if 'chunks' in var_encoding and len(var_encoding['chunks']) > 2: + if "chunks" in var_encoding and len(var_encoding["chunks"]) > 2: # Remove time dimension from chunks (first dimension) - updated_chunks = var_encoding['chunks'][1:] + updated_chunks = var_encoding["chunks"][1:] updated_encoding[var_name] = var_encoding.copy() - updated_encoding[var_name]['chunks'] = updated_chunks - + updated_encoding[var_name]["chunks"] = updated_chunks + # Update shards if present - if 'shards' in var_encoding and len(var_encoding['shards']) > 2: - updated_shards = var_encoding['shards'][1:] - updated_encoding[var_name]['shards'] = updated_shards + if "shards" in var_encoding and len(var_encoding["shards"]) > 2: + updated_shards = var_encoding["shards"][1:] + updated_encoding[var_name]["shards"] = updated_shards else: updated_encoding[var_name] = var_encoding else: # Coordinate or other variable updated_encoding[var_name] = encoding[var_name] - + return updated_encoding - + def _create_level_encoding(self, dataset: xr.Dataset, level: int) -> Dict: """Create optimized encoding for a pyramid level with xy-aligned sharding.""" encoding = {} - + # Calculate level-appropriate chunk sizes - chunk_size = max(256, self.spatial_chunk // (2 ** level)) - + chunk_size = max(256, self.spatial_chunk // (2**level)) + for var_name, var_data in dataset.data_vars.items(): if var_data.ndim >= 2: height, width = var_data.shape[-2:] - + # Use original geozarr.py chunk alignment logic spatial_chunk_aligned = min( chunk_size, self._calculate_aligned_chunk_size(width, chunk_size), self._calculate_aligned_chunk_size(height, chunk_size), ) - + if var_data.ndim == 3: # Single file per variable per time: chunk time dimension to 1 chunks = (1, spatial_chunk_aligned, spatial_chunk_aligned) @@ -838,55 +847,61 @@ def _create_level_encoding(self, dataset: xr.Dataset, level: int) -> Dict: chunks = (spatial_chunk_aligned, spatial_chunk_aligned) else: chunks = (min(chunk_size, var_data.shape[0]),) - + # Configure encoding - use proper compressor following geozarr.py pattern from zarr.codecs import BloscCodec - compressor = BloscCodec(cname="zstd", clevel=3, shuffle="shuffle", blocksize=0) - var_encoding = { - 'chunks': chunks, - 'compressors': [compressor] - } - + + compressor = BloscCodec( + cname="zstd", clevel=3, shuffle="shuffle", blocksize=0 + ) + var_encoding = {"chunks": chunks, "compressors": [compressor]} + # Add simplified sharding if enabled - shards match x/y dimensions exactly if self.enable_sharding and var_data.ndim >= 2: - shard_dims = self._calculate_simple_shard_dimensions(var_data.shape, chunks) - var_encoding['shards'] = shard_dims - + shard_dims = self._calculate_simple_shard_dimensions( + var_data.shape, chunks + ) + var_encoding["shards"] = shard_dims + encoding[var_name] = var_encoding - + # Add coordinate encoding for coord_name in dataset.coords: - encoding[coord_name] = {'compressors': None} - + encoding[coord_name] = {"compressors": None} + return encoding - def _calculate_aligned_chunk_size(self, dimension_size: int, target_chunk: int) -> int: + def _calculate_aligned_chunk_size( + self, dimension_size: int, target_chunk: int + ) -> int: 
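Illustrative sketch (not part of the patch) of the encoding shape `_create_level_encoding` produces for a (time, y, x) variable, assuming a zarr-python 3 / recent xarray stack that accepts the `chunks`, `shards` and `compressors` keys used throughout this series; the sizes below are example values only (the shard shape must be an exact multiple of the chunk shape, here 10980 = 12 x 915).

from zarr.codecs import BloscCodec

compressor = BloscCodec(cname="zstd", clevel=3, shuffle="shuffle", blocksize=0)

encoding = {
    "b02": {
        "chunks": (1, 915, 915),      # one time slice per chunk, xy-aligned spatial chunks
        "shards": (1, 10980, 10980),  # shards span the full spatial extent
        "compressors": [compressor],
    },
    # coordinates are written uncompressed
    "x": {"compressors": None},
    "y": {"compressors": None},
}

# dataset.to_zarr(level_path, mode="w", zarr_format=3, consolidated=True, encoding=encoding)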
""" Calculate aligned chunk size following geozarr.py logic. - + This ensures good chunk alignment without complex calculations. """ if target_chunk >= dimension_size: return dimension_size - + # Find the largest divisor of dimension_size that's close to target_chunk best_chunk = target_chunk for chunk_candidate in range(target_chunk, max(target_chunk // 2, 1), -1): if dimension_size % chunk_candidate == 0: best_chunk = chunk_candidate break - + return best_chunk - def _calculate_simple_shard_dimensions(self, data_shape: Tuple, chunks: Tuple) -> Tuple: + def _calculate_simple_shard_dimensions( + self, data_shape: Tuple, chunks: Tuple + ) -> Tuple: """ Calculate shard dimensions that are compatible with chunk dimensions. - + Shard dimensions must be evenly divisible by chunk dimensions for Zarr v3. When possible, shards should match x/y dimensions exactly as required. """ shard_dims = [] - + for i, (dim_size, chunk_size) in enumerate(zip(data_shape, chunks)): if i == 0 and len(data_shape) == 3: # First dimension in 3D data (time) - use single time slice per shard @@ -906,60 +921,70 @@ def _calculate_simple_shard_dimensions(self, data_shape: Tuple, chunks: Tuple) - else: # Fallback: use chunk size itself shard_dims.append(chunk_size) - + return tuple(shard_dims) - def _rechunk_dataset_for_encoding(self, dataset: xr.Dataset, encoding: Dict) -> xr.Dataset: + def _rechunk_dataset_for_encoding( + self, dataset: xr.Dataset, encoding: Dict + ) -> xr.Dataset: """ Rechunk dataset variables to align with sharding dimensions when sharding is enabled. - - When using Zarr v3 sharding, Dask chunks must align with shard dimensions to avoid + + When using Zarr v3 sharding, Dask chunks must align with shard dimensions to avoid checksum validation errors. """ rechunked_vars = {} - + for var_name, var_data in dataset.data_vars.items(): if var_name in encoding: var_encoding = encoding[var_name] - + # If sharding is enabled, rechunk based on shard dimensions - if 'shards' in var_encoding and var_encoding['shards'] is not None: - target_chunks = var_encoding['shards'] # Use shard dimensions for rechunking - elif 'chunks' in var_encoding: - target_chunks = var_encoding['chunks'] # Fallback to chunk dimensions + if "shards" in var_encoding and var_encoding["shards"] is not None: + target_chunks = var_encoding[ + "shards" + ] # Use shard dimensions for rechunking + elif "chunks" in var_encoding: + target_chunks = var_encoding[ + "chunks" + ] # Fallback to chunk dimensions else: # No specific chunking needed, use original variable rechunked_vars[var_name] = var_data continue - + # Create chunk dict using the actual dimensions of the variable var_dims = var_data.dims chunk_dict = {} for i, dim in enumerate(var_dims): if i < len(target_chunks): chunk_dict[dim] = target_chunks[i] - + # Rechunk the variable to match the target dimensions rechunked_vars[var_name] = var_data.chunk(chunk_dict) else: # No specific chunking needed, use original variable rechunked_vars[var_name] = var_data - + # Create new dataset with rechunked variables, preserving coordinates - rechunked_dataset = xr.Dataset(rechunked_vars, coords=dataset.coords, attrs=dataset.attrs) - + rechunked_dataset = xr.Dataset( + rechunked_vars, coords=dataset.coords, attrs=dataset.attrs + ) + return rechunked_dataset - def _write_geo_metadata(self, dataset: xr.Dataset, grid_mapping_var_name: str = "spatial_ref") -> None: + def _write_geo_metadata( + self, dataset: xr.Dataset, grid_mapping_var_name: str = "spatial_ref" + ) -> None: """ Write geographic metadata to 
the dataset. Adds a grid_mapping variable and updates all data variables to reference it. """ - + # take the CRS from one of the data variables if available crs = None for var in dataset.data_vars.values(): - if hasattr(var, 'rio') and var.rio.crs: + if hasattr(var, "rio") and var.rio.crs: crs = var.rio.crs break elif "proj:epsg" in var.attrs: @@ -969,11 +994,12 @@ def _write_geo_metadata(self, dataset: xr.Dataset, grid_mapping_var_name: str = # Use standard CRS and transform if available if crs is not None: - dataset.rio.write_crs(crs, grid_mapping_name=grid_mapping_var_name, inplace=True) + dataset.rio.write_crs( + crs, grid_mapping_name=grid_mapping_var_name, inplace=True + ) # Add grid mapping variable dataset.rio.write_grid_mapping(grid_mapping_var_name, inplace=True) - + # Set the grid mapping variable for all data variables for var in dataset.data_vars.values(): var.rio.write_grid_mapping(grid_mapping_var_name, inplace=True) - diff --git a/src/eopf_geozarr/s2_optimization/s2_resampling.py b/src/eopf_geozarr/s2_optimization/s2_resampling.py index 5f835f04..0840e30c 100644 --- a/src/eopf_geozarr/s2_optimization/s2_resampling.py +++ b/src/eopf_geozarr/s2_optimization/s2_resampling.py @@ -5,53 +5,56 @@ import numpy as np import xarray as xr + class S2ResamplingEngine: """Handles downsampling operations for S2 multiscale creation.""" - + def __init__(self): self.resampling_methods = { - 'reflectance': self._downsample_reflectance, - 'classification': self._downsample_classification, - 'quality_mask': self._downsample_quality_mask, - 'probability': self._downsample_probability, - 'detector_footprint': self._downsample_quality_mask, # Same as quality mask + "reflectance": self._downsample_reflectance, + "classification": self._downsample_classification, + "quality_mask": self._downsample_quality_mask, + "probability": self._downsample_probability, + "detector_footprint": self._downsample_quality_mask, # Same as quality mask } - - def downsample_variable(self, data: xr.DataArray, target_height: int, - target_width: int, var_type: str) -> xr.DataArray: + + def downsample_variable( + self, data: xr.DataArray, target_height: int, target_width: int, var_type: str + ) -> xr.DataArray: """ Downsample a variable to target dimensions. - + Args: data: Input data array target_height: Target height in pixels - target_width: Target width in pixels + target_width: Target width in pixels var_type: Type of variable ('reflectance', 'classification', etc.) 
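Illustrative sketch (not part of the patch) of the rioxarray calls that `_write_geo_metadata` builds on; the toy dataset and the EPSG code are made up for the example.

import numpy as np
import rioxarray  # noqa: F401  (registers the .rio accessor)
import xarray as xr

ds = xr.Dataset(
    {"b02": (("y", "x"), np.zeros((4, 4), dtype="float32"))},
    coords={"y": np.arange(4.0), "x": np.arange(4.0)},
)

# store the CRS in a "spatial_ref" grid-mapping variable ...
ds.rio.write_crs("EPSG:32632", grid_mapping_name="spatial_ref", inplace=True)
# ... and reference it from every data variable via grid_mapping metadata
ds.rio.write_grid_mapping("spatial_ref", inplace=True)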
- + Returns: Downsampled data array """ if var_type not in self.resampling_methods: raise ValueError(f"Unknown variable type: {var_type}") - + method = self.resampling_methods[var_type] return method(data, target_height, target_width) - - def _downsample_reflectance(self, data: xr.DataArray, target_height: int, - target_width: int) -> xr.DataArray: + + def _downsample_reflectance( + self, data: xr.DataArray, target_height: int, target_width: int + ) -> xr.DataArray: """Block averaging for reflectance bands.""" # Calculate block sizes current_height, current_width = data.shape[-2:] block_h = current_height // target_height block_w = current_width // target_width - + # Ensure exact divisibility if current_height % target_height != 0 or current_width % target_width != 0: # Crop to make it divisible new_height = (current_height // block_h) * block_h new_width = (current_width // block_w) * block_w data = data[..., :new_height, :new_width] - + # Perform block averaging if data.ndim == 3: # (time, y, x) or similar reshaped = data.values.reshape( @@ -59,54 +62,53 @@ def _downsample_reflectance(self, data: xr.DataArray, target_height: int, ) downsampled = reshaped.mean(axis=(2, 4)) else: # (y, x) - reshaped = data.values.reshape(target_height, block_h, target_width, block_w) + reshaped = data.values.reshape( + target_height, block_h, target_width, block_w + ) downsampled = reshaped.mean(axis=(1, 3)) - + # Create new coordinates y_coords = data.coords[data.dims[-2]][::block_h][:target_height] x_coords = data.coords[data.dims[-1]][::block_w][:target_width] - + # Create new DataArray if data.ndim == 3: coords = { data.dims[0]: data.coords[data.dims[0]], data.dims[-2]: y_coords, - data.dims[-1]: x_coords + data.dims[-1]: x_coords, } else: - coords = { - data.dims[-2]: y_coords, - data.dims[-1]: x_coords - } - + coords = {data.dims[-2]: y_coords, data.dims[-1]: x_coords} + return xr.DataArray( - downsampled, - dims=data.dims, - coords=coords, - attrs=data.attrs.copy() + downsampled, dims=data.dims, coords=coords, attrs=data.attrs.copy() ) - - def _downsample_classification(self, data: xr.DataArray, target_height: int, - target_width: int) -> xr.DataArray: + + def _downsample_classification( + self, data: xr.DataArray, target_height: int, target_width: int + ) -> xr.DataArray: """Mode-based downsampling for classification data.""" from scipy import stats - + current_height, current_width = data.shape[-2:] block_h = current_height // target_height block_w = current_width // target_width - + # Crop to make divisible new_height = (current_height // block_h) * block_h new_width = (current_width // block_w) * block_w data = data[..., :new_height, :new_width] - + # Reshape for block processing if data.ndim == 3: reshaped = data.values.reshape( data.shape[0], target_height, block_h, target_width, block_w ) # Compute mode for each block - downsampled = np.zeros((data.shape[0], target_height, target_width), dtype=data.dtype) + downsampled = np.zeros( + (data.shape[0], target_height, target_width), dtype=data.dtype + ) for t in range(data.shape[0]): for i in range(target_height): for j in range(target_width): @@ -114,49 +116,46 @@ def _downsample_classification(self, data: xr.DataArray, target_height: int, mode_val = stats.mode(block, keepdims=False)[0] downsampled[t, i, j] = mode_val else: - reshaped = data.values.reshape(target_height, block_h, target_width, block_w) + reshaped = data.values.reshape( + target_height, block_h, target_width, block_w + ) downsampled = np.zeros((target_height, target_width), 
dtype=data.dtype) for i in range(target_height): for j in range(target_width): block = reshaped[i, :, j, :].flatten() mode_val = stats.mode(block, keepdims=False)[0] downsampled[i, j] = mode_val - + # Create coordinates y_coords = data.coords[data.dims[-2]][::block_h][:target_height] x_coords = data.coords[data.dims[-1]][::block_w][:target_width] - + if data.ndim == 3: coords = { data.dims[0]: data.coords[data.dims[0]], data.dims[-2]: y_coords, - data.dims[-1]: x_coords + data.dims[-1]: x_coords, } else: - coords = { - data.dims[-2]: y_coords, - data.dims[-1]: x_coords - } - + coords = {data.dims[-2]: y_coords, data.dims[-1]: x_coords} + return xr.DataArray( - downsampled, - dims=data.dims, - coords=coords, - attrs=data.attrs.copy() + downsampled, dims=data.dims, coords=coords, attrs=data.attrs.copy() ) - - def _downsample_quality_mask(self, data: xr.DataArray, target_height: int, - target_width: int) -> xr.DataArray: + + def _downsample_quality_mask( + self, data: xr.DataArray, target_height: int, target_width: int + ) -> xr.DataArray: """Logical OR downsampling for quality masks (any bad pixel = bad block).""" current_height, current_width = data.shape[-2:] block_h = current_height // target_height block_w = current_width // target_width - + # Crop to make divisible new_height = (current_height // block_h) * block_h new_width = (current_width // block_w) * block_w data = data[..., :new_height, :new_width] - + if data.ndim == 3: reshaped = data.values.reshape( data.shape[0], target_height, block_h, target_width, block_w @@ -164,73 +163,71 @@ def _downsample_quality_mask(self, data: xr.DataArray, target_height: int, # Any non-zero value in block makes the downsampled pixel non-zero downsampled = (reshaped.sum(axis=(2, 4)) > 0).astype(data.dtype) else: - reshaped = data.values.reshape(target_height, block_h, target_width, block_w) + reshaped = data.values.reshape( + target_height, block_h, target_width, block_w + ) downsampled = (reshaped.sum(axis=(1, 3)) > 0).astype(data.dtype) - + # Create coordinates y_coords = data.coords[data.dims[-2]][::block_h][:target_height] x_coords = data.coords[data.dims[-1]][::block_w][:target_width] - + if data.ndim == 3: coords = { data.dims[0]: data.coords[data.dims[0]], data.dims[-2]: y_coords, - data.dims[-1]: x_coords + data.dims[-1]: x_coords, } else: - coords = { - data.dims[-2]: y_coords, - data.dims[-1]: x_coords - } - + coords = {data.dims[-2]: y_coords, data.dims[-1]: x_coords} + return xr.DataArray( - downsampled, - dims=data.dims, - coords=coords, - attrs=data.attrs.copy() + downsampled, dims=data.dims, coords=coords, attrs=data.attrs.copy() ) - - def _downsample_probability(self, data: xr.DataArray, target_height: int, - target_width: int) -> xr.DataArray: + + def _downsample_probability( + self, data: xr.DataArray, target_height: int, target_width: int + ) -> xr.DataArray: """Average downsampling for probability data.""" # Use same method as reflectance but ensure values stay in [0,1] or [0,100] range result = self._downsample_reflectance(data, target_height, target_width) - + # Clamp values to valid probability range if result.max() <= 1.0: # [0,1] probabilities result.values = np.clip(result.values, 0, 1) else: # [0,100] percentages result.values = np.clip(result.values, 0, 100) - + return result + def determine_variable_type(var_name: str, var_data: xr.DataArray) -> str: """ Determine the type of a variable for appropriate resampling. 
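Standalone example (not part of the patch) of the NumPy reshape trick the block-based resamplers above share: fold each spatial axis into (target, block) pairs, then reduce over the block axes. Shapes are made up for illustration.

import numpy as np

data = np.arange(36, dtype=float).reshape(6, 6)        # 6x6 grid, 2x2 blocks -> 3x3 output
target_h, block_h, target_w, block_w = 3, 2, 3, 2

blocks = data.reshape(target_h, block_h, target_w, block_w)
mean_down = blocks.mean(axis=(1, 3))                   # block averaging (reflectance)
mask_down = (blocks.sum(axis=(1, 3)) > 0).astype(int)  # "any flagged pixel" rule (quality masks)

assert mean_down.shape == (3, 3) and mask_down.shape == (3, 3)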
- + Args: var_name: Name of the variable var_data: The data array - + Returns: Variable type string """ # Spectral bands - if var_name.startswith('b') and (var_name[1:].isdigit() or var_name == 'b8a'): - return 'reflectance' - + if var_name.startswith("b") and (var_name[1:].isdigit() or var_name == "b8a"): + return "reflectance" + # Quality data - if var_name in ['scl']: # Scene Classification Layer - return 'classification' - - if var_name in ['cld', 'snw']: # Probability data - return 'probability' - - if var_name in ['aot', 'wvp']: # Atmosphere quality - treat as reflectance - return 'reflectance' - - if var_name.startswith('detector_footprint_') or var_name.startswith('quality_'): - return 'quality_mask' - + if var_name in ["scl"]: # Scene Classification Layer + return "classification" + + if var_name in ["cld", "snw"]: # Probability data + return "probability" + + if var_name in ["aot", "wvp"]: # Atmosphere quality - treat as reflectance + return "reflectance" + + if var_name.startswith("detector_footprint_") or var_name.startswith("quality_"): + return "quality_mask" + # Default to reflectance for unknown variables - return 'reflectance' \ No newline at end of file + return "reflectance" diff --git a/src/eopf_geozarr/s2_optimization/s2_validation.py b/src/eopf_geozarr/s2_optimization/s2_validation.py index 030ac5c2..15727a33 100644 --- a/src/eopf_geozarr/s2_optimization/s2_validation.py +++ b/src/eopf_geozarr/s2_optimization/s2_validation.py @@ -2,28 +2,23 @@ Validation for optimized Sentinel-2 datasets. """ -from typing import Dict, Any -import xarray as xr +from typing import Any, Dict + class S2OptimizationValidator: """Validates optimized Sentinel-2 dataset structure and integrity.""" - + def validate_optimized_dataset(self, dataset_path: str) -> Dict[str, Any]: """ Validate an optimized Sentinel-2 dataset. 
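Quick usage sketch (not part of the patch) of the name-based dispatch implemented by `determine_variable_type` as defined in this series; the dummy DataArray only satisfies the signature, since the rules inspect the variable name alone, and the import path follows the module layout introduced here.

import numpy as np
import xarray as xr

from eopf_geozarr.s2_optimization.s2_resampling import determine_variable_type

dummy = xr.DataArray(np.zeros((2, 2)), dims=("y", "x"))

assert determine_variable_type("b04", dummy) == "reflectance"     # spectral band
assert determine_variable_type("b8a", dummy) == "reflectance"     # special-cased band name
assert determine_variable_type("scl", dummy) == "classification"  # scene classification layer
assert determine_variable_type("cld", dummy) == "probability"     # cloud probability
assert determine_variable_type("quality_b02", dummy) == "quality_mask"
assert determine_variable_type("something_else", dummy) == "reflectance"  # default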
- + Args: dataset_path: Path to the optimized dataset - + Returns: Validation results dictionary """ - results = { - 'is_valid': True, - 'issues': [], - 'warnings': [], - 'summary': {} - } - + results = {"is_valid": True, "issues": [], "warnings": [], "summary": {}} + # Placeholder for validation logic return results diff --git a/src/eopf_geozarr/tests/test_s2_band_mapping.py b/src/eopf_geozarr/tests/test_s2_band_mapping.py index e18f5051..859a100b 100644 --- a/src/eopf_geozarr/tests/test_s2_band_mapping.py +++ b/src/eopf_geozarr/tests/test_s2_band_mapping.py @@ -1,14 +1,13 @@ -import pytest from eopf_geozarr.s2_optimization.s2_band_mapping import ( - BandInfo, - NATIVE_BANDS, BAND_INFO, + NATIVE_BANDS, QUALITY_DATA_NATIVE, - DETECTOR_FOOTPRINT_NATIVE, + BandInfo, get_bands_for_level, get_quality_data_for_level, ) + def test_bandinfo_initialization(): band = BandInfo("b01", 60, "uint16", 443, 21) assert band.name == "b01" @@ -17,11 +16,13 @@ def test_bandinfo_initialization(): assert band.wavelength_center == 443 assert band.wavelength_width == 21 + def test_native_bands(): assert NATIVE_BANDS[10] == ["b02", "b03", "b04", "b08"] assert NATIVE_BANDS[20] == ["b05", "b06", "b07", "b11", "b12", "b8a"] assert NATIVE_BANDS[60] == ["b01", "b09"] + def test_band_info(): assert BAND_INFO["b01"].name == "b01" assert BAND_INFO["b01"].native_resolution == 60 @@ -29,6 +30,7 @@ def test_band_info(): assert BAND_INFO["b01"].wavelength_center == 443 assert BAND_INFO["b01"].wavelength_width == 21 + def test_quality_data_native(): assert QUALITY_DATA_NATIVE["scl"] == 20 assert QUALITY_DATA_NATIVE["aot"] == 20 @@ -36,11 +38,19 @@ def test_quality_data_native(): assert QUALITY_DATA_NATIVE["cld"] == 20 assert QUALITY_DATA_NATIVE["snw"] == 20 + def test_get_bands_for_level(): assert get_bands_for_level(0) == set(NATIVE_BANDS[10]) - assert get_bands_for_level(1) == set(NATIVE_BANDS[10] + NATIVE_BANDS[20] + NATIVE_BANDS[60]) - assert get_bands_for_level(2) == set(NATIVE_BANDS[10] + NATIVE_BANDS[20] + NATIVE_BANDS[60]) - assert get_bands_for_level(3) == set(NATIVE_BANDS[10] + NATIVE_BANDS[20] + NATIVE_BANDS[60]) + assert get_bands_for_level(1) == set( + NATIVE_BANDS[10] + NATIVE_BANDS[20] + NATIVE_BANDS[60] + ) + assert get_bands_for_level(2) == set( + NATIVE_BANDS[10] + NATIVE_BANDS[20] + NATIVE_BANDS[60] + ) + assert get_bands_for_level(3) == set( + NATIVE_BANDS[10] + NATIVE_BANDS[20] + NATIVE_BANDS[60] + ) + def test_get_quality_data_for_level(): assert get_quality_data_for_level(0) == set() diff --git a/src/eopf_geozarr/tests/test_s2_converter.py b/src/eopf_geozarr/tests/test_s2_converter.py index 723a2220..8fd52cdb 100644 --- a/src/eopf_geozarr/tests/test_s2_converter.py +++ b/src/eopf_geozarr/tests/test_s2_converter.py @@ -3,31 +3,32 @@ """ import pytest -import xarray as xr from xarray import DataTree from eopf_geozarr.s2_optimization.s2_converter import S2OptimizedConverter + @pytest.fixture def mock_input_data(): """Create mock input DataTree for testing.""" # Placeholder for creating mock DataTree return DataTree() + def test_conversion_pipeline(mock_input_data, tmp_path): """Test the full conversion pipeline.""" output_path = tmp_path / "optimized_output" converter = S2OptimizedConverter(enable_sharding=True, spatial_chunk=1024) - + result = converter.convert_s2(mock_input_data, str(output_path)) - + # Validate multiscale data assert "multiscale_data" in result assert isinstance(result["multiscale_data"], dict) - + # Validate output path assert output_path.exists() - + # Validate validation results 
assert "validation_results" in result assert result["validation_results"]["is_valid"] diff --git a/src/eopf_geozarr/tests/test_s2_converter_simplified.py b/src/eopf_geozarr/tests/test_s2_converter_simplified.py index 47096485..aef8e868 100644 --- a/src/eopf_geozarr/tests/test_s2_converter_simplified.py +++ b/src/eopf_geozarr/tests/test_s2_converter_simplified.py @@ -5,14 +5,14 @@ """ import os -import tempfile import shutil -from unittest.mock import Mock, patch, MagicMock -import pytest +import tempfile +from unittest.mock import Mock, patch + import numpy as np +import pytest import xarray as xr import zarr -from rasterio.crs import CRS from eopf_geozarr.s2_optimization.s2_converter import S2OptimizedConverter @@ -22,33 +22,27 @@ def mock_s2_dataset(): """Create a mock S2 dataset for testing.""" # Create test data arrays coords = { - 'x': (['x'], np.linspace(0, 1000, 100)), - 'y': (['y'], np.linspace(0, 1000, 100)), - 'time': (['time'], [np.datetime64('2023-01-01')]) + "x": (["x"], np.linspace(0, 1000, 100)), + "y": (["y"], np.linspace(0, 1000, 100)), + "time": (["time"], [np.datetime64("2023-01-01")]), } - + # Create test variables data_vars = { - 'b02': (['time', 'y', 'x'], np.random.rand(1, 100, 100)), - 'b03': (['time', 'y', 'x'], np.random.rand(1, 100, 100)), - 'b04': (['time', 'y', 'x'], np.random.rand(1, 100, 100)), + "b02": (["time", "y", "x"], np.random.rand(1, 100, 100)), + "b03": (["time", "y", "x"], np.random.rand(1, 100, 100)), + "b04": (["time", "y", "x"], np.random.rand(1, 100, 100)), } - + ds = xr.Dataset(data_vars, coords=coords) - + # Add rioxarray CRS - ds = ds.rio.write_crs('EPSG:32632') - + ds = ds.rio.write_crs("EPSG:32632") + # Create datatree dt = xr.DataTree(ds) - dt.attrs = { - 'stac_discovery': { - 'properties': { - 'mission': 'sentinel-2' - } - } - } - + dt.attrs = {"stac_discovery": {"properties": {"mission": "sentinel-2"}}} + return dt @@ -62,330 +56,357 @@ def temp_output_dir(): class TestS2OptimizedConverter: """Test the S2OptimizedConverter class.""" - + def test_init(self): """Test converter initialization.""" converter = S2OptimizedConverter( - enable_sharding=True, - spatial_chunk=512, - compression_level=5, - max_retries=2 + enable_sharding=True, spatial_chunk=512, compression_level=5, max_retries=2 ) - + assert converter.enable_sharding is True assert converter.spatial_chunk == 512 assert converter.compression_level == 5 assert converter.max_retries == 2 assert converter.pyramid_creator is not None assert converter.validator is not None - + def test_is_sentinel2_dataset_with_mission(self): """Test S2 detection via mission attribute.""" converter = S2OptimizedConverter() - + # Test with S2 mission dt = xr.DataTree() - dt.attrs = { - 'stac_discovery': { - 'properties': { - 'mission': 'sentinel-2a' - } - } - } - + dt.attrs = {"stac_discovery": {"properties": {"mission": "sentinel-2a"}}} + assert converter._is_sentinel2_dataset(dt) is True - + # Test with non-S2 mission - dt.attrs['stac_discovery']['properties']['mission'] = 'sentinel-1' + dt.attrs["stac_discovery"]["properties"]["mission"] = "sentinel-1" assert converter._is_sentinel2_dataset(dt) is False - + def test_is_sentinel2_dataset_with_groups(self): """Test S2 detection via characteristic groups.""" converter = S2OptimizedConverter() - + dt = xr.DataTree() dt.attrs = {} - + # Mock groups property using patch - with patch.object(type(dt), 'groups', new_callable=lambda: property(lambda self: [ - '/measurements/reflectance', - '/conditions/geometry', - '/quality/atmosphere' - ])): + with 
patch.object( + type(dt), + "groups", + new_callable=lambda: property( + lambda self: [ + "/measurements/reflectance", + "/conditions/geometry", + "/quality/atmosphere", + ] + ), + ): assert converter._is_sentinel2_dataset(dt) is True - + # Test with insufficient indicators - with patch.object(type(dt), 'groups', new_callable=lambda: property(lambda self: ['/measurements/reflectance'])): + with patch.object( + type(dt), + "groups", + new_callable=lambda: property(lambda self: ["/measurements/reflectance"]), + ): assert converter._is_sentinel2_dataset(dt) is False class TestMultiscalesMetadata: """Test multiscales metadata creation.""" - + def test_create_multiscales_metadata_with_rio(self, temp_output_dir): """Test multiscales metadata creation using rioxarray.""" converter = S2OptimizedConverter() - + # Create mock pyramid datasets with rioxarray pyramid_datasets = {} for level in [0, 1, 2]: # Create test dataset coords = { - 'x': (['x'], np.linspace(0, 1000, 100 // (2**level))), - 'y': (['y'], np.linspace(0, 1000, 100 // (2**level))) + "x": (["x"], np.linspace(0, 1000, 100 // (2**level))), + "y": (["y"], np.linspace(0, 1000, 100 // (2**level))), } data_vars = { - 'b02': (['y', 'x'], np.random.rand(100 // (2**level), 100 // (2**level))) + "b02": ( + ["y", "x"], + np.random.rand(100 // (2**level), 100 // (2**level)), + ) } ds = xr.Dataset(data_vars, coords=coords) - ds = ds.rio.write_crs('EPSG:32632') - + ds = ds.rio.write_crs("EPSG:32632") + pyramid_datasets[level] = ds - + # Test metadata creation metadata = converter._create_multiscales_metadata_with_rio(pyramid_datasets) - + # Verify structure matches geozarr.py format - assert 'tile_matrix_set' in metadata - assert 'resampling_method' in metadata - assert 'tile_matrix_limits' in metadata - assert metadata['resampling_method'] == 'average' - + assert "tile_matrix_set" in metadata + assert "resampling_method" in metadata + assert "tile_matrix_limits" in metadata + assert metadata["resampling_method"] == "average" + # Verify tile matrix set structure - tms = metadata['tile_matrix_set'] - assert 'id' in tms - assert 'crs' in tms - assert 'tileMatrices' in tms - assert len(tms['tileMatrices']) == 3 # 3 levels - + tms = metadata["tile_matrix_set"] + assert "id" in tms + assert "crs" in tms + assert "tileMatrices" in tms + assert len(tms["tileMatrices"]) == 3 # 3 levels + def test_create_multiscales_metadata_no_datasets(self): """Test metadata creation with no datasets.""" converter = S2OptimizedConverter() - + metadata = converter._create_multiscales_metadata_with_rio({}) assert metadata == {} - + def test_create_multiscales_metadata_no_crs(self): """Test metadata creation with datasets lacking CRS.""" converter = S2OptimizedConverter() - + # Create dataset without CRS - ds = xr.Dataset({'b02': (['y', 'x'], np.random.rand(10, 10))}) + ds = xr.Dataset({"b02": (["y", "x"], np.random.rand(10, 10))}) pyramid_datasets = {0: ds} - + metadata = converter._create_multiscales_metadata_with_rio(pyramid_datasets) assert metadata == {} class TestAuxiliaryGroupWriting: """Test auxiliary group writing functionality.""" - - @patch('eopf_geozarr.s2_optimization.s2_converter.distributed') - def test_write_auxiliary_group_with_distributed(self, mock_distributed, temp_output_dir): + + @patch("eopf_geozarr.s2_optimization.s2_converter.distributed") + def test_write_auxiliary_group_with_distributed( + self, mock_distributed, temp_output_dir + ): """Test auxiliary group writing with distributed available.""" converter = S2OptimizedConverter() - + # Create test 
dataset data_vars = { - 'solar_zenith': (['y', 'x'], np.random.rand(50, 50)), - 'solar_azimuth': (['y', 'x'], np.random.rand(50, 50)) + "solar_zenith": (["y", "x"], np.random.rand(50, 50)), + "solar_azimuth": (["y", "x"], np.random.rand(50, 50)), } coords = { - 'x': (['x'], np.linspace(0, 1000, 50)), - 'y': (['y'], np.linspace(0, 1000, 50)) + "x": (["x"], np.linspace(0, 1000, 50)), + "y": (["y"], np.linspace(0, 1000, 50)), } dataset = xr.Dataset(data_vars, coords=coords) - - group_path = os.path.join(temp_output_dir, 'geometry') - + + group_path = os.path.join(temp_output_dir, "geometry") + # Mock distributed progress mock_progress = Mock() mock_distributed.progress = mock_progress - + # Test writing - converter._write_auxiliary_group(dataset, group_path, 'geometry', verbose=True) - + converter._write_auxiliary_group(dataset, group_path, "geometry", verbose=True) + # Verify zarr group was created assert os.path.exists(group_path) - + # Verify group can be opened - zarr_group = zarr.open_group(group_path, mode='r') - assert 'solar_zenith' in zarr_group - assert 'solar_azimuth' in zarr_group - + zarr_group = zarr.open_group(group_path, mode="r") + assert "solar_zenith" in zarr_group + assert "solar_azimuth" in zarr_group + def test_write_auxiliary_group_without_distributed(self, temp_output_dir): """Test auxiliary group writing without distributed.""" converter = S2OptimizedConverter() - + # Create test dataset data_vars = { - 'temperature': (['y', 'x'], np.random.rand(30, 30)), - 'pressure': (['y', 'x'], np.random.rand(30, 30)) + "temperature": (["y", "x"], np.random.rand(30, 30)), + "pressure": (["y", "x"], np.random.rand(30, 30)), } coords = { - 'x': (['x'], np.linspace(0, 1000, 30)), - 'y': (['y'], np.linspace(0, 1000, 30)) + "x": (["x"], np.linspace(0, 1000, 30)), + "y": (["y"], np.linspace(0, 1000, 30)), } dataset = xr.Dataset(data_vars, coords=coords) - - group_path = os.path.join(temp_output_dir, 'meteorology') - + + group_path = os.path.join(temp_output_dir, "meteorology") + # Patch DISTRIBUTED_AVAILABLE to False - with patch('eopf_geozarr.s2_optimization.s2_converter.DISTRIBUTED_AVAILABLE', False): - converter._write_auxiliary_group(dataset, group_path, 'meteorology', verbose=False) - + with patch( + "eopf_geozarr.s2_optimization.s2_converter.DISTRIBUTED_AVAILABLE", False + ): + converter._write_auxiliary_group( + dataset, group_path, "meteorology", verbose=False + ) + # Verify zarr group was created assert os.path.exists(group_path) - + # Verify group can be opened - zarr_group = zarr.open_group(group_path, mode='r') - assert 'temperature' in zarr_group - assert 'pressure' in zarr_group + zarr_group = zarr.open_group(group_path, mode="r") + assert "temperature" in zarr_group + assert "pressure" in zarr_group class TestMetadataConsolidation: """Test metadata consolidation functionality.""" - + def test_add_measurements_multiscales_metadata(self, temp_output_dir): """Test adding multiscales metadata to measurements group.""" converter = S2OptimizedConverter() - + # Create measurements group structure - measurements_path = os.path.join(temp_output_dir, 'measurements') + measurements_path = os.path.join(temp_output_dir, "measurements") os.makedirs(measurements_path) - + # Create a minimal zarr group - zarr_group = zarr.open_group(measurements_path, mode='w') - zarr_group.attrs['test'] = 'value' - + zarr_group = zarr.open_group(measurements_path, mode="w") + zarr_group.attrs["test"] = "value" + # Create mock pyramid datasets pyramid_datasets = {} for level in [0, 1]: coords = { - 'x': 
(['x'], np.linspace(0, 1000, 50 // (2**level))), - 'y': (['y'], np.linspace(0, 1000, 50 // (2**level))) + "x": (["x"], np.linspace(0, 1000, 50 // (2**level))), + "y": (["y"], np.linspace(0, 1000, 50 // (2**level))), } data_vars = { - 'b02': (['y', 'x'], np.random.rand(50 // (2**level), 50 // (2**level))) + "b02": (["y", "x"], np.random.rand(50 // (2**level), 50 // (2**level))) } ds = xr.Dataset(data_vars, coords=coords) - ds = ds.rio.write_crs('EPSG:32632') + ds = ds.rio.write_crs("EPSG:32632") pyramid_datasets[level] = ds - + # Test adding metadata - converter._add_measurements_multiscales_metadata(temp_output_dir, pyramid_datasets) - + converter._add_measurements_multiscales_metadata( + temp_output_dir, pyramid_datasets + ) + # Verify metadata was added - zarr_group = zarr.open_group(measurements_path, mode='r') - assert 'multiscales' in zarr_group.attrs - - multiscales = zarr_group.attrs['multiscales'] - assert 'tile_matrix_set' in multiscales - assert 'resampling_method' in multiscales - assert 'tile_matrix_limits' in multiscales - - def test_add_measurements_multiscales_metadata_error_handling(self, temp_output_dir): + zarr_group = zarr.open_group(measurements_path, mode="r") + assert "multiscales" in zarr_group.attrs + + multiscales = zarr_group.attrs["multiscales"] + assert "tile_matrix_set" in multiscales + assert "resampling_method" in multiscales + assert "tile_matrix_limits" in multiscales + + def test_add_measurements_multiscales_metadata_error_handling( + self, temp_output_dir + ): """Test error handling in multiscales metadata addition.""" converter = S2OptimizedConverter() - + # Test with non-existent measurements path converter._add_measurements_multiscales_metadata(temp_output_dir, {}) - + # Should not raise an exception, just print warnings # (We can't easily test print output in unit tests, but the method should handle errors gracefully) - - @patch('xarray.open_zarr') + + @patch("xarray.open_zarr") def test_simple_root_consolidation_success(self, mock_open_zarr, temp_output_dir): """Test successful root consolidation with xarray.""" converter = S2OptimizedConverter() - + # Mock successful xarray consolidation mock_ds = Mock() mock_open_zarr.return_value.__enter__.return_value = mock_ds - + converter._simple_root_consolidation(temp_output_dir, {}) - + # Verify xarray.open_zarr was called with correct parameters mock_open_zarr.assert_called_once() args, kwargs = mock_open_zarr.call_args assert args[0] == temp_output_dir - assert kwargs['consolidated'] is True - assert kwargs['chunks'] == {} - - @patch('zarr.consolidate_metadata') - @patch('xarray.open_zarr') - def test_simple_root_consolidation_fallback(self, mock_open_zarr, mock_consolidate, temp_output_dir): + assert kwargs["consolidated"] is True + assert kwargs["chunks"] == {} + + @patch("zarr.consolidate_metadata") + @patch("xarray.open_zarr") + def test_simple_root_consolidation_fallback( + self, mock_open_zarr, mock_consolidate, temp_output_dir + ): """Test fallback to zarr consolidation when xarray fails.""" converter = S2OptimizedConverter() - + # Mock xarray failure mock_open_zarr.side_effect = Exception("xarray failed") - + converter._simple_root_consolidation(temp_output_dir, {}) - + # Verify fallback to zarr.consolidate_metadata mock_consolidate.assert_called_once() class TestEndToEndSimplified: """Test simplified end-to-end functionality with mocks.""" - - @patch('eopf_geozarr.s2_optimization.s2_converter.S2DataConsolidator') - @patch('eopf_geozarr.s2_optimization.s2_converter.S2MultiscalePyramid') - 
@patch('eopf_geozarr.s2_optimization.s2_converter.S2OptimizationValidator') - def test_convert_s2_optimized_simplified_flow(self, mock_validator, mock_pyramid, mock_consolidator, - mock_s2_dataset, temp_output_dir): + + @patch("eopf_geozarr.s2_optimization.s2_converter.S2DataConsolidator") + @patch("eopf_geozarr.s2_optimization.s2_converter.S2MultiscalePyramid") + @patch("eopf_geozarr.s2_optimization.s2_converter.S2OptimizationValidator") + def test_convert_s2_optimized_simplified_flow( + self, + mock_validator, + mock_pyramid, + mock_consolidator, + mock_s2_dataset, + temp_output_dir, + ): """Test the simplified conversion flow with all major components mocked.""" - + # Create mock pyramid datasets with rioxarray pyramid_datasets = {} for level in [0, 1]: coords = { - 'x': (['x'], np.linspace(0, 1000, 50 // (2**level))), - 'y': (['y'], np.linspace(0, 1000, 50 // (2**level))) + "x": (["x"], np.linspace(0, 1000, 50 // (2**level))), + "y": (["y"], np.linspace(0, 1000, 50 // (2**level))), } data_vars = { - 'b02': (['y', 'x'], np.random.rand(50 // (2**level), 50 // (2**level))) + "b02": (["y", "x"], np.random.rand(50 // (2**level), 50 // (2**level))) } ds = xr.Dataset(data_vars, coords=coords) - ds = ds.rio.write_crs('EPSG:32632') + ds = ds.rio.write_crs("EPSG:32632") pyramid_datasets[level] = ds - + # Mock consolidator mock_consolidator_instance = Mock() mock_consolidator.return_value = mock_consolidator_instance mock_consolidator_instance.consolidate_all_data.return_value = ( - {10: {'bands': {'b02': Mock(), 'b03': Mock()}}}, # measurements - {'solar_zenith': Mock()}, # geometry - {'temperature': Mock()} # meteorology + {10: {"bands": {"b02": Mock(), "b03": Mock()}}}, # measurements + {"solar_zenith": Mock()}, # geometry + {"temperature": Mock()}, # meteorology ) - + # Mock pyramid creator mock_pyramid_instance = Mock() mock_pyramid.return_value = mock_pyramid_instance - mock_pyramid_instance.create_multiscale_measurements.return_value = pyramid_datasets - + mock_pyramid_instance.create_multiscale_measurements.return_value = ( + pyramid_datasets + ) + # Mock validator mock_validator_instance = Mock() mock_validator.return_value = mock_validator_instance mock_validator_instance.validate_optimized_dataset.return_value = { - 'is_valid': True, - 'issues': [] + "is_valid": True, + "issues": [], } - + # Create converter and replace instances that were created during initialization converter = S2OptimizedConverter() converter.pyramid_creator = mock_pyramid_instance converter.validator = mock_validator_instance - + # Mock the multiscales metadata methods - with patch.object(converter, '_add_measurements_multiscales_metadata') as mock_add_metadata, \ - patch.object(converter, '_simple_root_consolidation') as mock_consolidation, \ - patch.object(converter, '_write_auxiliary_group') as mock_write_aux, \ - patch.object(converter, '_create_result_datatree') as mock_create_result: - + with ( + patch.object( + converter, "_add_measurements_multiscales_metadata" + ) as mock_add_metadata, + patch.object(converter, "_simple_root_consolidation") as mock_consolidation, + patch.object(converter, "_write_auxiliary_group") as mock_write_aux, + patch.object(converter, "_create_result_datatree") as mock_create_result, + ): mock_create_result.return_value = xr.DataTree() - + # Run conversion result = converter.convert_s2_optimized( mock_s2_dataset, @@ -393,36 +414,38 @@ def test_convert_s2_optimized_simplified_flow(self, mock_validator, mock_pyramid create_geometry_group=True, create_meteorology_group=True, 
validate_output=True, - verbose=True + verbose=True, ) - + # Verify all steps were called mock_consolidator_instance.consolidate_all_data.assert_called_once() mock_pyramid_instance.create_multiscale_measurements.assert_called_once() mock_write_aux.assert_called() # Should be called twice (geometry + meteorology) mock_add_metadata.assert_called_once_with(temp_output_dir, pyramid_datasets) - mock_consolidation.assert_called_once_with(temp_output_dir, pyramid_datasets) + mock_consolidation.assert_called_once_with( + temp_output_dir, pyramid_datasets + ) mock_validator_instance.validate_optimized_dataset.assert_called_once() - + assert result is not None class TestConvenienceFunction: """Test the convenience function.""" - - @patch('eopf_geozarr.s2_optimization.s2_converter.S2OptimizedConverter') + + @patch("eopf_geozarr.s2_optimization.s2_converter.S2OptimizedConverter") def test_convert_s2_optimized_convenience_function(self, mock_converter_class): """Test the convenience function parameter separation.""" from eopf_geozarr.s2_optimization.s2_converter import convert_s2_optimized - + mock_converter_instance = Mock() mock_converter_class.return_value = mock_converter_instance mock_converter_instance.convert_s2_optimized.return_value = Mock() - + # Test parameter separation dt_input = Mock() output_path = "/test/path" - + result = convert_s2_optimized( dt_input, output_path, @@ -432,26 +455,23 @@ def test_convert_s2_optimized_convenience_function(self, mock_converter_class): max_retries=5, create_geometry_group=False, validate_output=False, - verbose=True + verbose=True, ) - + # Verify constructor was called with correct args mock_converter_class.assert_called_once_with( - enable_sharding=False, - spatial_chunk=512, - compression_level=2, - max_retries=5 + enable_sharding=False, spatial_chunk=512, compression_level=2, max_retries=5 ) - + # Verify method was called with remaining args mock_converter_instance.convert_s2_optimized.assert_called_once_with( dt_input, output_path, create_geometry_group=False, validate_output=False, - verbose=True + verbose=True, ) -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) diff --git a/src/eopf_geozarr/tests/test_s2_data_consolidator.py b/src/eopf_geozarr/tests/test_s2_data_consolidator.py index 3f876e65..89cd1d19 100644 --- a/src/eopf_geozarr/tests/test_s2_data_consolidator.py +++ b/src/eopf_geozarr/tests/test_s2_data_consolidator.py @@ -1,10 +1,10 @@ """Tests for S2 data consolidator module.""" -import pytest +from unittest.mock import MagicMock, Mock + import numpy as np +import pytest import xarray as xr -from unittest.mock import Mock, MagicMock -from typing import Dict, List, Tuple, Any from eopf_geozarr.s2_optimization.s2_data_consolidator import ( S2DataConsolidator, @@ -14,7 +14,7 @@ class TestS2DataConsolidator: """Test S2DataConsolidator class.""" - + @pytest.fixture def sample_s2_datatree(self): """Create a sample S2 DataTree structure for testing.""" @@ -25,351 +25,446 @@ def sample_s2_datatree(self): y_20m = y_10m[::2] x_60m = x_10m[::6] # 183 points y_60m = y_10m[::6] - time = np.array(['2023-01-15'], dtype='datetime64[ns]') - + time = np.array(["2023-01-15"], dtype="datetime64[ns]") + # Create sample data arrays data_10m = np.random.randint(0, 10000, (1, 1098, 1098), dtype=np.uint16) data_20m = np.random.randint(0, 10000, (1, 549, 549), dtype=np.uint16) data_60m = np.random.randint(0, 10000, (1, 183, 183), dtype=np.uint16) - + # Create datasets for different resolution groups (using lowercase band names) - ds_10m = 
xr.Dataset({ - 'b02': (['time', 'y', 'x'], data_10m), - 'b03': (['time', 'y', 'x'], data_10m.copy()), - 'b04': (['time', 'y', 'x'], data_10m.copy()), - 'b08': (['time', 'y', 'x'], data_10m.copy()), - }, coords={'time': time, 'x': x_10m, 'y': y_10m}) - - ds_20m = xr.Dataset({ - 'b05': (['time', 'y', 'x'], data_20m), - 'b06': (['time', 'y', 'x'], data_20m.copy()), - 'b07': (['time', 'y', 'x'], data_20m.copy()), - 'b8a': (['time', 'y', 'x'], data_20m.copy()), - 'b11': (['time', 'y', 'x'], data_20m.copy()), - 'b12': (['time', 'y', 'x'], data_20m.copy()), - 'aot': (['time', 'y', 'x'], data_20m.copy()), # atmosphere - 'wvp': (['time', 'y', 'x'], data_20m.copy()), - 'scl': (['time', 'y', 'x'], data_20m.copy()), # classification - 'cld': (['time', 'y', 'x'], data_20m.copy()), # probability - 'snw': (['time', 'y', 'x'], data_20m.copy()), - }, coords={'time': time, 'x': x_20m, 'y': y_20m}) - - ds_60m = xr.Dataset({ - 'b01': (['time', 'y', 'x'], data_60m), - 'b09': (['time', 'y', 'x'], data_60m.copy()), - }, coords={'time': time, 'x': x_60m, 'y': y_60m}) - + ds_10m = xr.Dataset( + { + "b02": (["time", "y", "x"], data_10m), + "b03": (["time", "y", "x"], data_10m.copy()), + "b04": (["time", "y", "x"], data_10m.copy()), + "b08": (["time", "y", "x"], data_10m.copy()), + }, + coords={"time": time, "x": x_10m, "y": y_10m}, + ) + + ds_20m = xr.Dataset( + { + "b05": (["time", "y", "x"], data_20m), + "b06": (["time", "y", "x"], data_20m.copy()), + "b07": (["time", "y", "x"], data_20m.copy()), + "b8a": (["time", "y", "x"], data_20m.copy()), + "b11": (["time", "y", "x"], data_20m.copy()), + "b12": (["time", "y", "x"], data_20m.copy()), + "aot": (["time", "y", "x"], data_20m.copy()), # atmosphere + "wvp": (["time", "y", "x"], data_20m.copy()), + "scl": (["time", "y", "x"], data_20m.copy()), # classification + "cld": (["time", "y", "x"], data_20m.copy()), # probability + "snw": (["time", "y", "x"], data_20m.copy()), + }, + coords={"time": time, "x": x_20m, "y": y_20m}, + ) + + ds_60m = xr.Dataset( + { + "b01": (["time", "y", "x"], data_60m), + "b09": (["time", "y", "x"], data_60m.copy()), + }, + coords={"time": time, "x": x_60m, "y": y_60m}, + ) + # Create quality datasets (using lowercase band names) - quality_10m = xr.Dataset({ - 'b02': (['time', 'y', 'x'], np.random.randint(0, 2, (1, 1098, 1098), dtype=np.uint8)), - 'b03': (['time', 'y', 'x'], np.random.randint(0, 2, (1, 1098, 1098), dtype=np.uint8)), - 'b04': (['time', 'y', 'x'], np.random.randint(0, 2, (1, 1098, 1098), dtype=np.uint8)), - 'b08': (['time', 'y', 'x'], np.random.randint(0, 2, (1, 1098, 1098), dtype=np.uint8)), - }, coords={'time': time, 'x': x_10m, 'y': y_10m}) - + quality_10m = xr.Dataset( + { + "b02": ( + ["time", "y", "x"], + np.random.randint(0, 2, (1, 1098, 1098), dtype=np.uint8), + ), + "b03": ( + ["time", "y", "x"], + np.random.randint(0, 2, (1, 1098, 1098), dtype=np.uint8), + ), + "b04": ( + ["time", "y", "x"], + np.random.randint(0, 2, (1, 1098, 1098), dtype=np.uint8), + ), + "b08": ( + ["time", "y", "x"], + np.random.randint(0, 2, (1, 1098, 1098), dtype=np.uint8), + ), + }, + coords={"time": time, "x": x_10m, "y": y_10m}, + ) + # Create detector footprint datasets (using lowercase band names) - detector_10m = xr.Dataset({ - 'b02': (['time', 'y', 'x'], np.random.randint(0, 13, (1, 1098, 1098), dtype=np.uint8)), - 'b03': (['time', 'y', 'x'], np.random.randint(0, 13, (1, 1098, 1098), dtype=np.uint8)), - 'b04': (['time', 'y', 'x'], np.random.randint(0, 13, (1, 1098, 1098), dtype=np.uint8)), - 'b08': (['time', 'y', 'x'], 
np.random.randint(0, 13, (1, 1098, 1098), dtype=np.uint8)), - }, coords={'time': time, 'x': x_10m, 'y': y_10m}) - + detector_10m = xr.Dataset( + { + "b02": ( + ["time", "y", "x"], + np.random.randint(0, 13, (1, 1098, 1098), dtype=np.uint8), + ), + "b03": ( + ["time", "y", "x"], + np.random.randint(0, 13, (1, 1098, 1098), dtype=np.uint8), + ), + "b04": ( + ["time", "y", "x"], + np.random.randint(0, 13, (1, 1098, 1098), dtype=np.uint8), + ), + "b08": ( + ["time", "y", "x"], + np.random.randint(0, 13, (1, 1098, 1098), dtype=np.uint8), + ), + }, + coords={"time": time, "x": x_10m, "y": y_10m}, + ) + # Create geometry data - geometry_ds = xr.Dataset({ - 'solar_zenith_angle': (['time', 'y', 'x'], np.random.uniform(0, 90, (1, 549, 549))), - 'solar_azimuth_angle': (['time', 'y', 'x'], np.random.uniform(0, 360, (1, 549, 549))), - 'view_zenith_angle': (['time', 'y', 'x'], np.random.uniform(0, 90, (1, 549, 549))), - 'view_azimuth_angle': (['time', 'y', 'x'], np.random.uniform(0, 360, (1, 549, 549))), - }, coords={'time': time, 'x': x_20m, 'y': y_20m}) - + geometry_ds = xr.Dataset( + { + "solar_zenith_angle": ( + ["time", "y", "x"], + np.random.uniform(0, 90, (1, 549, 549)), + ), + "solar_azimuth_angle": ( + ["time", "y", "x"], + np.random.uniform(0, 360, (1, 549, 549)), + ), + "view_zenith_angle": ( + ["time", "y", "x"], + np.random.uniform(0, 90, (1, 549, 549)), + ), + "view_azimuth_angle": ( + ["time", "y", "x"], + np.random.uniform(0, 360, (1, 549, 549)), + ), + }, + coords={"time": time, "x": x_20m, "y": y_20m}, + ) + # Create meteorology data - cams_ds = xr.Dataset({ - 'total_ozone': (['time', 'y', 'x'], np.random.uniform(200, 400, (1, 183, 183))), - 'relative_humidity': (['time', 'y', 'x'], np.random.uniform(0, 100, (1, 183, 183))), - }, coords={'time': time, 'x': x_60m, 'y': y_60m}) - - ecmwf_ds = xr.Dataset({ - 'temperature': (['time', 'y', 'x'], np.random.uniform(250, 320, (1, 183, 183))), - 'pressure': (['time', 'y', 'x'], np.random.uniform(950, 1050, (1, 183, 183))), - }, coords={'time': time, 'x': x_60m, 'y': y_60m}) - + cams_ds = xr.Dataset( + { + "total_ozone": ( + ["time", "y", "x"], + np.random.uniform(200, 400, (1, 183, 183)), + ), + "relative_humidity": ( + ["time", "y", "x"], + np.random.uniform(0, 100, (1, 183, 183)), + ), + }, + coords={"time": time, "x": x_60m, "y": y_60m}, + ) + + ecmwf_ds = xr.Dataset( + { + "temperature": ( + ["time", "y", "x"], + np.random.uniform(250, 320, (1, 183, 183)), + ), + "pressure": ( + ["time", "y", "x"], + np.random.uniform(950, 1050, (1, 183, 183)), + ), + }, + coords={"time": time, "x": x_60m, "y": y_60m}, + ) + # Build the mock DataTree structure mock_dt = MagicMock() mock_dt.groups = { - '/measurements/reflectance/r10m': Mock(), - '/measurements/reflectance/r20m': Mock(), - '/measurements/reflectance/r60m': Mock(), - '/quality/mask/r10m': Mock(), - '/quality/mask/r20m': Mock(), - '/quality/mask/r60m': Mock(), - '/conditions/mask/detector_footprint/r10m': Mock(), - '/conditions/mask/detector_footprint/r20m': Mock(), - '/conditions/mask/detector_footprint/r60m': Mock(), - '/quality/atmosphere/r20m': Mock(), - '/conditions/mask/l2a_classification/r20m': Mock(), - '/quality/probability/r20m': Mock(), - '/conditions/geometry': Mock(), - '/conditions/meteorology/cams': Mock(), - '/conditions/meteorology/ecmwf': Mock(), + "/measurements/reflectance/r10m": Mock(), + "/measurements/reflectance/r20m": Mock(), + "/measurements/reflectance/r60m": Mock(), + "/quality/mask/r10m": Mock(), + "/quality/mask/r20m": Mock(), + "/quality/mask/r60m": Mock(), + 
"/conditions/mask/detector_footprint/r10m": Mock(), + "/conditions/mask/detector_footprint/r20m": Mock(), + "/conditions/mask/detector_footprint/r60m": Mock(), + "/quality/atmosphere/r20m": Mock(), + "/conditions/mask/l2a_classification/r20m": Mock(), + "/quality/probability/r20m": Mock(), + "/conditions/geometry": Mock(), + "/conditions/meteorology/cams": Mock(), + "/conditions/meteorology/ecmwf": Mock(), } - + # Mock the dataset access def mock_getitem(self, path): mock_node = MagicMock() - if 'r10m' in path: - if 'reflectance' in path: + if "r10m" in path: + if "reflectance" in path: mock_node.to_dataset.return_value = ds_10m - elif 'quality/mask' in path: + elif "quality/mask" in path: mock_node.to_dataset.return_value = quality_10m - elif 'detector_footprint' in path: + elif "detector_footprint" in path: mock_node.to_dataset.return_value = detector_10m - elif 'r20m' in path: - if 'reflectance' in path: + elif "r20m" in path: + if "reflectance" in path: mock_node.to_dataset.return_value = ds_20m - elif 'atmosphere' in path: - mock_node.to_dataset.return_value = ds_20m[['aot', 'wvp']] - elif 'classification' in path: - mock_node.to_dataset.return_value = ds_20m[['scl']] - elif 'probability' in path: - mock_node.to_dataset.return_value = ds_20m[['cld', 'snw']] - elif 'r60m' in path: - if 'reflectance' in path: + elif "atmosphere" in path: + mock_node.to_dataset.return_value = ds_20m[["aot", "wvp"]] + elif "classification" in path: + mock_node.to_dataset.return_value = ds_20m[["scl"]] + elif "probability" in path: + mock_node.to_dataset.return_value = ds_20m[["cld", "snw"]] + elif "r60m" in path: + if "reflectance" in path: mock_node.to_dataset.return_value = ds_60m - elif 'geometry' in path: + elif "geometry" in path: mock_node.to_dataset.return_value = geometry_ds - elif 'cams' in path: + elif "cams" in path: mock_node.to_dataset.return_value = cams_ds - elif 'ecmwf' in path: + elif "ecmwf" in path: mock_node.to_dataset.return_value = ecmwf_ds - + return mock_node - + mock_dt.__getitem__ = mock_getitem return mock_dt - + def test_init(self, sample_s2_datatree): """Test consolidator initialization.""" consolidator = S2DataConsolidator(sample_s2_datatree) - + assert consolidator.dt_input == sample_s2_datatree assert consolidator.measurements_data == {} assert consolidator.geometry_data == {} assert consolidator.meteorology_data == {} - + def test_consolidate_all_data(self, sample_s2_datatree): """Test complete data consolidation.""" consolidator = S2DataConsolidator(sample_s2_datatree) measurements, geometry, meteorology = consolidator.consolidate_all_data() - + # Check that all three categories are returned assert isinstance(measurements, dict) assert isinstance(geometry, dict) assert isinstance(meteorology, dict) - + # Check resolution groups in measurements assert 10 in measurements assert 20 in measurements assert 60 in measurements - + # Check data categories exist for resolution in [10, 20, 60]: - assert 'bands' in measurements[resolution] - assert 'quality' in measurements[resolution] - assert 'detector_footprints' in measurements[resolution] - assert 'classification' in measurements[resolution] - assert 'atmosphere' in measurements[resolution] - assert 'probability' in measurements[resolution] - + assert "bands" in measurements[resolution] + assert "quality" in measurements[resolution] + assert "detector_footprints" in measurements[resolution] + assert "classification" in measurements[resolution] + assert "atmosphere" in measurements[resolution] + assert "probability" in 
measurements[resolution] + def test_extract_reflectance_bands(self, sample_s2_datatree): """Test reflectance band extraction.""" consolidator = S2DataConsolidator(sample_s2_datatree) consolidator._extract_measurements_data() - + # Check 10m bands - assert 'b02' in consolidator.measurements_data[10]['bands'] - assert 'b03' in consolidator.measurements_data[10]['bands'] - assert 'b04' in consolidator.measurements_data[10]['bands'] - assert 'b08' in consolidator.measurements_data[10]['bands'] - - # Check 20m bands - assert 'b05' in consolidator.measurements_data[20]['bands'] - assert 'b06' in consolidator.measurements_data[20]['bands'] - assert 'b11' in consolidator.measurements_data[20]['bands'] - assert 'b12' in consolidator.measurements_data[20]['bands'] - + assert "b02" in consolidator.measurements_data[10]["bands"] + assert "b03" in consolidator.measurements_data[10]["bands"] + assert "b04" in consolidator.measurements_data[10]["bands"] + assert "b08" in consolidator.measurements_data[10]["bands"] + + # Check 20m bands + assert "b05" in consolidator.measurements_data[20]["bands"] + assert "b06" in consolidator.measurements_data[20]["bands"] + assert "b11" in consolidator.measurements_data[20]["bands"] + assert "b12" in consolidator.measurements_data[20]["bands"] + # Check 60m bands - assert 'b01' in consolidator.measurements_data[60]['bands'] - assert 'b09' in consolidator.measurements_data[60]['bands'] - + assert "b01" in consolidator.measurements_data[60]["bands"] + assert "b09" in consolidator.measurements_data[60]["bands"] + def test_extract_quality_data(self, sample_s2_datatree): """Test quality data extraction.""" consolidator = S2DataConsolidator(sample_s2_datatree) consolidator._extract_measurements_data() - + # Check quality data exists for native bands - assert 'quality_b02' in consolidator.measurements_data[10]['quality'] - assert 'quality_b03' in consolidator.measurements_data[10]['quality'] - + assert "quality_b02" in consolidator.measurements_data[10]["quality"] + assert "quality_b03" in consolidator.measurements_data[10]["quality"] + def test_extract_detector_footprints(self, sample_s2_datatree): """Test detector footprint extraction.""" consolidator = S2DataConsolidator(sample_s2_datatree) consolidator._extract_measurements_data() - + # Check detector footprint data - assert 'detector_footprint_b02' in consolidator.measurements_data[10]['detector_footprints'] - assert 'detector_footprint_b03' in consolidator.measurements_data[10]['detector_footprints'] - + assert ( + "detector_footprint_b02" + in consolidator.measurements_data[10]["detector_footprints"] + ) + assert ( + "detector_footprint_b03" + in consolidator.measurements_data[10]["detector_footprints"] + ) + def test_extract_atmosphere_data(self, sample_s2_datatree): """Test atmosphere data extraction.""" consolidator = S2DataConsolidator(sample_s2_datatree) consolidator._extract_measurements_data() - + # Atmosphere data should be at 20m resolution - assert 'aot' in consolidator.measurements_data[20]['atmosphere'] - assert 'wvp' in consolidator.measurements_data[20]['atmosphere'] - + assert "aot" in consolidator.measurements_data[20]["atmosphere"] + assert "wvp" in consolidator.measurements_data[20]["atmosphere"] + def test_extract_classification_data(self, sample_s2_datatree): """Test classification data extraction.""" consolidator = S2DataConsolidator(sample_s2_datatree) consolidator._extract_measurements_data() - + # Classification should be at 20m resolution - assert 'scl' in 
consolidator.measurements_data[20]['classification'] - + assert "scl" in consolidator.measurements_data[20]["classification"] + def test_extract_probability_data(self, sample_s2_datatree): """Test probability data extraction.""" consolidator = S2DataConsolidator(sample_s2_datatree) consolidator._extract_measurements_data() - + # Probability data should be at 20m resolution - assert 'cld' in consolidator.measurements_data[20]['probability'] - assert 'snw' in consolidator.measurements_data[20]['probability'] - + assert "cld" in consolidator.measurements_data[20]["probability"] + assert "snw" in consolidator.measurements_data[20]["probability"] + def test_extract_geometry_data(self, sample_s2_datatree): """Test geometry data extraction.""" consolidator = S2DataConsolidator(sample_s2_datatree) consolidator._extract_geometry_data() - + # Check that geometry variables are extracted - assert 'solar_zenith_angle' in consolidator.geometry_data - assert 'solar_azimuth_angle' in consolidator.geometry_data - assert 'view_zenith_angle' in consolidator.geometry_data - assert 'view_azimuth_angle' in consolidator.geometry_data - + assert "solar_zenith_angle" in consolidator.geometry_data + assert "solar_azimuth_angle" in consolidator.geometry_data + assert "view_zenith_angle" in consolidator.geometry_data + assert "view_azimuth_angle" in consolidator.geometry_data + def test_extract_meteorology_data(self, sample_s2_datatree): """Test meteorology data extraction.""" consolidator = S2DataConsolidator(sample_s2_datatree) consolidator._extract_meteorology_data() - + # Check CAMS data - assert 'cams_total_ozone' in consolidator.meteorology_data - assert 'cams_relative_humidity' in consolidator.meteorology_data - + assert "cams_total_ozone" in consolidator.meteorology_data + assert "cams_relative_humidity" in consolidator.meteorology_data + # Check ECMWF data - assert 'ecmwf_temperature' in consolidator.meteorology_data - assert 'ecmwf_pressure' in consolidator.meteorology_data - + assert "ecmwf_temperature" in consolidator.meteorology_data + assert "ecmwf_pressure" in consolidator.meteorology_data + def test_missing_groups_handling(self): """Test handling of missing data groups.""" # Create DataTree with missing groups mock_dt = MagicMock() mock_dt.groups = {} # No groups present - + consolidator = S2DataConsolidator(mock_dt) measurements, geometry, meteorology = consolidator.consolidate_all_data() - + # Should handle missing groups gracefully assert isinstance(measurements, dict) assert isinstance(geometry, dict) assert isinstance(meteorology, dict) - + # Data structures should be initialized but empty for resolution in [10, 20, 60]: assert resolution in measurements - for category in ['bands', 'quality', 'detector_footprints', 'classification', 'atmosphere', 'probability']: + for category in [ + "bands", + "quality", + "detector_footprints", + "classification", + "atmosphere", + "probability", + ]: assert category in measurements[resolution] assert len(measurements[resolution][category]) == 0 class TestCreateConsolidatedDataset: """Test the create_consolidated_dataset function.""" - + @pytest.fixture def sample_data_dict(self): """Create sample consolidated data dictionary.""" # Create coordinate arrays x = np.linspace(100000, 200000, 100) y = np.linspace(5000000, 5100000, 100) - time = np.array(['2023-01-15'], dtype='datetime64[ns]') - + time = np.array(["2023-01-15"], dtype="datetime64[ns]") + # Create sample data arrays data = np.random.randint(0, 10000, (1, 100, 100), dtype=np.uint16) - + return { - 
'bands': { - 'b02': xr.DataArray(data, dims=['time', 'y', 'x'], - coords={'time': time, 'x': x, 'y': y}), - 'b03': xr.DataArray(data.copy(), dims=['time', 'y', 'x'], - coords={'time': time, 'x': x, 'y': y}), + "bands": { + "b02": xr.DataArray( + data, dims=["time", "y", "x"], coords={"time": time, "x": x, "y": y} + ), + "b03": xr.DataArray( + data.copy(), + dims=["time", "y", "x"], + coords={"time": time, "x": x, "y": y}, + ), }, - 'quality': { - 'quality_b02': xr.DataArray(np.random.randint(0, 2, (1, 100, 100), dtype=np.uint8), - dims=['time', 'y', 'x'], - coords={'time': time, 'x': x, 'y': y}), + "quality": { + "quality_b02": xr.DataArray( + np.random.randint(0, 2, (1, 100, 100), dtype=np.uint8), + dims=["time", "y", "x"], + coords={"time": time, "x": x, "y": y}, + ), + }, + "atmosphere": { + "aot": xr.DataArray( + np.random.uniform(0.1, 0.5, (1, 100, 100)), + dims=["time", "y", "x"], + coords={"time": time, "x": x, "y": y}, + ), }, - 'atmosphere': { - 'aot': xr.DataArray(np.random.uniform(0.1, 0.5, (1, 100, 100)), - dims=['time', 'y', 'x'], - coords={'time': time, 'x': x, 'y': y}), - } } - + def test_create_consolidated_dataset_success(self, sample_data_dict): """Test successful dataset creation.""" ds = create_consolidated_dataset(sample_data_dict, resolution=10) - + assert isinstance(ds, xr.Dataset) - + # Check that all variables are included - expected_vars = {'b02', 'b03', 'quality_b02', 'aot'} + expected_vars = {"b02", "b03", "quality_b02", "aot"} assert set(ds.data_vars.keys()) == expected_vars - + # Check metadata - assert ds.attrs['native_resolution_meters'] == 10 - assert ds.attrs['processing_level'] == 'L2A' - assert ds.attrs['product_type'] == 'S2MSI2A' - + assert ds.attrs["native_resolution_meters"] == 10 + assert ds.attrs["processing_level"] == "L2A" + assert ds.attrs["product_type"] == "S2MSI2A" + # Check coordinates - assert 'x' in ds.coords - assert 'y' in ds.coords - assert 'time' in ds.coords - + assert "x" in ds.coords + assert "y" in ds.coords + assert "time" in ds.coords + def test_create_consolidated_dataset_empty_data(self): """Test dataset creation with empty data.""" - empty_data_dict = {'bands': {}, 'quality': {}, 'atmosphere': {}} + empty_data_dict = {"bands": {}, "quality": {}, "atmosphere": {}} ds = create_consolidated_dataset(empty_data_dict, resolution=20) - + # Should return empty dataset assert isinstance(ds, xr.Dataset) assert len(ds.data_vars) == 0 - + def test_create_consolidated_dataset_with_crs(self, sample_data_dict): """Test dataset creation with CRS information.""" # Add CRS to one of the data arrays - sample_data_dict['bands']['b02'] = sample_data_dict['bands']['b02'].rio.write_crs('EPSG:32632') - + sample_data_dict["bands"]["b02"] = sample_data_dict["bands"][ + "b02" + ].rio.write_crs("EPSG:32632") + ds = create_consolidated_dataset(sample_data_dict, resolution=10) - + assert isinstance(ds, xr.Dataset) # Check that CRS is propagated (assuming rio accessor is available) - if hasattr(ds, 'rio'): + if hasattr(ds, "rio"): assert ds.rio.crs is not None class TestIntegration: """Integration tests combining consolidator and dataset creation.""" - + @pytest.fixture def complete_s2_datatree(self): """Create a complete S2 DataTree for integration testing.""" @@ -379,49 +474,70 @@ def complete_s2_datatree(self): y_10m = np.linspace(5000000, 5100000, 100) x_20m = x_10m[::2] y_20m = y_10m[::2] - time = np.array(['2023-01-15'], dtype='datetime64[ns]') - + time = np.array(["2023-01-15"], dtype="datetime64[ns]") + # Create complete mock DataTree (simplified 
for integration test) mock_dt = MagicMock() mock_dt.groups = { - '/measurements/reflectance/r10m': Mock(), - '/conditions/geometry': Mock(), - '/conditions/meteorology/cams': Mock(), + "/measurements/reflectance/r10m": Mock(), + "/conditions/geometry": Mock(), + "/conditions/meteorology/cams": Mock(), } - + # Mock datasets - reflectance_10m = xr.Dataset({ - 'b02': (['time', 'y', 'x'], np.random.randint(0, 10000, (1, 100, 100), dtype=np.uint16)), - 'b03': (['time', 'y', 'x'], np.random.randint(0, 10000, (1, 100, 100), dtype=np.uint16)), - }, coords={'time': time, 'x': x_10m, 'y': y_10m}) - - geometry_ds = xr.Dataset({ - 'solar_zenith_angle': (['time', 'y', 'x'], np.random.uniform(0, 90, (1, 50, 50))), - }, coords={'time': time, 'x': x_20m, 'y': y_20m}) - - cams_ds = xr.Dataset({ - 'total_ozone': (['time', 'y', 'x'], np.random.uniform(200, 400, (1, 50, 50))), - }, coords={'time': time, 'x': x_20m, 'y': y_20m}) - + reflectance_10m = xr.Dataset( + { + "b02": ( + ["time", "y", "x"], + np.random.randint(0, 10000, (1, 100, 100), dtype=np.uint16), + ), + "b03": ( + ["time", "y", "x"], + np.random.randint(0, 10000, (1, 100, 100), dtype=np.uint16), + ), + }, + coords={"time": time, "x": x_10m, "y": y_10m}, + ) + + geometry_ds = xr.Dataset( + { + "solar_zenith_angle": ( + ["time", "y", "x"], + np.random.uniform(0, 90, (1, 50, 50)), + ), + }, + coords={"time": time, "x": x_20m, "y": y_20m}, + ) + + cams_ds = xr.Dataset( + { + "total_ozone": ( + ["time", "y", "x"], + np.random.uniform(200, 400, (1, 50, 50)), + ), + }, + coords={"time": time, "x": x_20m, "y": y_20m}, + ) + def mock_getitem(self, path): mock_node = MagicMock() - if '/measurements/reflectance/r10m' in path: + if "/measurements/reflectance/r10m" in path: mock_node.to_dataset.return_value = reflectance_10m - elif '/conditions/geometry' in path: + elif "/conditions/geometry" in path: mock_node.to_dataset.return_value = geometry_ds - elif '/conditions/meteorology/cams' in path: + elif "/conditions/meteorology/cams" in path: mock_node.to_dataset.return_value = cams_ds return mock_node - + mock_dt.__getitem__ = mock_getitem return mock_dt - + def test_end_to_end_consolidation(self, complete_s2_datatree): """Test complete end-to-end consolidation and dataset creation.""" # Step 1: Consolidate data consolidator = S2DataConsolidator(complete_s2_datatree) measurements, geometry, meteorology = consolidator.consolidate_all_data() - + # Step 2: Create consolidated datasets for each resolution consolidated_datasets = {} for resolution in [10, 20, 60]: @@ -429,53 +545,59 @@ def test_end_to_end_consolidation(self, complete_s2_datatree): ds = create_consolidated_dataset(measurements[resolution], resolution) if len(ds.data_vars) > 0: # Only keep non-empty datasets consolidated_datasets[resolution] = ds - + # Step 3: Verify results assert len(consolidated_datasets) > 0 - + # Check that 10m data is present (from our mock) if 10 in consolidated_datasets: ds_10m = consolidated_datasets[10] - assert 'b02' in ds_10m.data_vars - assert 'b03' in ds_10m.data_vars - assert ds_10m.attrs['native_resolution_meters'] == 10 - + assert "b02" in ds_10m.data_vars + assert "b03" in ds_10m.data_vars + assert ds_10m.attrs["native_resolution_meters"] == 10 + # Verify geometry data assert len(geometry) > 0 - geometry_ds = create_consolidated_dataset({'geometry': geometry}, resolution=20) + geometry_ds = create_consolidated_dataset({"geometry": geometry}, resolution=20) if len(geometry_ds.data_vars) > 0: - assert 'solar_zenith_angle' in geometry_ds.data_vars - - # Verify 
meteorology data + assert "solar_zenith_angle" in geometry_ds.data_vars + + # Verify meteorology data assert len(meteorology) > 0 - met_ds = create_consolidated_dataset({'meteorology': meteorology}, resolution=60) + met_ds = create_consolidated_dataset( + {"meteorology": meteorology}, resolution=60 + ) if len(met_ds.data_vars) > 0: - assert 'cams_total_ozone' in met_ds.data_vars + assert "cams_total_ozone" in met_ds.data_vars class TestEdgeCases: """Test edge cases and error conditions.""" - + def test_create_dataset_with_inconsistent_coordinates(self): """Test dataset creation with inconsistent coordinate systems.""" # Create data with mismatched coordinates x1 = np.linspace(100000, 200000, 50) - y1 = np.linspace(5000000, 5100000, 50) + y1 = np.linspace(5000000, 5100000, 50) x2 = np.linspace(100000, 200000, 100) # Different size y2 = np.linspace(5000000, 5100000, 100) - time = np.array(['2023-01-15'], dtype='datetime64[ns]') - + time = np.array(["2023-01-15"], dtype="datetime64[ns]") + inconsistent_data = { - 'bands': { - 'b02': xr.DataArray(np.random.randint(0, 10000, (1, 50, 50), dtype=np.uint16), - dims=['time', 'y', 'x'], - coords={'time': time, 'x': x1, 'y': y1}), - 'b03': xr.DataArray(np.random.randint(0, 10000, (1, 100, 100), dtype=np.uint16), - dims=['time', 'y', 'x'], - coords={'time': time, 'x': x2, 'y': y2}), + "bands": { + "b02": xr.DataArray( + np.random.randint(0, 10000, (1, 50, 50), dtype=np.uint16), + dims=["time", "y", "x"], + coords={"time": time, "x": x1, "y": y1}, + ), + "b03": xr.DataArray( + np.random.randint(0, 10000, (1, 100, 100), dtype=np.uint16), + dims=["time", "y", "x"], + coords={"time": time, "x": x2, "y": y2}, + ), } } - + # Should handle inconsistent coordinates gracefully or raise appropriate error # The exact behavior depends on xarray's handling of mixed coordinates try: diff --git a/src/eopf_geozarr/tests/test_s2_multiscale.py b/src/eopf_geozarr/tests/test_s2_multiscale.py index 6a03f1cd..198b605a 100644 --- a/src/eopf_geozarr/tests/test_s2_multiscale.py +++ b/src/eopf_geozarr/tests/test_s2_multiscale.py @@ -2,12 +2,12 @@ Tests for S2 multiscale pyramid creation with xy-aligned sharding. 
""" +import shutil +import tempfile +from unittest.mock import Mock, patch + import numpy as np import pytest -import tempfile -import shutil -from pathlib import Path -from unittest.mock import Mock, patch, MagicMock import xarray as xr from eopf_geozarr.s2_optimization.s2_multiscale import S2MultiscalePyramid @@ -26,36 +26,32 @@ def sample_dataset(self): """Create a sample xarray dataset for testing.""" x = np.linspace(0, 1000, 100) y = np.linspace(0, 1000, 100) - time = np.array(['2023-01-01', '2023-01-02'], dtype='datetime64[ns]') - + time = np.array(["2023-01-01", "2023-01-02"], dtype="datetime64[ns]") + # Create sample variables with different dimensions b02 = xr.DataArray( np.random.randint(0, 4000, (2, 100, 100)), - dims=['time', 'y', 'x'], - coords={'time': time, 'y': y, 'x': x}, - name='b02' + dims=["time", "y", "x"], + coords={"time": time, "y": y, "x": x}, + name="b02", ) - + b05 = xr.DataArray( np.random.randint(0, 4000, (2, 100, 100)), - dims=['time', 'y', 'x'], - coords={'time': time, 'y': y, 'x': x}, - name='b05' + dims=["time", "y", "x"], + coords={"time": time, "y": y, "x": x}, + name="b05", ) - + scl = xr.DataArray( np.random.randint(0, 11, (2, 100, 100)), - dims=['time', 'y', 'x'], - coords={'time': time, 'y': y, 'x': x}, - name='scl' + dims=["time", "y", "x"], + coords={"time": time, "y": y, "x": x}, + name="scl", ) - dataset = xr.Dataset({ - 'b02': b02, - 'b05': b05, - 'scl': scl - }) - + dataset = xr.Dataset({"b02": b02, "b05": b05, "scl": scl}) + return dataset @pytest.fixture @@ -67,50 +63,43 @@ def sample_measurements_by_resolution(self): y_20m = np.linspace(0, 1000, 100) x_60m = np.linspace(0, 1000, 50) y_60m = np.linspace(0, 1000, 50) - time = np.array(['2023-01-01'], dtype='datetime64[ns]') + time = np.array(["2023-01-01"], dtype="datetime64[ns]") # 10m data b02_10m = xr.DataArray( np.random.randint(0, 4000, (1, 200, 200)), - dims=['time', 'y', 'x'], - coords={'time': time, 'y': y_10m, 'x': x_10m}, - name='b02' + dims=["time", "y", "x"], + coords={"time": time, "y": y_10m, "x": x_10m}, + name="b02", ) # 20m data b05_20m = xr.DataArray( np.random.randint(0, 4000, (1, 100, 100)), - dims=['time', 'y', 'x'], - coords={'time': time, 'y': y_20m, 'x': x_20m}, - name='b05' + dims=["time", "y", "x"], + coords={"time": time, "y": y_20m, "x": x_20m}, + name="b05", ) scl_20m = xr.DataArray( np.random.randint(0, 11, (1, 100, 100)), - dims=['time', 'y', 'x'], - coords={'time': time, 'y': y_20m, 'x': x_20m}, - name='scl' + dims=["time", "y", "x"], + coords={"time": time, "y": y_20m, "x": x_20m}, + name="scl", ) # 60m data b01_60m = xr.DataArray( np.random.randint(0, 4000, (1, 50, 50)), - dims=['time', 'y', 'x'], - coords={'time': time, 'y': y_60m, 'x': x_60m}, - name='b01' + dims=["time", "y", "x"], + coords={"time": time, "y": y_60m, "x": x_60m}, + name="b01", ) return { - 10: { - 'reflectance': {'b02': b02_10m} - }, - 20: { - 'reflectance': {'b05': b05_20m}, - 'quality': {'scl': scl_20m} - }, - 60: { - 'reflectance': {'b01': b01_60m} - } + 10: {"reflectance": {"b02": b02_10m}}, + 20: {"reflectance": {"b05": b05_20m}, "quality": {"scl": scl_20m}}, + 60: {"reflectance": {"b01": b01_60m}}, } @pytest.fixture @@ -123,10 +112,10 @@ def temp_dir(self): def test_init(self): """Test S2MultiscalePyramid initialization.""" pyramid = S2MultiscalePyramid(enable_sharding=True, spatial_chunk=512) - + assert pyramid.enable_sharding is True assert pyramid.spatial_chunk == 512 - assert hasattr(pyramid, 'resampler') + assert hasattr(pyramid, "resampler") assert len(pyramid.pyramid_levels) == 
7 assert pyramid.pyramid_levels[0] == 10 assert pyramid.pyramid_levels[1] == 20 @@ -135,246 +124,256 @@ def test_init(self): def test_pyramid_levels_structure(self, pyramid): """Test the pyramid levels structure.""" expected_levels = { - 0: 10, # Level 0: 10m - 1: 20, # Level 1: 20m - 2: 60, # Level 2: 60m - 3: 120, # Level 3: 120m - 4: 240, # Level 4: 240m - 5: 480, # Level 5: 480m - 6: 960 # Level 6: 960m + 0: 10, # Level 0: 10m + 1: 20, # Level 1: 20m + 2: 60, # Level 2: 60m + 3: 120, # Level 3: 120m + 4: 240, # Level 4: 240m + 5: 480, # Level 5: 480m + 6: 960, # Level 6: 960m } - + assert pyramid.pyramid_levels == expected_levels def test_calculate_simple_shard_dimensions(self, pyramid): """Test simplified shard dimensions calculation.""" # Test 3D data (time, y, x) - shards match dimensions exactly data_shape = (5, 1000, 1000) - + shard_dims = pyramid._calculate_simple_shard_dimensions(data_shape) - + assert len(shard_dims) == 3 - assert shard_dims[0] == 1 # Time dimension should be 1 + assert shard_dims[0] == 1 # Time dimension should be 1 assert shard_dims[1] == 1000 # Y dimension matches exactly assert shard_dims[2] == 1000 # X dimension matches exactly - + # Test 2D data (y, x) - shards match dimensions exactly data_shape = (500, 800) - + shard_dims = pyramid._calculate_simple_shard_dimensions(data_shape) - + assert len(shard_dims) == 2 - assert shard_dims[0] == 500 # Y dimension matches exactly - assert shard_dims[1] == 800 # X dimension matches exactly + assert shard_dims[0] == 500 # Y dimension matches exactly + assert shard_dims[1] == 800 # X dimension matches exactly def test_create_level_encoding(self, pyramid, sample_dataset): """Test level encoding creation with xy-aligned sharding.""" encoding = pyramid._create_level_encoding(sample_dataset, level=1) - + # Check that encoding is created for all variables for var_name in sample_dataset.data_vars: assert var_name in encoding var_encoding = encoding[var_name] - + # Check basic encoding structure - assert 'chunks' in var_encoding - assert 'compressor' in var_encoding - + assert "chunks" in var_encoding + assert "compressor" in var_encoding + # Check sharding is included when enabled if pyramid.enable_sharding: - assert 'shards' in var_encoding - + assert "shards" in var_encoding + # Check coordinate encoding for coord_name in sample_dataset.coords: if coord_name in encoding: - assert encoding[coord_name]['compressor'] is None + assert encoding[coord_name]["compressor"] is None def test_create_level_encoding_time_chunking(self, pyramid, sample_dataset): """Test that time dimension is chunked to 1 for single file per time.""" encoding = pyramid._create_level_encoding(sample_dataset, level=0) - + for var_name in sample_dataset.data_vars: if sample_dataset[var_name].ndim == 3: # 3D variable with time - chunks = encoding[var_name]['chunks'] + chunks = encoding[var_name]["chunks"] assert chunks[0] == 1 # Time dimension should be chunked to 1 def test_should_separate_time_files(self, pyramid): """Test time file separation detection.""" # Create dataset with multiple time points - time = np.array(['2023-01-01', '2023-01-02'], dtype='datetime64[ns]') + time = np.array(["2023-01-01", "2023-01-02"], dtype="datetime64[ns]") x = np.linspace(0, 100, 10) y = np.linspace(0, 100, 10) - + data_multi_time = xr.DataArray( np.random.rand(2, 10, 10), - dims=['time', 'y', 'x'], - coords={'time': time, 'y': y, 'x': x} + dims=["time", "y", "x"], + coords={"time": time, "y": y, "x": x}, ) - - dataset_multi_time = xr.Dataset({'var1': data_multi_time}) 
+ + dataset_multi_time = xr.Dataset({"var1": data_multi_time}) assert pyramid._should_separate_time_files(dataset_multi_time) is True - + # Create dataset with single time point data_single_time = xr.DataArray( np.random.rand(1, 10, 10), - dims=['time', 'y', 'x'], - coords={'time': time[:1], 'y': y, 'x': x} + dims=["time", "y", "x"], + coords={"time": time[:1], "y": y, "x": x}, ) - - dataset_single_time = xr.Dataset({'var1': data_single_time}) + + dataset_single_time = xr.Dataset({"var1": data_single_time}) assert pyramid._should_separate_time_files(dataset_single_time) is False - + # Create dataset with no time dimension data_no_time = xr.DataArray( - np.random.rand(10, 10), - dims=['y', 'x'], - coords={'y': y, 'x': x} + np.random.rand(10, 10), dims=["y", "x"], coords={"y": y, "x": x} ) - - dataset_no_time = xr.Dataset({'var1': data_no_time}) + + dataset_no_time = xr.Dataset({"var1": data_no_time}) assert pyramid._should_separate_time_files(dataset_no_time) is False def test_update_encoding_for_time_slice(self, pyramid): """Test encoding update for time slices.""" # Original encoding with 3D chunks original_encoding = { - 'var1': { - 'chunks': (1, 100, 100), - 'shards': (1, 200, 200), - 'compressor': 'default' + "var1": { + "chunks": (1, 100, 100), + "shards": (1, 200, 200), + "compressor": "default", }, - 'x': {'compressor': None}, - 'y': {'compressor': None} + "x": {"compressor": None}, + "y": {"compressor": None}, } - + # Create a time slice dataset x = np.linspace(0, 100, 100) y = np.linspace(0, 100, 100) - - time_slice = xr.Dataset({ - 'var1': xr.DataArray( - np.random.rand(100, 100), - dims=['y', 'x'], - coords={'y': y, 'x': x} - ) - }) - - updated_encoding = pyramid._update_encoding_for_time_slice(original_encoding, time_slice) - + + time_slice = xr.Dataset( + { + "var1": xr.DataArray( + np.random.rand(100, 100), dims=["y", "x"], coords={"y": y, "x": x} + ) + } + ) + + updated_encoding = pyramid._update_encoding_for_time_slice( + original_encoding, time_slice + ) + # Check that time dimension is removed from chunks and shards - assert updated_encoding['var1']['chunks'] == (100, 100) - assert updated_encoding['var1']['shards'] == (200, 200) - assert updated_encoding['var1']['compressor'] == 'default' - - # Check coordinates are preserved - assert updated_encoding['x']['compressor'] is None - assert updated_encoding['y']['compressor'] is None + assert updated_encoding["var1"]["chunks"] == (100, 100) + assert updated_encoding["var1"]["shards"] == (200, 200) + assert updated_encoding["var1"]["compressor"] == "default" - @patch('builtins.print') - @patch('xarray.Dataset.to_zarr') - def test_write_level_dataset_no_time(self, mock_to_zarr, mock_print, pyramid, sample_dataset, temp_dir): + # Check coordinates are preserved + assert updated_encoding["x"]["compressor"] is None + assert updated_encoding["y"]["compressor"] is None + + @patch("builtins.print") + @patch("xarray.Dataset.to_zarr") + def test_write_level_dataset_no_time( + self, mock_to_zarr, mock_print, pyramid, sample_dataset, temp_dir + ): """Test writing level dataset without time separation.""" # Create dataset without multiple time points single_time_dataset = sample_dataset.isel(time=0) - + pyramid._write_level_dataset(single_time_dataset, temp_dir, level=0) - + # Should call to_zarr once (no time separation) mock_to_zarr.assert_called_once() args, kwargs = mock_to_zarr.call_args - - assert kwargs['mode'] == 'w' - assert kwargs['consolidated'] is True - assert kwargs['zarr_format'] == 3 - assert 'encoding' in kwargs - - 
@patch('builtins.print') - def test_write_level_dataset_with_time_separation(self, mock_print, pyramid, sample_dataset, temp_dir): + + assert kwargs["mode"] == "w" + assert kwargs["consolidated"] is True + assert kwargs["zarr_format"] == 3 + assert "encoding" in kwargs + + @patch("builtins.print") + def test_write_level_dataset_with_time_separation( + self, mock_print, pyramid, sample_dataset, temp_dir + ): """Test writing level dataset with time separation.""" - with patch.object(pyramid, '_write_time_separated_dataset') as mock_time_sep: + with patch.object(pyramid, "_write_time_separated_dataset") as mock_time_sep: pyramid._write_level_dataset(sample_dataset, temp_dir, level=0) - + # Should call time separation method mock_time_sep.assert_called_once() def test_create_level_0_dataset(self, pyramid, sample_measurements_by_resolution): """Test level 0 dataset creation.""" dataset = pyramid._create_level_0_dataset(sample_measurements_by_resolution) - + assert len(dataset.data_vars) > 0 - assert dataset.attrs['pyramid_level'] == 0 - assert dataset.attrs['resolution_meters'] == 10 - + assert dataset.attrs["pyramid_level"] == 0 + assert dataset.attrs["resolution_meters"] == 10 + # Should only contain 10m native data - assert 'b02' in dataset.data_vars + assert "b02" in dataset.data_vars def test_create_level_0_dataset_no_10m_data(self, pyramid): """Test level 0 dataset creation with no 10m data.""" measurements_no_10m = { - 20: {'reflectance': {'b05': Mock()}}, - 60: {'reflectance': {'b01': Mock()}} + 20: {"reflectance": {"b05": Mock()}}, + 60: {"reflectance": {"b01": Mock()}}, } - + dataset = pyramid._create_level_0_dataset(measurements_no_10m) assert len(dataset.data_vars) == 0 - @patch.object(S2MultiscalePyramid, '_create_level_0_dataset') - @patch.object(S2MultiscalePyramid, '_create_level_1_dataset') - @patch.object(S2MultiscalePyramid, '_create_level_2_dataset') - @patch.object(S2MultiscalePyramid, '_create_downsampled_dataset') - def test_create_level_dataset_routing(self, mock_downsampled, mock_level2, mock_level1, mock_level0, pyramid): + @patch.object(S2MultiscalePyramid, "_create_level_0_dataset") + @patch.object(S2MultiscalePyramid, "_create_level_1_dataset") + @patch.object(S2MultiscalePyramid, "_create_level_2_dataset") + @patch.object(S2MultiscalePyramid, "_create_downsampled_dataset") + def test_create_level_dataset_routing( + self, mock_downsampled, mock_level2, mock_level1, mock_level0, pyramid + ): """Test that _create_level_dataset routes to correct methods.""" measurements = {} - + # Test level 0 pyramid._create_level_dataset(0, 10, measurements) mock_level0.assert_called_once_with(measurements) - + # Test level 1 pyramid._create_level_dataset(1, 20, measurements) mock_level1.assert_called_once_with(measurements) - + # Test level 2 pyramid._create_level_dataset(2, 60, measurements) mock_level2.assert_called_once_with(measurements) - + # Test level 3+ pyramid._create_level_dataset(3, 120, measurements) mock_downsampled.assert_called_once_with(3, 120, measurements) - @patch('builtins.print') - @patch.object(S2MultiscalePyramid, '_write_level_dataset') - @patch.object(S2MultiscalePyramid, '_create_level_dataset') - def test_create_multiscale_measurements(self, mock_create, mock_write, mock_print, pyramid, temp_dir): + @patch("builtins.print") + @patch.object(S2MultiscalePyramid, "_write_level_dataset") + @patch.object(S2MultiscalePyramid, "_create_level_dataset") + def test_create_multiscale_measurements( + self, mock_create, mock_write, mock_print, pyramid, temp_dir + ): 
"""Test multiscale measurements creation.""" # Mock dataset creation mock_dataset = Mock() - mock_dataset.data_vars = {'b02': Mock()} # Non-empty dataset + mock_dataset.data_vars = {"b02": Mock()} # Non-empty dataset mock_create.return_value = mock_dataset - - measurements = {10: {'reflectance': {'b02': Mock()}}} - + + measurements = {10: {"reflectance": {"b02": Mock()}}} + result = pyramid.create_multiscale_measurements(measurements, temp_dir) - + # Should create all pyramid levels assert len(result) == len(pyramid.pyramid_levels) assert mock_create.call_count == len(pyramid.pyramid_levels) assert mock_write.call_count == len(pyramid.pyramid_levels) - @patch('builtins.print') - @patch.object(S2MultiscalePyramid, '_write_level_dataset') - @patch.object(S2MultiscalePyramid, '_create_level_dataset') - def test_create_multiscale_measurements_empty_dataset(self, mock_create, mock_write, mock_print, pyramid, temp_dir): + @patch("builtins.print") + @patch.object(S2MultiscalePyramid, "_write_level_dataset") + @patch.object(S2MultiscalePyramid, "_create_level_dataset") + def test_create_multiscale_measurements_empty_dataset( + self, mock_create, mock_write, mock_print, pyramid, temp_dir + ): """Test multiscale measurements creation with empty dataset.""" # Mock empty dataset creation mock_dataset = Mock() mock_dataset.data_vars = {} # Empty dataset mock_create.return_value = mock_dataset - + measurements = {} - + result = pyramid.create_multiscale_measurements(measurements, temp_dir) - + # Should not include empty datasets assert len(result) == 0 assert mock_write.call_count == 0 @@ -386,36 +385,36 @@ def test_create_level_1_dataset_with_downsampling(self, pyramid): y_20m = np.linspace(0, 1000, 100) x_10m = np.linspace(0, 1000, 200) y_10m = np.linspace(0, 1000, 200) - time = np.array(['2023-01-01'], dtype='datetime64[ns]') + time = np.array(["2023-01-01"], dtype="datetime64[ns]") # 20m native data b05_20m = xr.DataArray( np.random.randint(0, 4000, (1, 100, 100)), - dims=['time', 'y', 'x'], - coords={'time': time, 'y': y_20m, 'x': x_20m}, - name='b05' + dims=["time", "y", "x"], + coords={"time": time, "y": y_20m, "x": x_20m}, + name="b05", ) # 10m data to be downsampled b02_10m = xr.DataArray( np.random.randint(0, 4000, (1, 200, 200)), - dims=['time', 'y', 'x'], - coords={'time': time, 'y': y_10m, 'x': x_10m}, - name='b02' + dims=["time", "y", "x"], + coords={"time": time, "y": y_10m, "x": x_10m}, + name="b02", ) measurements = { - 20: {'reflectance': {'b05': b05_20m}}, - 10: {'reflectance': {'b02': b02_10m}} + 20: {"reflectance": {"b05": b05_20m}}, + 10: {"reflectance": {"b02": b02_10m}}, } - with patch.object(pyramid.resampler, 'downsample_variable') as mock_downsample: + with patch.object(pyramid.resampler, "downsample_variable") as mock_downsample: # Mock the downsampling to return a properly shaped array mock_downsampled = xr.DataArray( np.random.randint(0, 4000, (1, 100, 100)), - dims=['time', 'y', 'x'], - coords={'time': time, 'y': y_20m, 'x': x_20m}, - name='b02' + dims=["time", "y", "x"], + coords={"time": time, "y": y_20m, "x": x_20m}, + name="b02", ) mock_downsample.return_value = mock_downsampled @@ -423,23 +422,25 @@ def test_create_level_1_dataset_with_downsampling(self, pyramid): # Should call downsampling for 10m data mock_downsample.assert_called() - - # Should contain both native 20m and downsampled 10m data - assert 'b05' in dataset.data_vars - assert 'b02' in dataset.data_vars - assert dataset.attrs['pyramid_level'] == 1 - assert dataset.attrs['resolution_meters'] == 20 - def 
test_create_level_2_dataset_structure(self, pyramid, sample_measurements_by_resolution): + # Should contain both native 20m and downsampled 10m data + assert "b05" in dataset.data_vars + assert "b02" in dataset.data_vars + assert dataset.attrs["pyramid_level"] == 1 + assert dataset.attrs["resolution_meters"] == 20 + + def test_create_level_2_dataset_structure( + self, pyramid, sample_measurements_by_resolution + ): """Test level 2 dataset creation according to optimization plan.""" dataset = pyramid._create_level_2_dataset(sample_measurements_by_resolution) - + # Check basic structure - assert dataset.attrs['pyramid_level'] == 2 - assert dataset.attrs['resolution_meters'] == 60 - + assert dataset.attrs["pyramid_level"] == 2 + assert dataset.attrs["resolution_meters"] == 60 + # Should contain 60m native data - assert 'b01' in dataset.data_vars + assert "b01" in dataset.data_vars def test_create_level_2_dataset_with_downsampling(self, pyramid): """Test level 2 dataset creation with 20m data downsampling.""" @@ -448,36 +449,36 @@ def test_create_level_2_dataset_with_downsampling(self, pyramid): y_60m = np.linspace(0, 1000, 50) x_20m = np.linspace(0, 1000, 100) y_20m = np.linspace(0, 1000, 100) - time = np.array(['2023-01-01'], dtype='datetime64[ns]') + time = np.array(["2023-01-01"], dtype="datetime64[ns]") # 60m native data b01_60m = xr.DataArray( np.random.randint(0, 4000, (1, 50, 50)), - dims=['time', 'y', 'x'], - coords={'time': time, 'y': y_60m, 'x': x_60m}, - name='b01' + dims=["time", "y", "x"], + coords={"time": time, "y": y_60m, "x": x_60m}, + name="b01", ) # 20m data to be downsampled scl_20m = xr.DataArray( np.random.randint(0, 11, (1, 100, 100)), - dims=['time', 'y', 'x'], - coords={'time': time, 'y': y_20m, 'x': x_20m}, - name='scl' + dims=["time", "y", "x"], + coords={"time": time, "y": y_20m, "x": x_20m}, + name="scl", ) measurements = { - 60: {'reflectance': {'b01': b01_60m}}, - 20: {'quality': {'scl': scl_20m}} + 60: {"reflectance": {"b01": b01_60m}}, + 20: {"quality": {"scl": scl_20m}}, } - with patch.object(pyramid.resampler, 'downsample_variable') as mock_downsample: + with patch.object(pyramid.resampler, "downsample_variable") as mock_downsample: # Mock the downsampling to return a properly shaped array mock_downsampled = xr.DataArray( np.random.randint(0, 11, (1, 50, 50)), - dims=['time', 'y', 'x'], - coords={'time': time, 'y': y_60m, 'x': x_60m}, - name='scl' + dims=["time", "y", "x"], + coords={"time": time, "y": y_60m, "x": x_60m}, + name="scl", ) mock_downsample.return_value = mock_downsampled @@ -485,17 +486,17 @@ def test_create_level_2_dataset_with_downsampling(self, pyramid): # Should call downsampling for 20m data mock_downsample.assert_called() - + # Should contain both native 60m and downsampled 20m data - assert 'b01' in dataset.data_vars - assert 'scl' in dataset.data_vars - assert dataset.attrs['pyramid_level'] == 2 - assert dataset.attrs['resolution_meters'] == 60 + assert "b01" in dataset.data_vars + assert "scl" in dataset.data_vars + assert dataset.attrs["pyramid_level"] == 2 + assert dataset.attrs["resolution_meters"] == 60 def test_error_handling_invalid_level(self, pyramid): """Test error handling for invalid pyramid levels.""" measurements = {} - + # Test with invalid level (should work but return empty dataset if no source data) dataset = pyramid._create_level_dataset(-1, 5, measurements) # Should create downsampled dataset (empty in this case) @@ -508,16 +509,16 @@ class TestS2MultiscalePyramidIntegration: @pytest.fixture def 
real_measurements_data(self): """Create realistic measurements data for integration testing.""" - time = np.array(['2023-06-15T10:30:00'], dtype='datetime64[ns]') - + time = np.array(["2023-06-15T10:30:00"], dtype="datetime64[ns]") + # 10m resolution data (200x200 pixels) x_10m = np.linspace(300000, 310000, 200) # UTM coordinates y_10m = np.linspace(4900000, 4910000, 200) - - # 20m resolution data (100x100 pixels) + + # 20m resolution data (100x100 pixels) x_20m = np.linspace(300000, 310000, 100) y_20m = np.linspace(4900000, 4910000, 100) - + # 60m resolution data (50x50 pixels) x_60m = np.linspace(300000, 310000, 50) y_60m = np.linspace(4900000, 4910000, 50) @@ -525,188 +526,200 @@ def real_measurements_data(self): # Create realistic spectral bands measurements = { 10: { - 'reflectance': { - 'b02': xr.DataArray( + "reflectance": { + "b02": xr.DataArray( np.random.randint(500, 3000, (1, 200, 200), dtype=np.int16), - dims=['time', 'y', 'x'], - coords={'time': time, 'y': y_10m, 'x': x_10m}, - attrs={'long_name': 'Blue band', 'units': 'digital_number'} + dims=["time", "y", "x"], + coords={"time": time, "y": y_10m, "x": x_10m}, + attrs={"long_name": "Blue band", "units": "digital_number"}, ), - 'b03': xr.DataArray( + "b03": xr.DataArray( np.random.randint(600, 3500, (1, 200, 200), dtype=np.int16), - dims=['time', 'y', 'x'], - coords={'time': time, 'y': y_10m, 'x': x_10m}, - attrs={'long_name': 'Green band', 'units': 'digital_number'} + dims=["time", "y", "x"], + coords={"time": time, "y": y_10m, "x": x_10m}, + attrs={"long_name": "Green band", "units": "digital_number"}, ), - 'b04': xr.DataArray( + "b04": xr.DataArray( np.random.randint(400, 3200, (1, 200, 200), dtype=np.int16), - dims=['time', 'y', 'x'], - coords={'time': time, 'y': y_10m, 'x': x_10m}, - attrs={'long_name': 'Red band', 'units': 'digital_number'} + dims=["time", "y", "x"], + coords={"time": time, "y": y_10m, "x": x_10m}, + attrs={"long_name": "Red band", "units": "digital_number"}, ), - 'b08': xr.DataArray( + "b08": xr.DataArray( np.random.randint(3000, 6000, (1, 200, 200), dtype=np.int16), - dims=['time', 'y', 'x'], - coords={'time': time, 'y': y_10m, 'x': x_10m}, - attrs={'long_name': 'NIR band', 'units': 'digital_number'} - ) + dims=["time", "y", "x"], + coords={"time": time, "y": y_10m, "x": x_10m}, + attrs={"long_name": "NIR band", "units": "digital_number"}, + ), } }, 20: { - 'reflectance': { - 'b05': xr.DataArray( + "reflectance": { + "b05": xr.DataArray( np.random.randint(2000, 4000, (1, 100, 100), dtype=np.int16), - dims=['time', 'y', 'x'], - coords={'time': time, 'y': y_20m, 'x': x_20m}, - attrs={'long_name': 'Red edge 1', 'units': 'digital_number'} + dims=["time", "y", "x"], + coords={"time": time, "y": y_20m, "x": x_20m}, + attrs={"long_name": "Red edge 1", "units": "digital_number"}, ), - 'b06': xr.DataArray( + "b06": xr.DataArray( np.random.randint(2500, 4500, (1, 100, 100), dtype=np.int16), - dims=['time', 'y', 'x'], - coords={'time': time, 'y': y_20m, 'x': x_20m}, - attrs={'long_name': 'Red edge 2', 'units': 'digital_number'} + dims=["time", "y", "x"], + coords={"time": time, "y": y_20m, "x": x_20m}, + attrs={"long_name": "Red edge 2", "units": "digital_number"}, ), - 'b07': xr.DataArray( + "b07": xr.DataArray( np.random.randint(2800, 4800, (1, 100, 100), dtype=np.int16), - dims=['time', 'y', 'x'], - coords={'time': time, 'y': y_20m, 'x': x_20m}, - attrs={'long_name': 'Red edge 3', 'units': 'digital_number'} + dims=["time", "y", "x"], + coords={"time": time, "y": y_20m, "x": x_20m}, + attrs={"long_name": "Red 
edge 3", "units": "digital_number"}, ), - 'b11': xr.DataArray( + "b11": xr.DataArray( np.random.randint(1000, 3000, (1, 100, 100), dtype=np.int16), - dims=['time', 'y', 'x'], - coords={'time': time, 'y': y_20m, 'x': x_20m}, - attrs={'long_name': 'SWIR 1', 'units': 'digital_number'} + dims=["time", "y", "x"], + coords={"time": time, "y": y_20m, "x": x_20m}, + attrs={"long_name": "SWIR 1", "units": "digital_number"}, ), - 'b12': xr.DataArray( + "b12": xr.DataArray( np.random.randint(500, 2500, (1, 100, 100), dtype=np.int16), - dims=['time', 'y', 'x'], - coords={'time': time, 'y': y_20m, 'x': x_20m}, - attrs={'long_name': 'SWIR 2', 'units': 'digital_number'} + dims=["time", "y", "x"], + coords={"time": time, "y": y_20m, "x": x_20m}, + attrs={"long_name": "SWIR 2", "units": "digital_number"}, ), - 'b8a': xr.DataArray( + "b8a": xr.DataArray( np.random.randint(2800, 5500, (1, 100, 100), dtype=np.int16), - dims=['time', 'y', 'x'], - coords={'time': time, 'y': y_20m, 'x': x_20m}, - attrs={'long_name': 'NIR narrow', 'units': 'digital_number'} - ) + dims=["time", "y", "x"], + coords={"time": time, "y": y_20m, "x": x_20m}, + attrs={"long_name": "NIR narrow", "units": "digital_number"}, + ), }, - 'quality': { - 'scl': xr.DataArray( + "quality": { + "scl": xr.DataArray( np.random.randint(0, 11, (1, 100, 100), dtype=np.uint8), - dims=['time', 'y', 'x'], - coords={'time': time, 'y': y_20m, 'x': x_20m}, - attrs={'long_name': 'Scene classification', 'units': 'class'} + dims=["time", "y", "x"], + coords={"time": time, "y": y_20m, "x": x_20m}, + attrs={"long_name": "Scene classification", "units": "class"}, ), - 'aot': xr.DataArray( + "aot": xr.DataArray( np.random.randint(0, 1000, (1, 100, 100), dtype=np.uint16), - dims=['time', 'y', 'x'], - coords={'time': time, 'y': y_20m, 'x': x_20m}, - attrs={'long_name': 'Aerosol optical thickness', 'units': 'dimensionless'} + dims=["time", "y", "x"], + coords={"time": time, "y": y_20m, "x": x_20m}, + attrs={ + "long_name": "Aerosol optical thickness", + "units": "dimensionless", + }, ), - 'wvp': xr.DataArray( + "wvp": xr.DataArray( np.random.randint(0, 5000, (1, 100, 100), dtype=np.uint16), - dims=['time', 'y', 'x'], - coords={'time': time, 'y': y_20m, 'x': x_20m}, - attrs={'long_name': 'Water vapor', 'units': 'kg/m^2'} - ) - } + dims=["time", "y", "x"], + coords={"time": time, "y": y_20m, "x": x_20m}, + attrs={"long_name": "Water vapor", "units": "kg/m^2"}, + ), + }, }, 60: { - 'reflectance': { - 'b01': xr.DataArray( + "reflectance": { + "b01": xr.DataArray( np.random.randint(1500, 3500, (1, 50, 50), dtype=np.int16), - dims=['time', 'y', 'x'], - coords={'time': time, 'y': y_60m, 'x': x_60m}, - attrs={'long_name': 'Coastal aerosol', 'units': 'digital_number'} + dims=["time", "y", "x"], + coords={"time": time, "y": y_60m, "x": x_60m}, + attrs={ + "long_name": "Coastal aerosol", + "units": "digital_number", + }, ), - 'b09': xr.DataArray( + "b09": xr.DataArray( np.random.randint(100, 1000, (1, 50, 50), dtype=np.int16), - dims=['time', 'y', 'x'], - coords={'time': time, 'y': y_60m, 'x': x_60m}, - attrs={'long_name': 'Water vapor', 'units': 'digital_number'} - ) + dims=["time", "y", "x"], + coords={"time": time, "y": y_60m, "x": x_60m}, + attrs={"long_name": "Water vapor", "units": "digital_number"}, + ), } - } + }, } - + return measurements - @patch('builtins.print') # Mock print to avoid test output + @patch("builtins.print") # Mock print to avoid test output def test_full_pyramid_creation(self, mock_print, real_measurements_data, tmp_path): """Test complete pyramid 
creation with realistic data.""" pyramid = S2MultiscalePyramid(enable_sharding=True, spatial_chunk=512) - + output_path = str(tmp_path) - - with patch.object(pyramid, '_write_level_dataset') as mock_write: - result = pyramid.create_multiscale_measurements(real_measurements_data, output_path) - + + with patch.object(pyramid, "_write_level_dataset") as mock_write: + result = pyramid.create_multiscale_measurements( + real_measurements_data, output_path + ) + # Should create all 7 pyramid levels assert len(result) == 7 - + # Check that each level has appropriate characteristics for level, dataset in result.items(): - assert dataset.attrs['pyramid_level'] == level - assert dataset.attrs['resolution_meters'] == pyramid.pyramid_levels[level] + assert dataset.attrs["pyramid_level"] == level + assert ( + dataset.attrs["resolution_meters"] == pyramid.pyramid_levels[level] + ) assert len(dataset.data_vars) > 0 - + # Verify write was called for each level assert mock_write.call_count == 7 def test_level_specific_content(self, real_measurements_data): """Test that each pyramid level contains appropriate content.""" - pyramid = S2MultiscalePyramid(enable_sharding=False, spatial_chunk=256) # Disable sharding for simpler testing - + pyramid = S2MultiscalePyramid( + enable_sharding=False, spatial_chunk=256 + ) # Disable sharding for simpler testing + # Test level 0 (10m native) level_0 = pyramid._create_level_0_dataset(real_measurements_data) level_0_vars = set(level_0.data_vars.keys()) - expected_10m_vars = {'b02', 'b03', 'b04', 'b08'} + expected_10m_vars = {"b02", "b03", "b04", "b08"} assert len(expected_10m_vars.intersection(level_0_vars)) > 0 - + # Test level 1 (20m consolidated) level_1 = pyramid._create_level_1_dataset(real_measurements_data) # Should contain both native 20m and downsampled 10m data level_1_vars = set(level_1.data_vars.keys()) # Check some expected variables are present - expected_vars = {'b05', 'b06', 'b07', 'b11', 'b12', 'b8a', 'scl', 'aot', 'wvp'} + expected_vars = {"b05", "b06", "b07", "b11", "b12", "b8a", "scl", "aot", "wvp"} assert len(expected_vars.intersection(level_1_vars)) > 0 - - # Test level 2 (60m consolidated) + + # Test level 2 (60m consolidated) level_2 = pyramid._create_level_2_dataset(real_measurements_data) # Should contain native 60m and processed 20m data level_2_vars = set(level_2.data_vars.keys()) - expected_60m_vars = {'b01', 'b09'} + expected_60m_vars = {"b01", "b09"} assert len(expected_60m_vars.intersection(level_2_vars)) > 0 def test_sharding_configuration_integration(self, real_measurements_data): """Test sharding configuration with realistic data.""" pyramid = S2MultiscalePyramid(enable_sharding=True, spatial_chunk=256) - + # Create a test dataset level_0 = pyramid._create_level_0_dataset(real_measurements_data) - + if len(level_0.data_vars) > 0: encoding = pyramid._create_level_encoding(level_0, level=0) - + # Check encoding structure for var_name, var_data in level_0.data_vars.items(): assert var_name in encoding var_encoding = encoding[var_name] - + # Check sharding configuration if var_data.ndim >= 2: - assert 'shards' in var_encoding - shards = var_encoding['shards'] - + assert "shards" in var_encoding + shards = var_encoding["shards"] + # Verify shard dimensions are reasonable if var_data.ndim == 3: assert shards[0] == 1 # Time dimension - assert shards[1] > 0 # Y dimension - assert shards[2] > 0 # X dimension + assert shards[1] > 0 # Y dimension + assert shards[2] > 0 # X dimension elif var_data.ndim == 2: - assert shards[0] > 0 # Y dimension - 
assert shards[1] > 0 # X dimension + assert shards[0] > 0 # Y dimension + assert shards[1] > 0 # X dimension class TestEdgeCases: @@ -715,77 +728,77 @@ class TestEdgeCases: def test_empty_measurements_data(self): """Test handling of empty measurements data.""" pyramid = S2MultiscalePyramid() - + empty_measurements = {} - - with patch('builtins.print'): - with patch.object(pyramid, '_write_level_dataset'): - result = pyramid.create_multiscale_measurements(empty_measurements, "/tmp") - + + with patch("builtins.print"): + with patch.object(pyramid, "_write_level_dataset"): + result = pyramid.create_multiscale_measurements( + empty_measurements, "/tmp" + ) + # Should return empty results assert len(result) == 0 def test_missing_resolution_data(self): """Test handling when specific resolution data is missing.""" pyramid = S2MultiscalePyramid() - + # Only provide 20m data, missing 10m and 60m measurements_partial = { 20: { - 'reflectance': { - 'b05': xr.DataArray( + "reflectance": { + "b05": xr.DataArray( np.random.rand(1, 50, 50), - dims=['time', 'y', 'x'], + dims=["time", "y", "x"], coords={ - 'time': ['2023-01-01'], - 'y': np.arange(50), - 'x': np.arange(50) - } + "time": ["2023-01-01"], + "y": np.arange(50), + "x": np.arange(50), + }, ) } } } - + # Should handle gracefully level_0 = pyramid._create_level_0_dataset(measurements_partial) assert len(level_0.data_vars) == 0 # No 10m data available - + level_1 = pyramid._create_level_1_dataset(measurements_partial) assert len(level_1.data_vars) > 0 # Should have 20m data def test_coordinate_preservation(self): """Test that coordinate systems are preserved through processing.""" pyramid = S2MultiscalePyramid() - + # Create data with specific coordinate attributes x = np.linspace(300000, 310000, 100) y = np.linspace(4900000, 4910000, 100) - time = np.array(['2023-01-01'], dtype='datetime64[ns]') - + time = np.array(["2023-01-01"], dtype="datetime64[ns]") + # Add coordinate attributes - x_coord = xr.DataArray(x, dims=['x'], attrs={'units': 'm', 'crs': 'EPSG:32633'}) - y_coord = xr.DataArray(y, dims=['y'], attrs={'units': 'm', 'crs': 'EPSG:32633'}) - time_coord = xr.DataArray(time, dims=['time'], attrs={'calendar': 'gregorian'}) - + x_coord = xr.DataArray(x, dims=["x"], attrs={"units": "m", "crs": "EPSG:32633"}) + y_coord = xr.DataArray(y, dims=["y"], attrs={"units": "m", "crs": "EPSG:32633"}) + time_coord = xr.DataArray(time, dims=["time"], attrs={"calendar": "gregorian"}) + test_data = xr.DataArray( np.random.rand(1, 100, 100), - dims=['time', 'y', 'x'], - coords={'time': time_coord, 'y': y_coord, 'x': x_coord}, - name='b05' + dims=["time", "y", "x"], + coords={"time": time_coord, "y": y_coord, "x": x_coord}, + name="b05", ) - - measurements = { - 20: {'reflectance': {'b05': test_data}} - } - + + measurements = {20: {"reflectance": {"b05": test_data}}} + dataset = pyramid._create_level_1_dataset(measurements) - + # Check that coordinate attributes are preserved - if 'b05' in dataset.data_vars: - assert 'x' in dataset.coords - assert 'y' in dataset.coords - assert 'time' in dataset.coords - + if "b05" in dataset.data_vars: + assert "x" in dataset.coords + assert "y" in dataset.coords + assert "time" in dataset.coords + # Check coordinate attributes preservation - assert dataset.coords['x'].attrs.get('units') == 'm' - assert dataset.coords['y'].attrs.get('units') == 'm' + assert dataset.coords["x"].attrs.get("units") == "m" + assert dataset.coords["y"].attrs.get("units") == "m" diff --git a/src/eopf_geozarr/tests/test_s2_multiscale_geo_metadata.py 
b/src/eopf_geozarr/tests/test_s2_multiscale_geo_metadata.py index cd450cdd..282c62f6 100644 --- a/src/eopf_geozarr/tests/test_s2_multiscale_geo_metadata.py +++ b/src/eopf_geozarr/tests/test_s2_multiscale_geo_metadata.py @@ -4,11 +4,11 @@ Tests the geographic metadata writing functionality added to level creation. """ -import pytest +from unittest.mock import patch + import numpy as np +import pytest import xarray as xr -from unittest.mock import Mock, patch -from pyproj import CRS from eopf_geozarr.s2_optimization.s2_multiscale import S2MultiscalePyramid @@ -23,23 +23,23 @@ def pyramid_creator(): def sample_dataset_with_crs(): """Create a sample dataset with CRS information.""" coords = { - 'x': (['x'], np.linspace(0, 1000, 100)), - 'y': (['y'], np.linspace(0, 1000, 100)), - 'time': (['time'], [np.datetime64('2023-01-01')]) + "x": (["x"], np.linspace(0, 1000, 100)), + "y": (["y"], np.linspace(0, 1000, 100)), + "time": (["time"], [np.datetime64("2023-01-01")]), } - + data_vars = { - 'b02': (['time', 'y', 'x'], np.random.rand(1, 100, 100)), - 'b03': (['time', 'y', 'x'], np.random.rand(1, 100, 100)), - 'b04': (['y', 'x'], np.random.rand(100, 100)) + "b02": (["time", "y", "x"], np.random.rand(1, 100, 100)), + "b03": (["time", "y", "x"], np.random.rand(1, 100, 100)), + "b04": (["y", "x"], np.random.rand(100, 100)), } - + ds = xr.Dataset(data_vars, coords=coords) - - ds['b02'].attrs['proj:epsg'] = 32632 - ds['b03'].attrs['proj:epsg'] = 32632 - ds['b04'].attrs['proj:epsg'] = 32632 - + + ds["b02"].attrs["proj:epsg"] = 32632 + ds["b03"].attrs["proj:epsg"] = 32632 + ds["b04"].attrs["proj:epsg"] = 32632 + return ds @@ -47,21 +47,21 @@ def sample_dataset_with_crs(): def sample_dataset_with_epsg_attrs(): """Create a sample dataset with EPSG in attributes.""" coords = { - 'x': (['x'], np.linspace(0, 1000, 50)), - 'y': (['y'], np.linspace(0, 1000, 50)) + "x": (["x"], np.linspace(0, 1000, 50)), + "y": (["y"], np.linspace(0, 1000, 50)), } - + data_vars = { - 'b05': (['y', 'x'], np.random.rand(50, 50)), - 'b06': (['y', 'x'], np.random.rand(50, 50)) + "b05": (["y", "x"], np.random.rand(50, 50)), + "b06": (["y", "x"], np.random.rand(50, 50)), } - + ds = xr.Dataset(data_vars, coords=coords) - + # Add EPSG to variable attributes - ds['b05'].attrs['proj:epsg'] = 32632 - ds['b06'].attrs['proj:epsg'] = 32632 - + ds["b05"].attrs["proj:epsg"] = 32632 + ds["b06"].attrs["proj:epsg"] = 32632 + return ds @@ -69,230 +69,245 @@ def sample_dataset_with_epsg_attrs(): def sample_dataset_no_crs(): """Create a sample dataset without CRS information.""" coords = { - 'x': (['x'], np.linspace(0, 1000, 25)), - 'y': (['y'], np.linspace(0, 1000, 25)) + "x": (["x"], np.linspace(0, 1000, 25)), + "y": (["y"], np.linspace(0, 1000, 25)), } - + data_vars = { - 'b11': (['y', 'x'], np.random.rand(25, 25)), - 'b12': (['y', 'x'], np.random.rand(25, 25)) + "b11": (["y", "x"], np.random.rand(25, 25)), + "b12": (["y", "x"], np.random.rand(25, 25)), } - + return xr.Dataset(data_vars, coords=coords) class TestWriteGeoMetadata: """Test the _write_geo_metadata method.""" - - def test_write_geo_metadata_with_rio_crs(self, pyramid_creator, sample_dataset_with_crs): + + def test_write_geo_metadata_with_rio_crs( + self, pyramid_creator, sample_dataset_with_crs + ): """Test _write_geo_metadata with dataset that has rioxarray CRS.""" - + # Call the method pyramid_creator._write_geo_metadata(sample_dataset_with_crs) - + # Verify CRS was written - assert hasattr(sample_dataset_with_crs, 'rio') + assert hasattr(sample_dataset_with_crs, "rio") assert 
sample_dataset_with_crs.rio.crs is not None assert sample_dataset_with_crs.rio.crs.to_epsg() == 32632 - - def test_write_geo_metadata_with_epsg_attrs(self, pyramid_creator, sample_dataset_with_epsg_attrs): + + def test_write_geo_metadata_with_epsg_attrs( + self, pyramid_creator, sample_dataset_with_epsg_attrs + ): """Test _write_geo_metadata with dataset that has EPSG in variable attributes.""" - + # Verify initial state - no CRS - assert not hasattr(sample_dataset_with_epsg_attrs, 'rio') or sample_dataset_with_epsg_attrs.rio.crs is None - + assert ( + not hasattr(sample_dataset_with_epsg_attrs, "rio") + or sample_dataset_with_epsg_attrs.rio.crs is None + ) + # Call the method pyramid_creator._write_geo_metadata(sample_dataset_with_epsg_attrs) - + # Verify CRS was written from attributes - assert hasattr(sample_dataset_with_epsg_attrs, 'rio') + assert hasattr(sample_dataset_with_epsg_attrs, "rio") assert sample_dataset_with_epsg_attrs.rio.crs is not None assert sample_dataset_with_epsg_attrs.rio.crs.to_epsg() == 32632 - + def test_write_geo_metadata_no_crs(self, pyramid_creator, sample_dataset_no_crs): """Test _write_geo_metadata with dataset that has no CRS information.""" - + # Verify initial state - no CRS - assert not hasattr(sample_dataset_no_crs, 'rio') or sample_dataset_no_crs.rio.crs is None - + assert ( + not hasattr(sample_dataset_no_crs, "rio") + or sample_dataset_no_crs.rio.crs is None + ) + # Call the method - should not fail but also not add CRS pyramid_creator._write_geo_metadata(sample_dataset_no_crs) - + # Verify no CRS was added (method handles gracefully) # The method should not fail even when no CRS is available # This tests the robustness of the method - - def test_write_geo_metadata_custom_grid_mapping_name(self, pyramid_creator, sample_dataset_with_crs): + + def test_write_geo_metadata_custom_grid_mapping_name( + self, pyramid_creator, sample_dataset_with_crs + ): """Test _write_geo_metadata with custom grid_mapping variable name.""" - + # Call the method with custom grid mapping name custom_name = "custom_spatial_ref" pyramid_creator._write_geo_metadata(sample_dataset_with_crs, custom_name) - + # Verify CRS was written - assert hasattr(sample_dataset_with_crs, 'rio') + assert hasattr(sample_dataset_with_crs, "rio") assert sample_dataset_with_crs.rio.crs is not None - - def test_write_geo_metadata_preserves_existing_data(self, pyramid_creator, sample_dataset_with_crs): + + def test_write_geo_metadata_preserves_existing_data( + self, pyramid_creator, sample_dataset_with_crs + ): """Test that _write_geo_metadata preserves existing data variables and coordinates.""" - + # Store original data original_vars = list(sample_dataset_with_crs.data_vars.keys()) original_coords = list(sample_dataset_with_crs.coords.keys()) - original_b02_data = sample_dataset_with_crs['b02'].values.copy() - + original_b02_data = sample_dataset_with_crs["b02"].values.copy() + # Call the method pyramid_creator._write_geo_metadata(sample_dataset_with_crs) - + # Verify all original data is preserved assert list(sample_dataset_with_crs.data_vars.keys()) == original_vars assert all(coord in sample_dataset_with_crs.coords for coord in original_coords) - assert np.array_equal(sample_dataset_with_crs['b02'].values, original_b02_data) - + assert np.array_equal(sample_dataset_with_crs["b02"].values, original_b02_data) + def test_write_geo_metadata_empty_dataset(self, pyramid_creator): """Test _write_geo_metadata with empty dataset.""" - + empty_ds = xr.Dataset({}, coords={}) - + # Call the method - should 
handle gracefully pyramid_creator._write_geo_metadata(empty_ds) - + # Verify method doesn't fail with empty dataset # This tests robustness - - def test_write_geo_metadata_rio_write_crs_called(self, pyramid_creator, sample_dataset_with_crs): + + def test_write_geo_metadata_rio_write_crs_called( + self, pyramid_creator, sample_dataset_with_crs + ): """Test that rio.write_crs is called correctly.""" - + # Mock the rio.write_crs method - with patch.object(sample_dataset_with_crs.rio, 'write_crs') as mock_write_crs: + with patch.object(sample_dataset_with_crs.rio, "write_crs") as mock_write_crs: # Call the method pyramid_creator._write_geo_metadata(sample_dataset_with_crs) - + # Verify rio.write_crs was called with correct arguments mock_write_crs.assert_called_once() call_args = mock_write_crs.call_args - assert call_args[1]['inplace'] is True # inplace=True should be passed - + assert call_args[1]["inplace"] is True # inplace=True should be passed + def test_write_geo_metadata_crs_from_multiple_sources(self, pyramid_creator): """Test CRS detection from multiple sources in priority order.""" - + # Create dataset with both rio CRS and EPSG attributes coords = { - 'x': (['x'], np.linspace(0, 1000, 50)), - 'y': (['y'], np.linspace(0, 1000, 50)) + "x": (["x"], np.linspace(0, 1000, 50)), + "y": (["y"], np.linspace(0, 1000, 50)), } - - data_vars = { - 'b08': (['y', 'x'], np.random.rand(50, 50)) - } - + + data_vars = {"b08": (["y", "x"], np.random.rand(50, 50))} + ds = xr.Dataset(data_vars, coords=coords) - + # Add both rio CRS and EPSG attribute (rio should take priority) - ds = ds.rio.write_crs('EPSG:4326') # Rio CRS - ds['b08'].attrs['proj:epsg'] = 32632 # EPSG attribute - + ds = ds.rio.write_crs("EPSG:4326") # Rio CRS + ds["b08"].attrs["proj:epsg"] = 32632 # EPSG attribute + # Call the method pyramid_creator._write_geo_metadata(ds) - + # Verify rio CRS was used (priority over attributes) assert ds.rio.crs.to_epsg() == 4326 # Should still be 4326, not 32632 - + def test_write_geo_metadata_integration_with_level_creation(self, pyramid_creator): """Test that _write_geo_metadata is properly integrated in level creation methods.""" - + # Create mock measurements data measurements_by_resolution = { 10: { - 'bands': { - 'b02': xr.DataArray( + "bands": { + "b02": xr.DataArray( np.random.rand(100, 100), - dims=['y', 'x'], + dims=["y", "x"], coords={ - 'x': (['x'], np.linspace(0, 1000, 100)), - 'y': (['y'], np.linspace(0, 1000, 100)) - } - ).rio.write_crs('EPSG:32632') + "x": (["x"], np.linspace(0, 1000, 100)), + "y": (["y"], np.linspace(0, 1000, 100)), + }, + ).rio.write_crs("EPSG:32632") } } } - + # Create level 0 dataset (which should call _write_geo_metadata) level_0_ds = pyramid_creator._create_level_0_dataset(measurements_by_resolution) - + # Verify CRS was written by _write_geo_metadata - assert hasattr(level_0_ds, 'rio') + assert hasattr(level_0_ds, "rio") assert level_0_ds.rio.crs is not None assert level_0_ds.rio.crs.to_epsg() == 32632 class TestWriteGeoMetadataEdgeCases: """Test edge cases for _write_geo_metadata method.""" - + def test_write_geo_metadata_invalid_crs(self, pyramid_creator): """Test _write_geo_metadata with invalid CRS data.""" - + coords = { - 'x': (['x'], np.linspace(0, 1000, 10)), - 'y': (['y'], np.linspace(0, 1000, 10)) - } - - data_vars = { - 'test_var': (['y', 'x'], np.random.rand(10, 10)) + "x": (["x"], np.linspace(0, 1000, 10)), + "y": (["y"], np.linspace(0, 1000, 10)), } - + + data_vars = {"test_var": (["y", "x"], np.random.rand(10, 10))} + ds = xr.Dataset(data_vars, 
coords=coords) - + # Add invalid EPSG code - ds['test_var'].attrs['proj:epsg'] = 'invalid_epsg' - + ds["test_var"].attrs["proj:epsg"] = "invalid_epsg" + # Method should raise an exception for invalid CRS (normal behavior) from pyproj.exceptions import CRSError + with pytest.raises(CRSError): pyramid_creator._write_geo_metadata(ds) - + def test_write_geo_metadata_mixed_crs_variables(self, pyramid_creator): """Test _write_geo_metadata with variables having different CRS information.""" - + coords = { - 'x': (['x'], np.linspace(0, 1000, 20)), - 'y': (['y'], np.linspace(0, 1000, 20)) + "x": (["x"], np.linspace(0, 1000, 20)), + "y": (["y"], np.linspace(0, 1000, 20)), } - + data_vars = { - 'var1': (['y', 'x'], np.random.rand(20, 20)), - 'var2': (['y', 'x'], np.random.rand(20, 20)) + "var1": (["y", "x"], np.random.rand(20, 20)), + "var2": (["y", "x"], np.random.rand(20, 20)), } - + ds = xr.Dataset(data_vars, coords=coords) - + # Add different EPSG codes to different variables - ds['var1'].attrs['proj:epsg'] = 32632 - ds['var2'].attrs['proj:epsg'] = 4326 - + ds["var1"].attrs["proj:epsg"] = 32632 + ds["var2"].attrs["proj:epsg"] = 4326 + # Call the method (should use the first CRS found) pyramid_creator._write_geo_metadata(ds) - + # Verify a CRS was applied (should be the first one found) - assert hasattr(ds, 'rio') - - def test_write_geo_metadata_maintains_dataset_attrs(self, pyramid_creator, sample_dataset_with_crs): + assert hasattr(ds, "rio") + + def test_write_geo_metadata_maintains_dataset_attrs( + self, pyramid_creator, sample_dataset_with_crs + ): """Test that _write_geo_metadata maintains dataset-level attributes.""" - + # Add some dataset attributes - sample_dataset_with_crs.attrs['pyramid_level'] = 1 - sample_dataset_with_crs.attrs['resolution_meters'] = 20 - sample_dataset_with_crs.attrs['custom_attr'] = 'test_value' - + sample_dataset_with_crs.attrs["pyramid_level"] = 1 + sample_dataset_with_crs.attrs["resolution_meters"] = 20 + sample_dataset_with_crs.attrs["custom_attr"] = "test_value" + original_attrs = sample_dataset_with_crs.attrs.copy() - + # Call the method pyramid_creator._write_geo_metadata(sample_dataset_with_crs) - + # Verify dataset attributes are preserved for key, value in original_attrs.items(): assert sample_dataset_with_crs.attrs[key] == value -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) diff --git a/src/eopf_geozarr/tests/test_s2_resampling.py b/src/eopf_geozarr/tests/test_s2_resampling.py index b49392f0..0dca2200 100644 --- a/src/eopf_geozarr/tests/test_s2_resampling.py +++ b/src/eopf_geozarr/tests/test_s2_resampling.py @@ -2,8 +2,8 @@ Unit tests for S2 resampling functionality. 
""" -import pytest import numpy as np +import pytest import xarray as xr from eopf_geozarr.s2_optimization.s2_resampling import ( @@ -16,23 +16,26 @@ def sample_reflectance_data_2d(): """Create a 2D reflectance data array for testing.""" # Create a 4x4 array with known values - data = np.array([ - [100, 200, 300, 400], - [150, 250, 350, 450], - [110, 210, 310, 410], - [160, 260, 360, 460] - ], dtype=np.uint16) - + data = np.array( + [ + [100, 200, 300, 400], + [150, 250, 350, 450], + [110, 210, 310, 410], + [160, 260, 360, 460], + ], + dtype=np.uint16, + ) + coords = { - 'y': np.array([1000, 990, 980, 970]), - 'x': np.array([500000, 500010, 500020, 500030]) + "y": np.array([1000, 990, 980, 970]), + "x": np.array([500000, 500010, 500020, 500030]), } - + return xr.DataArray( data, - dims=['y', 'x'], + dims=["y", "x"], coords=coords, - attrs={'units': 'reflectance', 'scale_factor': 0.0001} + attrs={"units": "reflectance", "scale_factor": 0.0001}, ) @@ -40,28 +43,35 @@ def sample_reflectance_data_2d(): def sample_reflectance_data_3d(): """Create a 3D reflectance data array with time dimension for testing.""" # Create a 2x4x4 array (time, y, x) - data = np.array([ - [[100, 200, 300, 400], - [150, 250, 350, 450], - [110, 210, 310, 410], - [160, 260, 360, 460]], - [[120, 220, 320, 420], - [170, 270, 370, 470], - [130, 230, 330, 430], - [180, 280, 380, 480]] - ], dtype=np.uint16) - + data = np.array( + [ + [ + [100, 200, 300, 400], + [150, 250, 350, 450], + [110, 210, 310, 410], + [160, 260, 360, 460], + ], + [ + [120, 220, 320, 420], + [170, 270, 370, 470], + [130, 230, 330, 430], + [180, 280, 380, 480], + ], + ], + dtype=np.uint16, + ) + coords = { - 'time': np.array(['2023-01-01', '2023-01-02'], dtype='datetime64[D]'), - 'y': np.array([1000, 990, 980, 970]), - 'x': np.array([500000, 500010, 500020, 500030]) + "time": np.array(["2023-01-01", "2023-01-02"], dtype="datetime64[D]"), + "y": np.array([1000, 990, 980, 970]), + "x": np.array([500000, 500010, 500020, 500030]), } - + return xr.DataArray( data, - dims=['time', 'y', 'x'], + dims=["time", "y", "x"], coords=coords, - attrs={'units': 'reflectance', 'scale_factor': 0.0001} + attrs={"units": "reflectance", "scale_factor": 0.0001}, ) @@ -69,23 +79,20 @@ def sample_reflectance_data_3d(): def sample_classification_data(): """Create classification data for testing.""" # SCL values: 0=no_data, 1=saturated, 4=vegetation, 6=water, etc. 
- data = np.array([ - [0, 1, 4, 4], - [1, 4, 6, 6], - [4, 4, 6, 8], - [4, 6, 8, 8] - ], dtype=np.uint8) - + data = np.array( + [[0, 1, 4, 4], [1, 4, 6, 6], [4, 4, 6, 8], [4, 6, 8, 8]], dtype=np.uint8 + ) + coords = { - 'y': np.array([1000, 990, 980, 970]), - 'x': np.array([500000, 500010, 500020, 500030]) + "y": np.array([1000, 990, 980, 970]), + "x": np.array([500000, 500010, 500020, 500030]), } - + return xr.DataArray( data, - dims=['y', 'x'], + dims=["y", "x"], coords=coords, - attrs={'long_name': 'Scene Classification Layer'} + attrs={"long_name": "Scene Classification Layer"}, ) @@ -93,23 +100,17 @@ def sample_classification_data(): def sample_quality_mask(): """Create quality mask data for testing.""" # Binary mask: 0=good, 1=bad - data = np.array([ - [0, 0, 1, 0], - [0, 1, 0, 0], - [1, 0, 0, 1], - [0, 0, 1, 1] - ], dtype=np.uint8) - + data = np.array( + [[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 1], [0, 0, 1, 1]], dtype=np.uint8 + ) + coords = { - 'y': np.array([1000, 990, 980, 970]), - 'x': np.array([500000, 500010, 500020, 500030]) + "y": np.array([1000, 990, 980, 970]), + "x": np.array([500000, 500010, 500020, 500030]), } - + return xr.DataArray( - data, - dims=['y', 'x'], - coords=coords, - attrs={'long_name': 'Quality mask'} + data, dims=["y", "x"], coords=coords, attrs={"long_name": "Quality mask"} ) @@ -117,23 +118,26 @@ def sample_quality_mask(): def sample_probability_data(): """Create probability data for testing.""" # Cloud probabilities in percent (0-100) - data = np.array([ - [10.5, 20.3, 85.7, 92.1], - [15.2, 75.8, 88.3, 95.6], - [12.7, 18.9, 90.2, 87.4], - [8.1, 22.4, 78.9, 99.0] - ], dtype=np.float32) - + data = np.array( + [ + [10.5, 20.3, 85.7, 92.1], + [15.2, 75.8, 88.3, 95.6], + [12.7, 18.9, 90.2, 87.4], + [8.1, 22.4, 78.9, 99.0], + ], + dtype=np.float32, + ) + coords = { - 'y': np.array([1000, 990, 980, 970]), - 'x': np.array([500000, 500010, 500020, 500030]) + "y": np.array([1000, 990, 980, 970]), + "x": np.array([500000, 500010, 500020, 500030]), } - + return xr.DataArray( data, - dims=['y', 'x'], + dims=["y", "x"], coords=coords, - attrs={'long_name': 'Cloud probability', 'units': 'percent'} + attrs={"long_name": "Cloud probability", "units": "percent"}, ) @@ -143,134 +147,132 @@ class TestS2ResamplingEngine: def test_initialization(self): """Test engine initialization.""" engine = S2ResamplingEngine() - - assert hasattr(engine, 'resampling_methods') + + assert hasattr(engine, "resampling_methods") assert len(engine.resampling_methods) == 5 - assert 'reflectance' in engine.resampling_methods - assert 'classification' in engine.resampling_methods - assert 'quality_mask' in engine.resampling_methods - assert 'probability' in engine.resampling_methods - assert 'detector_footprint' in engine.resampling_methods + assert "reflectance" in engine.resampling_methods + assert "classification" in engine.resampling_methods + assert "quality_mask" in engine.resampling_methods + assert "probability" in engine.resampling_methods + assert "detector_footprint" in engine.resampling_methods def test_downsample_reflectance_2d(self, sample_reflectance_data_2d): """Test reflectance downsampling for 2D data.""" engine = S2ResamplingEngine() - + # Downsample from 4x4 to 2x2 result = engine.downsample_variable( - sample_reflectance_data_2d, 2, 2, 'reflectance' + sample_reflectance_data_2d, 2, 2, "reflectance" ) - + # Check dimensions assert result.shape == (2, 2) - assert result.dims == ('y', 'x') - + assert result.dims == ("y", "x") + # Check that values are averages of 2x2 blocks # 
Top-left block: mean of [100, 200, 150, 250] = 175 assert result.values[0, 0] == 175.0 - + # Top-right block: mean of [300, 400, 350, 450] = 375 assert result.values[0, 1] == 375.0 - + # Check coordinates are properly subsampled - assert len(result.coords['y']) == 2 - assert len(result.coords['x']) == 2 - np.testing.assert_array_equal(result.coords['y'].values, [1000, 980]) - np.testing.assert_array_equal(result.coords['x'].values, [500000, 500020]) - + assert len(result.coords["y"]) == 2 + assert len(result.coords["x"]) == 2 + np.testing.assert_array_equal(result.coords["y"].values, [1000, 980]) + np.testing.assert_array_equal(result.coords["x"].values, [500000, 500020]) + # Check attributes are preserved assert result.attrs == sample_reflectance_data_2d.attrs def test_downsample_reflectance_3d(self, sample_reflectance_data_3d): """Test reflectance downsampling for 3D data.""" engine = S2ResamplingEngine() - + # Downsample from 2x4x4 to 2x2x2 result = engine.downsample_variable( - sample_reflectance_data_3d, 2, 2, 'reflectance' + sample_reflectance_data_3d, 2, 2, "reflectance" ) - + # Check dimensions assert result.shape == (2, 2, 2) - assert result.dims == ('time', 'y', 'x') - + assert result.dims == ("time", "y", "x") + # Check first time slice values # Top-left block: mean of [100, 200, 150, 250] = 175 assert result.values[0, 0, 0] == 175.0 - - # Check second time slice values + + # Check second time slice values # Top-left block: mean of [120, 220, 170, 270] = 195 assert result.values[1, 0, 0] == 195.0 - + # Check coordinates - assert len(result.coords['time']) == 2 - assert len(result.coords['y']) == 2 - assert len(result.coords['x']) == 2 + assert len(result.coords["time"]) == 2 + assert len(result.coords["y"]) == 2 + assert len(result.coords["x"]) == 2 def test_downsample_classification(self, sample_classification_data): """Test classification downsampling using mode.""" engine = S2ResamplingEngine() - + # Downsample from 4x4 to 2x2 result = engine.downsample_variable( - sample_classification_data, 2, 2, 'classification' + sample_classification_data, 2, 2, "classification" ) - + # Check dimensions assert result.shape == (2, 2) - assert result.dims == ('y', 'x') - + assert result.dims == ("y", "x") + # Check mode values # Top-left block: [0, 1, 1, 4] -> mode should be 1 (most frequent) # Top-right block: [4, 4, 6, 6] -> mode could be either 4 or 6 (both appear twice) assert result.values[0, 0] in [0, 1, 4] # Allow for mode calculation variations - + # Check data type is preserved assert result.dtype == sample_classification_data.dtype def test_downsample_quality_mask(self, sample_quality_mask): """Test quality mask downsampling using logical OR.""" engine = S2ResamplingEngine() - + # Downsample from 4x4 to 2x2 - result = engine.downsample_variable( - sample_quality_mask, 2, 2, 'quality_mask' - ) - + result = engine.downsample_variable(sample_quality_mask, 2, 2, "quality_mask") + # Check dimensions assert result.shape == (2, 2) - assert result.dims == ('y', 'x') - + assert result.dims == ("y", "x") + # Check logical OR behavior # Top-left block: [0, 0, 0, 1] -> any non-zero = 1 assert result.values[0, 0] == 1 - + # Top-right block: [1, 0, 0, 0] -> any non-zero = 1 assert result.values[0, 1] == 1 - + # Bottom-left block: [1, 0, 0, 0] -> any non-zero = 1 assert result.values[1, 0] == 1 - + # Bottom-right block: [0, 1, 1, 1] -> any non-zero = 1 assert result.values[1, 1] == 1 def test_downsample_probability(self, sample_probability_data): """Test probability downsampling with value 
clamping.""" engine = S2ResamplingEngine() - + # Downsample from 4x4 to 2x2 result = engine.downsample_variable( - sample_probability_data, 2, 2, 'probability' + sample_probability_data, 2, 2, "probability" ) - + # Check dimensions assert result.shape == (2, 2) - assert result.dims == ('y', 'x') - + assert result.dims == ("y", "x") + # Values should be averages and clamped to [0, 100] assert np.all(result.values >= 0) assert np.all(result.values <= 100) - + # Check specific average calculation # Top-left block: mean of [10.5, 20.3, 15.2, 75.8] ≈ 30.45 expected_val = (10.5 + 20.3 + 15.2 + 75.8) / 4 @@ -279,59 +281,51 @@ def test_downsample_probability(self, sample_probability_data): def test_detector_footprint_same_as_quality_mask(self, sample_quality_mask): """Test that detector footprint uses same method as quality mask.""" engine = S2ResamplingEngine() - + result_quality = engine.downsample_variable( - sample_quality_mask, 2, 2, 'quality_mask' + sample_quality_mask, 2, 2, "quality_mask" ) result_detector = engine.downsample_variable( - sample_quality_mask, 2, 2, 'detector_footprint' + sample_quality_mask, 2, 2, "detector_footprint" ) - + # Results should be identical np.testing.assert_array_equal(result_quality.values, result_detector.values) def test_invalid_variable_type(self, sample_reflectance_data_2d): """Test error handling for invalid variable type.""" engine = S2ResamplingEngine() - + with pytest.raises(ValueError, match="Unknown variable type"): - engine.downsample_variable( - sample_reflectance_data_2d, 2, 2, 'invalid_type' - ) + engine.downsample_variable(sample_reflectance_data_2d, 2, 2, "invalid_type") def test_non_divisible_dimensions(self): """Test handling of non-divisible dimensions.""" engine = S2ResamplingEngine() - + # Create 5x5 data (not evenly divisible by 2) data = np.random.rand(5, 5).astype(np.float32) - coords = { - 'y': np.arange(5), - 'x': np.arange(5) - } - da = xr.DataArray(data, dims=['y', 'x'], coords=coords) - + coords = {"y": np.arange(5), "x": np.arange(5)} + da = xr.DataArray(data, dims=["y", "x"], coords=coords) + # Should crop to make it divisible - result = engine.downsample_variable(da, 2, 2, 'reflectance') - + result = engine.downsample_variable(da, 2, 2, "reflectance") + # Should result in 2x2 output (cropped from 4x4) assert result.shape == (2, 2) def test_single_pixel_downsampling(self): """Test downsampling to single pixel.""" engine = S2ResamplingEngine() - + # Create 4x4 data data = np.ones((4, 4), dtype=np.float32) * 100 - coords = { - 'y': np.arange(4), - 'x': np.arange(4) - } - da = xr.DataArray(data, dims=['y', 'x'], coords=coords) - + coords = {"y": np.arange(4), "x": np.arange(4)} + da = xr.DataArray(data, dims=["y", "x"], coords=coords) + # Downsample to 1x1 - result = engine.downsample_variable(da, 1, 1, 'reflectance') - + result = engine.downsample_variable(da, 1, 1, "reflectance") + assert result.shape == (1, 1) assert result.values[0, 0] == 100.0 @@ -342,50 +336,53 @@ class TestDetermineVariableType: def test_spectral_bands(self): """Test recognition of spectral bands.""" dummy_data = xr.DataArray([1, 2, 3]) - + # Test standard bands - assert determine_variable_type('b01', dummy_data) == 'reflectance' - assert determine_variable_type('b02', dummy_data) == 'reflectance' - assert determine_variable_type('b8a', dummy_data) == 'reflectance' - + assert determine_variable_type("b01", dummy_data) == "reflectance" + assert determine_variable_type("b02", dummy_data) == "reflectance" + assert determine_variable_type("b8a", dummy_data) == 
"reflectance" + # Test specific non-band variables that should be classified differently - assert determine_variable_type('scl', dummy_data) == 'classification' - assert determine_variable_type('cld', dummy_data) == 'probability' - assert determine_variable_type('quality_b01', dummy_data) == 'quality_mask' + assert determine_variable_type("scl", dummy_data) == "classification" + assert determine_variable_type("cld", dummy_data) == "probability" + assert determine_variable_type("quality_b01", dummy_data) == "quality_mask" def test_classification_data(self): """Test recognition of classification data.""" dummy_data = xr.DataArray([1, 2, 3]) - - assert determine_variable_type('scl', dummy_data) == 'classification' + + assert determine_variable_type("scl", dummy_data) == "classification" def test_probability_data(self): """Test recognition of probability data.""" dummy_data = xr.DataArray([1, 2, 3]) - - assert determine_variable_type('cld', dummy_data) == 'probability' - assert determine_variable_type('snw', dummy_data) == 'probability' + + assert determine_variable_type("cld", dummy_data) == "probability" + assert determine_variable_type("snw", dummy_data) == "probability" def test_atmospheric_quality(self): """Test recognition of atmospheric quality data.""" dummy_data = xr.DataArray([1, 2, 3]) - - assert determine_variable_type('aot', dummy_data) == 'reflectance' - assert determine_variable_type('wvp', dummy_data) == 'reflectance' + + assert determine_variable_type("aot", dummy_data) == "reflectance" + assert determine_variable_type("wvp", dummy_data) == "reflectance" def test_quality_masks(self): """Test recognition of quality mask data.""" dummy_data = xr.DataArray([1, 2, 3]) - - assert determine_variable_type('detector_footprint_b01', dummy_data) == 'quality_mask' - assert determine_variable_type('quality_b02', dummy_data) == 'quality_mask' + + assert ( + determine_variable_type("detector_footprint_b01", dummy_data) + == "quality_mask" + ) + assert determine_variable_type("quality_b02", dummy_data) == "quality_mask" def test_unknown_variable_defaults_to_reflectance(self): """Test that unknown variables default to reflectance.""" dummy_data = xr.DataArray([1, 2, 3]) - - assert determine_variable_type('unknown_var', dummy_data) == 'reflectance' - assert determine_variable_type('custom_band', dummy_data) == 'reflectance' + + assert determine_variable_type("unknown_var", dummy_data) == "reflectance" + assert determine_variable_type("custom_band", dummy_data) == "reflectance" class TestEdgeCases: @@ -394,59 +391,53 @@ class TestEdgeCases: def test_empty_data_array(self): """Test handling of empty data arrays.""" engine = S2ResamplingEngine() - + # Create minimal data array data = np.array([[1]]) - coords = {'y': [0], 'x': [0]} - da = xr.DataArray(data, dims=['y', 'x'], coords=coords) - + coords = {"y": [0], "x": [0]} + da = xr.DataArray(data, dims=["y", "x"], coords=coords) + # This should work for 1x1 -> 1x1 downsampling - result = engine.downsample_variable(da, 1, 1, 'reflectance') + result = engine.downsample_variable(da, 1, 1, "reflectance") assert result.shape == (1, 1) assert result.values[0, 0] == 1 def test_preserve_attributes_and_encoding(self): """Test that attributes and encoding are preserved.""" engine = S2ResamplingEngine() - + data = np.ones((4, 4), dtype=np.uint16) * 1000 - coords = { - 'y': np.arange(4), - 'x': np.arange(4) - } - + coords = {"y": np.arange(4), "x": np.arange(4)} + attrs = { - 'long_name': 'Test reflectance', - 'units': 'reflectance', - 'scale_factor': 0.0001, 
- 'add_offset': 0 + "long_name": "Test reflectance", + "units": "reflectance", + "scale_factor": 0.0001, + "add_offset": 0, } - - da = xr.DataArray(data, dims=['y', 'x'], coords=coords, attrs=attrs) - - result = engine.downsample_variable(da, 2, 2, 'reflectance') - + + da = xr.DataArray(data, dims=["y", "x"], coords=coords, attrs=attrs) + + result = engine.downsample_variable(da, 2, 2, "reflectance") + # Attributes should be preserved assert result.attrs == attrs def test_coordinate_names_preserved(self): """Test that coordinate names are preserved during downsampling.""" engine = S2ResamplingEngine() - + data = np.ones((4, 4), dtype=np.float32) - coords = { - 'latitude': np.arange(4), - 'longitude': np.arange(4) - } - - da = xr.DataArray(data, dims=['latitude', 'longitude'], coords=coords) - - result = engine.downsample_variable(da, 2, 2, 'reflectance') - + coords = {"latitude": np.arange(4), "longitude": np.arange(4)} + + da = xr.DataArray(data, dims=["latitude", "longitude"], coords=coords) + + result = engine.downsample_variable(da, 2, 2, "reflectance") + # Coordinate names should be preserved - assert 'latitude' in result.coords - assert 'longitude' in result.coords - assert result.dims == ('latitude', 'longitude') + assert "latitude" in result.coords + assert "longitude" in result.coords + assert result.dims == ("latitude", "longitude") class TestIntegrationScenarios: @@ -455,37 +446,34 @@ class TestIntegrationScenarios: def test_multiscale_pyramid_creation(self): """Test creating a complete multiscale pyramid.""" engine = S2ResamplingEngine() - + # Start with 32x32 data original_size = 32 data = np.random.rand(original_size, original_size).astype(np.float32) * 1000 - coords = { - 'y': np.arange(original_size), - 'x': np.arange(original_size) - } - - da = xr.DataArray(data, dims=['y', 'x'], coords=coords) - + coords = {"y": np.arange(original_size), "x": np.arange(original_size)} + + da = xr.DataArray(data, dims=["y", "x"], coords=coords) + # Create pyramid levels: 32x32 -> 16x16 -> 8x8 -> 4x4 -> 2x2 -> 1x1 levels = [] current_data = da current_size = original_size - + while current_size >= 2: next_size = current_size // 2 downsampled = engine.downsample_variable( - current_data, next_size, next_size, 'reflectance' + current_data, next_size, next_size, "reflectance" ) levels.append(downsampled) current_data = downsampled current_size = next_size - + # Verify pyramid structure expected_sizes = [16, 8, 4, 2, 1] for i, level in enumerate(levels): expected_size = expected_sizes[i] assert level.shape == (expected_size, expected_size) - + # Verify that values are reasonable (not NaN, not extreme) for level in levels: assert not np.isnan(level.values).any() @@ -494,46 +482,41 @@ def test_multiscale_pyramid_creation(self): def test_mixed_variable_types_processing(self): """Test processing different variable types together.""" engine = S2ResamplingEngine() - + # Create base 4x4 data size = 4 - coords = {'y': np.arange(size), 'x': np.arange(size)} - + coords = {"y": np.arange(size), "x": np.arange(size)} + # Create different variable types reflectance_data = xr.DataArray( - np.random.rand(size, size) * 1000, - dims=['y', 'x'], coords=coords + np.random.rand(size, size) * 1000, dims=["y", "x"], coords=coords ) - + classification_data = xr.DataArray( - np.random.randint(0, 10, (size, size)), - dims=['y', 'x'], coords=coords + np.random.randint(0, 10, (size, size)), dims=["y", "x"], coords=coords ) - + quality_data = xr.DataArray( - np.random.randint(0, 2, (size, size)), - dims=['y', 'x'], 
coords=coords + np.random.randint(0, 2, (size, size)), dims=["y", "x"], coords=coords ) - + # Process each with appropriate method results = {} for var_name, var_data, var_type in [ - ('b04', reflectance_data, 'reflectance'), - ('scl', classification_data, 'classification'), - ('quality_b04', quality_data, 'quality_mask') + ("b04", reflectance_data, "reflectance"), + ("scl", classification_data, "classification"), + ("quality_b04", quality_data, "quality_mask"), ]: - results[var_name] = engine.downsample_variable( - var_data, 2, 2, var_type - ) - + results[var_name] = engine.downsample_variable(var_data, 2, 2, var_type) + # Verify all results have same dimensions for result in results.values(): assert result.shape == (2, 2) - + # Verify coordinate consistency - y_coords = results['b04'].coords['y'] - x_coords = results['b04'].coords['x'] - + y_coords = results["b04"].coords["y"] + x_coords = results["b04"].coords["x"] + for result in results.values(): - np.testing.assert_array_equal(result.coords['y'].values, y_coords.values) - np.testing.assert_array_equal(result.coords['x'].values, x_coords.values) + np.testing.assert_array_equal(result.coords["y"].values, y_coords.values) + np.testing.assert_array_equal(result.coords["x"].values, x_coords.values) diff --git a/test_sharding_fix.py b/test_sharding_fix.py index af8ad669..588700cb 100644 --- a/test_sharding_fix.py +++ b/test_sharding_fix.py @@ -6,15 +6,16 @@ """ import sys -import os -sys.path.insert(0, 'src') + +sys.path.insert(0, "src") + def test_calculate_shard_dimension(): """Test the _calculate_shard_dimension function.""" from eopf_geozarr.conversion.geozarr import _calculate_shard_dimension - + print("🧪 Testing _calculate_shard_dimension function...") - + # Test cases: (data_dim, chunk_dim, description) test_cases = [ (10980, 4096, "Sentinel-2 10m resolution typical case"), @@ -28,33 +29,41 @@ def test_calculate_shard_dimension(): (256, 512, "Chunk larger than data"), (1024, 256, "4x multiple case"), ] - + print("\nTest Results:") print("=" * 80) - print(f"{'Data Dim':<10} {'Chunk Dim':<10} {'Shard Dim':<10} {'Divisible?':<12} {'Description'}") + print( + f"{'Data Dim':<10} {'Chunk Dim':<10} {'Shard Dim':<10} {'Divisible?':<12} {'Description'}" + ) print("-" * 80) - + all_passed = True for data_dim, chunk_dim, description in test_cases: shard_dim = _calculate_shard_dimension(data_dim, chunk_dim) - + # When chunk_dim >= data_dim, the effective chunk size is data_dim effective_chunk_dim = min(chunk_dim, data_dim) is_divisible = shard_dim % effective_chunk_dim == 0 status = "✅ YES" if is_divisible else "❌ NO" - - print(f"{data_dim:<10} {chunk_dim:<10} {shard_dim:<10} {status:<12} {description}") - + + print( + f"{data_dim:<10} {chunk_dim:<10} {shard_dim:<10} {status:<12} {description}" + ) + if not is_divisible: all_passed = False - print(f" ⚠️ ERROR: {shard_dim} % {effective_chunk_dim} = {shard_dim % effective_chunk_dim}") - + print( + f" ⚠️ ERROR: {shard_dim} % {effective_chunk_dim} = {shard_dim % effective_chunk_dim}" + ) + print("-" * 80) if all_passed: - print("✅ All tests passed! Shard dimensions are properly divisible by chunk dimensions.") + print( + "✅ All tests passed! Shard dimensions are properly divisible by chunk dimensions." + ) else: print("❌ Some tests failed! 
Check the implementation.") - + return all_passed @@ -63,55 +72,67 @@ def test_encoding_creation(): import numpy as np import xarray as xr from zarr.codecs import BloscCodec + from eopf_geozarr.conversion.geozarr import _create_geozarr_encoding - + print("\n🧪 Testing encoding creation with sharding...") - + # Create a test dataset data = np.random.rand(1, 10980, 10980).astype(np.float32) - ds = xr.Dataset({ - 'b02': (['time', 'y', 'x'], data), - }, coords={ - 'time': [np.datetime64('2023-01-01')], - 'y': np.arange(10980), - 'x': np.arange(10980), - }) - + ds = xr.Dataset( + { + "b02": (["time", "y", "x"], data), + }, + coords={ + "time": [np.datetime64("2023-01-01")], + "y": np.arange(10980), + "x": np.arange(10980), + }, + ) + compressor = BloscCodec(cname="zstd", clevel=3, shuffle="shuffle", blocksize=0) spatial_chunk = 4096 - + # Test with sharding enabled print("\nTesting with sharding enabled:") - encoding = _create_geozarr_encoding(ds, compressor, spatial_chunk, enable_sharding=True) - + encoding = _create_geozarr_encoding( + ds, compressor, spatial_chunk, enable_sharding=True + ) + for var, enc in encoding.items(): - if 'shards' in enc and enc['shards'] is not None: - chunks = enc['chunks'] - shards = enc['shards'] + if "shards" in enc and enc["shards"] is not None: + chunks = enc["chunks"] + shards = enc["shards"] print(f"Variable: {var}") print(f" Data shape: {ds[var].shape}") print(f" Chunks: {chunks}") print(f" Shards: {shards}") - + # Validate divisibility valid = True for i, (shard_dim, chunk_dim) in enumerate(zip(shards, chunks)): if shard_dim % chunk_dim != 0: - print(f" ❌ Axis {i}: {shard_dim} % {chunk_dim} = {shard_dim % chunk_dim}") + print( + f" ❌ Axis {i}: {shard_dim} % {chunk_dim} = {shard_dim % chunk_dim}" + ) valid = False else: print(f" ✅ Axis {i}: {shard_dim} % {chunk_dim} = 0") - + if valid: print(" ✅ All shard dimensions are divisible by chunk dimensions") else: - print(" ❌ Some shard dimensions are not divisible by chunk dimensions") - + print( + " ❌ Some shard dimensions are not divisible by chunk dimensions" + ) + print("\nTesting with sharding disabled:") - encoding_no_shard = _create_geozarr_encoding(ds, compressor, spatial_chunk, enable_sharding=False) - + encoding_no_shard = _create_geozarr_encoding( + ds, compressor, spatial_chunk, enable_sharding=False + ) + for var, enc in encoding_no_shard.items(): - if 'shards' in enc: + if "shards" in enc: print(f"Variable: {var}, Shards: {enc['shards']}") @@ -119,13 +140,13 @@ def main(): """Run all tests.""" print("🔧 Testing Zarr v3 Sharding Fix for GeoZarr") print("=" * 50) - + # Test the shard dimension calculation test1_passed = test_calculate_shard_dimension() - + # Test the encoding creation test_encoding_creation() - + print("\n" + "=" * 50) if test1_passed: print("✅ All critical tests passed!") @@ -136,7 +157,7 @@ def main(): print("- Enhanced shard calculation with preference for larger multipliers") else: print("❌ Some tests failed. 
Please review the implementation.") - + return test1_passed From 4849e3fb7960345818b7542752258399ba4b4e44 Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Mon, 29 Sep 2025 10:43:39 +0000 Subject: [PATCH 35/83] feat: update memory limit for Dask client to 8GB and adjust spatial chunk size to 256 in S2MultiscalePyramid --- src/eopf_geozarr/cli.py | 6 +- .../s2_optimization/s2_converter.py | 68 ++++++++++--------- .../s2_optimization/s2_data_consolidator.py | 28 +++----- .../s2_optimization/s2_multiscale.py | 55 ++++++++++----- 4 files changed, 84 insertions(+), 73 deletions(-) diff --git a/src/eopf_geozarr/cli.py b/src/eopf_geozarr/cli.py index 8c0ba32a..c070863f 100644 --- a/src/eopf_geozarr/cli.py +++ b/src/eopf_geozarr/cli.py @@ -54,7 +54,7 @@ def setup_dask_cluster(enable_dask: bool, verbose: bool = False) -> Optional[Any from dask.distributed import Client # Set up local cluster with high memory limits - client = Client(n_workers=3, memory_limit="4GB") # set up local cluster with 3 workers and 8GB memory each + client = Client(n_workers=3, memory_limit="8GB") # set up local cluster with 3 workers and 8GB memory each # client = Client() # set up local cluster if verbose: @@ -1173,8 +1173,8 @@ def add_s2_optimization_commands(subparsers): s2_parser.add_argument( '--spatial-chunk', type=int, - default=1024, - help='Spatial chunk size (default: 1024)' + default=256, + help='Spatial chunk size (default: 256)' ) s2_parser.add_argument( '--enable-sharding', diff --git a/src/eopf_geozarr/s2_optimization/s2_converter.py b/src/eopf_geozarr/s2_optimization/s2_converter.py index d1a7aadc..1d6bd240 100644 --- a/src/eopf_geozarr/s2_optimization/s2_converter.py +++ b/src/eopf_geozarr/s2_optimization/s2_converter.py @@ -105,9 +105,9 @@ def convert_s2_optimized( meteorology_path = f"{output_path}/meteorology" self._write_auxiliary_group(meteorology_ds, meteorology_path, "meteorology", verbose) - # Step 5: Add multiscales metadata to measurements group - print("Step 5: Adding multiscales metadata to measurements group...") - self._add_measurements_multiscales_metadata(output_path, pyramid_datasets) + # Step 5: Create measurements group and add multiscales metadata + print("Step 5: Creating measurements group...") + measurement_dt = self._write_measurements_group(pyramid_datasets, "measurements", verbose) # Step 6: Simple root-level consolidation print("Step 6: Final root-level metadata consolidation...") @@ -200,36 +200,38 @@ def _write_auxiliary_group( if verbose: print(f" {group_type.title()} group written: {len(dataset.data_vars)} variables") - def _add_measurements_multiscales_metadata(self, output_path: str, pyramid_datasets: Dict[int, xr.Dataset]) -> None: - """Add multiscales metadata to the measurements group using rioxarray.""" - try: - measurements_path = f"{output_path}/measurements" - - # Create multiscales metadata using rioxarray .rio accessor - multiscales_metadata = self._create_multiscales_metadata_with_rio(pyramid_datasets) - - if multiscales_metadata: - # Use zarr to add metadata to the measurements group - storage_options = get_storage_options(measurements_path) - - try: - import zarr - if storage_options: - store = zarr.storage.FSStore(measurements_path, **storage_options) - else: - store = measurements_path - - # Open the measurements group and add multiscales metadata - measurements_group = zarr.open_group(store, mode='r+') - measurements_group.attrs['multiscales'] = multiscales_metadata - - print(" ✅ Added multiscales metadata to measurements group") - - except Exception as e: 
- print(f" ⚠️ Could not add multiscales metadata: {e}") - - except Exception as e: - print(f" ⚠️ Error adding multiscales metadata: {e}") + def _write_measurements_group( + self, + pyramid_datasets: Dict[int, xr.Dataset], + group_name: str, + verbose: bool + ) -> None: + """Write measurements group with pyramid datasets.""" + group_path = f"{group_name}" + + measurements_group = xr.DataTree() + for level, ds in pyramid_datasets.items(): + if ds is not None: + measurements_group[level] = ds + + multiscales_attrs = self._create_multiscales_metadata_with_rio(pyramid_datasets) + if multiscales_attrs: + measurements_group.attrs['multiscales'] = [multiscales_attrs] + if verbose: + print(f" Multiscales metadata added with {len(multiscales_attrs.get('tile_matrix_set', {}).get('matrices', []))} levels") + + # Write the measurements group with consolidation + storage_options = get_storage_options(group_path) + measurements_group.to_zarr( + group_path, + mode='w', + consolidated=True, + zarr_format=3, + storage_options=storage_options, + compute=True # Direct compute for simplicity + ) + + return measurements_group def _create_multiscales_metadata_with_rio(self, pyramid_datasets: Dict[int, xr.Dataset]) -> Dict: """Create multiscales metadata using rioxarray .rio accessor, following geozarr.py format.""" diff --git a/src/eopf_geozarr/s2_optimization/s2_data_consolidator.py b/src/eopf_geozarr/s2_optimization/s2_data_consolidator.py index 1bf1d2bd..8231a04f 100644 --- a/src/eopf_geozarr/s2_optimization/s2_data_consolidator.py +++ b/src/eopf_geozarr/s2_optimization/s2_data_consolidator.py @@ -3,11 +3,7 @@ """ import xarray as xr -from typing import Dict, List, Tuple, Optional -from .s2_band_mapping import ( - NATIVE_BANDS, QUALITY_DATA_NATIVE, DETECTOR_FOOTPRINT_NATIVE, - get_bands_for_level, get_quality_data_for_level -) +from typing import Dict, List, Tuple class S2DataConsolidator: """Consolidates S2 data from scattered structure into organized groups.""" @@ -84,10 +80,8 @@ def _extract_reflectance_bands(self) -> None: ds = group_node.to_dataset() # Extract only native bands for this resolution - native_bands = NATIVE_BANDS.get(res_num, []) - for band in native_bands: - if band in ds.data_vars: - self.measurements_data[res_num]['bands'][band] = ds[band] + for band in ds.data_vars: + self.measurements_data[res_num]['bands'][band] = ds[band] def _extract_quality_data(self) -> None: """Extract quality mask data.""" @@ -100,11 +94,8 @@ def _extract_quality_data(self) -> None: if group_path in self.dt_input.groups: ds = self.dt_input[group_path].to_dataset() - # Only extract quality for native bands at this resolution - native_bands = NATIVE_BANDS.get(res_num, []) - for band in native_bands: - if band in ds.data_vars: - self.measurements_data[res_num]['quality'][f'quality_{band}'] = ds[band] + for band in ds.data_vars: + self.measurements_data[res_num]['quality'][f'quality_{band}'] = ds[band] def _extract_detector_footprints(self) -> None: """Extract detector footprint data.""" @@ -117,12 +108,9 @@ def _extract_detector_footprints(self) -> None: if group_path in self.dt_input.groups: ds = self.dt_input[group_path].to_dataset() - # Only extract footprints for native bands - native_bands = NATIVE_BANDS.get(res_num, []) - for band in native_bands: - if band in ds.data_vars: - var_name = f'detector_footprint_{band}' - self.measurements_data[res_num]['detector_footprints'][var_name] = ds[band] + for band in ds.data_vars: + var_name = f'detector_footprint_{band}' + 
self.measurements_data[res_num]['detector_footprints'][var_name] = ds[band] def _extract_atmosphere_data(self) -> None: """Extract atmosphere quality data (aot, wvp) - native at 20m.""" diff --git a/src/eopf_geozarr/s2_optimization/s2_multiscale.py b/src/eopf_geozarr/s2_optimization/s2_multiscale.py index d6169fb0..9b7c503e 100644 --- a/src/eopf_geozarr/s2_optimization/s2_multiscale.py +++ b/src/eopf_geozarr/s2_optimization/s2_multiscale.py @@ -26,7 +26,7 @@ def compute(*args, **kwargs): class S2MultiscalePyramid: """Creates multiscale pyramids for consolidated S2 data.""" - def __init__(self, enable_sharding: bool = True, spatial_chunk: int = 1024): + def __init__(self, enable_sharding: bool = True, spatial_chunk: int = 256): self.enable_sharding = enable_sharding self.spatial_chunk = spatial_chunk self.resampler = S2ResamplingEngine() @@ -37,9 +37,8 @@ def __init__(self, enable_sharding: bool = True, spatial_chunk: int = 1024): 1: 20, # Level 1: 20m (native for b05,b06,b07,b11,b12,b8a + all quality) 2: 60, # Level 2: 60m (3x downsampling from 20m) 3: 120, # Level 3: 120m (2x downsampling from 60m) - 4: 240, # Level 4: 240m (2x downsampling from 120m) - 5: 480, # Level 5: 480m (2x downsampling from 240m) - 6: 960 # Level 6: 960m (2x downsampling from 480m) + 4: 360, # Level 4: 360m (3x downsampling from 120m) + 5: 720, # Level 5: 720m (2x downsampling from 360m) } def create_multiscale_measurements( @@ -108,15 +107,25 @@ def _create_multiscale_measurements_parallel( # Write immediately to avoid memory buildup level_path = f"{output_path}/measurements/{level}" print(f" Writing level {level} to {level_path}") - self._write_level_dataset(dataset, level_path, level) + dataset = self._write_level_dataset(dataset, level_path, level) # Store only essential levels for dependencies - if level == 2: + if level <= 2: # Keep level 2 for creating higher levels pyramid_datasets[level] = dataset - elif level < 2: - # Keep reference but could be cleaned up if memory is tight - pyramid_datasets[level] = dataset + # update measurements_by_resolution also to use the in memory dataset + for var in dataset.data_vars: + var_data = dataset[var] + res = dataset.attrs.get('resolution_meters', None) + if res: + if var.startswith('b') or var.startswith('quality_') or var.startswith('detector_footprint_'): + if var.startswith('b'): + measurements_by_resolution[res]['bands'][var] = var_data + elif var.startswith('quality_'): + measurements_by_resolution[res]['quality'][var] = var_data + elif var.startswith('detector_footprint_'): + measurements_by_resolution[res]['detector_footprints'][var] = var_data + # Clean up memory for higher levels (they're already written) if level > 2: @@ -209,7 +218,7 @@ def _create_level_1_dataset_parallel(self, measurements_by_resolution: Dict) -> # Start with native 20m data if 20 in measurements_by_resolution: data_20m = measurements_by_resolution[20] - for category, vars_dict in data_20m.items(): + for category, vars_dict in data_20m.items(): all_vars.update(vars_dict) # Get reference coordinates from 20m data @@ -311,15 +320,17 @@ def _create_level_2_dataset_parallel(self, measurements_by_resolution: Dict) -> data_20m = measurements_by_resolution[20] for category, vars_dict in data_20m.items(): for var_name, var_data in vars_dict.items(): - vars_to_downsample.append((var_name, var_data, '20m')) + if var_name not in all_vars: + vars_to_downsample.append((var_name, var_data, '20m')) # Add 10m data for downsampling if 10 in measurements_by_resolution: data_10m = 
measurements_by_resolution[10] for category, vars_dict in data_10m.items(): for var_name, var_data in vars_dict.items(): - vars_to_downsample.append((var_name, var_data, '10m')) - + if var_name not in all_vars: + vars_to_downsample.append((var_name, var_data, '10m')) + # Process all downsampling in parallel if Dask is available if DASK_AVAILABLE and vars_to_downsample: @delayed @@ -499,7 +510,7 @@ def _create_level_2_dataset(self, measurements_by_resolution: Dict) -> xr.Datase for category, vars_dict in data_20m.items(): for var_name, var_data in vars_dict.items(): # skip if already present from 20m data - if var_name in all_vars: + if var_name in all_vars or var_name in [v[0] for v in vars_to_downsample]: continue vars_to_downsample.append((var_name, var_data, '20m')) @@ -509,7 +520,7 @@ def _create_level_2_dataset(self, measurements_by_resolution: Dict) -> xr.Datase for category, vars_dict in data_10m.items(): for var_name, var_data in vars_dict.items(): # skip if already present from 20m data - if var_name in all_vars: + if var_name in all_vars or var_name in [v[0] for v in vars_to_downsample]: continue vars_to_downsample.append((var_name, var_data, '10m')) @@ -671,7 +682,7 @@ def _downsample_variables_sequential( return dataset - def _write_level_dataset(self, dataset: xr.Dataset, level_path: str, level: int) -> None: + def _write_level_dataset(self, dataset: xr.Dataset, level_path: str, level: int) -> xr.Dataset: """ Write a pyramid level dataset to storage with xy-aligned sharding. @@ -686,7 +697,15 @@ def _write_level_dataset(self, dataset: xr.Dataset, level_path: str, level: int) if has_time_dim and self._should_separate_time_files(dataset): # Write each time slice separately to ensure single file per variable per time self._write_time_separated_dataset(dataset, level_path, level, encoding) + return dataset else: + # check if level_path exists and skip writing if it does + import os + if os.path.exists(level_path): + print(f" Level path {level_path} already exists. 
Skipping write.") + # return the existing dataset + return xr.open_dataset(level_path, engine='zarr') + # Write as single dataset with xy-aligned sharding print(f" Writing level {level} to {level_path} (xy-aligned sharding)") @@ -716,7 +735,9 @@ def _write_level_dataset(self, dataset: xr.Dataset, level_path: str, level: int) else: print(f" Writing zarr file...") write_job.compute() - + + return rechunked_dataset + def _should_separate_time_files(self, dataset: xr.Dataset) -> bool: """Determine if time files should be separated for single file per variable per time.""" for var in dataset.data_vars.values(): From 3d02eab96dc420c8c467c46c77708c626faead75 Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Mon, 29 Sep 2025 12:00:55 +0000 Subject: [PATCH 36/83] feat: add new CLI command for converting to GeoZarr S2L2A optimized format with sharding support --- .vscode/launch.json | 28 ++++++++++++++++++++++++++++ src/eopf_geozarr/cli.py | 2 +- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 3beff475..55ebab49 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -133,6 +133,34 @@ "AWS_ENDPOINT_URL": "https://s3.de.io.cloud.ovh.net/" }, + }, + { + // eopf_geozarr convert https://objectstore.eodc.eu:2222/e05ab01a9d56408d82ac32d69a5aae2a:sample-data/tutorial_data/cpm_v253/S2B_MSIL1C_20250113T103309_N0511_R108_T32TLQ_20250113T122458.zarr /tmp/tmp7mmjkjk3/s2b_subset_test.zarr --groups /measurements/reflectance/r10m --spatial-chunk 512 --min-dimension 128 --tile-width 256 --max-retries 2 --verbose + "name": "Convert to GeoZarr S2L2A Optimized (S3)", + "type": "debugpy", + "request": "launch", + "module": "eopf_geozarr", + "args": [ + "convert-s2-optimized", + "https://objects.eodc.eu/e05ab01a9d56408d82ac32d69a5aae2a:202509-s02msil2a/08/products/cpm_v256/S2A_MSIL2A_20250908T100041_N0511_R122_T32TQM_20250908T115116.zarr", + // "s3://esa-zarr-sentinel-explorer-fra/tests-output/sentinel-2-l2a-opt/S2A_MSIL2A_20250908T100041_N0511_R122_T32TQM_20250908T115116.zarr", + "./tests-output/eopf_geozarr/s2l2_optimized.zarr", + "--spatial-chunk", "256", + "--compression-level", "5", + "--enable-sharding", + "--dask-cluster", + "--verbose" + ], + "cwd": "${workspaceFolder}", + "justMyCode": false, + "console": "integratedTerminal", + "env": { + "PYTHONPATH": "${workspaceFolder}/.venv/bin", + "AWS_PROFILE": "eopf-explorer", + "AWS_DEFAULT_REGION": "de", + "AWS_ENDPOINT_URL": "https://s3.de.io.cloud.ovh.net/" + }, + }, { "name": "Convert to GeoZarr Sentinel-1 GRD (Local)", diff --git a/src/eopf_geozarr/cli.py b/src/eopf_geozarr/cli.py index 1648e702..768ac646 100644 --- a/src/eopf_geozarr/cli.py +++ b/src/eopf_geozarr/cli.py @@ -68,7 +68,7 @@ def setup_dask_cluster(enable_dask: bool, verbose: bool = False) -> Optional[Any except ImportError: print( - "❌ Error: dask.distributed not available. Install with: pip install 'dask[distributed]'"vars_to_downsample + "❌ Error: dask.distributed not available. 
Install with: pip install 'dask[distributed]'" ) sys.exit(1) except Exception as e: From 48f5dd8788d42ec8ad0aa3083cd9221c33ca5647 Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Mon, 29 Sep 2025 14:06:33 +0200 Subject: [PATCH 37/83] feat: implement batched parallel downsampling for S2 datasets and improve classification downsampling method --- .../s2_optimization/s2_multiscale.py | 178 ++++++++++++------ .../s2_optimization/s2_resampling.py | 48 ++--- 2 files changed, 143 insertions(+), 83 deletions(-) diff --git a/src/eopf_geozarr/s2_optimization/s2_multiscale.py b/src/eopf_geozarr/s2_optimization/s2_multiscale.py index 6dbc6f41..4a98e448 100644 --- a/src/eopf_geozarr/s2_optimization/s2_multiscale.py +++ b/src/eopf_geozarr/s2_optimization/s2_multiscale.py @@ -213,7 +213,7 @@ def _create_level_dataset_parallel( def _create_level_1_dataset_parallel( self, measurements_by_resolution: Dict ) -> xr.Dataset: - """Create level 1 dataset with parallel downsampling from 10m data.""" + """Create level 1 dataset with batched parallel downsampling from 10m data.""" all_vars = {} reference_coords = None @@ -231,7 +231,7 @@ def _create_level_1_dataset_parallel( "y": first_var.coords["y"], } - # Add downsampled 10m data with parallelization + # Add downsampled 10m data with batched parallelization if 10 in measurements_by_resolution and reference_coords: data_10m = measurements_by_resolution[10] target_height = len(reference_coords["y"]) @@ -246,34 +246,22 @@ def _create_level_1_dataset_parallel( continue vars_to_downsample.append((var_name, var_data)) - # Process variables in parallel if Dask is available + # Process variables in batches if Dask is available if DASK_AVAILABLE and vars_to_downsample: - - @delayed - def downsample_10m_variable(var_name: str, var_data: xr.DataArray): - var_type = determine_variable_type(var_name, var_data) - downsampled = self.resampler.downsample_variable( - var_data, target_height, target_width, var_type + batch_size = min(8, max(4, len(vars_to_downsample) // 4)) # Adaptive batch size + print(f" Batched parallel downsampling {len(vars_to_downsample)} variables from 10m to 20m (batch size: {batch_size})...") + + # Process variables in batches + for i in range(0, len(vars_to_downsample), batch_size): + batch = vars_to_downsample[i:i + batch_size] + batch_vars = self._process_variable_batch_to_20m( + batch, target_height, target_width, reference_coords ) - # Align coordinates - downsampled = downsampled.assign_coords(reference_coords) - return var_name, downsampled - - # Create tasks for all variables - downsample_tasks = [ - downsample_10m_variable(var_name, var_data) - for var_name, var_data in vars_to_downsample - ] - - # Compute all in parallel - print( - f" Parallel downsampling {len(downsample_tasks)} variables from 10m to 20m..." 
- ) - results = compute(*downsample_tasks) - for var_name, downsampled_var in results: - all_vars[var_name] = downsampled_var + all_vars.update(batch_vars) + print(f" Completed batch {i//batch_size + 1}/{(len(vars_to_downsample) + batch_size - 1)//batch_size}") else: # Sequential fallback + print(f" Sequential downsampling {len(vars_to_downsample)} variables from 10m to 20m...") for var_name, var_data in vars_to_downsample: var_type = determine_variable_type(var_name, var_data) downsampled = self.resampler.downsample_variable( @@ -295,10 +283,37 @@ def downsample_10m_variable(var_name: str, var_data: xr.DataArray): return dataset + def _process_variable_batch_to_20m( + self, batch: list, target_height: int, target_width: int, reference_coords: dict + ) -> dict: + """Process a batch of variables for downsampling to 20m with memory management.""" + + @delayed + def downsample_to_20m_variable(var_name: str, var_data: xr.DataArray): + var_type = determine_variable_type(var_name, var_data) + downsampled = self.resampler.downsample_variable( + var_data, target_height, target_width, var_type + ) + # Align coordinates + downsampled = downsampled.assign_coords(reference_coords) + return var_name, downsampled + + # Create tasks for this batch only + batch_tasks = [ + downsample_to_20m_variable(var_name, var_data) + for var_name, var_data in batch + ] + + # Compute batch in parallel + batch_results = compute(*batch_tasks) + + # Return as dictionary + return dict(batch_results) + def _create_level_2_dataset_parallel( self, measurements_by_resolution: Dict ) -> xr.Dataset: - """Create level 2 dataset with parallel downsampling to 60m.""" + """Create level 2 dataset with batched parallel downsampling to 60m.""" all_vars = {} reference_coords = None @@ -338,36 +353,22 @@ def _create_level_2_dataset_parallel( if var_name not in all_vars: vars_to_downsample.append((var_name, var_data, '10m')) - # Process all downsampling in parallel if Dask is available + # Process downsampling in batches to manage memory if DASK_AVAILABLE and vars_to_downsample: - - @delayed - def downsample_to_60m_variable( - var_name: str, var_data: xr.DataArray, source_res: str - ): - var_type = determine_variable_type(var_name, var_data) - downsampled = self.resampler.downsample_variable( - var_data, target_height, target_width, var_type + batch_size = min(8, max(4, len(vars_to_downsample) // 4)) # Adaptive batch size + print(f" Batched parallel downsampling {len(vars_to_downsample)} variables to 60m (batch size: {batch_size})...") + + # Process variables in batches + for i in range(0, len(vars_to_downsample), batch_size): + batch = vars_to_downsample[i:i + batch_size] + batch_vars = self._process_variable_batch_to_60m( + batch, target_height, target_width, reference_coords ) - # Align coordinates - downsampled = downsampled.assign_coords(reference_coords) - return var_name, downsampled - - # Create tasks for all variables - downsample_tasks = [ - downsample_to_60m_variable(var_name, var_data, source_res) - for var_name, var_data, source_res in vars_to_downsample - ] - - # Compute all in parallel - print( - f" Parallel downsampling {len(downsample_tasks)} variables to 60m..." 
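
The batched helpers introduced in this patch bound peak memory by materialising only a handful of delayed downsampling tasks at a time instead of one large graph. The following is a minimal, self-contained sketch of that pattern, not the project code: dask is assumed to be installed, resample_one is a hypothetical stand-in for the real S2ResamplingEngine.downsample_variable dispatch, and the 4-to-8 adaptive batch bounds simply mirror the values used in the patch.

    # Sketch: adaptive batched execution of delayed downsampling tasks.
    from dask import compute, delayed
    import numpy as np

    def resample_one(name, data, factor):
        # Placeholder block-mean downsample; the real code dispatches on variable type.
        h, w = data.shape
        trimmed = data[: (h // factor) * factor, : (w // factor) * factor]
        return name, trimmed.reshape(h // factor, factor, w // factor, factor).mean(axis=(1, 3))

    def batched_downsample(variables, factor):
        items = list(variables.items())
        # Adaptive batch size between 4 and 8 tasks, as in the patch.
        batch_size = min(8, max(4, len(items) // 4))
        results = {}
        for i in range(0, len(items), batch_size):
            batch = items[i : i + batch_size]
            tasks = [delayed(resample_one)(name, data, factor) for name, data in batch]
            # Only one batch of results is held in memory at a time.
            results.update(dict(compute(*tasks)))
        return results

    if __name__ == "__main__":
        bands = {"b%02d" % i: np.random.rand(1098, 1098).astype("float32") for i in range(1, 13)}
        print({k: v.shape for k, v in batched_downsample(bands, 3).items()})

The per-batch compute call is the design choice that matters here: it trades a little scheduling overhead for a hard ceiling on how many downsampled arrays exist simultaneously.
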
- ) - results = compute(*downsample_tasks) - for var_name, downsampled_var in results: - all_vars[var_name] = downsampled_var + all_vars.update(batch_vars) + print(f" Completed batch {i//batch_size + 1}/{(len(vars_to_downsample) + batch_size - 1)//batch_size}") else: # Sequential fallback + print(f" Sequential downsampling {len(vars_to_downsample)} variables to 60m...") for var_name, var_data, source_res in vars_to_downsample: var_type = determine_variable_type(var_name, var_data) downsampled = self.resampler.downsample_variable( @@ -389,6 +390,77 @@ def downsample_to_60m_variable( return dataset + def _process_variable_batch_to_60m( + self, batch: list, target_height: int, target_width: int, reference_coords: dict + ) -> dict: + """Process a batch of variables for downsampling to 60m with memory management.""" + + @delayed + def downsample_to_60m_variable( + var_name: str, var_data: xr.DataArray, source_res: str + ): + try: + var_type = determine_variable_type(var_name, var_data) + downsampled = self.resampler.downsample_variable( + var_data, target_height, target_width, var_type + ) + # Align coordinates + downsampled = downsampled.assign_coords(reference_coords) + return var_name, downsampled + except Exception as e: + print(f" Warning: Failed to downsample {var_name} from {source_res}: {e}") + return var_name, None + + # Create tasks for this batch only + batch_tasks = [ + downsample_to_60m_variable(var_name, var_data, source_res) + for var_name, var_data, source_res in batch + ] + + # Compute batch in parallel with memory management + try: + batch_results = compute(*batch_tasks) + + # Filter out failed results and return as dictionary + successful_results = { + var_name: result for var_name, result in batch_results + if result is not None + } + + # Force garbage collection to free memory + import gc + gc.collect() + + return successful_results + + except Exception as e: + print(f" Error processing batch: {e}") + # Fallback to sequential processing for this batch + return self._process_batch_sequential_fallback_60m( + batch, target_height, target_width, reference_coords + ) + + def _process_batch_sequential_fallback_60m( + self, batch: list, target_height: int, target_width: int, reference_coords: dict + ) -> dict: + """Sequential fallback for failed batch processing.""" + print(f" Falling back to sequential processing for {len(batch)} variables...") + results = {} + + for var_name, var_data, source_res in batch: + try: + var_type = determine_variable_type(var_name, var_data) + downsampled = self.resampler.downsample_variable( + var_data, target_height, target_width, var_type + ) + downsampled = downsampled.assign_coords(reference_coords) + results[var_name] = downsampled + except Exception as e: + print(f" Warning: Failed to downsample {var_name}: {e}") + continue + + return results + def _create_downsampled_dataset_from_level2_parallel( self, level: int, target_resolution: int, level_2_dataset: xr.Dataset ) -> xr.Dataset: diff --git a/src/eopf_geozarr/s2_optimization/s2_resampling.py b/src/eopf_geozarr/s2_optimization/s2_resampling.py index 0840e30c..6c226e2d 100644 --- a/src/eopf_geozarr/s2_optimization/s2_resampling.py +++ b/src/eopf_geozarr/s2_optimization/s2_resampling.py @@ -88,9 +88,7 @@ def _downsample_reflectance( def _downsample_classification( self, data: xr.DataArray, target_height: int, target_width: int ) -> xr.DataArray: - """Mode-based downsampling for classification data.""" - from scipy import stats - + """Fast nearest neighbor downsampling for classification 
data.""" current_height, current_width = data.shape[-2:] block_h = current_height // target_height block_w = current_width // target_width @@ -100,35 +98,25 @@ def _downsample_classification( new_width = (current_width // block_w) * block_w data = data[..., :new_height, :new_width] - # Reshape for block processing + # Use simple nearest neighbor sampling (much faster than mode) + # Take the center pixel of each block as representative + center_h = block_h // 2 + center_w = block_w // 2 + if data.ndim == 3: - reshaped = data.values.reshape( - data.shape[0], target_height, block_h, target_width, block_w - ) - # Compute mode for each block - downsampled = np.zeros( - (data.shape[0], target_height, target_width), dtype=data.dtype - ) - for t in range(data.shape[0]): - for i in range(target_height): - for j in range(target_width): - block = reshaped[t, i, :, j, :].flatten() - mode_val = stats.mode(block, keepdims=False)[0] - downsampled[t, i, j] = mode_val + # Sample every block_h and block_w pixels, starting from center + downsampled = data.values[:, center_h::block_h, center_w::block_w] + # Ensure we get exactly the target dimensions + downsampled = downsampled[:, :target_height, :target_width] else: - reshaped = data.values.reshape( - target_height, block_h, target_width, block_w - ) - downsampled = np.zeros((target_height, target_width), dtype=data.dtype) - for i in range(target_height): - for j in range(target_width): - block = reshaped[i, :, j, :].flatten() - mode_val = stats.mode(block, keepdims=False)[0] - downsampled[i, j] = mode_val - - # Create coordinates - y_coords = data.coords[data.dims[-2]][::block_h][:target_height] - x_coords = data.coords[data.dims[-1]][::block_w][:target_width] + # Sample every block_h and block_w pixels, starting from center + downsampled = data.values[center_h::block_h, center_w::block_w] + # Ensure we get exactly the target dimensions + downsampled = downsampled[:target_height, :target_width] + + # Create coordinates (sample from the center positions) + y_coords = data.coords[data.dims[-2]][center_h::block_h][:target_height] + x_coords = data.coords[data.dims[-1]][center_w::block_w][:target_width] if data.ndim == 3: coords = { From e20c411dd59bd7d477cd4d391994d4308ad89733 Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Tue, 30 Sep 2025 08:21:55 +0000 Subject: [PATCH 38/83] fix: update measurement group keys and enhance dataset loading with decoding options --- src/eopf_geozarr/s2_optimization/s2_converter.py | 3 ++- src/eopf_geozarr/s2_optimization/s2_multiscale.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/eopf_geozarr/s2_optimization/s2_converter.py b/src/eopf_geozarr/s2_optimization/s2_converter.py index 4937401a..310d667d 100644 --- a/src/eopf_geozarr/s2_optimization/s2_converter.py +++ b/src/eopf_geozarr/s2_optimization/s2_converter.py @@ -220,7 +220,7 @@ def _write_measurements_group( measurements_group = xr.DataTree() for level, ds in pyramid_datasets.items(): if ds is not None: - measurements_group[level] = ds + measurements_group[f"{level}"] = ds multiscales_attrs = self._create_multiscales_metadata_with_rio(pyramid_datasets) if multiscales_attrs: @@ -235,6 +235,7 @@ def _write_measurements_group( mode='w', consolidated=True, zarr_format=3, + encoding={}, # Encoding handled at individual dataset level storage_options=storage_options, compute=True # Direct compute for simplicity ) diff --git a/src/eopf_geozarr/s2_optimization/s2_multiscale.py b/src/eopf_geozarr/s2_optimization/s2_multiscale.py index 
4a98e448..6a99a0ea 100644 --- a/src/eopf_geozarr/s2_optimization/s2_multiscale.py +++ b/src/eopf_geozarr/s2_optimization/s2_multiscale.py @@ -785,7 +785,7 @@ def _write_level_dataset(self, dataset: xr.Dataset, level_path: str, level: int) if os.path.exists(level_path): print(f" Level path {level_path} already exists. Skipping write.") # return the existing dataset - return xr.open_dataset(level_path, engine='zarr') + return xr.open_dataset(level_path, engine='zarr', chunks={}, decode_coords="all") # Write as single dataset with xy-aligned sharding print(f" Writing level {level} to {level_path} (xy-aligned sharding)") From e33f035fb1a4036a8399ebe69a5b08936f15789c Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Tue, 30 Sep 2025 10:34:50 +0200 Subject: [PATCH 39/83] feat: add streaming support for multiscale pyramid creation in S2 converter --- .../s2_optimization/s2_converter.py | 24 +- .../s2_multiscale_streaming.py | 513 ++++++++++++++++++ 2 files changed, 532 insertions(+), 5 deletions(-) create mode 100644 src/eopf_geozarr/s2_optimization/s2_multiscale_streaming.py diff --git a/src/eopf_geozarr/s2_optimization/s2_converter.py b/src/eopf_geozarr/s2_optimization/s2_converter.py index 310d667d..c2a2f18d 100644 --- a/src/eopf_geozarr/s2_optimization/s2_converter.py +++ b/src/eopf_geozarr/s2_optimization/s2_converter.py @@ -15,6 +15,7 @@ from .s2_data_consolidator import S2DataConsolidator from .s2_multiscale import S2MultiscalePyramid +from .s2_multiscale_streaming import S2StreamingMultiscalePyramid from .s2_validation import S2OptimizationValidator try: @@ -34,14 +35,19 @@ def __init__( spatial_chunk: int = 1024, compression_level: int = 3, max_retries: int = 3, + enable_streaming: bool = True, ): self.enable_sharding = enable_sharding self.spatial_chunk = spatial_chunk self.compression_level = compression_level self.max_retries = max_retries + self.enable_streaming = enable_streaming - # Initialize components - self.pyramid_creator = S2MultiscalePyramid(enable_sharding, spatial_chunk) + # Initialize components - choose between streaming and traditional + if enable_streaming: + self.pyramid_creator = S2StreamingMultiscalePyramid(enable_sharding, spatial_chunk) + else: + self.pyramid_creator = S2MultiscalePyramid(enable_sharding, spatial_chunk) self.validator = S2OptimizationValidator() def convert_s2_optimized( @@ -94,9 +100,16 @@ def convert_s2_optimized( # Step 2: Create multiscale measurements print("Step 2: Creating multiscale measurements pyramid...") - pyramid_datasets = self.pyramid_creator.create_multiscale_measurements( - measurements_data, output_path - ) + if self.enable_streaming: + # Use streaming approach - computation happens during write + pyramid_datasets = self.pyramid_creator.create_multiscale_measurements_streaming( + measurements_data, output_path + ) + else: + # Use traditional approach + pyramid_datasets = self.pyramid_creator.create_multiscale_measurements( + measurements_data, output_path + ) print(f" Created {len(pyramid_datasets)} pyramid levels") @@ -435,6 +448,7 @@ def convert_s2_optimized( "spatial_chunk": kwargs.pop("spatial_chunk", 1024), "compression_level": kwargs.pop("compression_level", 3), "max_retries": kwargs.pop("max_retries", 3), + "enable_streaming": kwargs.pop("enable_streaming", True), } # Remaining kwargs are for the convert_s2_optimized method diff --git a/src/eopf_geozarr/s2_optimization/s2_multiscale_streaming.py b/src/eopf_geozarr/s2_optimization/s2_multiscale_streaming.py new file mode 100644 index 00000000..df57db88 --- /dev/null +++ 
b/src/eopf_geozarr/s2_optimization/s2_multiscale_streaming.py @@ -0,0 +1,513 @@ +""" +Streaming multiscale pyramid creation for optimized S2 structure. +Uses lazy evaluation to minimize memory usage during dataset preparation. +""" + +from typing import Dict, Tuple + +import xarray as xr +from pyproj import CRS + +from .s2_resampling import S2ResamplingEngine, determine_variable_type + +try: + import distributed + from dask import compute, delayed + import dask.array as da + + DISTRIBUTED_AVAILABLE = True + DASK_AVAILABLE = True +except ImportError: + DISTRIBUTED_AVAILABLE = False + DASK_AVAILABLE = False + + # Create dummy delayed function for non-dask environments + def delayed(func): + return func + + def compute(*args, **kwargs): + return args + + +class S2StreamingMultiscalePyramid: + """Creates streaming multiscale pyramids with lazy evaluation.""" + + def __init__(self, enable_sharding: bool = True, spatial_chunk: int = 256): + self.enable_sharding = enable_sharding + self.spatial_chunk = spatial_chunk + self.resampler = S2ResamplingEngine() + + # Define pyramid levels: resolution in meters + self.pyramid_levels = { + 0: 10, # Level 0: 10m (native for b02,b03,b04,b08) + 1: 20, # Level 1: 20m (native for b05,b06,b07,b11,b12,b8a + all quality) + 2: 60, # Level 2: 60m (3x downsampling from 20m) + 3: 120, # Level 3: 120m (2x downsampling from 60m) + 4: 360, # Level 4: 360m (3x downsampling from 120m) + 5: 720, # Level 5: 720m (2x downsampling from 360m) + } + + def create_multiscale_measurements_streaming( + self, measurements_by_resolution: Dict[int, Dict], output_path: str + ) -> Dict[int, xr.Dataset]: + """ + Create multiscale pyramid with streaming lazy evaluation. + + Key innovation: Downsampling operations are prepared as computation graphs + but not executed until write time, enabling true streaming processing. + """ + if DASK_AVAILABLE: + return self._create_streaming_measurements_lazy( + measurements_by_resolution, output_path + ) + else: + # Fallback to regular processing + return self._create_multiscale_measurements_sequential( + measurements_by_resolution, output_path + ) + + def _create_streaming_measurements_lazy( + self, measurements_by_resolution: Dict[int, Dict], output_path: str + ) -> Dict[int, xr.Dataset]: + """ + Create multiscale pyramid with lazy evaluation and streaming writes. + + Strategy: + 1. Create lazy datasets with delayed downsampling operations + 2. Write each level with streaming execution + 3. Computation happens only during zarr write operations + 4. 
Minimal memory usage - no intermediate results stored + """ + print("Creating streaming multiscale pyramid with lazy evaluation...") + pyramid_datasets = {} + + # Process levels sequentially but prepare lazy operations + for level in sorted(self.pyramid_levels.keys()): + target_resolution = self.pyramid_levels[level] + print(f"Preparing lazy operations for level {level} ({target_resolution}m)...") + + # Create lazy dataset with delayed operations + if level <= 2: + # Base levels: use source measurements data + lazy_dataset = self._create_lazy_level_dataset( + level, target_resolution, measurements_by_resolution + ) + else: + # Higher levels: use level 2 data if available + if 2 in pyramid_datasets: + lazy_dataset = self._create_lazy_downsampled_dataset_from_level2( + level, target_resolution, pyramid_datasets[2] + ) + else: + print(f" Skipping level {level} - level 2 not available") + continue + + if lazy_dataset and len(lazy_dataset.data_vars) > 0: + # Store lazy dataset for potential use by higher levels + pyramid_datasets[level] = lazy_dataset + + # Stream write the lazy dataset (computation happens here) + level_path = f"{output_path}/measurements/{level}" + print(f" Streaming write of level {level} to {level_path}") + self._stream_write_lazy_dataset(lazy_dataset, level_path, level) + + # For levels 3+, we can discard after writing to save memory + if level > 2: + pyramid_datasets[level] = None + else: + print(f" Skipping empty level {level}") + + print(f"✅ Streaming pyramid creation complete") + return pyramid_datasets + + def _create_lazy_level_dataset( + self, + level: int, + target_resolution: int, + measurements_by_resolution: Dict[int, Dict], + ) -> xr.Dataset: + """Create dataset with lazy downsampling operations.""" + + if level == 0: + # Level 0: Only native 10m data (no downsampling needed) + return self._create_level_0_dataset(measurements_by_resolution) + elif level == 1: + # Level 1: All data at 20m (native + lazy downsampled from 10m) + return self._create_lazy_level_1_dataset(measurements_by_resolution) + elif level == 2: + # Level 2: All data at 60m (native + lazy downsampled from 20m/10m) + return self._create_lazy_level_2_dataset(measurements_by_resolution) + else: + # Should not be called for levels 3+ in streaming approach + raise ValueError(f"Use _create_lazy_downsampled_dataset_from_level2 for level {level}") + + def _create_lazy_level_1_dataset( + self, measurements_by_resolution: Dict + ) -> xr.Dataset: + """Create level 1 dataset with lazy downsampling from 10m data.""" + all_vars = {} + reference_coords = None + + # Start with native 20m data + if 20 in measurements_by_resolution: + data_20m = measurements_by_resolution[20] + for category, vars_dict in data_20m.items(): + all_vars.update(vars_dict) + + # Get reference coordinates from 20m data + if all_vars: + first_var = next(iter(all_vars.values())) + reference_coords = { + "x": first_var.coords["x"], + "y": first_var.coords["y"], + } + + # Add lazy downsampled 10m data + if 10 in measurements_by_resolution and reference_coords: + data_10m = measurements_by_resolution[10] + target_height = len(reference_coords["y"]) + target_width = len(reference_coords["x"]) + + # Create lazy downsampling operations + for category, vars_dict in data_10m.items(): + for var_name, var_data in vars_dict.items(): + if var_name in all_vars: + continue + + # Create lazy downsampling operation + lazy_downsampled = self._create_lazy_downsample_operation( + var_data, target_height, target_width, reference_coords + ) + 
all_vars[var_name] = lazy_downsampled + + if not all_vars: + return xr.Dataset() + + # Create dataset with lazy variables + dataset = xr.Dataset(all_vars) + dataset.attrs["pyramid_level"] = 1 + dataset.attrs["resolution_meters"] = 20 + + return dataset + + def _create_lazy_level_2_dataset( + self, measurements_by_resolution: Dict + ) -> xr.Dataset: + """Create level 2 dataset with lazy downsampling to 60m.""" + all_vars = {} + reference_coords = None + + # Start with native 60m data + if 60 in measurements_by_resolution: + data_60m = measurements_by_resolution[60] + for category, vars_dict in data_60m.items(): + all_vars.update(vars_dict) + + # Get reference coordinates from 60m data + if all_vars: + first_var = next(iter(all_vars.values())) + reference_coords = { + "x": first_var.coords["x"], + "y": first_var.coords["y"], + } + + if reference_coords: + target_height = len(reference_coords["y"]) + target_width = len(reference_coords["x"]) + + # Add lazy downsampling from 20m data + if 20 in measurements_by_resolution: + data_20m = measurements_by_resolution[20] + for category, vars_dict in data_20m.items(): + for var_name, var_data in vars_dict.items(): + if var_name not in all_vars: + lazy_downsampled = self._create_lazy_downsample_operation( + var_data, target_height, target_width, reference_coords + ) + all_vars[var_name] = lazy_downsampled + + # Add lazy downsampling from 10m data + if 10 in measurements_by_resolution: + data_10m = measurements_by_resolution[10] + for category, vars_dict in data_10m.items(): + for var_name, var_data in vars_dict.items(): + if var_name not in all_vars: + lazy_downsampled = self._create_lazy_downsample_operation( + var_data, target_height, target_width, reference_coords + ) + all_vars[var_name] = lazy_downsampled + + if not all_vars: + return xr.Dataset() + + # Create dataset with lazy variables + dataset = xr.Dataset(all_vars) + dataset.attrs["pyramid_level"] = 2 + dataset.attrs["resolution_meters"] = 60 + + return dataset + + def _create_lazy_downsampled_dataset_from_level2( + self, level: int, target_resolution: int, level_2_dataset: xr.Dataset + ) -> xr.Dataset: + """Create lazy downsampled dataset from level 2.""" + if len(level_2_dataset.data_vars) == 0: + return xr.Dataset() + + # Calculate target dimensions + downsample_factor = 2 ** (level - 2) + + # Get reference dimensions from level 2 + ref_var = next(iter(level_2_dataset.data_vars.values())) + current_height, current_width = ref_var.shape[-2:] + target_height = current_height // downsample_factor + target_width = current_width // downsample_factor + + # Create lazy downsampling operations for all variables + lazy_vars = {} + for var_name, var_data in level_2_dataset.data_vars.items(): + lazy_downsampled = self._create_lazy_downsample_operation_from_existing( + var_data, target_height, target_width + ) + lazy_vars[var_name] = lazy_downsampled + + # Create dataset with lazy variables + dataset = xr.Dataset(lazy_vars, coords=level_2_dataset.coords) + dataset.attrs["pyramid_level"] = level + dataset.attrs["resolution_meters"] = target_resolution + + return dataset + + def _create_lazy_downsample_operation( + self, + source_data: xr.DataArray, + target_height: int, + target_width: int, + reference_coords: dict + ) -> xr.DataArray: + """Create a lazy downsampling operation using Dask delayed.""" + + @delayed + def downsample_operation(): + var_type = determine_variable_type(source_data.name, source_data) + downsampled = self.resampler.downsample_variable( + source_data, target_height, 
target_width, var_type + ) + # Align coordinates + return downsampled.assign_coords(reference_coords) + + # Create delayed operation + lazy_result = downsample_operation() + + # Convert to Dask array with proper shape and chunks + # Estimate output shape based on target dimensions + if source_data.ndim == 3: + output_shape = (source_data.shape[0], target_height, target_width) + chunks = (1, min(256, target_height), min(256, target_width)) + else: + output_shape = (target_height, target_width) + chunks = (min(256, target_height), min(256, target_width)) + + # Create Dask array from delayed operation + dask_array = da.from_delayed( + lazy_result, + shape=output_shape, + dtype=source_data.dtype + ).rechunk(chunks) + + # Create coordinates for the output + if source_data.ndim == 3: + coords = { + source_data.dims[0]: source_data.coords[source_data.dims[0]], + source_data.dims[-2]: reference_coords["y"], + source_data.dims[-1]: reference_coords["x"], + } + else: + coords = { + source_data.dims[-2]: reference_coords["y"], + source_data.dims[-1]: reference_coords["x"], + } + + # Return as xarray DataArray with lazy data + return xr.DataArray( + dask_array, + dims=source_data.dims, + coords=coords, + attrs=source_data.attrs.copy(), + name=source_data.name + ) + + def _create_lazy_downsample_operation_from_existing( + self, + source_data: xr.DataArray, + target_height: int, + target_width: int + ) -> xr.DataArray: + """Create lazy downsampling operation from existing data.""" + + @delayed + def downsample_operation(): + var_type = determine_variable_type(source_data.name, source_data) + return self.resampler.downsample_variable( + source_data, target_height, target_width, var_type + ) + + # Create delayed operation + lazy_result = downsample_operation() + + # Estimate output shape and chunks + if source_data.ndim == 3: + output_shape = (source_data.shape[0], target_height, target_width) + chunks = (1, min(256, target_height), min(256, target_width)) + else: + output_shape = (target_height, target_width) + chunks = (min(256, target_height), min(256, target_width)) + + # Create Dask array from delayed operation + dask_array = da.from_delayed( + lazy_result, + shape=output_shape, + dtype=source_data.dtype + ).rechunk(chunks) + + # Return as xarray DataArray with lazy data + return xr.DataArray( + dask_array, + dims=source_data.dims, + attrs=source_data.attrs.copy(), + name=source_data.name + ) + + def _stream_write_lazy_dataset( + self, lazy_dataset: xr.Dataset, level_path: str, level: int + ) -> None: + """ + Stream write a lazy dataset - computation happens during write. + + This is where the magic happens: all the lazy downsampling operations + are executed as the data is streamed to storage. + """ + import os + + # Check if level already exists + if os.path.exists(level_path): + print(f" Level path {level_path} already exists. 
Skipping write.") + return + + # Create encoding for streaming write + encoding = self._create_level_encoding(lazy_dataset, level) + + print(f" Streaming computation and write to {level_path}") + print(f" Variables: {list(lazy_dataset.data_vars.keys())}") + + # Write with streaming computation + # The to_zarr operation will trigger all lazy computations + lazy_dataset.to_zarr( + level_path, + mode="w", + consolidated=True, + zarr_format=3, + encoding=encoding, + compute=True, # This triggers the lazy computation during write + ) + + print(f" ✅ Streaming write complete for level {level}") + + def _create_level_0_dataset(self, measurements_by_resolution: Dict) -> xr.Dataset: + """Create level 0 dataset with only native 10m data (no lazy operations needed).""" + if 10 not in measurements_by_resolution: + return xr.Dataset() + + data_10m = measurements_by_resolution[10] + all_vars = {} + + # Add only native 10m bands and their associated data + for category, vars_dict in data_10m.items(): + all_vars.update(vars_dict) + + if not all_vars: + return xr.Dataset() + + # Create consolidated dataset + dataset = xr.Dataset(all_vars) + dataset.attrs["pyramid_level"] = 0 + dataset.attrs["resolution_meters"] = 10 + + self._write_geo_metadata(dataset) + return dataset + + def _create_multiscale_measurements_sequential( + self, measurements_by_resolution: Dict[int, Dict], output_path: str + ) -> Dict[int, xr.Dataset]: + """Fallback sequential processing for non-Dask environments.""" + print("Creating multiscale pyramid sequentially (no streaming)...") + # Implementation would be similar to the original sequential approach + # This is a fallback - the main value is in the streaming approach + return {} + + def _create_level_encoding(self, dataset: xr.Dataset, level: int) -> Dict: + """Create encoding optimized for streaming writes.""" + encoding = {} + + # Calculate level-appropriate chunk sizes for streaming + chunk_size = max(256, self.spatial_chunk // (2**level)) + + for var_name, var_data in dataset.data_vars.items(): + if hasattr(var_data.data, 'chunks'): + # Use existing chunks from Dask array + chunks = var_data.data.chunks + if len(chunks) >= 2: + # Convert chunk tuples to sizes + encoding_chunks = tuple(chunks[i][0] for i in range(len(chunks))) + else: + encoding_chunks = (chunk_size,) + else: + # Fallback chunk calculation + if var_data.ndim >= 2: + if var_data.ndim == 3: + encoding_chunks = (1, chunk_size, chunk_size) + else: + encoding_chunks = (chunk_size, chunk_size) + else: + encoding_chunks = (min(chunk_size, var_data.shape[0]),) + + # Configure encoding for streaming + from zarr.codecs import BloscCodec + + compressor = BloscCodec( + cname="zstd", clevel=3, shuffle="shuffle", blocksize=0 + ) + encoding[var_name] = { + "chunks": encoding_chunks, + "compressors": [compressor] + } + + # Add coordinate encoding + for coord_name in dataset.coords: + encoding[coord_name] = {"compressors": None} + + return encoding + + def _write_geo_metadata( + self, dataset: xr.Dataset, grid_mapping_var_name: str = "spatial_ref" + ) -> None: + """Write geographic metadata to the dataset.""" + # Implementation same as original + crs = None + for var in dataset.data_vars.values(): + if hasattr(var, "rio") and var.rio.crs: + crs = var.rio.crs + break + elif "proj:epsg" in var.attrs: + epsg = var.attrs["proj:epsg"] + crs = CRS.from_epsg(epsg) + break + + if crs is not None: + dataset.rio.write_crs( + crs, grid_mapping_name=grid_mapping_var_name, inplace=True + ) + 
dataset.rio.write_grid_mapping(grid_mapping_var_name, inplace=True) + + for var in dataset.data_vars.values(): + var.rio.write_grid_mapping(grid_mapping_var_name, inplace=True) From 0307d0d95a45218cfd6e65d73140fde0d9976048 Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Tue, 30 Sep 2025 10:37:14 +0200 Subject: [PATCH 40/83] feat: add --enable-streaming option for experimental streaming mode in S2 optimization command --- src/eopf_geozarr/cli.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/eopf_geozarr/cli.py b/src/eopf_geozarr/cli.py index 768ac646..d1cdc141 100644 --- a/src/eopf_geozarr/cli.py +++ b/src/eopf_geozarr/cli.py @@ -1201,6 +1201,11 @@ def add_s2_optimization_commands(subparsers): action="store_true", help="Start a local dask cluster for parallel processing and progress bars", ) + s2_parser.add_argument( + "--enable-streaming", + action="store_true", + help="Enable streaming mode for large datasets (experimental)", + ) s2_parser.set_defaults(func=convert_s2_optimized_command) @@ -1233,6 +1238,7 @@ def convert_s2_optimized_command(args): create_meteorology_group=not args.skip_meteorology, validate_output=not args.skip_validation, verbose=args.verbose, + enable_streaming=args.enable_streaming, ) print(f"✅ S2 optimization completed: {args.output_path}") From 1d6a922a17a2a111f4db4fab9be99913c5301cc0 Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Tue, 30 Sep 2025 11:31:05 +0200 Subject: [PATCH 41/83] fix: avoid passing coordinates in lazy dataset creation to prevent alignment issues --- .../s2_optimization/s2_multiscale_streaming.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/eopf_geozarr/s2_optimization/s2_multiscale_streaming.py b/src/eopf_geozarr/s2_optimization/s2_multiscale_streaming.py index df57db88..8a9dbc3a 100644 --- a/src/eopf_geozarr/s2_optimization/s2_multiscale_streaming.py +++ b/src/eopf_geozarr/s2_optimization/s2_multiscale_streaming.py @@ -271,8 +271,9 @@ def _create_lazy_downsampled_dataset_from_level2( ) lazy_vars[var_name] = lazy_downsampled - # Create dataset with lazy variables - dataset = xr.Dataset(lazy_vars, coords=level_2_dataset.coords) + # Create dataset with lazy variables - don't pass coords to avoid alignment issues + # The coordinates will be computed when the lazy operations are executed + dataset = xr.Dataset(lazy_vars) dataset.attrs["pyramid_level"] = level dataset.attrs["resolution_meters"] = target_resolution @@ -370,7 +371,8 @@ def downsample_operation(): dtype=source_data.dtype ).rechunk(chunks) - # Return as xarray DataArray with lazy data + # Return as xarray DataArray with lazy data - no coords to avoid alignment issues + # Coordinates will be set when the lazy operation is computed return xr.DataArray( dask_array, dims=source_data.dims, From c32aaef185dc43381de17c2075f71c384426ba77 Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Tue, 30 Sep 2025 11:34:07 +0200 Subject: [PATCH 42/83] feat: implement Zarr v3 compatible encoding for optimized datasets --- src/eopf_geozarr/s2_optimization/s2_converter.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/eopf_geozarr/s2_optimization/s2_converter.py b/src/eopf_geozarr/s2_optimization/s2_converter.py index c2a2f18d..da76936b 100644 --- a/src/eopf_geozarr/s2_optimization/s2_converter.py +++ b/src/eopf_geozarr/s2_optimization/s2_converter.py @@ -241,6 +241,20 @@ def _write_measurements_group( if verbose: print(f" Multiscales metadata added with {len(multiscales_attrs.get('tile_matrix_set', 
{}).get('matrices', []))} levels") + # Create proper Zarr v3 compatible encoding + encoding = {} + from zarr.codecs import BloscCodec + + # Add encoding for any variables that might need it + for level, ds in pyramid_datasets.items(): + if ds is not None: + for var_name in ds.data_vars: + encoding[var_name] = { + "compressors": [BloscCodec(cname="zstd", clevel=3, shuffle="shuffle", blocksize=0)] + } + for coord_name in ds.coords: + encoding[coord_name] = {"compressors": None} + # Write the measurements group with consolidation storage_options = get_storage_options(group_path) measurements_group.to_zarr( @@ -248,7 +262,7 @@ def _write_measurements_group( mode='w', consolidated=True, zarr_format=3, - encoding={}, # Encoding handled at individual dataset level + encoding=encoding, # Use proper Zarr v3 encoding storage_options=storage_options, compute=True # Direct compute for simplicity ) From 2ccb11079c0a4d8ddf48bd97ea117ef7c9c2d27a Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Tue, 30 Sep 2025 11:39:53 +0200 Subject: [PATCH 43/83] fix: enhance measurements group writing by consolidating metadata and improving Zarr group handling --- .../s2_optimization/s2_converter.py | 70 ++++++++++--------- 1 file changed, 36 insertions(+), 34 deletions(-) diff --git a/src/eopf_geozarr/s2_optimization/s2_converter.py b/src/eopf_geozarr/s2_optimization/s2_converter.py index da76936b..6c7b8d3a 100644 --- a/src/eopf_geozarr/s2_optimization/s2_converter.py +++ b/src/eopf_geozarr/s2_optimization/s2_converter.py @@ -227,47 +227,49 @@ def _write_measurements_group( group_name: str, verbose: bool ) -> None: - """Write measurements group with pyramid datasets.""" + """Write measurements group metadata and consolidate all level metadata.""" + import zarr + import os + group_path = f"{group_name}" - measurements_group = xr.DataTree() - for level, ds in pyramid_datasets.items(): - if ds is not None: - measurements_group[f"{level}"] = ds - + print(" Creating measurements group with consolidated metadata...") + + # Create multiscales metadata multiscales_attrs = self._create_multiscales_metadata_with_rio(pyramid_datasets) + + # Get storage options + storage_options = get_storage_options(group_path) + + # Open/create the measurements group + if storage_options: + store = zarr.storage.FSStore(group_path, **storage_options) + else: + store = group_path + + # Create or open the measurements group + if not os.path.exists(group_path): + group = zarr.open_group(store, mode='w') + else: + group = zarr.open_group(store, mode='r+') + + # Add multiscales metadata if multiscales_attrs: - measurements_group.attrs['multiscales'] = [multiscales_attrs] + group.attrs['multiscales'] = [multiscales_attrs] if verbose: - print(f" Multiscales metadata added with {len(multiscales_attrs.get('tile_matrix_set', {}).get('matrices', []))} levels") - - # Create proper Zarr v3 compatible encoding - encoding = {} - from zarr.codecs import BloscCodec + num_levels = len(multiscales_attrs.get('tile_matrix_set', {}).get('matrices', [])) + print(f" Multiscales metadata added with {num_levels} levels") - # Add encoding for any variables that might need it - for level, ds in pyramid_datasets.items(): - if ds is not None: - for var_name in ds.data_vars: - encoding[var_name] = { - "compressors": [BloscCodec(cname="zstd", clevel=3, shuffle="shuffle", blocksize=0)] - } - for coord_name in ds.coords: - encoding[coord_name] = {"compressors": None} - - # Write the measurements group with consolidation - storage_options = get_storage_options(group_path) - 
measurements_group.to_zarr( - group_path, - mode='w', - consolidated=True, - zarr_format=3, - encoding=encoding, # Use proper Zarr v3 encoding - storage_options=storage_options, - compute=True # Direct compute for simplicity - ) + # Consolidate all level metadata into the group + print(" Consolidating metadata from all pyramid levels...") + try: + # Force consolidation of the entire measurements tree + zarr.consolidate_metadata(store) + print(" ✅ Measurements group metadata consolidated") + except Exception as e: + print(f" ⚠️ Warning: Metadata consolidation failed: {e}") - return measurements_group + return None def _create_multiscales_metadata_with_rio(self, pyramid_datasets: Dict[int, xr.Dataset]) -> Dict: """Create multiscales metadata using rioxarray .rio accessor, following geozarr.py format.""" From a4952e74091f986980fe74f7f28807bec4ebc7f4 Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Tue, 30 Sep 2025 11:54:49 +0200 Subject: [PATCH 44/83] feat: enhance streaming write with advanced chunking and sharding support --- .../s2_multiscale_streaming.py | 77 +++++++++++++++++-- 1 file changed, 71 insertions(+), 6 deletions(-) diff --git a/src/eopf_geozarr/s2_optimization/s2_multiscale_streaming.py b/src/eopf_geozarr/s2_optimization/s2_multiscale_streaming.py index 8a9dbc3a..7d5cc295 100644 --- a/src/eopf_geozarr/s2_optimization/s2_multiscale_streaming.py +++ b/src/eopf_geozarr/s2_optimization/s2_multiscale_streaming.py @@ -384,10 +384,10 @@ def _stream_write_lazy_dataset( self, lazy_dataset: xr.Dataset, level_path: str, level: int ) -> None: """ - Stream write a lazy dataset - computation happens during write. + Stream write a lazy dataset with advanced chunking and sharding. This is where the magic happens: all the lazy downsampling operations - are executed as the data is streamed to storage. + are executed as the data is streamed to storage with optimal performance. """ import os @@ -396,25 +396,90 @@ def _stream_write_lazy_dataset( print(f" Level path {level_path} already exists. Skipping write.") return - # Create encoding for streaming write + # Create advanced encoding for streaming write encoding = self._create_level_encoding(lazy_dataset, level) print(f" Streaming computation and write to {level_path}") print(f" Variables: {list(lazy_dataset.data_vars.keys())}") - # Write with streaming computation + # Rechunk dataset to align with encoding when sharding is enabled + if self.enable_sharding: + lazy_dataset = self._rechunk_dataset_for_encoding(lazy_dataset, encoding) + + # Write with streaming computation and progress tracking # The to_zarr operation will trigger all lazy computations - lazy_dataset.to_zarr( + write_job = lazy_dataset.to_zarr( level_path, mode="w", consolidated=True, zarr_format=3, encoding=encoding, - compute=True, # This triggers the lazy computation during write + compute=False, # Create job first for progress tracking ) + write_job = write_job.persist() + + # Show progress bar if distributed is available + if DISTRIBUTED_AVAILABLE: + try: + distributed.progress(write_job, notebook=False) + except Exception as e: + print(f" Warning: Could not display progress bar: {e}") + write_job.compute() + else: + print(" Writing zarr file...") + write_job.compute() print(f" ✅ Streaming write complete for level {level}") + def _rechunk_dataset_for_encoding( + self, dataset: xr.Dataset, encoding: Dict + ) -> xr.Dataset: + """ + Rechunk dataset variables to align with sharding dimensions when sharding is enabled. 
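
The rechunking step described in the docstring above exists because Zarr v3 sharded writes fail when a Dask chunk straddles a shard boundary. A minimal sketch of that alignment step follows, assuming dask-backed variables and a per-variable encoding dict shaped like the ones built in this module ({"chunks": ..., "shards": ...}); the names and sizes are illustrative only.

    # Sketch: align Dask chunks with Zarr v3 shard dimensions before calling to_zarr.
    import numpy as np
    import xarray as xr

    def rechunk_to_shards(ds, encoding):
        out = {}
        for name, var in ds.data_vars.items():
            enc = encoding.get(name, {})
            # Prefer shard dimensions; fall back to chunk dimensions.
            target = enc.get("shards") or enc.get("chunks")
            if target is None:
                out[name] = var
                continue
            chunk_map = {dim: size for dim, size in zip(var.dims, target)}
            out[name] = var.chunk(chunk_map)
        return xr.Dataset(out, coords=ds.coords, attrs=ds.attrs)

    if __name__ == "__main__":
        band = xr.DataArray(
            np.zeros((1, 1024, 1024), dtype="float32"), dims=("time", "y", "x")
        ).chunk({"time": 1, "y": 256, "x": 256})
        ds = xr.Dataset({"b02": band})
        enc = {"b02": {"chunks": (1, 256, 256), "shards": (1, 1024, 1024)}}
        aligned = rechunk_to_shards(ds, enc)
        print(aligned["b02"].data.chunksize)  # (1, 1024, 1024): one Dask chunk per shard
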
+ + When using Zarr v3 sharding, Dask chunks must align with shard dimensions to avoid + checksum validation errors. + """ + rechunked_vars = {} + + for var_name, var_data in dataset.data_vars.items(): + if var_name in encoding: + var_encoding = encoding[var_name] + + # If sharding is enabled, rechunk based on shard dimensions + if "shards" in var_encoding and var_encoding["shards"] is not None: + target_chunks = var_encoding[ + "shards" + ] # Use shard dimensions for rechunking + elif "chunks" in var_encoding: + target_chunks = var_encoding[ + "chunks" + ] # Fallback to chunk dimensions + else: + # No specific chunking needed, use original variable + rechunked_vars[var_name] = var_data + continue + + # Create chunk dict using the actual dimensions of the variable + var_dims = var_data.dims + chunk_dict = {} + for i, dim in enumerate(var_dims): + if i < len(target_chunks): + chunk_dict[dim] = target_chunks[i] + + # Rechunk the variable to match the target dimensions + rechunked_vars[var_name] = var_data.chunk(chunk_dict) + else: + # No specific chunking needed, use original variable + rechunked_vars[var_name] = var_data + + # Create new dataset with rechunked variables, preserving coordinates + rechunked_dataset = xr.Dataset( + rechunked_vars, coords=dataset.coords, attrs=dataset.attrs + ) + + return rechunked_dataset + def _create_level_0_dataset(self, measurements_by_resolution: Dict) -> xr.Dataset: """Create level 0 dataset with only native 10m data (no lazy operations needed).""" if 10 not in measurements_by_resolution: From 9d97eee31b97353e5627d04eeb94b13c6df13787 Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Tue, 30 Sep 2025 11:57:08 +0200 Subject: [PATCH 45/83] feat: enhance encoding for streaming writes with advanced chunking and sharding support --- .../s2_multiscale_streaming.py | 107 ++++++++++++++---- 1 file changed, 83 insertions(+), 24 deletions(-) diff --git a/src/eopf_geozarr/s2_optimization/s2_multiscale_streaming.py b/src/eopf_geozarr/s2_optimization/s2_multiscale_streaming.py index 7d5cc295..3bfca5dc 100644 --- a/src/eopf_geozarr/s2_optimization/s2_multiscale_streaming.py +++ b/src/eopf_geozarr/s2_optimization/s2_multiscale_streaming.py @@ -513,41 +513,47 @@ def _create_multiscale_measurements_sequential( return {} def _create_level_encoding(self, dataset: xr.Dataset, level: int) -> Dict: - """Create encoding optimized for streaming writes.""" + """Create optimized encoding for a pyramid level with advanced chunking and sharding.""" encoding = {} - - # Calculate level-appropriate chunk sizes for streaming + + # Calculate level-appropriate chunk sizes chunk_size = max(256, self.spatial_chunk // (2**level)) for var_name, var_data in dataset.data_vars.items(): - if hasattr(var_data.data, 'chunks'): - # Use existing chunks from Dask array - chunks = var_data.data.chunks - if len(chunks) >= 2: - # Convert chunk tuples to sizes - encoding_chunks = tuple(chunks[i][0] for i in range(len(chunks))) + if var_data.ndim >= 2: + height, width = var_data.shape[-2:] + + # Use advanced aligned chunk calculation + spatial_chunk_aligned = min( + chunk_size, + self._calculate_aligned_chunk_size(width, chunk_size), + self._calculate_aligned_chunk_size(height, chunk_size), + ) + + if var_data.ndim == 3: + # Single file per variable per time: chunk time dimension to 1 + chunks = (1, spatial_chunk_aligned, spatial_chunk_aligned) else: - encoding_chunks = (chunk_size,) + chunks = (spatial_chunk_aligned, spatial_chunk_aligned) else: - # Fallback chunk calculation - if var_data.ndim >= 2: - 
if var_data.ndim == 3: - encoding_chunks = (1, chunk_size, chunk_size) - else: - encoding_chunks = (chunk_size, chunk_size) - else: - encoding_chunks = (min(chunk_size, var_data.shape[0]),) + chunks = (min(chunk_size, var_data.shape[0]),) - # Configure encoding for streaming + # Configure encoding - use proper compressor following geozarr.py pattern from zarr.codecs import BloscCodec - + compressor = BloscCodec( cname="zstd", clevel=3, shuffle="shuffle", blocksize=0 ) - encoding[var_name] = { - "chunks": encoding_chunks, - "compressors": [compressor] - } + var_encoding = {"chunks": chunks, "compressors": [compressor]} + + # Add advanced sharding if enabled - shards match x/y dimensions exactly + if self.enable_sharding and var_data.ndim >= 2: + shard_dims = self._calculate_simple_shard_dimensions( + var_data.shape, chunks + ) + var_encoding["shards"] = shard_dims + + encoding[var_name] = var_encoding # Add coordinate encoding for coord_name in dataset.coords: @@ -555,6 +561,59 @@ def _create_level_encoding(self, dataset: xr.Dataset, level: int) -> Dict: return encoding + def _calculate_aligned_chunk_size( + self, dimension_size: int, target_chunk: int + ) -> int: + """ + Calculate aligned chunk size following geozarr.py logic. + + This ensures good chunk alignment without complex calculations. + """ + if target_chunk >= dimension_size: + return dimension_size + + # Find the largest divisor of dimension_size that's close to target_chunk + best_chunk = target_chunk + for chunk_candidate in range(target_chunk, max(target_chunk // 2, 1), -1): + if dimension_size % chunk_candidate == 0: + best_chunk = chunk_candidate + break + + return best_chunk + + def _calculate_simple_shard_dimensions( + self, data_shape: tuple, chunks: tuple + ) -> tuple: + """ + Calculate shard dimensions that are compatible with chunk dimensions. + + Shard dimensions must be evenly divisible by chunk dimensions for Zarr v3. + When possible, shards should match x/y dimensions exactly as required. 
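
The two helpers above reduce to simple integer arithmetic: pick a chunk size that divides the spatial dimension (searching downward from the target), then make each shard the largest multiple of that chunk that fits, with a 3D leading time axis getting a shard of 1. A standalone sketch with a worked example; the 10980-pixel tile edge is used only as an illustration and is not taken from this patch.

    # Sketch: chunk/shard arithmetic for Zarr v3 sharded encodings.
    def aligned_chunk(dim_size, target):
        # Largest divisor of dim_size in (target // 2, target]; falls back to target.
        if target >= dim_size:
            return dim_size
        for candidate in range(target, max(target // 2, 1), -1):
            if dim_size % candidate == 0:
                return candidate
        return target

    def shard_dims(shape, chunks):
        # Shards are whole multiples of the chunk; a 3D leading (time) axis gets shard 1.
        dims = []
        for i, (dim, chunk) in enumerate(zip(shape, chunks)):
            if i == 0 and len(shape) == 3:
                dims.append(1)
            else:
                dims.append(max(dim // chunk, 1) * chunk)
        return tuple(dims)

    if __name__ == "__main__":
        width = 10980                        # illustrative 10 m tile edge
        chunk = aligned_chunk(width, 256)    # 244, since 10980 % 244 == 0
        print(chunk, shard_dims((1, width, width), (1, chunk, chunk)))
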
+ """ + shard_dims = [] + + for i, (dim_size, chunk_size) in enumerate(zip(data_shape, chunks)): + if i == 0 and len(data_shape) == 3: + # First dimension in 3D data (time) - use single time slice per shard + shard_dims.append(1) + else: + # For x/y dimensions, try to use full dimension size + # But ensure it's divisible by chunk size + if dim_size % chunk_size == 0: + # Perfect: full dimension is divisible by chunk + shard_dims.append(dim_size) + else: + # Find the largest multiple of chunk_size that fits + num_chunks = dim_size // chunk_size + if num_chunks > 0: + shard_size = num_chunks * chunk_size + shard_dims.append(shard_size) + else: + # Fallback: use chunk size itself + shard_dims.append(chunk_size) + + return tuple(shard_dims) + def _write_geo_metadata( self, dataset: xr.Dataset, grid_mapping_var_name: str = "spatial_ref" ) -> None: From 76cde29bec2d0cd531c9343178926ae65e3ebed2 Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Tue, 30 Sep 2025 15:59:05 +0200 Subject: [PATCH 46/83] fix: improve root-level metadata consolidation with proper Zarr group creation and linking --- .../s2_optimization/s2_converter.py | 77 ++++++++++--------- 1 file changed, 42 insertions(+), 35 deletions(-) diff --git a/src/eopf_geozarr/s2_optimization/s2_converter.py b/src/eopf_geozarr/s2_optimization/s2_converter.py index 6c7b8d3a..ba682f34 100644 --- a/src/eopf_geozarr/s2_optimization/s2_converter.py +++ b/src/eopf_geozarr/s2_optimization/s2_converter.py @@ -352,47 +352,54 @@ def _create_multiscales_metadata_with_rio(self, pyramid_datasets: Dict[int, xr.D def _simple_root_consolidation( self, output_path: str, pyramid_datasets: Dict[int, xr.Dataset] ) -> None: - """Simple root-level metadata consolidation using only xarray.""" + """Simple root-level metadata consolidation with proper zarr group creation.""" try: - # Since each level and auxiliary group was written with consolidated=True, - # we just need to create a simple root-level consolidated metadata - print(" Performing simple root consolidation...") - - # Use xarray to open and immediately close the root group with consolidation - # This creates/updates the root .zmetadata file + print(" Performing root consolidation...") storage_options = get_storage_options(output_path) - # Open the root zarr group and let xarray handle consolidation + # First, ensure the root zarr group exists + import zarr + import os + + if storage_options: + store = zarr.storage.FSStore(output_path, **storage_options) + else: + store = output_path + + # Create root zarr group if it doesn't exist + if not os.path.exists(os.path.join(output_path, 'zarr.json')): + print(" Creating root zarr group...") + root_group = zarr.open_group(store, mode='w') + root_group.attrs.update({ + "title": "Optimized Sentinel-2 Dataset", + "description": "Multiscale pyramid structure for efficient access", + "zarr_format": 3 + }) + else: + root_group = zarr.open_group(store, mode='r+') + + # Ensure subgroups are properly linked + if self.enable_streaming: + # In streaming mode, link existing subgroups + for subgroup in ['measurements', 'geometry', 'meteorology']: + subgroup_path = os.path.join(output_path, subgroup) + if os.path.exists(subgroup_path): + try: + if subgroup not in root_group: + # Link the subgroup to the root + subgroup_obj = zarr.open_group(subgroup_path, mode='r') + # Copy attributes to root group reference + root_group.attrs[f"{subgroup}_info"] = f"Subgroup: {subgroup}" + except Exception as e: + print(f" Warning: Could not link subgroup {subgroup}: {e}") + + # Consolidate 
metadata try: - # This will create consolidated metadata at the root level - with xr.open_zarr( - output_path, - storage_options=storage_options, - consolidated=True, - chunks={}, - ) as root_ds: - # Just opening and closing with consolidated=True should be enough - pass + zarr.consolidate_metadata(store) print(" ✅ Root consolidation completed") except Exception as e: - print( - f" ⚠️ Root consolidation using xarray failed, trying zarr directly: {e}" - ) - - # Fallback: minimal zarr consolidation if needed - import zarr - - store = ( - zarr.storage.FSStore(output_path, **storage_options) - if storage_options - else output_path - ) - try: - zarr.consolidate_metadata(store) - print(" ✅ Root consolidation completed with zarr") - except Exception as e2: - print(f" ⚠️ Warning: Root consolidation failed: {e2}") - + print(f" ⚠️ Warning: Metadata consolidation failed: {e}") + except Exception as e: print(f" ⚠️ Warning: Root consolidation failed: {e}") From 52e516dc50caa0dc0d930b82ec6ac666460e3ccd Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Tue, 30 Sep 2025 14:03:19 +0000 Subject: [PATCH 47/83] feat: add streaming support to S2 optimized converter and update measurements group handling --- .vscode/launch.json | 1 + src/eopf_geozarr/s2_optimization/s2_converter.py | 14 +++++--------- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 55ebab49..97cdb407 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -149,6 +149,7 @@ "--compression-level", "5", "--enable-sharding", "--dask-cluster", + "--enable-streaming", "--verbose" ], "cwd": "${workspaceFolder}", diff --git a/src/eopf_geozarr/s2_optimization/s2_converter.py b/src/eopf_geozarr/s2_optimization/s2_converter.py index 6c7b8d3a..cc0e12b2 100644 --- a/src/eopf_geozarr/s2_optimization/s2_converter.py +++ b/src/eopf_geozarr/s2_optimization/s2_converter.py @@ -129,7 +129,8 @@ def convert_s2_optimized( # Step 5: Create measurements group and add multiscales metadata print("Step 5: Creating measurements group...") - measurement_dt = self._write_measurements_group(pyramid_datasets, "measurements", verbose) + measurement_path = f"{output_path}/measurements" + measurement_dt = self._write_measurements_group(measurement_path, pyramid_datasets, verbose) # Step 6: Simple root-level consolidation print("Step 6: Final root-level metadata consolidation...") @@ -223,16 +224,14 @@ def _write_auxiliary_group( def _write_measurements_group( self, + group_path: str, pyramid_datasets: Dict[int, xr.Dataset], - group_name: str, verbose: bool ) -> None: """Write measurements group metadata and consolidate all level metadata.""" import zarr import os - - group_path = f"{group_name}" - + print(" Creating measurements group with consolidated metadata...") # Create multiscales metadata @@ -248,10 +247,7 @@ def _write_measurements_group( store = group_path # Create or open the measurements group - if not os.path.exists(group_path): - group = zarr.open_group(store, mode='w') - else: - group = zarr.open_group(store, mode='r+') + group = zarr.open_group(store, mode='a') # Add multiscales metadata if multiscales_attrs: From 5e5825114d35a5dbd383707c9fd8c56f73fc8dc4 Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Tue, 30 Sep 2025 14:11:21 +0000 Subject: [PATCH 48/83] fix: change root Zarr group creation mode from 'w' to 'a' for appending data --- src/eopf_geozarr/s2_optimization/s2_converter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/src/eopf_geozarr/s2_optimization/s2_converter.py b/src/eopf_geozarr/s2_optimization/s2_converter.py index fe037be0..c653a8dc 100644 --- a/src/eopf_geozarr/s2_optimization/s2_converter.py +++ b/src/eopf_geozarr/s2_optimization/s2_converter.py @@ -365,7 +365,7 @@ def _simple_root_consolidation( # Create root zarr group if it doesn't exist if not os.path.exists(os.path.join(output_path, 'zarr.json')): print(" Creating root zarr group...") - root_group = zarr.open_group(store, mode='w') + root_group = zarr.open_group(store, mode='a') root_group.attrs.update({ "title": "Optimized Sentinel-2 Dataset", "description": "Multiscale pyramid structure for efficient access", From 8250bc7a9a12cfd80b950df7b3f2a804adfc6849 Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Tue, 30 Sep 2025 21:39:15 +0200 Subject: [PATCH 49/83] refactor: streamline Zarr group handling and metadata consolidation in S2 converter --- .../s2_optimization/s2_converter.py | 72 +++++++------------ 1 file changed, 26 insertions(+), 46 deletions(-) diff --git a/src/eopf_geozarr/s2_optimization/s2_converter.py b/src/eopf_geozarr/s2_optimization/s2_converter.py index c653a8dc..a6d5b7d0 100644 --- a/src/eopf_geozarr/s2_optimization/s2_converter.py +++ b/src/eopf_geozarr/s2_optimization/s2_converter.py @@ -6,7 +6,9 @@ from typing import Dict import xarray as xr +from zarr import consolidate_metadata +from eopf_geozarr.conversion import fs_utils from eopf_geozarr.conversion.fs_utils import get_storage_options from eopf_geozarr.conversion.geozarr import ( _create_tile_matrix_limits, @@ -236,19 +238,10 @@ def _write_measurements_group( # Create multiscales metadata multiscales_attrs = self._create_multiscales_metadata_with_rio(pyramid_datasets) - - # Get storage options - storage_options = get_storage_options(group_path) - - # Open/create the measurements group - if storage_options: - store = zarr.storage.FSStore(group_path, **storage_options) - else: - store = group_path # Create or open the measurements group - group = zarr.open_group(store, mode='a') - + group = fs_utils.open_zarr_group(group_path, mode='a') + # Add multiscales metadata if multiscales_attrs: group.attrs['multiscales'] = [multiscales_attrs] @@ -260,7 +253,7 @@ def _write_measurements_group( print(" Consolidating metadata from all pyramid levels...") try: # Force consolidation of the entire measurements tree - zarr.consolidate_metadata(store) + consolidate_metadata(group.store) print(" ✅ Measurements group metadata consolidated") except Exception as e: print(f" ⚠️ Warning: Metadata consolidation failed: {e}") @@ -351,47 +344,34 @@ def _simple_root_consolidation( """Simple root-level metadata consolidation with proper zarr group creation.""" try: print(" Performing root consolidation...") - storage_options = get_storage_options(output_path) - - # First, ensure the root zarr group exists - import zarr - import os - - if storage_options: - store = zarr.storage.FSStore(output_path, **storage_options) - else: - store = output_path # Create root zarr group if it doesn't exist - if not os.path.exists(os.path.join(output_path, 'zarr.json')): - print(" Creating root zarr group...") - root_group = zarr.open_group(store, mode='a') - root_group.attrs.update({ - "title": "Optimized Sentinel-2 Dataset", - "description": "Multiscale pyramid structure for efficient access", - "zarr_format": 3 - }) - else: - root_group = zarr.open_group(store, mode='r+') + print(" Creating root zarr group...") + root_group = fs_utils.open_zarr_group(output_path, mode="a") + root_group.attrs.update({ 
+ "title": "Optimized Sentinel-2 Dataset", + "description": "Multiscale pyramid structure for efficient access", + "zarr_format": 3 + }) # Ensure subgroups are properly linked - if self.enable_streaming: - # In streaming mode, link existing subgroups - for subgroup in ['measurements', 'geometry', 'meteorology']: - subgroup_path = os.path.join(output_path, subgroup) - if os.path.exists(subgroup_path): - try: - if subgroup not in root_group: - # Link the subgroup to the root - subgroup_obj = zarr.open_group(subgroup_path, mode='r') - # Copy attributes to root group reference - root_group.attrs[f"{subgroup}_info"] = f"Subgroup: {subgroup}" - except Exception as e: - print(f" Warning: Could not link subgroup {subgroup}: {e}") + # if self.enable_streaming: + # # In streaming mode, link existing subgroups + # for subgroup in ['measurements', 'geometry', 'meteorology']: + # subgroup_path = os.path.join(output_path, subgroup) + # if os.path.exists(subgroup_path): + # try: + # if subgroup not in root_group: + # # Link the subgroup to the root + # subgroup_obj = zarr.open_group(subgroup_path, mode='r') + # # Copy attributes to root group reference + # root_group.attrs[f"{subgroup}_info"] = f"Subgroup: {subgroup}" + # except Exception as e: + # print(f" Warning: Could not link subgroup {subgroup}: {e}") # Consolidate metadata try: - zarr.consolidate_metadata(store) + consolidate_metadata(root_group.store) print(" ✅ Root consolidation completed") except Exception as e: print(f" ⚠️ Warning: Metadata consolidation failed: {e}") From deff685ef319dbc63eda01ac96ce6c9454c25d49 Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Tue, 30 Sep 2025 19:39:37 +0000 Subject: [PATCH 50/83] fix: streamline root Zarr group creation by removing existence check and ensuring proper attributes are set --- .vscode/launch.json | 4 ++-- .../s2_optimization/s2_converter.py | 17 +++++++---------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 97cdb407..ca714cd2 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -143,8 +143,8 @@ "args": [ "convert-s2-optimized", "https://objects.eodc.eu/e05ab01a9d56408d82ac32d69a5aae2a:202509-s02msil2a/08/products/cpm_v256/S2A_MSIL2A_20250908T100041_N0511_R122_T32TQM_20250908T115116.zarr", - // "s3://esa-zarr-sentinel-explorer-fra/tests-output/sentinel-2-l2a-opt/S2A_MSIL2A_20250908T100041_N0511_R122_T32TQM_20250908T115116.zarr", - "./tests-output/eopf_geozarr/s2l2_optimized.zarr", + "s3://esa-zarr-sentinel-explorer-fra/tests-output/sentinel-2-l2a-opt/S2A_MSIL2A_20250908T100041_N0511_R122_T32TQM_20250908T115116.zarr", + // "./tests-output/eopf_geozarr/s2l2_optimized.zarr", "--spatial-chunk", "256", "--compression-level", "5", "--enable-sharding", diff --git a/src/eopf_geozarr/s2_optimization/s2_converter.py b/src/eopf_geozarr/s2_optimization/s2_converter.py index c653a8dc..a0ee5def 100644 --- a/src/eopf_geozarr/s2_optimization/s2_converter.py +++ b/src/eopf_geozarr/s2_optimization/s2_converter.py @@ -363,16 +363,13 @@ def _simple_root_consolidation( store = output_path # Create root zarr group if it doesn't exist - if not os.path.exists(os.path.join(output_path, 'zarr.json')): - print(" Creating root zarr group...") - root_group = zarr.open_group(store, mode='a') - root_group.attrs.update({ - "title": "Optimized Sentinel-2 Dataset", - "description": "Multiscale pyramid structure for efficient access", - "zarr_format": 3 - }) - else: - root_group = zarr.open_group(store, mode='r+') + print(" Creating root zarr group...") 
+ root_group = zarr.open_group(store, mode='a') + root_group.attrs.update({ + "title": "Optimized Sentinel-2 Dataset", + "description": "Multiscale pyramid structure for efficient access", + "zarr_format": 3 + }) # Ensure subgroups are properly linked if self.enable_streaming: From cb4ada149d58377152399744137911053cb225ef Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Tue, 30 Sep 2025 19:58:59 +0000 Subject: [PATCH 51/83] fix: correct multiscales attribute assignment and update group prefix handling refactor: replace os.path.exists with fs_utils.path_exists for level path check --- src/eopf_geozarr/s2_optimization/s2_converter.py | 4 ++-- src/eopf_geozarr/s2_optimization/s2_multiscale_streaming.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/eopf_geozarr/s2_optimization/s2_converter.py b/src/eopf_geozarr/s2_optimization/s2_converter.py index a6d5b7d0..da9cc5b5 100644 --- a/src/eopf_geozarr/s2_optimization/s2_converter.py +++ b/src/eopf_geozarr/s2_optimization/s2_converter.py @@ -244,7 +244,7 @@ def _write_measurements_group( # Add multiscales metadata if multiscales_attrs: - group.attrs['multiscales'] = [multiscales_attrs] + group.attrs['multiscales'] = multiscales_attrs if verbose: num_levels = len(multiscales_attrs.get('tile_matrix_set', {}).get('matrices', [])) print(f" Multiscales metadata added with {num_levels} levels") @@ -315,7 +315,7 @@ def _create_multiscales_metadata_with_rio(self, pyramid_datasets: Dict[int, xr.D native_crs, native_bounds, overview_levels, - "measurements", # group prefix + "", # group prefix ) # Create tile matrix limits following geozarr.py exactly diff --git a/src/eopf_geozarr/s2_optimization/s2_multiscale_streaming.py b/src/eopf_geozarr/s2_optimization/s2_multiscale_streaming.py index 3bfca5dc..9ea66645 100644 --- a/src/eopf_geozarr/s2_optimization/s2_multiscale_streaming.py +++ b/src/eopf_geozarr/s2_optimization/s2_multiscale_streaming.py @@ -8,6 +8,8 @@ import xarray as xr from pyproj import CRS +from eopf_geozarr.conversion import fs_utils + from .s2_resampling import S2ResamplingEngine, determine_variable_type try: @@ -389,10 +391,9 @@ def _stream_write_lazy_dataset( This is where the magic happens: all the lazy downsampling operations are executed as the data is streamed to storage with optimal performance. """ - import os # Check if level already exists - if os.path.exists(level_path): + if fs_utils.path_exists(level_path): print(f" Level path {level_path} already exists. 
Skipping write.") return From 16c245b86803dc92ed2261f4041c2df053f877e1 Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Tue, 30 Sep 2025 22:25:34 +0200 Subject: [PATCH 52/83] feat: add downsampled coordinates creation for multiscale pyramid levels --- .vscode/launch.json | 3 +- .../s2_multiscale_streaming.py | 46 +++++++++++++++++-- 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index ca714cd2..e1bca17f 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -267,7 +267,8 @@ "module": "eopf_geozarr", "args": [ "info", - "s3://esa-zarr-sentinel-explorer-fra/tests-output/sentinel-2-l2a/S2A_MSIL2A_20250704T094051_N0511_R036_T33SWB_20250704T115824.zarr", + // "s3://esa-zarr-sentinel-explorer-fra/tests-output/sentinel-2-l2a/S2A_MSIL2A_20250704T094051_N0511_R036_T33SWB_20250704T115824.zarr", + "s3://esa-zarr-sentinel-explorer-fra/tests-output/sentinel-2-l2a-opt/S2A_MSIL2A_20250908T100041_N0511_R122_T32TQM_20250908T115116.zarr", "--verbose", "--html-output", "dataset_info.html" ], diff --git a/src/eopf_geozarr/s2_optimization/s2_multiscale_streaming.py b/src/eopf_geozarr/s2_optimization/s2_multiscale_streaming.py index 9ea66645..67b9d383 100644 --- a/src/eopf_geozarr/s2_optimization/s2_multiscale_streaming.py +++ b/src/eopf_geozarr/s2_optimization/s2_multiscale_streaming.py @@ -265,6 +265,11 @@ def _create_lazy_downsampled_dataset_from_level2( target_height = current_height // downsample_factor target_width = current_width // downsample_factor + # Create downsampled coordinates from level 2 + downsampled_coords = self._create_downsampled_coordinates( + level_2_dataset, target_height, target_width, downsample_factor + ) + # Create lazy downsampling operations for all variables lazy_vars = {} for var_name, var_data in level_2_dataset.data_vars.items(): @@ -273,9 +278,8 @@ def _create_lazy_downsampled_dataset_from_level2( ) lazy_vars[var_name] = lazy_downsampled - # Create dataset with lazy variables - don't pass coords to avoid alignment issues - # The coordinates will be computed when the lazy operations are executed - dataset = xr.Dataset(lazy_vars) + # Create dataset with lazy variables AND proper coordinates + dataset = xr.Dataset(lazy_vars, coords=downsampled_coords) dataset.attrs["pyramid_level"] = level dataset.attrs["resolution_meters"] = target_resolution @@ -615,6 +619,42 @@ def _calculate_simple_shard_dimensions( return tuple(shard_dims) + def _create_downsampled_coordinates( + self, level_2_dataset: xr.Dataset, target_height: int, target_width: int, downsample_factor: int + ) -> Dict: + """Create downsampled coordinates for higher pyramid levels.""" + import numpy as np + + # Get original coordinates from level 2 + if 'x' not in level_2_dataset.coords or 'y' not in level_2_dataset.coords: + return {} + + x_coords_orig = level_2_dataset.coords['x'].values + y_coords_orig = level_2_dataset.coords['y'].values + + # Calculate downsampled coordinates by taking every nth point + # where n is the downsample_factor + x_coords_downsampled = x_coords_orig[::downsample_factor][:target_width] + y_coords_downsampled = y_coords_orig[::downsample_factor][:target_height] + + # Create coordinate dictionary with proper attributes + coords = {} + + # Copy x coordinate with attributes + x_attrs = level_2_dataset.coords['x'].attrs.copy() + coords['x'] = (['x'], x_coords_downsampled, x_attrs) + + # Copy y coordinate with attributes + y_attrs = level_2_dataset.coords['y'].attrs.copy() + coords['y'] = (['y'], y_coords_downsampled, y_attrs) + + 
# Copy any other coordinates that might exist + for coord_name, coord_data in level_2_dataset.coords.items(): + if coord_name not in ['x', 'y']: + coords[coord_name] = coord_data + + return coords + def _write_geo_metadata( self, dataset: xr.Dataset, grid_mapping_var_name: str = "spatial_ref" ) -> None: From 42a72fbcf2c589066471191e689a4561e96b2b9e Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Wed, 1 Oct 2025 05:57:43 +0000 Subject: [PATCH 53/83] fix: update launch configuration for S2A MSIL2A dataset and adjust grid mapping attributes in streaming pyramid creation --- .vscode/launch.json | 4 ++-- src/eopf_geozarr/s2_optimization/s2_multiscale_streaming.py | 6 ++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index e1bca17f..cc3f0157 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -143,8 +143,8 @@ "args": [ "convert-s2-optimized", "https://objects.eodc.eu/e05ab01a9d56408d82ac32d69a5aae2a:202509-s02msil2a/08/products/cpm_v256/S2A_MSIL2A_20250908T100041_N0511_R122_T32TQM_20250908T115116.zarr", - "s3://esa-zarr-sentinel-explorer-fra/tests-output/sentinel-2-l2a-opt/S2A_MSIL2A_20250908T100041_N0511_R122_T32TQM_20250908T115116.zarr", - // "./tests-output/eopf_geozarr/s2l2_optimized.zarr", + // "s3://esa-zarr-sentinel-explorer-fra/tests-output/sentinel-2-l2a-opt/S2A_MSIL2A_20250908T100041_N0511_R122_T32TQM_20250908T115116.zarr", + "./tests-output/eopf_geozarr/s2l2_optimized.zarr", "--spatial-chunk", "256", "--compression-level", "5", "--enable-sharding", diff --git a/src/eopf_geozarr/s2_optimization/s2_multiscale_streaming.py b/src/eopf_geozarr/s2_optimization/s2_multiscale_streaming.py index 67b9d383..1992777d 100644 --- a/src/eopf_geozarr/s2_optimization/s2_multiscale_streaming.py +++ b/src/eopf_geozarr/s2_optimization/s2_multiscale_streaming.py @@ -112,10 +112,6 @@ def _create_streaming_measurements_lazy( level_path = f"{output_path}/measurements/{level}" print(f" Streaming write of level {level} to {level_path}") self._stream_write_lazy_dataset(lazy_dataset, level_path, level) - - # For levels 3+, we can discard after writing to save memory - if level > 2: - pyramid_datasets[level] = None else: print(f" Skipping empty level {level}") @@ -675,6 +671,8 @@ def _write_geo_metadata( crs, grid_mapping_name=grid_mapping_var_name, inplace=True ) dataset.rio.write_grid_mapping(grid_mapping_var_name, inplace=True) + dataset.attrs["grid_mapping"] = grid_mapping_var_name for var in dataset.data_vars.values(): var.rio.write_grid_mapping(grid_mapping_var_name, inplace=True) + var.attrs["grid_mapping"] = grid_mapping_var_name From e16d9f7f5a0789b0a15039e059d94c5dc884ebea Mon Sep 17 00:00:00 2001 From: Emmanuel Mathot Date: Wed, 1 Oct 2025 08:01:13 +0200 Subject: [PATCH 54/83] Refactor downsample factor calculation in S2StreamingMultiscalePyramid Updated the downsample factor calculation to use resolution ratios from pyramid_levels. This change improves clarity by explicitly referencing the resolutions of level 2 and the target level, ensuring accurate downsampling based on the defined pyramid structure. 
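For illustration, a minimal sketch of the resolution-ratio calculation this commit describes. The pyramid_levels mapping and the downsample_factor helper below are assumptions made up for the example, not the project code; the actual mapping lives in S2StreamingMultiscalePyramid.

    # Hypothetical illustration: derive the integer downsample factor for a
    # pyramid level from the ratio of its resolution to the level-2 base (60 m).
    # The pyramid_levels mapping here is assumed for the example only.
    pyramid_levels = {2: 60, 3: 120, 4: 240, 5: 480}  # level -> resolution in metres

    def downsample_factor(level: int, base_level: int = 2) -> int:
        """Integer ratio of the target level's resolution to the base level's resolution."""
        base = pyramid_levels[base_level]
        target = pyramid_levels[level]
        if target % base != 0:
            raise ValueError(f"level {level} resolution {target} is not a multiple of {base}")
        return target // base

    # Example: level 3 -> 2, level 4 -> 4, level 5 -> 8
    print([downsample_factor(lvl) for lvl in (3, 4, 5)])
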
--- dataset_info.html | 29271 +++++----------- .../s2_multiscale_streaming.py | 6 +- 2 files changed, 8850 insertions(+), 20427 deletions(-) diff --git a/dataset_info.html b/dataset_info.html index 6a36a8cb..b0307145 100644 --- a/dataset_info.html +++ b/dataset_info.html @@ -4,7 +4,7 @@ - DataTree Visualization - S2A_MSIL2A_20250704T094051_N0511_R036_T33SWB_20250704T115824.zarr + DataTree Visualization - S2A_MSIL2A_20250908T100041_N0511_R122_T32TQM_20250908T115116.zarr
[dataset_info.html diff: regenerated HTML/CSS from the xarray DataTree repr, roughly 29k changed lines of generated markup. The substantive changes recorded here: the visualization now describes the T32TQM product, the geometry group's x/y coordinates follow the new tile's extent, overview-level attributes gain pyramid_level (2) and resolution_meters (60) in place of the bare grid_mapping entry, and the measurement levels now report 36 variables on a 5490 x 5490 grid with downsampled x/y coordinates.]