diff --git a/cf/data/dask_utils.py b/cf/data/dask_utils.py index bcf9b412d2..3c12a65f24 100644 --- a/cf/data/dask_utils.py +++ b/cf/data/dask_utils.py @@ -9,6 +9,7 @@ import dask.array as da import numpy as np +from dask.core import flatten from ..cfdatetime import dt2rt, rt2dt from ..functions import atol as cf_atol @@ -83,7 +84,13 @@ def allclose(a_blocks, b_blocks, rtol=rtol, atol=atol): if not isinstance(b_blocks, list): b_blocks = (b_blocks,) - for a, b in zip(a_blocks, b_blocks): + # Note: If a_blocks or b_blocks has more than one chunk in + # more than one dimension they will comprise a nested + # sequence of sequences, that needs to be flattened so + # that we can safely iterate through the actual numpy + # array elements. + + for a, b in zip(flatten(a_blocks), flatten(b_blocks)): result &= np.ma.allclose( a, b, diff --git a/cf/test/test_Data.py b/cf/test/test_Data.py index 393c91451e..53eed8b5e6 100644 --- a/cf/test/test_Data.py +++ b/cf/test/test_Data.py @@ -420,6 +420,13 @@ def test_Data_equals(self): # Test ignore_data_type parameter self.assertTrue(d2.equals(d, ignore_data_type=True)) + # Test all possible chunk combinations + for j, i in itertools.product([1, 2], [1, 2, 3]): + d = cf.Data(np.arange(6).reshape(2, 3), "m", chunks=(j, i)) + for j, i in itertools.product([1, 2], [1, 2, 3]): + e = cf.Data(np.arange(6).reshape(2, 3), "m", chunks=(j, i)) + self.assertTrue(d.equals(e)) + @unittest.skipIf(TEST_DASKIFIED_ONLY, "hits unexpected kwarg 'ndim'") def test_Data_halo(self): if self.test_only and inspect.stack()[0][3] not in self.test_only: