diff --git a/nemo_curator/_compat.py b/nemo_curator/_compat.py index 1dc07d9e07..a89426d529 100644 --- a/nemo_curator/_compat.py +++ b/nemo_curator/_compat.py @@ -20,3 +20,4 @@ # TODO: remove when dask min version gets bumped DASK_SHUFFLE_METHOD_ARG = _dask_version > parseVersion("2024.1.0") DASK_P2P_ERROR = _dask_version < parseVersion("2023.10.0") +DASK_SHUFFLE_CAST_DTYPE = _dask_version > parseVersion("2023.12.0") diff --git a/nemo_curator/utils/fuzzy_dedup_utils/merge_utils.py b/nemo_curator/utils/fuzzy_dedup_utils/merge_utils.py index a144b5602d..70bf73004c 100644 --- a/nemo_curator/utils/fuzzy_dedup_utils/merge_utils.py +++ b/nemo_curator/utils/fuzzy_dedup_utils/merge_utils.py @@ -16,13 +16,14 @@ from operator import getitem import numpy as np +import pandas as pd from dask.base import tokenize from dask.dataframe.core import new_dd_object from dask.dataframe.shuffle import partitioning_index from dask.highlevelgraph import HighLevelGraph from dask.utils import M -from nemo_curator.utils.fuzzy_dedup_utils.shuffle_utils import rearange_by_column_direct +from nemo_curator._compat import DASK_SHUFFLE_CAST_DTYPE def _split_part(part, nsplits): @@ -129,6 +130,21 @@ def extract_partitioning_index( # a partition-wise merge between `left_df` and `right_df`. 
# We call this `global_partitioning_index`: + if DASK_SHUFFLE_CAST_DTYPE: + # Need to use the same type-casting logic as `shuffle` + dtypes = {} + if not isinstance(merge_on, list): + merge_on = [merge_on] + for col, dtype in left_df[merge_on].dtypes.items(): + if pd.api.types.is_numeric_dtype(dtype): + dtypes[col] = np.float64 + if not dtypes: + dtypes = None + cast_dtype = {"cast_dtype": dtypes} + else: + # `cast_dtype` argument doesn't exist yet + cast_dtype = {} + num_bucket_files = bk_mapping.file_id.max() + 1 global_partitioning_index = left_df[merge_on].map_partitions( partitioning_index, @@ -137,6 +153,7 @@ def extract_partitioning_index( enforce_metadata=False, transform_divisions=False, align_dataframes=False, + **cast_dtype, ) if total_bucket_partitions < num_bucket_files: @@ -157,7 +174,7 @@ def extract_partitioning_index( # want to send the rows of `left_df` to the partition # indices encoded in `global_partitioning_index`. Instead, we # need to take a modulus with `parts_per_bucket_batch` to - # define a `"_partitoins"` column. + # define a `"_partitions"` column. left_df["_partitions"] = global_partitioning_index % parts_per_bucket_batch return left_df, global_partitioning_index @@ -195,6 +212,10 @@ def merge_left_to_shuffled_right( subset_bucket_df, merge_on, ): + from nemo_curator.utils.fuzzy_dedup_utils.shuffle_utils import ( + rearange_by_column_direct, + ) + # We are merging an unshuffled batch of "left" partitions # with a shuffled batch of "right" partitions. 
To minimize # data movement, we can manaully rerrange the "left" batch diff --git a/tests/test_fuzzy_dedup.py b/tests/test_fuzzy_dedup.py index e89f998e0e..1c952d27d3 100644 --- a/tests/test_fuzzy_dedup.py +++ b/tests/test_fuzzy_dedup.py @@ -16,14 +16,17 @@ from itertools import combinations from typing import Iterable +import dask.dataframe as dd import numpy as np import pytest import yaml +from dask import config from dask.dataframe.utils import assert_eq from distributed import Client from nemo_curator import LSH, FuzzyDuplicates, FuzzyDuplicatesConfig, MinHash from nemo_curator.datasets import DocumentDataset +from nemo_curator.utils.fuzzy_dedup_utils.merge_utils import extract_partitioning_index from nemo_curator.utils.import_utils import gpu_only_import, gpu_only_import_from cudf = gpu_only_import("cudf") @@ -367,3 +370,74 @@ def test_from_yaml(self, tmpdir): config = FuzzyDuplicatesConfig.from_yaml(tmpdir / "config.yaml") for param in yaml_params: assert getattr(config, param) == yaml_params[param] + + +@pytest.mark.parametrize( + "backend", + [ + "pandas", + pytest.param( + "cudf", + marks=pytest.mark.gpu, + ), + ], +) +def test_extract_partitioning_index(backend): + + def add_partition_info(df, partition_info=None): + if partition_info is None: + df["file_id"] = -1 + else: + df["file_id"] = partition_info["number"] + return df + + with config.set({"dataframe.backend": backend}): + + # Create a random `unshuffled` DataFrame with a + # "part_id" column to be used as the shuffle index + npartitions_left = 7 + unshuffled = dd.from_dict( + {"part_id": np.random.randint(25, size=1000, dtype="int32")}, + npartitions=npartitions_left, + ) + + # Create a `bk_mapping` DataFrame that defines + # the "correct" mapping between "part_id" and + # the destination partition ("file_id") + npartitions_right = 5 + bk_mapping = ( + dd.from_dict( + {"part_id": np.arange(25, dtype="int32")}, + npartitions=npartitions_right, + ) + .shuffle("part_id")
.map_partitions(add_partition_info) + .compute() + ) + + # Use `extract_partitioning_index` to calculate + # the partitioning index and assign it as a new + # "_partitions" column + result, _ = extract_partitioning_index( + unshuffled, + "part_id", + bk_mapping, + npartitions_right, + npartitions_right, + ) + + # Rename the "_partitions" column, shuffle by "part_id", + # and then assign a "file_id" column to reflect the final + # partition of each row + check = ( + result.rename(columns={"_partitions": "expected_file_id"}) + .shuffle( + "part_id", + npartitions=npartitions_right, + ) + .map_partitions(add_partition_info) + .compute() + ) + + # Check that the real and expected partitions match + assert (check["file_id"] == check["expected_file_id"]).all()