diff --git a/changes/3559.misc.md b/changes/3559.misc.md new file mode 100644 index 0000000000..2611c3fa29 --- /dev/null +++ b/changes/3559.misc.md @@ -0,0 +1 @@ +Create `Bytes`, a new data type for variable-length bytes. This data type is a drop-in replacement for `VariableLengthBytes` that complies with the published [`Bytes`](https://github.com/zarr-developers/zarr-extensions/tree/main/data-types/bytes) data type spec. \ No newline at end of file diff --git a/examples/custom_dtype/custom_dtype.py b/examples/custom_dtype/custom_dtype.py index a98f3414f6..bd770569b2 100644 --- a/examples/custom_dtype/custom_dtype.py +++ b/examples/custom_dtype/custom_dtype.py @@ -25,11 +25,11 @@ from zarr.core.common import JSON, ZarrFormat from zarr.core.dtype import ZDType, data_type_registry from zarr.core.dtype.common import ( - DataTypeValidationError, DTypeConfig_V2, DTypeJSON, check_dtype_spec_v2, ) +from zarr.errors import DataTypeValidationError # This is the int2 array data type int2_dtype_cls = type(np.dtype("int2")) diff --git a/src/zarr/core/dtype/__init__.py b/src/zarr/core/dtype/__init__.py index f3077c32e5..b722675293 100644 --- a/src/zarr/core/dtype/__init__.py +++ b/src/zarr/core/dtype/__init__.py @@ -4,11 +4,13 @@ from typing import TYPE_CHECKING, Final, TypeAlias from zarr.core.dtype.common import ( - DataTypeValidationError, DTypeJSON, ) from zarr.core.dtype.npy.bool import Bool from zarr.core.dtype.npy.bytes import ( + Bytes, + BytesJSON_V2, + BytesJSON_V3, NullTerminatedBytes, NullterminatedBytesJSON_V2, NullTerminatedBytesJSON_V3, @@ -30,6 +32,7 @@ TimeDelta64JSON_V2, TimeDelta64JSON_V3, ) +from zarr.errors import DataTypeValidationError if TYPE_CHECKING: from zarr.core.common import ZarrFormat @@ -52,8 +55,12 @@ __all__ = [ "Bool", + "Bytes", + "BytesJSON_V2", + "BytesJSON_V3", "Complex64", "Complex128", + "DTypeJSON", "DataTypeRegistry", "DataTypeValidationError", "DateTime64", @@ -94,6 +101,8 @@ "VariableLengthUTF8JSON_V2", "ZDType", "data_type_registry", + "disable_legacy_bytes_dtype", + "enable_legacy_bytes_dtype", "parse_data_type", "parse_dtype", ] @@ -115,8 +124,8 @@ TimeDType = DateTime64 | TimeDelta64 TIME_DTYPE: Final = DateTime64, TimeDelta64 -BytesDType = RawBytes | NullTerminatedBytes | VariableLengthBytes -BYTES_DTYPE: Final = RawBytes, NullTerminatedBytes, VariableLengthBytes +BytesDType = RawBytes | NullTerminatedBytes | Bytes +BYTES_DTYPE: Final = RawBytes, NullTerminatedBytes, Bytes AnyDType = ( Bool @@ -127,7 +136,6 @@ | BytesDType | Structured | TimeDType - | VariableLengthBytes ) # mypy has trouble inferring the type of variablelengthstring dtype, because its class definition # depends on the installed numpy version. That's why the type: ignore statement is needed here. @@ -140,7 +148,6 @@ *BYTES_DTYPE, Structured, *TIME_DTYPE, - VariableLengthBytes, ) # These are aliases for variable-length UTF-8 strings @@ -277,6 +284,36 @@ def parse_dtype( # If the dtype request is one of the aliases for variable-length UTF-8 strings, # return that dtype. return VariableLengthUTF8() # type: ignore[return-value] + if dtype_spec is bytes: + # Treat the bytes type as a request for the Bytes dtype + return Bytes() + # otherwise, we have either a numpy dtype string, or a zarr v3 dtype string, and in either case # we can create a native dtype from it, and do the dtype inference from that return get_data_type_from_native_dtype(dtype_spec) # type: ignore[arg-type] + + +def enable_legacy_bytes_dtype() -> None: + """ + Unregister the new Bytes data type from the registry, and replace it with the + VariableLengthBytes dtype instead. Used for backwards compatibility. + """ + if ( + "bytes" in data_type_registry.contents + and "variable_length_bytes" not in data_type_registry.contents + ): + data_type_registry.unregister("bytes") + data_type_registry.register("variable_length_bytes", VariableLengthBytes) + + +def disable_legacy_bytes_dtype() -> None: + """ + Unregister the old VariableLengthBytes dtype from the registry, and replace it with + the new Bytes dtype. Used to reverse the effect of enable_legacy_bytes_dtype + """ + if ( + "variable_length_bytes" in data_type_registry.contents + and "bytes" not in data_type_registry.contents + ): + data_type_registry.unregister("variable_length_bytes") + data_type_registry.register("bytes", Bytes) diff --git a/src/zarr/core/dtype/common.py b/src/zarr/core/dtype/common.py index 6b70f595ba..35a4b4504c 100644 --- a/src/zarr/core/dtype/common.py +++ b/src/zarr/core/dtype/common.py @@ -151,12 +151,6 @@ def unpack_dtype_json(data: DTypeSpec_V2 | DTypeSpec_V3) -> DTypeJSON: return data -class DataTypeValidationError(ValueError): ... - - -class ScalarTypeValidationError(ValueError): ... - - @dataclass(frozen=True, kw_only=True) class HasLength: """ diff --git a/src/zarr/core/dtype/npy/bool.py b/src/zarr/core/dtype/npy/bool.py index 3e7f5b72f0..f92476a455 100644 --- a/src/zarr/core/dtype/npy/bool.py +++ b/src/zarr/core/dtype/npy/bool.py @@ -6,13 +6,13 @@ import numpy as np from zarr.core.dtype.common import ( - DataTypeValidationError, DTypeConfig_V2, DTypeJSON, HasItemSize, check_dtype_spec_v2, ) from zarr.core.dtype.wrapper import TBaseDType, ZDType +from zarr.errors import DataTypeValidationError if TYPE_CHECKING: from zarr.core.common import JSON, ZarrFormat diff --git a/src/zarr/core/dtype/npy/bytes.py b/src/zarr/core/dtype/npy/bytes.py index cb7d86e957..b22eaa9b5a 100644 --- a/src/zarr/core/dtype/npy/bytes.py +++ b/src/zarr/core/dtype/npy/bytes.py @@ -9,7 +9,6 @@ from zarr.core.common import JSON, NamedConfig, ZarrFormat from zarr.core.dtype.common import ( - DataTypeValidationError, DTypeConfig_V2, DTypeJSON, HasItemSize, @@ -18,8 +17,9 @@ check_dtype_spec_v2, v3_unstable_dtype_warning, ) -from zarr.core.dtype.npy.common import check_json_str +from zarr.core.dtype.npy.common import check_json_array_of_ints, check_json_str from zarr.core.dtype.wrapper import TBaseDType, ZDType +from zarr.errors import DataTypeValidationError BytesLike = np.bytes_ | str | bytes | int @@ -142,6 +142,33 @@ class RawBytesJSON_V3(NamedConfig[Literal["raw_bytes"], FixedLengthBytesConfig]) """ +class BytesJSON_V2(DTypeConfig_V2[Literal["|O"], Literal["vlen-bytes"]]): + """ + A wrapper around the JSON representation of the `Bytes` data type in Zarr V2. + + The `name` field of this class contains the value that would appear under the + `dtype` field in Zarr V2 array metadata. The `object_codec_id` field is always `"vlen-bytes"` + + References + ---------- + The structure of the `name` field is defined in the Zarr V2 + [specification document](https://github.com/zarr-developers/zarr-specs/blob/main/docs/v2/v2.0.rst#data-type-encoding). + + Examples + -------- + ```python + { + "name": "|O", + "object_codec_id": "vlen-bytes" + } + ``` + """ + + +BytesJSON_V3 = Literal["bytes"] +"""The Zarr V3 JSON representation of the `Bytes` data type.""" + + class VariableLengthBytesJSON_V2(DTypeConfig_V2[Literal["|O"], Literal["vlen-bytes"]]): """ A wrapper around the JSON representation of the ``VariableLengthBytes`` data type in Zarr V2. @@ -165,6 +192,20 @@ class VariableLengthBytesJSON_V2(DTypeConfig_V2[Literal["|O"], Literal["vlen-byt """ +def base64_encode_bytes(data: bytes) -> str: + """ + Encode bytes into a base64-encoded string. + """ + return base64.standard_b64encode(data).decode("ascii") + + +def base64_decode_bytes(data: str) -> bytes: + """ + Decode a base64-encoded string into bytes. + """ + return base64.standard_b64decode(data.encode("ascii")) + + @dataclass(frozen=True, kw_only=True) class NullTerminatedBytes(ZDType[np.dtypes.BytesDType[int], np.bytes_], HasLength, HasItemSize): """ @@ -1175,7 +1216,7 @@ def to_json_scalar(self, data: object, *, zarr_format: ZarrFormat) -> str: str A string representation of the scalar. """ - return base64.standard_b64encode(data).decode("ascii") # type: ignore[arg-type] + return base64_encode_bytes(data) # type: ignore[arg-type] def from_json_scalar(self, data: JSON, *, zarr_format: ZarrFormat) -> bytes: """ @@ -1200,7 +1241,7 @@ def from_json_scalar(self, data: JSON, *, zarr_format: ZarrFormat) -> bytes: """ if check_json_str(data): - return base64.standard_b64decode(data.encode("ascii")) + return base64_decode_bytes(data) raise TypeError(f"Invalid type: {data}. Expected a string.") # pragma: no cover def _check_scalar(self, data: object) -> TypeGuard[BytesLike]: @@ -1277,3 +1318,352 @@ def cast_scalar(self, data: object) -> bytes: f"data type {self}." ) raise TypeError(msg) + + +@dataclass(frozen=True, kw_only=True) +class Bytes(ZDType[np.dtypes.ObjectDType, bytes], HasObjectCodec): + """ + A Zarr data type for arrays containing variable-length sequences of bytes. + + Wraps the NumPy "object" data type. Scalars for this data type are instances of ``bytes``. + + This data type inherits from `VariableLengthBytes` for backwards compatibility. + + Attributes + ---------- + dtype_cls: ClassVar[type[np.dtypes.ObjectDType]] = np.dtypes.ObjectDType + The NumPy data type wrapped by this ZDType. + _zarr_v3_name: ClassVar[Literal["bytes"]] = "bytes" + The name of this data type in Zarr V3. + object_codec_id: ClassVar[Literal["vlen-bytes"]] = "vlen-bytes" + The object codec ID for this data type. + + References + ---------- + The specification for this data type can be found at + https://github.com/zarr-developers/zarr-extensions/tree/main/data-types/bytes + + Notes + ----- + Because this data type uses the NumPy "object" data type, it does not guarantee a compact memory + representation of array data. Therefore a "vlen-bytes" codec is needed to ensure that the array + data can be persisted to storage. + """ + + dtype_cls = np.dtypes.ObjectDType + _zarr_v3_name: ClassVar[BytesJSON_V3] = "bytes" + object_codec_id: ClassVar[Literal["vlen-bytes"]] = "vlen-bytes" + + @classmethod + def from_native_dtype(cls, dtype: TBaseDType) -> Self: + """ + Create an instance of Bytes from an instance of np.dtypes.ObjectDType. + + This method checks if the provided data type is an instance of np.dtypes.ObjectDType. + If so, it returns an instance of Bytes. + + Parameters + ---------- + dtype : TBaseDType + The native dtype to convert. + + Returns + ------- + VariableLengthBytes + An instance of VariableLengthBytes. + + Raises + ------ + DataTypeValidationError + If the dtype is not compatible with VariableLengthBytes. + """ + if cls._check_native_dtype(dtype): + return cls() + raise DataTypeValidationError( + f"Invalid data type: {dtype}. Expected an instance of {cls.dtype_cls}" + ) + + def to_native_dtype(self) -> np.dtypes.ObjectDType: + """ + Create a NumPy object dtype from this VariableLengthBytes ZDType. + + Returns + ------- + np.dtypes.ObjectDType + A NumPy data type object representing variable-length bytes. + """ + return self.dtype_cls() + + @classmethod + def _check_json_v2( + cls, + data: DTypeJSON, + ) -> TypeGuard[BytesJSON_V2]: + """ + Check that the input is a valid JSON representation of a NumPy O dtype, and that the + object codec id is appropriate for variable-length byte strings. + + Parameters + ---------- + data : DTypeJSON + The JSON data to check. + + Returns + ------- + True if the input is a valid representation of this class in Zarr V2, False + otherwise. + """ + # Check that the input is a valid JSON representation of a Zarr v2 data type spec. + if not check_dtype_spec_v2(data): + return False + + # Check that the object codec id is appropriate for variable-length bytes strings. + if data["name"] != "|O": + return False + return data["object_codec_id"] == cls.object_codec_id + + @classmethod + def _check_json_v3(cls, data: DTypeJSON) -> TypeGuard[BytesJSON_V3]: + """ + Check that the input is a valid JSON representation of this class in Zarr V3. + + Parameters + ---------- + data : DTypeJSON + The JSON data to check. + + Returns + ------- + TypeGuard[Literal["Bytes"]] + True if the input is "bytes", False otherwise. + """ + + return data in (cls._zarr_v3_name, "variable_length_bytes") + + @classmethod + def _from_json_v2(cls, data: DTypeJSON) -> Self: + """ + Create an instance of Bytes from Zarr V2-flavored JSON. + + This method checks if the input data is a valid representation of this class + in Zarr V2. If so, it returns a new instance this class. + + Parameters + ---------- + data : DTypeJSON + The JSON data to parse. + + Returns + ------- + Self + An instance of this data type. + + Raises + ------ + DataTypeValidationError + If the input data is not a valid representation of this class class. + """ + + if cls._check_json_v2(data): + return cls() + msg = ( + f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected the string " + f"'|O' and an object_codec_id of {cls.object_codec_id}" + ) + raise DataTypeValidationError(msg) + + @classmethod + def _from_json_v3(cls, data: DTypeJSON) -> Self: + """ + Create an instance of Bytes from Zarr V3-flavored JSON. + + This method checks if the input data is a valid representation of + Bytes in Zarr V3. If so, it returns a new instance of + Bytes. + + Parameters + ---------- + data : DTypeJSON + The JSON data to parse. + + Returns + ------- + Bytes + + Raises + ------ + DataTypeValidationError + If the input data is not a valid representation of this class. + """ + + if cls._check_json_v3(data): + return cls() + msg = f"Invalid JSON representation of {cls.__name__}. Got {data!r}, expected the string {cls._zarr_v3_name!r}" + raise DataTypeValidationError(msg) + + @overload + def to_json(self, zarr_format: Literal[2]) -> BytesJSON_V2: ... + + @overload + def to_json(self, zarr_format: Literal[3]) -> BytesJSON_V3: ... + + def to_json(self, zarr_format: ZarrFormat) -> BytesJSON_V2 | BytesJSON_V3: + """ + Convert the variable-length bytes data type to a JSON-serializable form. + + Parameters + ---------- + zarr_format : ZarrFormat + The zarr format version. Accepted values are 2 and 3. + + Returns + ------- + ``DTypeConfig_V2[Literal["|O"], Literal["bytes"]] | Literal["bytes"]`` + The JSON-serializable representation of the variable-length bytes data type. + For zarr_format 2, returns a dictionary with "name" and "object_codec_id". + For zarr_format 3, returns a string identifier "bytes". + + Raises + ------ + ValueError + If zarr_format is not 2 or 3. + """ + + if zarr_format == 2: + return {"name": "|O", "object_codec_id": self.object_codec_id} + elif zarr_format == 3: + return self._zarr_v3_name + raise ValueError(f"zarr_format must be 2 or 3, got {zarr_format}") # pragma: no cover + + def default_scalar(self) -> bytes: + """ + Return the default scalar value for the variable-length bytes data type. + + Returns + ------- + bytes + The default scalar value, which is an empty byte string. + """ + + return b"" + + def to_json_scalar(self, data: object, *, zarr_format: ZarrFormat) -> str: + """ + Convert a scalar to a JSON-serializable tuple of integers. + + This method encodes the given scalar as bytes and then + encodes the bytes as a base64-encoded string. + + Parameters + ---------- + data : object + The scalar to convert. + zarr_format : ZarrFormat + The zarr format version. + + Returns + ------- + str + A base64-encoded string. + """ + return base64_encode_bytes(self.cast_scalar(data)) + + def from_json_scalar(self, data: JSON, *, zarr_format: ZarrFormat) -> bytes: + """ + Decode a base64-encoded JSON string or sequence of integers to bytes. + + Parameters + ---------- + data : JSON + The JSON-serializable base64-encoded string. + zarr_format : ZarrFormat + The zarr format version. + + Returns + ------- + bytes + The decoded bytes from the base64 string. + + Raises + ------ + TypeError + If the input data is not a base64-encoded string or a sequence of integers. + """ + if check_json_str(data): + return base64_decode_bytes(data) + if check_json_array_of_ints(data): + return bytes(data) + raise TypeError( + f"Invalid type: {data}. Expected a sequence of integers or a base64-encoded string." + ) # pragma: no cover + + def _check_scalar(self, data: object) -> TypeGuard[bytes | str]: + """ + Check if the provided data is bytes. + + Parameters + ---------- + data : object + The data to check. + + Returns + ------- + TypeGuard[BytesLike] + True if the data is bytes, False otherwise. + """ + return isinstance(data, (bytes, str)) + + def _cast_scalar_unchecked(self, data: bytes | str) -> bytes: + """ + Cast the provided scalar data to bytes. + + Parameters + ---------- + data : BytesLike + The data to cast. + + Returns + ------- + bytes + The casted data as bytes. + + Notes + ----- + This method does not perform any type checking. + The input data must be bytes-like. + """ + if isinstance(data, str): + return bytes(data, encoding="utf-8") + return bytes(data) + + def cast_scalar(self, data: object) -> bytes: + """ + Attempt to cast a given object to a bytes scalar. + + This method first checks if the provided data is a valid scalar that can be + converted to a bytes scalar. If the check succeeds, the unchecked casting + operation is performed. If the data is not valid, a TypeError is raised. + + Parameters + ---------- + data : object + The data to be cast to a bytes scalar. + + Returns + ------- + bytes + The data cast as a bytes scalar. + + Raises + ------ + TypeError + If the data cannot be converted to a bytes scalar. + """ + + if self._check_scalar(data): + return self._cast_scalar_unchecked(data) + msg = ( + f"Cannot convert object {data!r} with type {type(data)} to a scalar compatible with the " + f"data type {self}." + ) + raise TypeError(msg) diff --git a/src/zarr/core/dtype/npy/common.py b/src/zarr/core/dtype/npy/common.py index 107b3bd12d..45cbe41562 100644 --- a/src/zarr/core/dtype/npy/common.py +++ b/src/zarr/core/dtype/npy/common.py @@ -547,6 +547,23 @@ def check_json_floatish_str(data: JSON) -> TypeGuard[FloatishStr]: return True +def check_json_array_of_ints(data: JSON) -> TypeGuard[Sequence[int]]: + """ + Check if an object is a sequence of integers. + + Parameters + ---------- + data : JSON + The JSON value to check. + + Returns + ------- + bool + True if the data is a sequence of integers, False otherwise. + """ + return isinstance(data, Sequence) and all(isinstance(item, int) for item in data) + + def check_json_str(data: JSON) -> TypeGuard[str]: """ Check if a JSON value is a string. diff --git a/src/zarr/core/dtype/npy/complex.py b/src/zarr/core/dtype/npy/complex.py index 99abee5e24..f8d59aa5b4 100644 --- a/src/zarr/core/dtype/npy/complex.py +++ b/src/zarr/core/dtype/npy/complex.py @@ -13,7 +13,6 @@ import numpy as np from zarr.core.dtype.common import ( - DataTypeValidationError, DTypeConfig_V2, DTypeJSON, HasEndianness, @@ -34,6 +33,7 @@ get_endianness_from_numpy_dtype, ) from zarr.core.dtype.wrapper import TBaseDType, ZDType +from zarr.errors import DataTypeValidationError if TYPE_CHECKING: from zarr.core.common import JSON, ZarrFormat diff --git a/src/zarr/core/dtype/npy/float.py b/src/zarr/core/dtype/npy/float.py index 0be2cbca9b..32cb8e7d74 100644 --- a/src/zarr/core/dtype/npy/float.py +++ b/src/zarr/core/dtype/npy/float.py @@ -6,7 +6,6 @@ import numpy as np from zarr.core.dtype.common import ( - DataTypeValidationError, DTypeConfig_V2, DTypeJSON, HasEndianness, @@ -28,6 +27,7 @@ get_endianness_from_numpy_dtype, ) from zarr.core.dtype.wrapper import TBaseDType, ZDType +from zarr.errors import DataTypeValidationError if TYPE_CHECKING: from zarr.core.common import JSON, ZarrFormat diff --git a/src/zarr/core/dtype/npy/int.py b/src/zarr/core/dtype/npy/int.py index 580776a865..59d36091b5 100644 --- a/src/zarr/core/dtype/npy/int.py +++ b/src/zarr/core/dtype/npy/int.py @@ -16,7 +16,6 @@ import numpy as np from zarr.core.dtype.common import ( - DataTypeValidationError, DTypeConfig_V2, DTypeJSON, HasEndianness, @@ -31,6 +30,7 @@ get_endianness_from_numpy_dtype, ) from zarr.core.dtype.wrapper import TBaseDType, ZDType +from zarr.errors import DataTypeValidationError if TYPE_CHECKING: from zarr.core.common import JSON, ZarrFormat diff --git a/src/zarr/core/dtype/npy/string.py b/src/zarr/core/dtype/npy/string.py index 41d3a60078..ead868d0e3 100644 --- a/src/zarr/core/dtype/npy/string.py +++ b/src/zarr/core/dtype/npy/string.py @@ -18,7 +18,6 @@ from zarr.core.common import NamedConfig from zarr.core.dtype.common import ( - DataTypeValidationError, DTypeConfig_V2, DTypeJSON, HasEndianness, @@ -34,6 +33,7 @@ get_endianness_from_numpy_dtype, ) from zarr.core.dtype.wrapper import TDType_co, ZDType +from zarr.errors import DataTypeValidationError if TYPE_CHECKING: from zarr.core.common import JSON, ZarrFormat diff --git a/src/zarr/core/dtype/npy/structured.py b/src/zarr/core/dtype/npy/structured.py index 8bedee07ef..d4c06a5c19 100644 --- a/src/zarr/core/dtype/npy/structured.py +++ b/src/zarr/core/dtype/npy/structured.py @@ -8,7 +8,6 @@ from zarr.core.common import NamedConfig from zarr.core.dtype.common import ( - DataTypeValidationError, DTypeConfig_V2, DTypeJSON, HasItemSize, @@ -23,6 +22,7 @@ check_json_str, ) from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType +from zarr.errors import DataTypeValidationError if TYPE_CHECKING: from zarr.core.common import JSON, ZarrFormat diff --git a/src/zarr/core/dtype/npy/time.py b/src/zarr/core/dtype/npy/time.py index 402a140321..3f358962c5 100644 --- a/src/zarr/core/dtype/npy/time.py +++ b/src/zarr/core/dtype/npy/time.py @@ -20,7 +20,6 @@ from zarr.core.common import NamedConfig from zarr.core.dtype.common import ( - DataTypeValidationError, DTypeConfig_V2, DTypeJSON, HasEndianness, @@ -35,6 +34,7 @@ get_endianness_from_numpy_dtype, ) from zarr.core.dtype.wrapper import TBaseDType, ZDType +from zarr.errors import DataTypeValidationError if TYPE_CHECKING: from zarr.core.common import JSON, ZarrFormat diff --git a/src/zarr/core/dtype/registry.py b/src/zarr/core/dtype/registry.py index cb9ab50044..f4e57e375d 100644 --- a/src/zarr/core/dtype/registry.py +++ b/src/zarr/core/dtype/registry.py @@ -6,15 +6,15 @@ import numpy as np -from zarr.core.dtype.common import ( - DataTypeValidationError, - DTypeJSON, -) +from zarr.errors import DataTypeResolutionError, DataTypeValidationError if TYPE_CHECKING: from importlib.metadata import EntryPoint from zarr.core.common import ZarrFormat + from zarr.core.dtype.common import ( + DTypeJSON, + ) from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType @@ -173,8 +173,8 @@ def match_dtype(self, dtype: TBaseDType) -> ZDType[TBaseDType, TBaseScalar]: "entirely by providing a specific Zarr data type when creating your array." "For more information, see https://github.com/zarr-developers/zarr-python/issues/3117" ) - raise ValueError(msg) - raise ValueError(f"No Zarr data type found that matches dtype '{dtype!r}'") + raise DataTypeResolutionError(msg) + raise DataTypeResolutionError(f"No Zarr data type found that matches dtype '{dtype!r}'") def match_json( self, data: DTypeJSON, *, zarr_format: ZarrFormat @@ -198,11 +198,27 @@ def match_json( ------ ValueError If no matching Zarr data type is found for the given JSON data. - """ + Notes + ----- + + If multiple matches are found, this function raises a ValueError. In this case + conflicting data types must be unregistered, or the Zarr data type should be explicitly + constructed. + """ + matched: list[ZDType[TBaseDType, TBaseScalar]] = [] for val in self.contents.values(): - try: - return val.from_json(data, zarr_format=zarr_format) - except DataTypeValidationError: - pass - raise ValueError(f"No Zarr data type found that matches {data!r}") + with contextlib.suppress(DataTypeValidationError): + matched.append(val.from_json(data, zarr_format=zarr_format)) + if len(matched) == 1: + return matched[0] + elif len(matched) > 1: + msg = ( + f"Zarr data type resolution from {data} failed. " + f"Multiple data type wrappers found that match dtype '{data}': {matched}. " + "You should unregister one of these data types, or avoid Zarr data type inference " + "entirely by providing a specific Zarr data type when creating your array." + "For more information, see https://github.com/zarr-developers/zarr-python/issues/3117" + ) + raise DataTypeResolutionError(msg) + raise DataTypeResolutionError(f"No Zarr data type found that matches {data!r}") diff --git a/src/zarr/dtype.py b/src/zarr/dtype.py index 616d1c1ce2..77f7406bec 100644 --- a/src/zarr/dtype.py +++ b/src/zarr/dtype.py @@ -2,7 +2,6 @@ Bool, Complex64, Complex128, - DataTypeValidationError, DateTime64, DateTime64JSON_V2, DateTime64JSON_V3, @@ -43,6 +42,7 @@ parse_data_type, # noqa: F401 parse_dtype, ) +from zarr.errors import DataTypeValidationError __all__ = [ "Bool", diff --git a/src/zarr/errors.py b/src/zarr/errors.py index bcd6a08deb..39ed3b6e7b 100644 --- a/src/zarr/errors.py +++ b/src/zarr/errors.py @@ -6,10 +6,12 @@ "ContainsArrayAndGroupError", "ContainsArrayError", "ContainsGroupError", + "DataTypeValidationError", "GroupNotFoundError", "MetadataValidationError", "NegativeStepError", "NodeTypeValidationError", + "ScalarTypeValidationError", "UnstableSpecificationWarning", "VindexInvalidSelectionError", "ZarrDeprecationWarning", @@ -144,3 +146,13 @@ class BoundsCheckError(IndexError): ... class ArrayIndexError(IndexError): ... + + +class DataTypeValidationError(ValueError): ... + + +class ScalarTypeValidationError(ValueError): ... + + +class DataTypeResolutionError(ValueError): + """Error raised when an input cannot be unambiguously resolved to a Zarr data type class.""" diff --git a/tests/package_with_entrypoint/__init__.py b/tests/package_with_entrypoint/__init__.py index ae86378cb5..1dcefcd623 100644 --- a/tests/package_with_entrypoint/__init__.py +++ b/tests/package_with_entrypoint/__init__.py @@ -9,8 +9,8 @@ from zarr.abc.codec import ArrayBytesCodec, CodecInput, CodecPipeline from zarr.codecs import BytesCodec from zarr.core.buffer import Buffer, NDBuffer -from zarr.core.dtype.common import DataTypeValidationError, DTypeJSON, DTypeSpec_V2 from zarr.core.dtype.npy.bool import Bool +from zarr.errors import DataTypeValidationError if TYPE_CHECKING: from collections.abc import Iterable @@ -18,6 +18,7 @@ from zarr.core.array_spec import ArraySpec from zarr.core.common import ZarrFormat + from zarr.core.dtype.common import DTypeJSON, DTypeSpec_V2 class TestEntrypointCodec(ArrayBytesCodec): diff --git a/tests/test_dtype/test_npy/test_bytes.py b/tests/test_dtype/test_npy/test_bytes.py index 6a4bcc4691..b87b4364f7 100644 --- a/tests/test_dtype/test_npy/test_bytes.py +++ b/tests/test_dtype/test_npy/test_bytes.py @@ -2,7 +2,12 @@ import pytest from tests.test_dtype.test_wrapper import BaseTestZDType -from zarr.core.dtype.npy.bytes import NullTerminatedBytes, RawBytes, VariableLengthBytes +from zarr.core.dtype import ( + data_type_registry, + disable_legacy_bytes_dtype, + enable_legacy_bytes_dtype, +) +from zarr.core.dtype.npy.bytes import Bytes, NullTerminatedBytes, RawBytes, VariableLengthBytes from zarr.errors import UnstableSpecificationWarning @@ -101,6 +106,62 @@ class TestRawBytes(BaseTestZDType): ) +class TestBytes(BaseTestZDType): + test_cls = Bytes + valid_dtype = (np.dtype("|O"),) + invalid_dtype = ( + np.dtype(np.int8), + np.dtype(np.float64), + np.dtype("|U10"), + ) + valid_json_v2 = ({"name": "|O", "object_codec_id": "vlen-bytes"},) + valid_json_v3 = ("bytes",) + invalid_json_v2 = ( + "|S", + "|U10", + "|f8", + ) + + invalid_json_v3 = ( + {"name": "fixed_length_ascii", "configuration": {"length_bits": 0}}, + {"name": "numpy.fixed_length_ascii", "configuration": {"length_bits": "invalid"}}, + ) + + scalar_v2_params = ( + (Bytes(), ""), + (Bytes(), "YWI="), + ) + scalar_v3_params = ( + (Bytes(), ""), + (Bytes(), "YWI="), + ) + cast_value_params = ( + (Bytes(), "", b""), + (Bytes(), "ab", b"ab"), + (Bytes(), "abcdefg", b"abcdefg"), + ) + invalid_scalar_params = ((Bytes(), 1.0),) + item_size_params = (Bytes(),) + + +def test_bytes_string_fill_alias() -> None: + """ + Test that the bytes dtype parses a sequence of ints as a valid JSON + encoding for a bytes scalar. + """ + data = (1, 2, 3) + a = Bytes().from_json_scalar(data, zarr_format=3) + b = bytes(data) + assert a == b + + +def test_bytes_alias() -> None: + """Test that "variable_length_bytes" is an accepted alias for "bytes" in JSON metadata""" + a = Bytes.from_json("bytes", zarr_format=3) + b = Bytes.from_json("variable_length_bytes", zarr_format=3) + assert a == b + + class TestVariableLengthBytes(BaseTestZDType): test_cls = VariableLengthBytes valid_dtype = (np.dtype("|O"),) @@ -170,3 +231,24 @@ def test_invalid_size(zdtype_cls: type[NullTerminatedBytes] | type[RawBytes]) -> msg = f"length must be >= 1, got {length}." with pytest.raises(ValueError, match=msg): zdtype_cls(length=length) + + +def test_legacy_bytes_compatibility() -> None: + """ + Test that the enable_legacy_bytes_dtype function unregisters the Bytes + dtype and inserts the VariableLengthBytes dtype in the registry. Also + test that this operation is reversed by the disable_legacy_bytes_dtype() + function. + """ + assert "bytes" in data_type_registry.contents + assert "variable_length_bytes" not in data_type_registry.contents + + enable_legacy_bytes_dtype() + + assert "bytes" not in data_type_registry.contents + assert "variable_length_bytes" in data_type_registry.contents + + disable_legacy_bytes_dtype() + + assert "bytes" in data_type_registry.contents + assert "variable_length_bytes" not in data_type_registry.contents