# Patch overview (commentary placed before the first "diff --git" header;
# it is not part of any hunk).
#
# cf/data/dask_utils.py
#     Adds the chunk function `cf_contains(a, value)`, which evaluates
#     `value in a` and returns the result as a size 1 Boolean array
#     reshaped to `(1,) * a.ndim`.
#
# cf/data/data.py
#     Imports `cf_contains`. `__contains__` now raises `TypeError` unless
#     *value* is scalar: a truthy `shape` is rejected, a `str` or an
#     object with no `len()` is accepted, and any other object with a
#     `len()` is rejected. The removed nested `contains_chunk` helper and
#     `partial(contains_chunk, value=value)` are replaced by passing
#     `cf_contains`, plus `value` with index `()`, to `da.blockwise`.
#
# cf/test/test_Data.py
#     Imports `dask.array as da`, drops the `skipIf` decorator and replaces
#     the test body with loops covering values that are contained, values
#     that are not, non-scalar values (expecting `TypeError`) and strings.
#
# NOTE(review): the test is renamed from `test_Data___contains__` to
# `test_Data__contains__` (one underscore fewer) - confirm this is intended.
#
# NOTE(review): line breaks and indentation were rebuilt from a
# whitespace-collapsed copy of this patch, using the hunk header line
# counts to place blank context lines. Confirm against the upstream patch
# before applying.
diff --git a/cf/data/dask_utils.py b/cf/data/dask_utils.py
index 3c12a65f24..6429fd2afe 100644
--- a/cf/data/dask_utils.py
+++ b/cf/data/dask_utils.py
@@ -115,6 +115,32 @@ def allclose(a_blocks, b_blocks, rtol=rtol, atol=atol):
     )
 
 
+def cf_contains(a, value):
+    """Whether or not an array contains a value.
+
+    .. versionadded:: TODODASK
+
+    .. seealso:: `cf.Data.__contains__`
+
+    :Parameters:
+
+        a: `numpy.ndarray`
+            The array.
+
+        value: array_like
+            The value.
+
+    :Returns:
+
+        `numpy.ndarray`
+            A size 1 Boolean array, with the same number of dimensions
+            as *a*, that indicates whether or not *a* contains the
+            value.
+
+    """
+    return np.array(value in a).reshape((1,) * a.ndim)
+
+
 try:
     from scipy.ndimage import convolve1d
 except ImportError:
diff --git a/cf/data/data.py b/cf/data/data.py
index 068b309d90..ffc1c03c65 100644
--- a/cf/data/data.py
+++ b/cf/data/data.py
@@ -104,6 +104,7 @@
 )
 from .dask_utils import (
     _da_ma_allclose,
+    cf_contains,
     cf_dt2rt,
     cf_harden_mask,
     cf_percentile,
@@ -650,37 +651,81 @@ def __contains__(self, value):
 
         x.__contains__(y) <==> y in x
 
-        Returns True if the value is contained anywhere in the data
-        array. The value may be a `cf.Data` object.
+        Returns True if the scalar *value* is contained anywhere in
+        the data. If *value* is not scalar then an exception is
+        raised.
 
         **Performance**
 
-        All delayed operations are exectued, and there is no
-        short-circuit once the first occurrence is found.
+        `__contains__` causes all delayed operations to be computed
+        unless *value* is a `Data` object with incompatible units, in
+        which case `False` is always returned.
 
-        **Examples:**
+        **Examples**
 
-        >>> d = cf.Data([[0.0, 1, 2], [3, 4, 5]], 'm')
+        >>> d = cf.Data([[0, 1, 2], [3, 4, 5]], 'm')
         >>> 4 in d
         True
-        >>> cf.Data(3) in d
+        >>> 4.0 in d
         True
-        >>> cf.Data([2.5], units='2 m') in d
+        >>> cf.Data(5) in d
         True
-        >>> [[2]] in d
+        >>> cf.Data(5, 'm') in d
         True
-        >>> numpy.array([[[2]]]) in d
+        >>> cf.Data(0.005, 'km') in d
         True
-        >>> Data(2, 'seconds') in d
+
+        >>> 99 in d
+        False
+        >>> cf.Data(2, 'seconds') in d
         False
 
-        """
+        >>> [1] in d
+        Traceback (most recent call last):
+            ...
+        TypeError: elementwise comparison failed; must test against a scalar, not [1]
+        >>> [1, 2] in d
+        Traceback (most recent call last):
+            ...
+        TypeError: elementwise comparison failed; must test against a scalar, not [1, 2]
 
-        def contains_chunk(a, value):
-            out = value in a
-            return np.array(out).reshape((1,) * a.ndim)
+        >>> d = cf.Data(["foo", "bar"])
+        >>> 'foo' in d
+        True
+        >>> 'xyz' in d
+        False
+
+        """
+        # Check that value is scalar by seeing if its shape is ()
+        shape = getattr(value, "shape", None)
+        if shape is None:
+            if isinstance(value, str):
+                # Strings are scalars, even though they have a len().
+                shape = ()
+            else:
+                try:
+                    len(value)
+                except TypeError:
+                    # value has no len() so assume that it is a scalar
+                    shape = ()
+                else:
+                    # value has a len() so assume that it is not a scalar
+                    shape = True
+        elif is_dask_collection(value) and math.isnan(value.size):
+            # value is a dask array with unknown size, so calculate
+            # the size. This is acceptable, as we're going to compute
+            # it anyway at the end of this method.
+            value.compute_chunk_sizes()
+            shape = value.shape
+
+        if shape:
+            raise TypeError(
+                "elementwise comparison failed; must test against a scalar, "
+                f"not {value!r}"
+            )
 
-        if isinstance(value, self.__class__):  # TODDASK chek aother type stoo
+        # If value is a scalar Data object then conform its units
+        if isinstance(value, self.__class__):
             self_units = self.Units
             value_units = value.Units
             if value_units.equivalent(self_units):
@@ -688,6 +733,8 @@ def contains_chunk(a, value):
                     value = value.copy()
                     value.Units = self_units
             elif value_units:
+                # No need to check the dask array if the value units
+                # are incompatible
                 return False
 
             value = value._get_dask()
@@ -698,10 +745,12 @@ def contains_chunk(a, value):
         dx_ind = out_ind
 
         dx = da.blockwise(
-            partial(contains_chunk, value=value),
+            cf_contains,
             out_ind,
             dx,
             dx_ind,
+            value,
+            (),
             adjust_chunks={i: 1 for i in out_ind},
             dtype=bool,
         )
diff --git a/cf/test/test_Data.py b/cf/test/test_Data.py
index 53eed8b5e6..a68ab17555 100644
--- a/cf/test/test_Data.py
+++ b/cf/test/test_Data.py
@@ -8,6 +8,7 @@
 from functools import reduce
 from operator import mul
 
+import dask.array as da
 import numpy as np
 
 SCIPY_AVAILABLE = False
@@ -1094,21 +1095,60 @@ def test_Data_AUXILIARY_MASK(self):
         self.assertEqual(f.shape, fm.shape)
         self.assertTrue((f._auxiliary_mask_return().array == fm).all())
 
-    @unittest.skipIf(TEST_DASKIFIED_ONLY, "TypeError: 'int' is not iterable")
-    def test_Data___contains__(self):
+    def test_Data__contains__(self):
         if self.test_only and inspect.stack()[0][3] not in self.test_only:
             return
 
-        d = cf.Data([[0.0, 1, 2], [3, 4, 5]], units="m")
-        self.assertIn(4, d)
-        self.assertNotIn(40, d)
-        self.assertIn(cf.Data(3), d)
-        self.assertIn(cf.Data([[[[3]]]]), d)
-        value = d[1, 2]
-        value.Units *= 2
-        value.squeeze(0)
-        self.assertIn(value, d)
-        self.assertIn(np.array([[[2]]]), d)
+        d = cf.Data([[0, 1, 2], [3, 4, 5]], units="m", chunks=2)
+
+        for value in (
+            4,
+            4.0,
+            cf.Data(3),
+            cf.Data(0.005, "km"),
+            np.array(2),
+            da.from_array(2),
+        ):
+            self.assertIn(value, d)
+
+        for value in (
+            99,
+            np.array(99),
+            da.from_array(99),
+            cf.Data(99, "km"),
+            cf.Data(2, "seconds"),
+        ):
+            self.assertNotIn(value, d)
+
+        for value in (
+            [1],
+            [[1]],
+            [1, 2],
+            [[1, 2]],
+            np.array([1]),
+            np.array([[1]]),
+            np.array([1, 2]),
+            np.array([[1, 2]]),
+            da.from_array([1]),
+            da.from_array([[1]]),
+            da.from_array([1, 2]),
+            da.from_array([[1, 2]]),
+            cf.Data([1]),
+            cf.Data([[1]]),
+            cf.Data([1, 2]),
+            cf.Data([[1, 2]]),
+            cf.Data([0.005], "km"),
+        ):
+            with self.assertRaises(TypeError):
+                value in d
+
+        # Strings
+        d = cf.Data(["foo", "bar"])
+        self.assertIn("foo", d)
+        self.assertNotIn("xyz", d)
+
+        with self.assertRaises(TypeError):
+            ["foo"] in d
 
     @unittest.skipIf(TEST_DASKIFIED_ONLY, "no attr. 'partition_configuration'")
     def test_Data_asdata(self):