diff --git a/cf/data/data.py b/cf/data/data.py index 31cb3566dd..004f3eb6b8 100644 --- a/cf/data/data.py +++ b/cf/data/data.py @@ -7542,98 +7542,80 @@ def argmax(self, axis=None, unravel=False): If no axis is specified then the returned index locates the maximum of the whole data. + In case of multiple occurrences of the maximum values, the + indices corresponding to the first occurrence are returned. + + **Performance** + + If the data index is returned as a `tuple` (see the *unravel* + parameter) then all delayed operations are computed. + :Parameters: axis: `int`, optional The specified axis over which to locate the maximum - values. By default the maximum over the whole data is - located. + values. By default the maximum over the flattened data + is located. unravel: `bool`, optional - If True, then when locating the maximum over the whole - data, return the location as a tuple of indices for each - axis. By default an index to the flattened array is - returned in this case. Ignored if locating the maxima over - a subset of the axes. + + If True then when locating the maximum over the whole + data, return the location as an index for each axis as + a `tuple`. By default an index to the flattened array + is returned in this case. Ignored if locating the + maxima over a subset of the axes. :Returns: - `int` or `tuple` or `Data` + `Data` or `tuple` The location of the maximum, or maxima. - **Examples:** + **Examples** + + >>> d = cf.Data(np.arange(6).reshape(2, 3)) + >>> print(d.array) + [[0 1 2] + [3 4 5]] + >>> a = d.argmax() + >>> a + + >>> a.array + 5 + + >>> index = d.argmax(unravel=True) + >>> index + (1, 2) + >>> d[index] + - >>> d = cf.Data(numpy.arange(120).reshape(4, 5, 6)) - >>> d.argmax() - 119 - >>> d.argmax(unravel=True) - (3, 4, 5) >>> d.argmax(axis=0) - + >>> d.argmax(axis=1) - - >>> d.argmax(axis=2) - + - """ - if axis is not None: - ndim = self._ndim - if -ndim - 1 <= axis < 0: - axis += ndim + 1 - elif not 0 <= axis <= ndim: - raise ValueError( - "Can't argmax: Invalid axis specification: Expected " - "-{0}<=axis<{0}, got axis={1}".format(ndim, axis) - ) - - if ndim == 1 and axis == 0: - axis = None - # --- End: if + Only the location of the first occurrence is returned: - if axis is None: - config = self.partition_configuration(readonly=True) - - out = [] - - for partition in self.partitions.matrix.flat: - partition.open(config) - array = partition.array - index = np.unravel_index(array.argmax(), array.shape) - mx = array[index] - index = [x[0] + i for x, i in zip(partition.location, index)] - out.append((mx, index)) - partition.close() - - mx, index = sorted(out)[-1] - - if unravel: - return tuple(index) - - return np.ravel_multi_index(index, self.shape) - - # Parse axis - ndim = self._ndim - if -ndim - 1 <= axis < 0: - axis += ndim + 1 - elif not 0 <= axis <= ndim: - raise ValueError( - "Can't argmax: Invalid axis specification: Expected " - "-{0}<=axis<{0}, got axis={1}".format(ndim, axis) - ) + >>> d = cf.Data([0, 4, 2, 3, 4]) + >>> d.argmax() + - sections = self.section(axis, chunks=True) - for key, d in sections.items(): - array = d.varray.argmax(axis=axis) - array = np.expand_dims(array, axis) - sections[key] = type(self)( - array, self.Units, fill_value=self.fill_value - ) + >>> d = cf.Data(np.arange(6).reshape(2, 3)) + >>> d[1, 1] = 5 + >>> print(d.array) + [[0 1 2] + [3 5 5]] + >>> d.argmax(1) + - out = self.reconstruct_sectioned_data(sections) + """ + dx = self._get_dask() + a = dx.argmax(axis=axis) - out.squeeze(axis, inplace=True) + if unravel and (axis is None or self.ndim <= 1): + # Return a multidimensional index tuple + return tuple(np.array(da.unravel_index(a, self.shape))) - return out + return type(self)(a) def get_data(self, default=ValueError(), _units=None, _fill_value=None): """Returns the data. diff --git a/cf/test/test_Data.py b/cf/test/test_Data.py index 1707e1cb87..c11b508d23 100644 --- a/cf/test/test_Data.py +++ b/cf/test/test_Data.py @@ -2457,22 +2457,33 @@ def test_Data__round__(self): with self.assertRaises(Exception): _ = round(cf.Data([1, 2])) - @unittest.skipIf(TEST_DASKIFIED_ONLY, "no attr. 'partition_configuration'") def test_Data_argmax(self): if self.test_only and inspect.stack()[0][3] not in self.test_only: return - d = cf.Data(np.arange(1200).reshape(40, 5, 6)) + d = cf.Data(np.arange(120).reshape(4, 5, 6)) + + self.assertEqual(d.argmax().array, 119) - self.assertEqual(d.argmax(), 1199) - self.assertEqual(d.argmax(unravel=True), (39, 4, 5)) + index = d.argmax(unravel=True) + self.assertEqual(index, (3, 4, 5)) + self.assertEqual(d[index].array, 119) e = d.argmax(axis=1) - self.assertEqual(e.shape, (40, 6)) + self.assertEqual(e.shape, (4, 6)) self.assertTrue( - e.equals(cf.Data.full(shape=(40, 6), fill_value=4, dtype=int)) + e.equals(cf.Data.full(shape=(4, 6), fill_value=4, dtype=int)) ) + self.assertEqual(d[d.argmax(unravel=True)].array, 119) + + d = cf.Data([0, 4, 2, 3, 4]) + self.assertEqual(d.argmax().array, 1) + + # Bad axis + with self.assertRaises(Exception): + d.argmax(axis=d.ndim) + @unittest.skipIf(TEST_DASKIFIED_ONLY, "hits 'NoneType' is not iterable") def test_Data__collapse_SHAPE(self): if self.test_only and inspect.stack()[0][3] not in self.test_only: