diff --git a/cf/data/data.py b/cf/data/data.py index a5cf4bf9fc..9819ab0f68 100644 --- a/cf/data/data.py +++ b/cf/data/data.py @@ -1656,6 +1656,8 @@ def dumps(self): return json_dumps(d, default=convert_to_builtin_type) + @daskified(_DASKIFIED_VERBOSE) + @_inplace_enabled(default=False) def digitize( self, bins, @@ -1663,6 +1665,7 @@ def digitize( open_ends=False, closed_ends=None, return_bins=False, + inplace=False, ): """Return the indices of the bins to which each value belongs. @@ -1747,6 +1750,8 @@ def digitize( return_bins: `bool`, optional If True then also return the bins in their 2-d form. + {{inplace: `bool`, optional}} + :Returns: `Data`, [`Data`] @@ -1755,7 +1760,7 @@ def digitize( If *return_bins* is True then also return the bins in their 2-d form. - **Examples:** + **Examples** >>> d = cf.Data(numpy.arange(12).reshape(3, 4)) [[ 0 1 2 3] @@ -1811,9 +1816,9 @@ def digitize( [ 1 1 1 --]] """ - out = self.copy() + d = _inplace_enabled_define_and_cleanup(self) - org_units = self.Units + org_units = d.Units bin_units = getattr(bins, "Units", None) @@ -1830,12 +1835,16 @@ def digitize( else: bin_units = org_units - bins = np.asanyarray(bins) + # Get bins as a numpy array + if isinstance(bins, np.ndarray): + bins = bins.copy() + else: + bins = np.asanyarray(bins) if bins.ndim > 2: raise ValueError( - "The 'bins' parameter must be scalar, 1-d or 2-d" - "Got: {!r}".format(bins) + "The 'bins' parameter must be scalar, 1-d or 2-d. " + f"Got: {bins!r}" ) two_d_bins = None @@ -1848,7 +1857,7 @@ def digitize( if bins.shape[1] != 2: raise ValueError( "The second dimension of the 'bins' parameter must " - "have size 2. Got: {!r}".format(bins) + f"have size 2. Got: {bins!r}" ) bins.sort(axis=1) @@ -1858,11 +1867,9 @@ def digitize( for i, (u, l) in enumerate(zip(bins[:-1, 1], bins[1:, 0])): if u > l: raise ValueError( - "Overlapping bins: {}, {}".format( - tuple(bins[i]), tuple(bins[i + i]) - ) + f"Overlapping bins: " + f"{tuple(bins[i])}, {tuple(bins[i + i])}" ) - # --- End: for two_d_bins = bins bins = np.unique(bins) @@ -1900,8 +1907,8 @@ def digitize( "scalar." ) - mx = self.max().datum() - mn = self.min().datum() + mx = d.max().datum() + mn = d.min().datum() bins = np.linspace(mn, mx, int(bins) + 1, dtype=float) delete_bins = [] @@ -1913,7 +1920,8 @@ def digitize( "Can't set open_ends=True when closed_ends is True." ) - bins = bins.astype(float, copy=True) + if bins.dtype.kind != "f": + bins = bins.astype(float, copy=False) epsilon = np.finfo(float).eps ndim = bins.ndim @@ -1923,53 +1931,27 @@ def digitize( else: mx = bins[(-1,) * ndim] bins[(-1,) * ndim] += abs(mx) * epsilon - # --- End: if if not open_ends: delete_bins.insert(0, 0) delete_bins.append(bins.size) - if return_bins and two_d_bins is None: - x = np.empty((bins.size - 1, 2), dtype=bins.dtype) - x[:, 0] = bins[:-1] - x[:, 1] = bins[1:] - two_d_bins = x - - config = out.partition_configuration(readonly=True) - - for partition in out.partitions.matrix.flat: - partition.open(config) - array = partition.array - - mask = None - if np.ma.isMA(array): - mask = array.mask.copy() - - array = np.digitize(array, bins, right=upper) - - if delete_bins: - for n, d in enumerate(delete_bins): - d -= n - array = np.ma.where(array == d, np.ma.masked, array) - array = np.ma.where(array > d, array - 1, array) - # --- End: if - - if mask is not None: - array = np.ma.where(mask, np.ma.masked, array) - - partition.subarray = array - partition.Units = _units_None - - partition.close() - - out.dtype = int - - out.override_units(_units_None, inplace=True) + # Digitise the array + dx = d._get_dask() + dx = da.digitize(dx, bins, right=upper) + d._set_dask(dx, reset_mask_hardness=True) + d.override_units(_units_None, inplace=True) if return_bins: - return out, type(self)(two_d_bins, units=bin_units) + if two_d_bins is None: + two_d_bins = np.empty((bins.size - 1, 2), dtype=bins.dtype) + two_d_bins[:, 0] = bins[:-1] + two_d_bins[:, 1] = bins[1:] - return out + two_d_bins = type(self)(two_d_bins, units=bin_units) + return d, two_d_bins + + return d def median( self, diff --git a/cf/test/test_Data.py b/cf/test/test_Data.py index 66421119af..7425e7a460 100644 --- a/cf/test/test_Data.py +++ b/cf/test/test_Data.py @@ -804,7 +804,6 @@ def test_Data__init__dtype_mask(self): self.assertTrue((d.array == a).all()) self.assertTrue((d.mask.array == np.ma.getmaskarray(a)).all()) - @unittest.skipIf(TEST_DASKIFIED_ONLY, "no attr. 'partition_configuration'") def test_Data_digitize(self): if self.test_only and inspect.stack()[0][3] not in self.test_only: return @@ -829,15 +828,37 @@ def test_Data_digitize(self): b = np.digitize(a, [2, 6, 10, 50, 100], right=upper) self.assertTrue((e.array == b).all()) - - e.where( - cf.set([e.minimum(), e.maximum()]), - cf.masked, - e - 1, - inplace=True, + self.assertTrue( + (np.ma.getmask(e.array) == np.ma.getmask(b)).all() ) - f = d.digitize(bins, upper=upper) - self.assertTrue(e.equals(f, verbose=2)) + + # TODODASK: Reinstate the following test when + # __sub__, minimum, and maximum have + # been daskified + + # e.where( + # cf.set([e.minimum(), e.maximum()]), + # cf.masked, + # e - 1, + # inplace=True, + # ) + # f = d.digitize(bins, upper=upper) + # self.assertTrue(e.equals(f, verbose=2)) + + # Check returned bins + bins = [2, 6, 10, 50, 100] + e, b = d.digitize(bins, return_bins=True) + self.assertTrue( + (b.array == [[2, 6], [6, 10], [10, 50], [50, 100]]).all() + ) + self.assertTrue(b.Units == d.Units) + + # Check digitized units + self.assertTrue(e.Units == cf.Units(None)) + + # Check inplace + self.assertIsNone(d.digitize(bins, inplace=True)) + self.assertTrue(d.equals(e)) @unittest.skipIf(TEST_DASKIFIED_ONLY, "no attribute '_ndim'") def test_Data_cumsum(self):