Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 54 additions & 72 deletions cf/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7542,98 +7542,80 @@ def argmax(self, axis=None, unravel=False):
If no axis is specified then the returned index locates the
maximum of the whole data.

In case of multiple occurrences of the maximum values, the
indices corresponding to the first occurrence are returned.

**Performance**

If the data index is returned as a `tuple` (see the *unravel*
parameter) then all delayed operations are computed.

:Parameters:

axis: `int`, optional
The specified axis over which to locate the maximum
values. By default the maximum over the whole data is
located.
values. By default the maximum over the flattened data
is located.

unravel: `bool`, optional
If True, then when locating the maximum over the whole
data, return the location as a tuple of indices for each
axis. By default an index to the flattened array is
returned in this case. Ignored if locating the maxima over
a subset of the axes.

If True then when locating the maximum over the whole
data, return the location as an index for each axis as
a `tuple`. By default an index to the flattened array
is returned in this case. Ignored if locating the
maxima over a subset of the axes.

:Returns:

`int` or `tuple` or `Data`
`Data` or `tuple`
The location of the maximum, or maxima.

**Examples:**
**Examples**

>>> d = cf.Data(np.arange(6).reshape(2, 3))
>>> print(d.array)
[[0 1 2]
[3 4 5]]
>>> a = d.argmax()
>>> a
<CF Data(): 5>
>>> a.array
5

>>> index = d.argmax(unravel=True)
>>> index
(1, 2)
>>> d[index]
<CF Data(1, 1): [[5]]>

>>> d = cf.Data(numpy.arange(120).reshape(4, 5, 6))
>>> d.argmax()
119
>>> d.argmax(unravel=True)
(3, 4, 5)
>>> d.argmax(axis=0)
<CF Data(5, 6): [[3, ..., 3]]>
<CF Data(3): [1, 1, 1]>
>>> d.argmax(axis=1)
<CF Data(4, 6): [[4, ..., 4]]>
>>> d.argmax(axis=2)
<CF Data(4, 5): [[5, ..., 5]]>
<CF Data(2): [2, 2]>

"""
if axis is not None:
ndim = self._ndim
if -ndim - 1 <= axis < 0:
axis += ndim + 1
elif not 0 <= axis <= ndim:
raise ValueError(
"Can't argmax: Invalid axis specification: Expected "
"-{0}<=axis<{0}, got axis={1}".format(ndim, axis)
)

if ndim == 1 and axis == 0:
axis = None
# --- End: if
Only the location of the first occurrence is returned:

if axis is None:
config = self.partition_configuration(readonly=True)

out = []

for partition in self.partitions.matrix.flat:
partition.open(config)
array = partition.array
index = np.unravel_index(array.argmax(), array.shape)
mx = array[index]
index = [x[0] + i for x, i in zip(partition.location, index)]
out.append((mx, index))
partition.close()

mx, index = sorted(out)[-1]

if unravel:
return tuple(index)

return np.ravel_multi_index(index, self.shape)

# Parse axis
ndim = self._ndim
if -ndim - 1 <= axis < 0:
axis += ndim + 1
elif not 0 <= axis <= ndim:
raise ValueError(
"Can't argmax: Invalid axis specification: Expected "
"-{0}<=axis<{0}, got axis={1}".format(ndim, axis)
)
>>> d = cf.Data([0, 4, 2, 3, 4])
>>> d.argmax()
<CF Data(): 1>

sections = self.section(axis, chunks=True)
for key, d in sections.items():
array = d.varray.argmax(axis=axis)
array = np.expand_dims(array, axis)
sections[key] = type(self)(
array, self.Units, fill_value=self.fill_value
)
>>> d = cf.Data(np.arange(6).reshape(2, 3))
>>> d[1, 1] = 5
>>> print(d.array)
[[0 1 2]
[3 5 5]]
>>> d.argmax(1)
<CF Data(2): [2, 1]>

out = self.reconstruct_sectioned_data(sections)
"""
dx = self._get_dask()
a = dx.argmax(axis=axis)

out.squeeze(axis, inplace=True)
if unravel and (axis is None or self.ndim <= 1):
# Return a multidimensional index tuple
return tuple(np.array(da.unravel_index(a, self.shape)))

return out
return type(self)(a)

def get_data(self, default=ValueError(), _units=None, _fill_value=None):
"""Returns the data.
Expand Down
23 changes: 17 additions & 6 deletions cf/test/test_Data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2457,22 +2457,33 @@ def test_Data__round__(self):
with self.assertRaises(Exception):
_ = round(cf.Data([1, 2]))

@unittest.skipIf(TEST_DASKIFIED_ONLY, "no attr. 'partition_configuration'")
def test_Data_argmax(self):
if self.test_only and inspect.stack()[0][3] not in self.test_only:
return

d = cf.Data(np.arange(1200).reshape(40, 5, 6))
d = cf.Data(np.arange(120).reshape(4, 5, 6))

self.assertEqual(d.argmax().array, 119)

self.assertEqual(d.argmax(), 1199)
self.assertEqual(d.argmax(unravel=True), (39, 4, 5))
index = d.argmax(unravel=True)
self.assertEqual(index, (3, 4, 5))
self.assertEqual(d[index].array, 119)

e = d.argmax(axis=1)
self.assertEqual(e.shape, (40, 6))
self.assertEqual(e.shape, (4, 6))
self.assertTrue(
e.equals(cf.Data.full(shape=(40, 6), fill_value=4, dtype=int))
e.equals(cf.Data.full(shape=(4, 6), fill_value=4, dtype=int))
)

self.assertEqual(d[d.argmax(unravel=True)].array, 119)

d = cf.Data([0, 4, 2, 3, 4])
self.assertEqual(d.argmax().array, 1)

# Bad axis
with self.assertRaises(Exception):
d.argmax(axis=d.ndim)

@unittest.skipIf(TEST_DASKIFIED_ONLY, "hits 'NoneType' is not iterable")
def test_Data__collapse_SHAPE(self):
if self.test_only and inspect.stack()[0][3] not in self.test_only:
Expand Down