diff --git a/xarray/core/common.py b/xarray/core/common.py index bae3b6cd73d..93a5bb71b07 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -20,6 +20,7 @@ ALL_DIMS = ReprObject('') +C = TypeVar('C') T = TypeVar('T') @@ -297,9 +298,11 @@ def get_index(self, key: Hashable) -> pd.Index: # need to ensure dtype=int64 in case range is empty on Python 2 return pd.Index(range(self.sizes[key]), name=key, dtype=np.int64) - def _calc_assign_results(self, kwargs: Mapping[str, T] - ) -> MutableMapping[str, T]: - results = SortedKeysDict() # type: SortedKeysDict[str, T] + def _calc_assign_results( + self: C, + kwargs: Mapping[Hashable, Union[T, Callable[[C], T]]] + ) -> MutableMapping[Hashable, T]: + results = SortedKeysDict() # type: SortedKeysDict[Hashable, T] for k, v in kwargs.items(): if callable(v): results[k] = v(self) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 0e28613323e..40966f684a2 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -4,7 +4,8 @@ from collections import OrderedDict from numbers import Number from typing import (Any, Callable, Dict, Hashable, Iterable, List, Mapping, - Optional, Sequence, Tuple, Union, cast, TYPE_CHECKING) + Optional, Sequence, Tuple, Union, cast, overload, + TYPE_CHECKING) import numpy as np import pandas as pd @@ -1752,17 +1753,35 @@ def transpose(self, def T(self) -> 'DataArray': return self.transpose() - def drop(self, - labels: Union[Hashable, Sequence[Hashable]], - dim: Hashable = None, - *, - errors: str = 'raise') -> 'DataArray': + # Drop coords + @overload + def drop( + self, + labels: Union[Hashable, Iterable[Hashable]], + *, + errors: str = 'raise' + ) -> 'DataArray': + ... + + # Drop index labels along dimension + @overload # noqa: F811 + def drop( + self, + labels: Any, # array-like + dim: Hashable, + *, + errors: str = 'raise' + ) -> 'DataArray': + ... 
+ + def drop(self, labels, dim=None, *, errors='raise'): # noqa: F811 """Drop coordinates or index labels from this DataArray. Parameters ---------- labels : hashable or sequence of hashables - Name(s) of coordinate variables or index labels to drop. + Name(s) of coordinates or index labels to drop. + If dim is not None, labels can be any array-like. dim : hashable, optional Dimension along which to drop index labels. By default (if ``dim is None``), drops coordinates rather than index labels. @@ -1775,8 +1794,6 @@ def drop(self, ------- dropped : DataArray """ - if utils.is_scalar(labels): - labels = [labels] ds = self._to_temp_dataset().drop(labels, dim, errors=errors) return self._from_temp_dataset(ds) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index b00dad965ed..5d3ca932ccc 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -6,9 +6,10 @@ from distutils.version import LooseVersion from numbers import Number from pathlib import Path -from typing import (Any, DefaultDict, Dict, Hashable, Iterable, Iterator, List, - Mapping, MutableMapping, Optional, Sequence, Set, Tuple, - Union, cast, TYPE_CHECKING) +from typing import (Any, Callable, DefaultDict, Dict, Hashable, Iterable, + Iterator, List, Mapping, MutableMapping, Optional, + Sequence, Set, Tuple, Union, cast, overload, + TYPE_CHECKING) import numpy as np import pandas as pd @@ -315,10 +316,10 @@ class _LocIndexer: def __init__(self, dataset: 'Dataset'): self.dataset = dataset - def __getitem__(self, key: Mapping[str, Any]) -> 'Dataset': + def __getitem__(self, key: Mapping[Hashable, Any]) -> 'Dataset': if not utils.is_dict_like(key): raise TypeError('can only lookup dictionaries from Dataset.loc') - return self.dataset.sel(**key) + return self.dataset.sel(key) class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords): @@ -792,7 +793,7 @@ def _replace_with_new_dims( # type: ignore self, variables: 'OrderedDict[Any, Variable]', coord_names: set = None, - attrs: 
'OrderedDict' = __default, + attrs: Optional['OrderedDict'] = __default, indexes: 'OrderedDict[Any, pd.Index]' = __default, inplace: bool = False, ) -> 'Dataset': @@ -3261,7 +3262,8 @@ def merge( return self._replace_vars_and_dims(variables, coord_names, dims, inplace=inplace) - def _assert_all_in_dataset(self, names, virtual_okay=False): + def _assert_all_in_dataset(self, names: Iterable[Hashable], + virtual_okay: bool = False) -> None: bad_names = set(names) - set(self._variables) if virtual_okay: bad_names -= self.virtual_variables @@ -3269,14 +3271,36 @@ def _assert_all_in_dataset(self, names, virtual_okay=False): raise ValueError('One or more of the specified variables ' 'cannot be found in this dataset') - def drop(self, labels, dim=None, *, errors='raise'): + # Drop variables + @overload + def drop( + self, + labels: Union[Hashable, Iterable[Hashable]], + *, + errors: str = 'raise' + ) -> 'Dataset': + ... + + # Drop index labels along dimension + @overload # noqa: F811 + def drop( + self, + labels: Any, # array-like + dim: Hashable, + *, + errors: str = 'raise' + ) -> 'Dataset': + ... + + def drop(self, labels, dim=None, *, errors='raise'): # noqa: F811 """Drop variables or index labels from this dataset. Parameters ---------- - labels : scalar or list of scalars + labels : hashable or iterable of hashables Name(s) of variables or index labels to drop. - dim : None or str, optional + If dim is not None, labels can be any array-like. + dim : None or hashable, optional Dimension along which to drop index labels. By default (if ``dim is None``), drops variables rather than index labels. 
errors: {'raise', 'ignore'}, optional @@ -3291,11 +3315,21 @@ def drop(self, labels, dim=None, *, errors='raise'): """ if errors not in ['raise', 'ignore']: raise ValueError('errors must be either "raise" or "ignore"') - if utils.is_scalar(labels): - labels = [labels] + if dim is None: + if isinstance(labels, str) or not isinstance(labels, Iterable): + labels = {labels} + else: + labels = set(labels) + return self._drop_vars(labels, errors=errors) else: + # Don't cast to set, as it would harm performance when labels + # is a large numpy array + if utils.is_scalar(labels): + labels = [labels] + labels = np.asarray(labels) + try: index = self.indexes[dim] except KeyError: @@ -3304,25 +3338,38 @@ def drop(self, labels, dim=None, *, errors='raise'): new_index = index.drop(labels, errors=errors) return self.loc[{dim: new_index}] - def _drop_vars(self, names, errors='raise'): + def _drop_vars( + self, + names: set, + errors: str = 'raise' + ) -> 'Dataset': if errors == 'raise': self._assert_all_in_dataset(names) - drop = set(names) + variables = OrderedDict((k, v) for k, v in self._variables.items() - if k not in drop) + if k not in names) coord_names = set(k for k in self._coord_names if k in variables) indexes = OrderedDict((k, v) for k, v in self.indexes.items() - if k not in drop) + if k not in names) return self._replace_with_new_dims( variables, coord_names=coord_names, indexes=indexes) - def drop_dims(self, drop_dims, *, errors='raise'): + def drop_dims( + self, + drop_dims: Union[Hashable, Iterable[Hashable]], + *, + errors: str = 'raise' + ) -> 'Dataset': """Drop dimensions and associated variables from this dataset. Parameters ---------- drop_dims : str or list Dimension or dimensions to drop. + errors: {'raise', 'ignore'}, optional + If 'raise' (default), raises a ValueError error if any of the + dimensions passed are not in the dataset. If 'ignore', any given + labels that are in the dataset are dropped and no error is raised. 
Returns ------- @@ -3338,8 +3385,10 @@ def drop_dims(self, drop_dims, *, errors='raise'): if errors not in ['raise', 'ignore']: raise ValueError('errors must be either "raise" or "ignore"') - if utils.is_scalar(drop_dims): + if isinstance(drop_dims, str) or not isinstance(drop_dims, Iterable): drop_dims = [drop_dims] + else: + drop_dims = list(drop_dims) if errors == 'raise': missing_dimensions = [d for d in drop_dims if d not in self.dims] @@ -3351,7 +3400,7 @@ def drop_dims(self, drop_dims, *, errors='raise'): for d in v.dims if d in drop_dims) return self._drop_vars(drop_vars) - def transpose(self, *dims): + def transpose(self, *dims: Hashable) -> 'Dataset': """Return a new Dataset object with all array dimensions transposed. Although the order of dimensions on each array will change, the dataset @@ -3359,7 +3408,7 @@ def transpose(self, *dims): Parameters ---------- - *dims : str, optional + *dims : Hashable, optional By default, reverse the dimensions on each array. Otherwise, reorder the dimensions to this order. @@ -3391,13 +3440,19 @@ def transpose(self, *dims): ds._variables[name] = var.transpose(*var_dims) return ds - def dropna(self, dim, how='any', thresh=None, subset=None): + def dropna( + self, + dim: Hashable, + how: str = 'any', + thresh: int = None, + subset: Iterable[Hashable] = None + ): """Returns a new dataset with dropped labels for missing values along the provided dimension. Parameters ---------- - dim : str + dim : Hashable Dimension along which to drop missing values. Dropping along multiple dimensions simultaneously is not yet supported. how : {'any', 'all'}, optional @@ -3405,8 +3460,8 @@ def dropna(self, dim, how='any', thresh=None, subset=None): * all : if all values are NA, drop that label thresh : int, default None If supplied, require this many non-NA values. - subset : sequence, optional - Subset of variables to check for missing values. 
By default, all + subset : iterable of hashable, optional + Which variables to check for missing values. By default, all variables in the dataset are checked. Returns @@ -3421,7 +3476,7 @@ def dropna(self, dim, how='any', thresh=None, subset=None): raise ValueError('%s must be a single dataset dimension' % dim) if subset is None: - subset = list(self.data_vars) + subset = iter(self.data_vars) count = np.zeros(self.dims[dim], dtype=np.int64) size = 0 @@ -3430,7 +3485,7 @@ def dropna(self, dim, how='any', thresh=None, subset=None): array = self._variables[k] if dim in array.dims: dims = [d for d in array.dims if d != dim] - count += np.asarray(array.count(dims)) + count += np.asarray(array.count(dims)) # type: ignore size += np.prod([self.dims[d] for d in dims]) if thresh is not None: @@ -3446,7 +3501,7 @@ def dropna(self, dim, how='any', thresh=None, subset=None): return self.isel({dim: mask}) - def fillna(self, value): + def fillna(self, value: Any) -> 'Dataset': """Fill missing values in this object. This operation follows the normal broadcasting and alignment rules that @@ -3475,14 +3530,19 @@ def fillna(self, value): out = ops.fillna(self, value) return out - def interpolate_na(self, dim=None, method='linear', limit=None, - use_coordinate=True, - **kwargs): + def interpolate_na( + self, + dim: Hashable = None, + method: str = 'linear', + limit: int = None, + use_coordinate: Union[bool, Hashable] = True, + **kwargs: Any + ) -> 'Dataset': """Interpolate values according to different methods. Parameters ---------- - dim : str + dim : Hashable Specifies the dimension along which to interpolate. method : {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic', 'polynomial', 'barycentric', 'krog', 'pchip', @@ -3506,6 +3566,8 @@ def interpolate_na(self, dim=None, method='linear', limit=None, limit : int, default None Maximum number of consecutive NaNs to fill. Must be greater than 0 or None for no limit. 
+ + kwargs : any + parameters passed verbatim to the underlying interpolation function Returns ------- @@ -3524,14 +3586,14 @@ def interpolate_na(self, dim=None, method='linear', limit=None, **kwargs) return new - def ffill(self, dim, limit=None): - '''Fill NaN values by propogating values forward + def ffill(self, dim: Hashable, limit: int = None) -> 'Dataset': + """Fill NaN values by propagating values forward *Requires bottleneck.* Parameters ---------- - dim : str + dim : Hashable Specifies the dimension along which to propagate values when filling. limit : int, default None @@ -3543,14 +3605,14 @@ def ffill(self, dim, limit=None): Returns ------- Dataset - ''' + """ from .missing import ffill, _apply_over_vars_with_dim new = _apply_over_vars_with_dim(ffill, self, dim=dim, limit=limit) return new - def bfill(self, dim, limit=None): - '''Fill NaN values by propogating values backward + def bfill(self, dim: Hashable, limit: int = None) -> 'Dataset': + """Fill NaN values by propagating values backward *Requires bottleneck.* @@ -3568,13 +3630,13 @@ def bfill(self, dim, limit=None): Returns ------- Dataset - ''' + """ from .missing import bfill, _apply_over_vars_with_dim new = _apply_over_vars_with_dim(bfill, self, dim=dim, limit=limit) return new - def combine_first(self, other): + def combine_first(self, other: 'Dataset') -> 'Dataset': """Combine two Datasets, default to data_vars of self. The new coordinates follow the normal broadcasting and alignment rules @@ -3583,7 +3645,7 @@ def combine_first(self, other): Parameters ---------- - other : DataArray + other : Dataset Used to fill all matching missing values in this array. 
Returns @@ -3593,13 +3655,21 @@ def combine_first(self, other): out = ops.fillna(self, other, join="outer", dataset_join="outer") return out - def reduce(self, func, dim=None, keep_attrs=None, keepdims=False, - numeric_only=False, allow_lazy=False, **kwargs): + def reduce( + self, + func: Callable, + dim: Union[Hashable, Iterable[Hashable]] = None, + keep_attrs: bool = None, + keepdims: bool = False, + numeric_only: bool = False, + allow_lazy: bool = False, + **kwargs: Any + ) -> 'Dataset': """Reduce this dataset by applying `func` along some dimension(s). Parameters ---------- - func : function + func : callable Function which can be called in the form `f(x, axis=axis, **kwargs)` to return the result of reducing an np.ndarray over an integer valued axis. @@ -3616,7 +3686,7 @@ def reduce(self, func, dim=None, keep_attrs=None, keepdims=False, are removed. numeric_only : bool, optional If True, only apply ``func`` to variables with a numeric dtype. - **kwargs : dict + **kwargs : Any Additional keyword arguments passed on to ``func``. 
Returns @@ -3627,10 +3697,10 @@ def reduce(self, func, dim=None, keep_attrs=None, keepdims=False, """ if dim is ALL_DIMS: dim = None - if isinstance(dim, str): - dims = set([dim]) - elif dim is None: + if dim is None: dims = set(self.dims) + elif isinstance(dim, str) or not isinstance(dim, Iterable): + dims = {dim} else: dims = set(dim) @@ -3642,9 +3712,12 @@ def reduce(self, func, dim=None, keep_attrs=None, keepdims=False, if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) - variables = OrderedDict() + variables = OrderedDict() # type: OrderedDict[Hashable, Variable] for name, var in self._variables.items(): - reduce_dims = [d for d in var.dims if d in dims] + reduce_dims = [ + d for d in var.dims + if d in dims + ] if name in self.coords: if not reduce_dims: variables[name] = var @@ -3660,7 +3733,7 @@ def reduce(self, func, dim=None, keep_attrs=None, keepdims=False, # prefer to aggregate over axis=None rather than # axis=(0, 1) if they will be equivalent, because # the former is often more efficient - reduce_dims = None + reduce_dims = None # type: ignore variables[name] = var.reduce(func, dim=reduce_dims, keep_attrs=keep_attrs, keepdims=keepdims, @@ -3674,12 +3747,18 @@ def reduce(self, func, dim=None, keep_attrs=None, keepdims=False, return self._replace_with_new_dims( variables, coord_names=coord_names, attrs=attrs, indexes=indexes) - def apply(self, func, keep_attrs=None, args=(), **kwargs): + def apply( + self, + func: Callable, + keep_attrs: bool = None, + args: Iterable[Any] = (), + **kwargs: Any + ) -> 'Dataset': """Apply a function over the data variables in this dataset. Parameters ---------- - func : function + func : callable Function which can be called in the form `func(x, *args, **kwargs)` to transform each DataArray `x` in this dataset into another DataArray. @@ -3689,7 +3768,7 @@ def apply(self, func, keep_attrs=None, args=(), **kwargs): be returned without attributes. 
args : tuple, optional Positional arguments passed on to `func`. - **kwargs : dict + **kwargs : Any Keyword arguments passed on to `func`. Returns @@ -3724,7 +3803,11 @@ def apply(self, func, keep_attrs=None, args=(), **kwargs): attrs = self.attrs if keep_attrs else None return type(self)(variables, attrs=attrs) - def assign(self, variables=None, **variables_kwargs): + def assign( + self, + variables: Mapping[Hashable, Any] = None, + **variables_kwargs: Hashable + ) -> 'Dataset': """Assign new data variables to a Dataset, returning a new object with all the original variables in addition to the new ones. @@ -3737,7 +3820,7 @@ def assign(self, variables=None, **variables_kwargs): scalar, or array), they are simply assigned. **variables_kwargs: The keyword arguments form of ``variables``. - One of variables or variables_kwarg must be provided. + One of variables or variables_kwargs must be provided. Returns ------- diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 5697704bdbc..000469f24bf 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1904,9 +1904,9 @@ def test_drop_coordinates(self): assert_identical(actual, expected) with raises_regex(ValueError, 'cannot be found'): - arr.drop(None) + arr.drop('w') - actual = expected.drop(None, errors='ignore') + actual = expected.drop('w', errors='ignore') assert_identical(actual, expected) renamed = arr.rename('foo') diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index fc6f7f36938..fc15393f269 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -2000,6 +2000,12 @@ def test_drop_index_labels(self): expected = data.isel(x=slice(0, 0)) assert_identical(expected, actual) + # DataArrays as labels are a nasty corner case as they are not + # Iterable[Hashable] - DataArray.__iter__ yields scalar DataArrays. 
+ actual = data.drop(DataArray(['a', 'b', 'c']), 'x', errors='ignore') + expected = data.isel(x=slice(0, 0)) + assert_identical(expected, actual) + with raises_regex( ValueError, 'does not have coordinate labels'): data.drop(1, 'y')