diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index c923ca2eb87..38f8f8cd495 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1,7 +1,8 @@ from __future__ import annotations -import copy import itertools +import textwrap +from collections import ChainMap from collections.abc import Hashable, Iterable, Iterator, Mapping, MutableMapping from html import escape from typing import ( @@ -16,6 +17,7 @@ ) from xarray.core import utils +from xarray.core.alignment import align from xarray.core.common import TreeAttrAccessMixin from xarray.core.coordinates import DatasetCoordinates from xarray.core.dataarray import DataArray @@ -31,7 +33,7 @@ MappedDataWithCoords, ) from xarray.core.datatree_render import RenderDataTree -from xarray.core.formatting import datatree_repr +from xarray.core.formatting import datatree_repr, dims_and_coords_repr from xarray.core.formatting_html import ( datatree_repr as datatree_repr_html, ) @@ -79,11 +81,24 @@ T_Path = Union[str, NodePath] +def _collect_data_and_coord_variables( + data: Dataset, +) -> tuple[dict[Hashable, Variable], dict[Hashable, Variable]]: + data_variables = {} + coord_variables = {} + for k, v in data.variables.items(): + if k in data._coord_names: + coord_variables[k] = v + else: + data_variables[k] = v + return data_variables, coord_variables + + def _coerce_to_dataset(data: Dataset | DataArray | None) -> Dataset: if isinstance(data, DataArray): ds = data.to_dataset() elif isinstance(data, Dataset): - ds = data + ds = data.copy(deep=False) elif data is None: ds = Dataset() else: @@ -93,14 +108,57 @@ def _coerce_to_dataset(data: Dataset | DataArray | None) -> Dataset: return ds -def _check_for_name_collisions( - children: Iterable[str], variables: Iterable[Hashable] +def _join_path(root: str, name: str) -> str: + return str(NodePath(root) / name) + + +def _inherited_dataset(ds: Dataset, parent: Dataset) -> Dataset: + return Dataset._construct_direct( + variables=parent._variables | ds._variables, + coord_names=parent._coord_names | ds._coord_names, + dims=parent._dims | ds._dims, + attrs=ds._attrs, + indexes=parent._indexes | ds._indexes, + encoding=ds._encoding, + close=ds._close, + ) + + +def _without_header(text: str) -> str: + return "\n".join(text.split("\n")[1:]) + + +def _indented(text: str) -> str: + return textwrap.indent(text, prefix=" ") + + +def _check_alignment( + path: str, + node_ds: Dataset, + parent_ds: Dataset | None, + children: Mapping[str, DataTree], ) -> None: - colliding_names = set(children).intersection(set(variables)) - if colliding_names: - raise KeyError( - f"Some names would collide between variables and children: {list(colliding_names)}" - ) + if parent_ds is not None: + try: + align(node_ds, parent_ds, join="exact") + except ValueError as e: + node_repr = _indented(_without_header(repr(node_ds))) + parent_repr = _indented(dims_and_coords_repr(parent_ds)) + raise ValueError( + f"group {path!r} is not aligned with its parents:\n" + f"Group:\n{node_repr}\nFrom parents:\n{parent_repr}" + ) from e + + if children: + if parent_ds is not None: + base_ds = _inherited_dataset(node_ds, parent_ds) + else: + base_ds = node_ds + + for child_name, child in children.items(): + child_path = str(NodePath(path) / child_name) + child_ds = child.to_dataset(inherited=False) + _check_alignment(child_path, child_ds, base_ds, child.children) class DatasetView(Dataset): @@ -118,7 +176,7 @@ class DatasetView(Dataset): __slots__ = ( "_attrs", - "_cache", + "_cache", # used by _CachedAccessor "_coord_names", "_dims", "_encoding", @@ -136,21 +194,27 @@ def __init__( raise AttributeError("DatasetView objects are not to be initialized directly") @classmethod - def _from_node( + def _constructor( cls, - wrapping_node: DataTree, + variables: dict[Any, Variable], + coord_names: set[Hashable], + dims: dict[Any, int], + attrs: dict | None, + indexes: dict[Any, Index], + encoding: dict | None, + close: Callable[[], None] | None, ) -> DatasetView: - """Constructor, using dataset attributes from wrapping node""" - + """Private constructor, from Dataset attributes.""" + # We override Dataset._construct_direct below, so we need a new + # constructor for creating DatasetView objects. obj: DatasetView = object.__new__(cls) - obj._variables = wrapping_node._variables - obj._coord_names = wrapping_node._coord_names - obj._dims = wrapping_node._dims - obj._indexes = wrapping_node._indexes - obj._attrs = wrapping_node._attrs - obj._close = wrapping_node._close - obj._encoding = wrapping_node._encoding - + obj._variables = variables + obj._coord_names = coord_names + obj._dims = dims + obj._indexes = indexes + obj._attrs = attrs + obj._close = close + obj._encoding = encoding return obj def __setitem__(self, key, val) -> None: @@ -337,27 +401,27 @@ class DataTree( _name: str | None _parent: DataTree | None _children: dict[str, DataTree] + _cache: dict[str, Any] # used by _CachedAccessor + _data_variables: dict[Hashable, Variable] + _node_coord_variables: dict[Hashable, Variable] + _node_dims: dict[Hashable, int] + _node_indexes: dict[Hashable, Index] _attrs: dict[Hashable, Any] | None - _cache: dict[str, Any] - _coord_names: set[Hashable] - _dims: dict[Hashable, int] _encoding: dict[Hashable, Any] | None _close: Callable[[], None] | None - _indexes: dict[Hashable, Index] - _variables: dict[Hashable, Variable] __slots__ = ( "_name", "_parent", "_children", + "_cache", # used by _CachedAccessor + "_data_variables", + "_node_coord_variables", + "_node_dims", + "_node_indexes", "_attrs", - "_cache", - "_coord_names", - "_dims", "_encoding", "_close", - "_indexes", - "_variables", ) def __init__( @@ -370,14 +434,15 @@ def __init__( """ Create a single node of a DataTree. - The node may optionally contain data in the form of data and coordinate variables, stored in the same way as - data is stored in an xarray.Dataset. + The node may optionally contain data in the form of data and coordinate + variables, stored in the same way as data is stored in an + xarray.Dataset. Parameters ---------- data : Dataset, DataArray, or None, optional - Data to store under the .ds attribute of this node. DataArrays will be promoted to Datasets. - Default is None. + Data to store under the .ds attribute of this node. DataArrays will + be promoted to Datasets. Default is None. parent : DataTree, optional Parent node to this node. Default is None. children : Mapping[str, DataTree], optional @@ -393,30 +458,48 @@ def __init__( -------- DataTree.from_dict """ - - # validate input if children is None: children = {} - ds = _coerce_to_dataset(data) - _check_for_name_collisions(children, ds.variables) super().__init__(name=name) + self._set_node_data(_coerce_to_dataset(data)) + self.parent = parent + self.children = children - # set data attributes - self._replace( - inplace=True, - variables=ds._variables, - coord_names=ds._coord_names, - dims=ds._dims, - indexes=ds._indexes, - attrs=ds._attrs, - encoding=ds._encoding, - ) + def _set_node_data(self, ds: Dataset): + data_vars, coord_vars = _collect_data_and_coord_variables(ds) + self._data_variables = data_vars + self._node_coord_variables = coord_vars + self._node_dims = ds._dims + self._node_indexes = ds._indexes + self._encoding = ds._encoding + self._attrs = ds._attrs self._close = ds._close - # set tree attributes (must happen after variables set to avoid initialization errors) - self.children = children - self.parent = parent + def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None: + super()._pre_attach(parent, name) + if name in parent.ds.variables: + raise KeyError( + f"parent {parent.name} already contains a variable named {name}" + ) + path = str(NodePath(parent.path) / name) + node_ds = self.to_dataset(inherited=False) + parent_ds = parent._to_dataset_view(rebuild_dims=False) + _check_alignment(path, node_ds, parent_ds, self.children) + + @property + def _coord_variables(self) -> ChainMap[Hashable, Variable]: + return ChainMap( + self._node_coord_variables, *(p._node_coord_variables for p in self.parents) + ) + + @property + def _dims(self) -> ChainMap[Hashable, int]: + return ChainMap(self._node_dims, *(p._node_dims for p in self.parents)) + + @property + def _indexes(self) -> ChainMap[Hashable, Index]: + return ChainMap(self._node_indexes, *(p._node_indexes for p in self.parents)) @property def parent(self: DataTree) -> DataTree | None: @@ -429,71 +512,87 @@ def parent(self: DataTree, new_parent: DataTree) -> None: raise ValueError("Cannot set an unnamed node as a child of another node") self._set_parent(new_parent, self.name) + def _to_dataset_view(self, rebuild_dims: bool) -> DatasetView: + variables = dict(self._data_variables) + variables |= self._coord_variables + if rebuild_dims: + dims = calculate_dimensions(variables) + else: + # Note: rebuild_dims=False can create technically invalid Dataset + # objects because it may not contain all dimensions on its direct + # member variables, e.g., consider: + # tree = DataTree.from_dict( + # { + # "/": xr.Dataset({"a": (("x",), [1, 2])}), # x has size 2 + # "/b/c": xr.Dataset({"d": (("x",), [3])}), # x has size1 + # } + # ) + # However, they are fine for internal use cases, for align() or + # building a repr(). + dims = dict(self._dims) + return DatasetView._constructor( + variables=variables, + coord_names=set(self._coord_variables), + dims=dims, + attrs=self._attrs, + indexes=dict(self._indexes), + encoding=self._encoding, + close=None, + ) + @property def ds(self) -> DatasetView: """ An immutable Dataset-like view onto the data in this node. - For a mutable Dataset containing the same data as in this node, use `.to_dataset()` instead. + Includes inherited coordinates and indexes from parent nodes. + + For a mutable Dataset containing the same data as in this node, use + `.to_dataset()` instead. See Also -------- DataTree.to_dataset """ - return DatasetView._from_node(self) + return self._to_dataset_view(rebuild_dims=True) @ds.setter def ds(self, data: Dataset | DataArray | None = None) -> None: - # Known mypy issue for setters with different type to property: - # https://github.com/python/mypy/issues/3004 ds = _coerce_to_dataset(data) + self._replace_node(ds) - _check_for_name_collisions(self.children, ds.variables) - - self._replace( - inplace=True, - variables=ds._variables, - coord_names=ds._coord_names, - dims=ds._dims, - indexes=ds._indexes, - attrs=ds._attrs, - encoding=ds._encoding, - ) - self._close = ds._close - - def _pre_attach(self: DataTree, parent: DataTree) -> None: - """ - Method which superclass calls before setting parent, here used to prevent having two - children with duplicate names (or a data variable with the same name as a child). - """ - super()._pre_attach(parent) - if self.name in list(parent.ds.variables): - raise KeyError( - f"parent {parent.name} already contains a data variable named {self.name}" - ) - - def to_dataset(self) -> Dataset: + def to_dataset(self, inherited: bool = True) -> Dataset: """ Return the data in this node as a new xarray.Dataset object. + Parameters + ---------- + inherited : bool, optional + If False, only include coordinates and indexes defined at the level + of this DataTree node, excluding inherited coordinates. + See Also -------- DataTree.ds """ + coord_vars = self._coord_variables if inherited else self._node_coord_variables + variables = dict(self._data_variables) + variables |= coord_vars + dims = calculate_dimensions(variables) if inherited else dict(self._node_dims) return Dataset._construct_direct( - self._variables, - self._coord_names, - self._dims, - self._attrs, - self._indexes, - self._encoding, + variables, + set(coord_vars), + dims, + None if self._attrs is None else dict(self._attrs), + dict(self._indexes if inherited else self._node_indexes), + None if self._encoding is None else dict(self._encoding), self._close, ) @property - def has_data(self): - """Whether or not there are any data variables in this node.""" - return len(self._variables) > 0 + def has_data(self) -> bool: + """Whether or not there are any variables in this node.""" + return bool(self._data_variables or self._node_coord_variables) @property def has_attrs(self) -> bool: @@ -518,7 +617,7 @@ def variables(self) -> Mapping[Hashable, Variable]: Dataset invariants. It contains all variable objects constituting this DataTree node, including both data variables and coordinates. """ - return Frozen(self._variables) + return Frozen(self._data_variables | self._coord_variables) @property def attrs(self) -> dict[Hashable, Any]: @@ -579,7 +678,7 @@ def _attr_sources(self) -> Iterable[Mapping[Hashable, Any]]: def _item_sources(self) -> Iterable[Mapping[Any, Any]]: """Places to look-up items for key-completion""" yield self.data_vars - yield HybridMappingProxy(keys=self._coord_names, mapping=self.coords) + yield HybridMappingProxy(keys=self._coord_variables, mapping=self.coords) # virtual coordinates yield HybridMappingProxy(keys=self.dims, mapping=self) @@ -621,10 +720,10 @@ def __contains__(self, key: object) -> bool: return key in self.variables or key in self.children def __bool__(self) -> bool: - return bool(self.ds.data_vars) or bool(self.children) + return bool(self._data_variables) or bool(self._children) def __iter__(self) -> Iterator[Hashable]: - return itertools.chain(self.ds.data_vars, self.children) + return itertools.chain(self._data_variables, self._children) def __array__(self, dtype=None, copy=None): raise TypeError( @@ -646,122 +745,32 @@ def _repr_html_(self): return f"
{escape(repr(self))}"
return datatree_repr_html(self)
- @classmethod
- def _construct_direct(
- cls,
- variables: dict[Any, Variable],
- coord_names: set[Hashable],
- dims: dict[Any, int] | None = None,
- attrs: dict | None = None,
- indexes: dict[Any, Index] | None = None,
- encoding: dict | None = None,
- name: str | None = None,
- parent: DataTree | None = None,
- children: dict[str, DataTree] | None = None,
- close: Callable[[], None] | None = None,
- ) -> DataTree:
- """Shortcut around __init__ for internal use when we want to skip costly validation."""
+ def _replace_node(
+ self: DataTree,
+ data: Dataset | Default = _default,
+ children: dict[str, DataTree] | Default = _default,
+ ) -> None:
- # data attributes
- if dims is None:
- dims = calculate_dimensions(variables)
- if indexes is None:
- indexes = {}
- if children is None:
- children = dict()
+ ds = self.to_dataset(inherited=False) if data is _default else data
- obj: DataTree = object.__new__(cls)
- obj._variables = variables
- obj._coord_names = coord_names
- obj._dims = dims
- obj._indexes = indexes
- obj._attrs = attrs
- obj._close = close
- obj._encoding = encoding
-
- # tree attributes
- obj._name = name
- obj._children = children
- obj._parent = parent
+ if children is _default:
+ children = self._children
- return obj
+ for child_name in children:
+ if child_name in ds.variables:
+ raise ValueError(f"node already contains a variable named {child_name}")
- def _replace(
- self: DataTree,
- variables: dict[Hashable, Variable] | None = None,
- coord_names: set[Hashable] | None = None,
- dims: dict[Any, int] | None = None,
- attrs: dict[Hashable, Any] | None | Default = _default,
- indexes: dict[Hashable, Index] | None = None,
- encoding: dict | None | Default = _default,
- name: str | None | Default = _default,
- parent: DataTree | None | Default = _default,
- children: dict[str, DataTree] | None = None,
- inplace: bool = False,
- ) -> DataTree:
- """
- Fastpath constructor for internal use.
+ parent_ds = (
+ self.parent._to_dataset_view(rebuild_dims=False)
+ if self.parent is not None
+ else None
+ )
+ _check_alignment(self.path, ds, parent_ds, children)
- Returns an object with optionally replaced attributes.
+ if data is not _default:
+ self._set_node_data(ds)
- Explicitly passed arguments are *not* copied when placed on the new
- datatree. It is up to the caller to ensure that they have the right type
- and are not used elsewhere.
- """
- # TODO Adding new children inplace using this method will cause bugs.
- # You will end up with an inconsistency between the name of the child node and the key the child is stored under.
- # Use ._set() instead for now
- if inplace:
- if variables is not None:
- self._variables = variables
- if coord_names is not None:
- self._coord_names = coord_names
- if dims is not None:
- self._dims = dims
- if attrs is not _default:
- self._attrs = attrs
- if indexes is not None:
- self._indexes = indexes
- if encoding is not _default:
- self._encoding = encoding
- if name is not _default:
- self._name = name
- if parent is not _default:
- self._parent = parent
- if children is not None:
- self._children = children
- obj = self
- else:
- if variables is None:
- variables = self._variables.copy()
- if coord_names is None:
- coord_names = self._coord_names.copy()
- if dims is None:
- dims = self._dims.copy()
- if attrs is _default:
- attrs = copy.copy(self._attrs)
- if indexes is None:
- indexes = self._indexes.copy()
- if encoding is _default:
- encoding = copy.copy(self._encoding)
- if name is _default:
- name = self._name # no need to copy str objects or None
- if parent is _default:
- parent = copy.copy(self._parent)
- if children is _default:
- children = copy.copy(self._children)
- obj = self._construct_direct(
- variables,
- coord_names,
- dims,
- attrs,
- indexes,
- encoding,
- name,
- parent,
- children,
- )
- return obj
+ self._children = children
def copy(
self: DataTree,
@@ -813,9 +822,8 @@ def _copy_node(
deep: bool = False,
) -> DataTree:
"""Copy just one node of a tree"""
- new_node: DataTree = DataTree()
- new_node.name = self.name
- new_node.ds = self.to_dataset().copy(deep=deep) # type: ignore[assignment]
+ data = self.ds.copy(deep=deep)
+ new_node: DataTree = DataTree(data, name=self.name)
return new_node
def __copy__(self: DataTree) -> DataTree:
@@ -963,11 +971,12 @@ def update(
raise TypeError(f"Type {type(v)} cannot be assigned to a DataTree")
vars_merge_result = dataset_update_method(self.to_dataset(), new_variables)
+ data = Dataset._construct_direct(**vars_merge_result._asdict())
+
# TODO are there any subtleties with preserving order of children like this?
merged_children = {**self.children, **new_children}
- self._replace(
- inplace=True, children=merged_children, **vars_merge_result._asdict()
- )
+
+ self._replace_node(data, children=merged_children)
def assign(
self, items: Mapping[Any, Any] | None = None, **items_kwargs: Any
@@ -1042,10 +1051,12 @@ def drop_nodes(
if extra:
raise KeyError(f"Cannot drop all nodes - nodes {extra} not present")
+ result = self.copy()
children_to_keep = {
- name: child for name, child in self.children.items() if name not in names
+ name: child for name, child in result.children.items() if name not in names
}
- return self._replace(children=children_to_keep)
+ result._replace_node(children=children_to_keep)
+ return result
@classmethod
def from_dict(
@@ -1137,7 +1148,9 @@ def indexes(self) -> Indexes[pd.Index]:
@property
def xindexes(self) -> Indexes[Index]:
"""Mapping of xarray Index objects used for label based indexing."""
- return Indexes(self._indexes, {k: self._variables[k] for k in self._indexes})
+ return Indexes(
+ self._indexes, {k: self._coord_variables[k] for k in self._indexes}
+ )
@property
def coords(self) -> DatasetCoordinates:
diff --git a/xarray/core/datatree_io.py b/xarray/core/datatree_io.py
index 1473e624d9e..36665a0d153 100644
--- a/xarray/core/datatree_io.py
+++ b/xarray/core/datatree_io.py
@@ -85,7 +85,7 @@ def _datatree_to_netcdf(
unlimited_dims = {}
for node in dt.subtree:
- ds = node.ds
+ ds = node.to_dataset(inherited=False)
group_path = node.path
if ds is None:
_create_empty_netcdf_group(filepath, group_path, mode, engine)
@@ -151,7 +151,7 @@ def _datatree_to_zarr(
)
for node in dt.subtree:
- ds = node.ds
+ ds = node.to_dataset(inherited=False)
group_path = node.path
if ds is None:
_create_empty_zarr_group(store, group_path, mode)
diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py
index 5c4a3015843..6dca4eba8e8 100644
--- a/xarray/core/formatting.py
+++ b/xarray/core/formatting.py
@@ -748,6 +748,27 @@ def dataset_repr(ds):
return "\n".join(summary)
+def dims_and_coords_repr(ds) -> str:
+ """Partial Dataset repr for use inside DataTree inheritance errors."""
+ summary = []
+
+ col_width = _calculate_col_width(ds.coords)
+ max_rows = OPTIONS["display_max_rows"]
+
+ dims_start = pretty_print("Dimensions:", col_width)
+ dims_values = dim_summary_limited(ds, col_width=col_width + 1, max_rows=max_rows)
+ summary.append(f"{dims_start}({dims_values})")
+
+ if ds.coords:
+ summary.append(coords_repr(ds.coords, col_width=col_width, max_rows=max_rows))
+
+ unindexed_dims_str = unindexed_dims_repr(ds.dims, ds.coords, max_rows=max_rows)
+ if unindexed_dims_str:
+ summary.append(unindexed_dims_str)
+
+ return "\n".join(summary)
+
+
def diff_dim_summary(a, b):
if a.sizes != b.sizes:
return f"Differing dimensions:\n ({dim_summary(a)}) != ({dim_summary(b)})"
@@ -1030,7 +1051,7 @@ def diff_datatree_repr(a: DataTree, b: DataTree, compat):
def _single_node_repr(node: DataTree) -> str:
"""Information about this node, not including its relationships to other nodes."""
if node.has_data or node.has_attrs:
- ds_info = "\n" + repr(node.ds)
+ ds_info = "\n" + repr(node._to_dataset_view(rebuild_dims=False))
else:
ds_info = ""
return f"Group: {node.path}{ds_info}"
diff --git a/xarray/core/formatting_html.py b/xarray/core/formatting_html.py
index 9bf5befbe3f..24b290031eb 100644
--- a/xarray/core/formatting_html.py
+++ b/xarray/core/formatting_html.py
@@ -386,7 +386,7 @@ def summarize_datatree_children(children: Mapping[str, DataTree]) -> str:
def datatree_node_repr(group_title: str, dt: DataTree) -> str:
header_components = [f"