diff --git a/cf/data/data.py b/cf/data/data.py index c93d48f7d4..11e9346712 100644 --- a/cf/data/data.py +++ b/cf/data/data.py @@ -11661,6 +11661,7 @@ def stats( return out + @daskified(_DASKIFIED_VERBOSE) @_deprecated_kwarg_check("i") @_inplace_enabled(default=False) def swapaxes(self, axis0, axis1, inplace=False, i=False): @@ -11684,7 +11685,7 @@ def swapaxes(self, axis0, axis1, inplace=False, i=False): `Data` or `None` The data with swapped axis positions. - **Examples:** + **Examples** >>> d = cf.Data([[[1, 2, 3], [4, 5, 6]]]) >>> d @@ -11700,15 +11701,9 @@ def swapaxes(self, axis0, axis1, inplace=False, i=False): """ d = _inplace_enabled_define_and_cleanup(self) - - axis0 = d._parse_axes((axis0,))[0] - axis1 = d._parse_axes((axis1,))[0] - - if axis0 != axis1: - iaxes = list(range(d._ndim)) - iaxes[axis1], iaxes[axis0] = axis0, axis1 - d.transpose(iaxes, inplace=True) - + dx = self._get_dask() + dx = da.swapaxes(dx, axis0, axis1) + d._set_dask(dx, reset_mask_hardness=False) return d def save_to_disk(self, itemsize=None): diff --git a/cf/test/test_Data.py b/cf/test/test_Data.py index 7e991b8f68..826fdcf36f 100644 --- a/cf/test/test_Data.py +++ b/cf/test/test_Data.py @@ -1927,22 +1927,21 @@ def test_Data_roll(self): self.assertEqual(f.shape, d.shape) self.assertTrue(f.equals(d, verbose=2)) - @unittest.skipIf(TEST_DASKIFIED_ONLY, "no attribute '_ndim'") def test_Data_swapaxes(self): - if self.test_only and inspect.stack()[0][3] not in self.test_only: - return - - a = np.arange(10 * 15 * 19).reshape(10, 1, 15, 19) - - d = cf.Data(a.copy()) + a = np.ma.arange(24).reshape(2, 3, 4) + a[1, 1] = np.ma.masked + d = cf.Data(a, chunks=(-1, -1, 2)) for i in range(-a.ndim, a.ndim): for j in range(-a.ndim, a.ndim): - b = np.swapaxes(a.copy(), i, j) + b = np.swapaxes(a, i, j) e = d.swapaxes(i, j) - message = "cf.Data.swapaxes({}, {}) failed".format(i, j) - self.assertEqual(b.shape, e.shape, message) - self.assertTrue((b == e.array).all(), message) + self.assertEqual(b.shape, e.shape) + self.assertTrue((b == e.array).all()) + + # Bad axes + with self.assertRaises(IndexError): + d.swapaxes(3, -3) def test_Data_transpose(self): if self.test_only and inspect.stack()[0][3] not in self.test_only: @@ -1956,12 +1955,8 @@ def test_Data_transpose(self): for axes in itertools.permutations(indices): a = np.transpose(a, axes) d.transpose(axes, inplace=True) - message = ( - "cf.Data.transpose({}) failed: " - "d.shape={}, a.shape={}".format(axes, d.shape, a.shape) - ) - self.assertEqual(d.shape, a.shape, message) - self.assertTrue((d.array == a).all(), message) + self.assertEqual(d.shape, a.shape) + self.assertTrue((d.array == a).all()) @unittest.skipIf(TEST_DASKIFIED_ONLY, "no attr. 'partition_configuration'") def test_Data_unique(self):