From e45fd526d3d48d1dc4cd511bc45a0fa9b849bcd3 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 18 Dec 2023 15:39:15 -0700 Subject: [PATCH 1/6] Adapt map_blocks to use new Coordinates API --- xarray/core/coordinates.py | 2 +- xarray/core/parallel.py | 44 ++++++++++++++++++++++++-------------- xarray/tests/test_dask.py | 19 ++++++++++++++++ 3 files changed, 48 insertions(+), 17 deletions(-) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index cdf1d354be6..c59c5deba16 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -213,7 +213,7 @@ class Coordinates(AbstractCoordinates): :py:class:`~xarray.Coordinates` object is passed, its indexes will be added to the new created object. indexes: dict-like, optional - Mapping of where keys are coordinate names and values are + Mapping where keys are coordinate names and values are :py:class:`~xarray.indexes.Index` objects. If None (default), pandas indexes will be created for each dimension coordinate. Passing an empty dictionary will skip this default behavior. 
diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index f971556b3f7..afe045965f5 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -11,6 +11,7 @@ from xarray.core.alignment import align from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset +from xarray.core.merge import merge from xarray.core.pycompat import is_dask_collection if TYPE_CHECKING: @@ -345,13 +346,16 @@ def _wrapper( for arg in aligned ) + merged_coordinates = merge([arg.coords for arg in aligned]).coords + _, npargs = unzip( sorted(list(zip(xarray_indices, xarray_objs)) + others, key=lambda x: x[0]) ) + # TODO (dcherian) cleanup to just use a Coordinates object + input_indexes = dict(npargs[0]._indexes) # check that chunk sizes are compatible input_chunks = dict(npargs[0].chunks) - input_indexes = dict(npargs[0]._indexes) for arg in xarray_objs[1:]: assert_chunks_compatible(npargs[0], arg) input_chunks.update(arg.chunks) @@ -360,18 +364,29 @@ def _wrapper( if template is None: # infer template by providing zero-shaped arrays template = infer_template(func, aligned[0], *args, **kwargs) - template_indexes = set(template._indexes) - preserved_indexes = template_indexes & set(input_indexes) - new_indexes = template_indexes - set(input_indexes) - indexes = {dim: input_indexes[dim] for dim in preserved_indexes} - indexes.update({k: template._indexes[k] for k in new_indexes}) + template_coords = set(template.coords) + preserved_indexes = template_coords & set(merged_coordinates) + new_indexes = template_coords - set(merged_coordinates) + + preserved_coords = merged_coordinates.to_dataset()[preserved_indexes] + # preserved_coords contains all coordinates bariables that share a dimension + # with any index variable in preserved_indexes + # Drop any unneeded vars in a second pass, this is required for e.g. + # if the mapped function were to drop a non-dimension coordinate variable. 
+ preserved_coords = preserved_coords.drop_vars( + tuple(k for k in preserved_coords.variables if k not in template_coords) + ) + + coordinates = merge( + (preserved_coords, template.coords.to_dataset()[new_indexes]) + ).coords output_chunks: Mapping[Hashable, tuple[int, ...]] = { dim: input_chunks[dim] for dim in template.dims if dim in input_chunks } else: # template xarray object has been provided with proper sizes and chunk shapes - indexes = dict(template._indexes) + coordinates = template.coords output_chunks = template.chunksizes if not output_chunks: raise ValueError( @@ -496,8 +511,10 @@ def subset_dataset_to_block( expected["data_vars"] = set(template.data_vars.keys()) # type: ignore[assignment] expected["coords"] = set(template.coords.keys()) # type: ignore[assignment] expected["indexes"] = { - dim: indexes[dim][_get_chunk_slicer(dim, chunk_index, output_chunk_bounds)] - for dim in indexes + dim: coordinates.xindexes[dim][ + _get_chunk_slicer(dim, chunk_index, output_chunk_bounds) + ] + for dim in coordinates.xindexes } from_wrapper = (gname,) + chunk_tuple @@ -506,7 +523,7 @@ def subset_dataset_to_block( # mapping from variable name to dask graph key var_key_map: dict[Hashable, str] = {} for name, variable in template.variables.items(): - if name in indexes: + if name in coordinates.indexes: continue gname_l = f"{name}-{gname}" var_key_map[name] = gname_l @@ -543,12 +560,7 @@ def subset_dataset_to_block( }, ) - # TODO: benbovy - flexible indexes: make it work with custom indexes - # this will need to pass both indexes and coords to the Dataset constructor - result = Dataset( - coords={k: idx.to_pandas_index() for k, idx in indexes.items()}, - attrs=template.attrs, - ) + result = Dataset(coords=coordinates, attrs=template.attrs) for index in result._indexes: result[index].attrs = template[index].attrs diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index c2a77c97d85..137d6020829 100644 --- a/xarray/tests/test_dask.py +++ 
b/xarray/tests/test_dask.py @@ -1367,6 +1367,25 @@ def test_map_blocks_da_ds_with_template(obj): assert_identical(actual, template) +def test_map_blocks_roundtrip_string_index(): + ds = xr.Dataset( + {"data": (["label"], [1, 2, 3])}, coords={"label": ["foo", "bar", "baz"]} + ).chunk(label=1) + assert ds.label.dtype == np.dtype(" Date: Mon, 18 Dec 2023 16:13:47 -0700 Subject: [PATCH 2/6] cleanup --- xarray/core/parallel.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index afe045965f5..1d707c31c14 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -352,14 +352,11 @@ def _wrapper( sorted(list(zip(xarray_indices, xarray_objs)) + others, key=lambda x: x[0]) ) - # TODO (dcherian) cleanup to just use a Coordinates object - input_indexes = dict(npargs[0]._indexes) # check that chunk sizes are compatible input_chunks = dict(npargs[0].chunks) for arg in xarray_objs[1:]: assert_chunks_compatible(npargs[0], arg) input_chunks.update(arg.chunks) - input_indexes.update(arg._indexes) if template is None: # infer template by providing zero-shaped arrays From b5f763d223c3c7c1312b944d7bf1441418151ebd Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 18 Dec 2023 16:28:57 -0700 Subject: [PATCH 3/6] typing fixes --- xarray/core/dataarray.py | 2 +- xarray/core/dataset.py | 2 +- xarray/core/parallel.py | 4 +++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 0335ad3bdda..0f245ff464b 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -80,7 +80,7 @@ try: from dask.dataframe import DataFrame as DaskDataFrame except ImportError: - DaskDataFrame = None # type: ignore + DaskDataFrame = None try: from dask.delayed import Delayed except ImportError: diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 9ec39e74ad1..a6fc0e2ca18 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -171,7 +171,7 @@ 
try: from dask.dataframe import DataFrame as DaskDataFrame except ImportError: - DaskDataFrame = None # type: ignore + DaskDataFrame = None # list of attributes of pd.DatetimeIndex that are ndarrays of time info diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 1d707c31c14..6b4797af841 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -9,6 +9,7 @@ import numpy as np from xarray.core.alignment import align +from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset from xarray.core.merge import merge @@ -358,6 +359,7 @@ def _wrapper( assert_chunks_compatible(npargs[0], arg) input_chunks.update(arg.chunks) + coordinates: Coordinates if template is None: # infer template by providing zero-shaped arrays template = infer_template(func, aligned[0], *args, **kwargs) @@ -499,7 +501,7 @@ def subset_dataset_to_block( # expected["shapes", "coords", "data_vars", "indexes"] are used to # raise nice error messages in _wrapper - expected = {} + expected: dict[Hashable, dict] = {} # input chunk 0 along a dimension maps to output chunk 0 along the same dimension # even if length of dimension is changed by the applied function expected["shapes"] = { From 85b4133fa841109fb720aa9e7a2347e24d15dca8 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 19 Dec 2023 09:32:01 -0700 Subject: [PATCH 4/6] optimize --- xarray/core/parallel.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 6b4797af841..2cc8e421a59 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -487,6 +487,7 @@ def subset_dataset_to_block( return (Dataset, (dict, data_vars), (dict, coords), dataset.attrs) + include_variables = set(template.variables) - set(coordinates.indexes) # iterate over all possible chunk combinations for chunk_tuple in itertools.product(*ichunk.values()): # mapping from dimension name to 
chunk index @@ -521,9 +522,8 @@ def subset_dataset_to_block( # mapping from variable name to dask graph key var_key_map: dict[Hashable, str] = {} - for name, variable in template.variables.items(): - if name in coordinates.indexes: - continue + for name in include_variables: + variable = template.variables[name] gname_l = f"{name}-{gname}" var_key_map[name] = gname_l From 8b1e341b45b47546ac266bb63c3de71c4fedc49b Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 19 Dec 2023 10:26:51 -0700 Subject: [PATCH 5/6] small cleanups --- xarray/core/parallel.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 2cc8e421a59..8da936fd28d 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -364,10 +364,10 @@ def _wrapper( # infer template by providing zero-shaped arrays template = infer_template(func, aligned[0], *args, **kwargs) template_coords = set(template.coords) - preserved_indexes = template_coords & set(merged_coordinates) - new_indexes = template_coords - set(merged_coordinates) + preserved_coord_vars = template_coords & set(merged_coordinates) + new_coord_vars = template_coords - set(merged_coordinates) - preserved_coords = merged_coordinates.to_dataset()[preserved_indexes] + preserved_coords = merged_coordinates.to_dataset()[preserved_coord_vars] # preserved_coords contains all coordinates bariables that share a dimension # with any index variable in preserved_indexes # Drop any unneeded vars in a second pass, this is required for e.g. 
@@ -377,7 +377,7 @@ def _wrapper( ) coordinates = merge( - (preserved_coords, template.coords.to_dataset()[new_indexes]) + (preserved_coords, template.coords.to_dataset()[new_coord_vars]) ).coords output_chunks: Mapping[Hashable, tuple[int, ...]] = { dim: input_chunks[dim] for dim in template.dims if dim in input_chunks @@ -487,7 +487,9 @@ def subset_dataset_to_block( return (Dataset, (dict, data_vars), (dict, coords), dataset.attrs) - include_variables = set(template.variables) - set(coordinates.indexes) + # variable names that depend on the computation. Currently, indexes + # cannot be modified in the mapped function, so we exclude those + computed_variables = set(template.variables) - set(coordinates.xindexes) # iterate over all possible chunk combinations for chunk_tuple in itertools.product(*ichunk.values()): # mapping from dimension name to chunk index @@ -522,7 +524,7 @@ def subset_dataset_to_block( # mapping from variable name to dask graph key var_key_map: dict[Hashable, str] = {} - for name in include_variables: + for name in computed_variables: variable = template.variables[name] gname_l = f"{name}-{gname}" var_key_map[name] = gname_l From 8d0188b704f9835a435b8b697e8e4c61a2ab7d7f Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 19 Dec 2023 10:38:46 -0700 Subject: [PATCH 6/6] Typing fixes --- xarray/core/parallel.py | 46 ++++++++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 17 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 8da936fd28d..ef505b55345 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -4,7 +4,7 @@ import itertools import operator from collections.abc import Hashable, Iterable, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict import numpy as np @@ -12,6 +12,7 @@ from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.dataset
import Dataset +from xarray.core.indexes import Index from xarray.core.merge import merge from xarray.core.pycompat import is_dask_collection @@ -19,6 +20,13 @@ from xarray.core.types import T_Xarray +class ExpectedDict(TypedDict): + shapes: dict[Hashable, int] + coords: set[Hashable] + data_vars: set[Hashable] + indexes: dict[Hashable, Index] + + def unzip(iterable): return zip(*iterable) @@ -33,7 +41,9 @@ def assert_chunks_compatible(a: Dataset, b: Dataset): def check_result_variables( - result: DataArray | Dataset, expected: Mapping[str, Any], kind: str + result: DataArray | Dataset, + expected: ExpectedDict, + kind: Literal["coords", "data_vars"], ): if kind == "coords": nice_str = "coordinate" @@ -256,7 +266,7 @@ def _wrapper( args: list, kwargs: dict, arg_is_array: Iterable[bool], - expected: dict, + expected: ExpectedDict, ): """ Wrapper function that receives datasets in args; converts to dataarrays when necessary; @@ -502,21 +512,23 @@ def subset_dataset_to_block( for isxr, arg in zip(is_xarray, npargs) ] - # expected["shapes", "coords", "data_vars", "indexes"] are used to # raise nice error messages in _wrapper - expected: dict[Hashable, dict] = {} - # input chunk 0 along a dimension maps to output chunk 0 along the same dimension - # even if length of dimension is changed by the applied function - expected["shapes"] = { - k: output_chunks[k][v] for k, v in chunk_index.items() if k in output_chunks - } - expected["data_vars"] = set(template.data_vars.keys()) # type: ignore[assignment] - expected["coords"] = set(template.coords.keys()) # type: ignore[assignment] - expected["indexes"] = { - dim: coordinates.xindexes[dim][ - _get_chunk_slicer(dim, chunk_index, output_chunk_bounds) - ] - for dim in coordinates.xindexes + expected: ExpectedDict = { + # input chunk 0 along a dimension maps to output chunk 0 along the same dimension + # even if length of dimension is changed by the applied function + "shapes": { + k: output_chunks[k][v] + for k, v in 
chunk_index.items() + if k in output_chunks + }, + "data_vars": set(template.data_vars.keys()), + "coords": set(template.coords.keys()), + "indexes": { + dim: coordinates.xindexes[dim][ + _get_chunk_slicer(dim, chunk_index, output_chunk_bounds) + ] + for dim in coordinates.xindexes + }, } from_wrapper = (gname,) + chunk_tuple