diff --git a/changes/3623.misc.md b/changes/3623.misc.md new file mode 100644 index 0000000000..4060e55e5f --- /dev/null +++ b/changes/3623.misc.md @@ -0,0 +1,5 @@ +This PR contains minor, non-function-altering, changes to use `ZarrFormat` across the repo as opposed to duplicating is with `Literal[2,3]`. + +Additionally, it fixes broken linting by using a `Literal[True, False]` type hint for Numpy hypothesis testing, as opposed to `bool`. + +Basically improves the typehints and reduces fat-finger error surface area slightly. diff --git a/examples/custom_dtype/custom_dtype.py b/examples/custom_dtype/custom_dtype.py index a98f3414f6..ec38d782b6 100644 --- a/examples/custom_dtype/custom_dtype.py +++ b/examples/custom_dtype/custom_dtype.py @@ -217,7 +217,7 @@ def from_json_scalar(self, data: JSON, *, zarr_format: ZarrFormat) -> ml_dtypes. # this parametrized function will create arrays in zarr v2 and v3 using our new data type @pytest.mark.parametrize("zarr_format", [2, 3]) -def test_custom_dtype(tmp_path: Path, zarr_format: Literal[2, 3]) -> None: +def test_custom_dtype(tmp_path: Path, zarr_format: ZarrFormat) -> None: # create array and write values z_w = zarr.create_array( store=tmp_path, shape=(4,), dtype="int2", zarr_format=zarr_format, compressors=None diff --git a/src/zarr/_cli/cli.py b/src/zarr/_cli/cli.py index 785efe505b..35521f01ab 100644 --- a/src/zarr/_cli/cli.py +++ b/src/zarr/_cli/cli.py @@ -6,6 +6,7 @@ import zarr import zarr.metadata.migrate_v3 as migrate_metadata +from zarr.core.common import ZarrFormat from zarr.core.sync import sync from zarr.storage._common import make_store @@ -23,12 +24,12 @@ def _set_logging_level(*, verbose: bool) -> None: zarr.set_format("%(message)s") -class ZarrFormat(str, Enum): +class CLIZarrFormat(str, Enum): v2 = "v2" v3 = "v3" -class ZarrFormatV3(str, Enum): +class CLIZarrFormatV3(str, Enum): """Limit CLI choice to only v3""" v3 = "v3" @@ -37,7 +38,7 @@ class ZarrFormatV3(str, Enum): @app.command() # type: ignore[misc] def migrate( zarr_format: Annotated[ - ZarrFormatV3, + CLIZarrFormatV3, typer.Argument( help="Zarr format to migrate to. Currently only 'v3' is supported.", ), @@ -122,7 +123,7 @@ def migrate( @app.command() # type: ignore[misc] def remove_metadata( zarr_format: Annotated[ - ZarrFormat, + CLIZarrFormat, typer.Argument(help="Which format's metadata to remove - v2 or v3."), ], store: Annotated[ @@ -160,7 +161,7 @@ def remove_metadata( sync( migrate_metadata.remove_metadata( store=input_zarr_store, - zarr_format=cast(Literal[2, 3], int(zarr_format[1:])), + zarr_format=cast(ZarrFormat, int(zarr_format[1:])), force=force, dry_run=dry_run, ) diff --git a/src/zarr/metadata/migrate_v3.py b/src/zarr/metadata/migrate_v3.py index 8f83e01f20..a72939100d 100644 --- a/src/zarr/metadata/migrate_v3.py +++ b/src/zarr/metadata/migrate_v3.py @@ -1,6 +1,6 @@ import asyncio import logging -from typing import Literal, cast +from typing import cast import numcodecs.abc @@ -140,7 +140,7 @@ async def remove_metadata( continue if force or await _metadata_exists( - cast(Literal[2, 3], alternative_metadata), store_path / parent_path + cast(ZarrFormat, alternative_metadata), store_path / parent_path ): _logger.info("Deleting metadata at %s", store_path / file_path) if not dry_run: diff --git a/src/zarr/testing/strategies.py b/src/zarr/testing/strategies.py index 5eb17214fe..330f220b56 100644 --- a/src/zarr/testing/strategies.py +++ b/src/zarr/testing/strategies.py @@ -25,6 +25,8 @@ from zarr.storage._utils import normalize_path from zarr.types import AnyArray +TrueOrFalse = Literal[True, False] + # Copied from Xarray _attr_keys = st.text(st.characters(), min_size=1) _attr_values = st.recursive( @@ -131,7 +133,7 @@ def array_metadata( draw: st.DrawFn, *, array_shapes: Callable[..., st.SearchStrategy[tuple[int, ...]]] = npst.array_shapes, - zarr_formats: st.SearchStrategy[Literal[2, 3]] = zarr_formats, + zarr_formats: st.SearchStrategy[ZarrFormat] = zarr_formats, attributes: SearchStrategy[Mapping[str, JSON] | None] = attrs, ) -> ArrayV2Metadata | ArrayV3Metadata: zarr_format = draw(zarr_formats) @@ -348,8 +350,8 @@ def basic_indices( shape: tuple[int, ...], min_dims: int = 0, max_dims: int | None = None, - allow_newaxis: bool = False, - allow_ellipsis: bool = True, + allow_newaxis: TrueOrFalse = False, + allow_ellipsis: TrueOrFalse = True, ) -> Any: """Basic indices without unsupported negative slices.""" strategy = npst.basic_indices( @@ -362,7 +364,7 @@ def basic_indices( lambda idxr: ( not ( is_negative_slice(idxr) - or (isinstance(idxr, tuple) and any(is_negative_slice(idx) for idx in idxr)) # type: ignore[redundant-expr] + or (isinstance(idxr, tuple) and any(is_negative_slice(idx) for idx in idxr)) ) ) ) diff --git a/tests/package_with_entrypoint/__init__.py b/tests/package_with_entrypoint/__init__.py index ae86378cb5..7b5dfb5a1e 100644 --- a/tests/package_with_entrypoint/__init__.py +++ b/tests/package_with_entrypoint/__init__.py @@ -84,7 +84,7 @@ class TestDataType(Bool): _zarr_v3_name: ClassVar[Literal["test"]] = "test" # type: ignore[assignment] @classmethod - def from_json(cls, data: DTypeJSON, *, zarr_format: Literal[2, 3]) -> Self: + def from_json(cls, data: DTypeJSON, *, zarr_format: ZarrFormat) -> Self: if zarr_format == 2 and data == {"name": cls._zarr_v3_name, "object_codec_id": None}: return cls() if zarr_format == 3 and data == cls._zarr_v3_name: