Hf checkpoint conversion for distributed checkpoints #424
BlueCrescent wants to merge 33 commits into main
Conversation
Should we maybe move the conversion directory into the checkpointing directory with this PR (after review)?
Pull request overview
This PR adds support for converting distributed checkpoint (DCP) formats (FSDP2, PP, TP) to HuggingFace transformers format. The conversion is implemented as a two-step process: first converting DCP checkpoints to standard PyTorch format, then using the existing conversion pipeline to create HuggingFace models.
- Added new `convert_dcp_to_torch` module to handle DCP-to-PyTorch checkpoint conversion
- Extended the GPT-2 conversion script to support DCP checkpoints via the `--dcp` flag
- Introduced `ConfigDictType` type alias for better type consistency across configuration handling
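For orientation, here is a minimal sketch of the two-step flow using PyTorch's own DCP utility. `dcp_to_torch_save` is a real helper in `torch.distributed.checkpoint.format_utils`; the paths and the commented `convert_gpt2` call are illustrative assumptions, not the PR's exact API.

```python
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

# Step 1: consolidate the sharded DCP checkpoint into a single torch.save file.
# (The PR's convert_dcp_to_torch module additionally rewrites the config.)
dcp_to_torch_save("checkpoints/dcp_step_42", "checkpoints/model.pt")

# Step 2: run the existing modalities -> HuggingFace conversion on the result.
# Hypothetical call for illustration; see convert_gpt2.py for the real entry point.
# convert_gpt2(config_path="checkpoints/model_config.yaml", output_dir="hf_model")
```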
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 11 comments.
| File | Description |
|---|---|
| src/modalities/checkpointing/convert_dcp_to_torch.py | New module implementing DCP to PyTorch checkpoint conversion with config file transformation |
| src/modalities/conversion/gpt2/convert_gpt2.py | Added --dcp flag support, new convert_gpt2_dcp function, and refactored main entry point |
| src/modalities/conversion/gpt2/conversion_model.py | Updated type hints to use ConfigDictType and added dtype assertion in model checking |
| src/modalities/config/config.py | Added ConfigDictType alias, new save_yaml_config_dict function, and fixed implicit return in resolver |
| src/modalities/models/utils.py | Updated type hints to use ConfigDictType for consistency |
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
…env context manager.
- Now only loading model weights into memory (no optimizer or scheduler weights); see the sketch below.
- Always creating an FP32 config, since FSDP2 always has FP32 weights.
- Disabled overwriting of existing config files.
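A minimal sketch of the "model weights only" load, assuming the DCP checkpoint stores the weights under a top-level `"model"` key (the key name and paths are assumptions, not taken from the PR):

```python
import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint import FileSystemReader

model = torch.nn.Linear(8, 8)  # stand-in for the actual modalities model

# Only the model entry is requested; optimizer and scheduler shards are
# never read because they do not appear in the requested state dict.
# (Depending on the torch version, an initialized process group may be needed.)
state_dict = {"model": model.state_dict()}
dcp.load(state_dict, storage_reader=FileSystemReader("checkpoints/dcp_step_42"))
model.load_state_dict(state_dict["model"])
```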
…ity again. This was not required after all.
- Detection and warning if an attention implementation other than the Huggingface default is used, since this is not saved with the checkpoint (sketched below).
- Correct handling and matching of FSDP2 mixed precision behavior (in particular for rotary pos embeddings).
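A rough sketch of such a detection/warning; note that `_attn_implementation` is a transformers-internal attribute and the `"sdpa"` default is an assumption here, not confirmed by the PR:

```python
import warnings

def warn_on_non_default_attention(hf_config, default_impl: str = "sdpa") -> None:
    # The attention implementation is not persisted in the checkpoint itself,
    # so anything other than the default would be silently lost on conversion.
    impl = getattr(hf_config, "_attn_implementation", default_impl)
    if impl != default_impl:
        warnings.warn(
            f"Attention implementation '{impl}' differs from the Huggingface "
            "default and is not saved with the checkpoint."
        )
```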
…llama implementation.
At this time, this bug seems to be fixed in main, and we should be able to use a version >4.57.3 once it is released.
Problematic line: https://github.com/huggingface/transformers/blob/47b0e478f324b54f177ea7998a0791870fdd0324/src/transformers/utils/generic.py#L947
Fixed version: https://github.com/huggingface/transformers/blob/d3ee06b8cb5e45aab51b85aafd54f4b3f7cad2e2/src/transformers/utils/generic.py#L791
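Until such a release exists, a defensive runtime check along these lines could flag the affected versions (a sketch based on the version noted above, not part of the PR):

```python
import transformers
from packaging.version import Version

# Releases up to and including 4.57.3 contain the problematic line linked above;
# the fix from main is expected in the next release.
if Version(transformers.__version__) <= Version("4.57.3"):
    print("Warning: transformers <= 4.57.3 detected; see the linked generic.py issue.")
```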
…onment variables.
…l uses of that function work.
…for missing fields in the original config.
```python
self._env_override = EnvOverride(
    {
        "MASTER_ADDR": "localhost",
        "MASTER_PORT": str(rdvz_port),
        "RANK": str(global_rank),
        "LOCAL_RANK": str(local_rank),
        "WORLD_SIZE": str(world_size),
    }
)
```
Isn't this stuff typically taken care of by torchrun? Why do we need the CudaEnv class?
The MultiProcessingCudaEnv is useful when running distributed code directly from Python. Previously, we only used it for our unit tests. For conversion, it is also needed so that the conversion script can work with the DCP model.
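For context, this is roughly what such an environment-driven setup does under the hood (a simplified sketch; the real class also handles CUDA device selection, and the variable values here are examples):

```python
import os
import torch.distributed as dist

# Single-process "distributed" setup, as needed when converting a DCP
# checkpoint from plain Python instead of launching via torchrun.
os.environ.update({
    "MASTER_ADDR": "localhost",
    "MASTER_PORT": "24570",
    "RANK": "0",
    "LOCAL_RANK": "0",
    "WORLD_SIZE": "1",
})
dist.init_process_group(backend="gloo")  # the conversion path uses nccl
# ... interact with the DCP checkpoint here ...
dist.destroy_process_group()
```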
…conversion_for_fsdp2
- also added missing test config
…e transformers it is no longer necessary to retain buffers in fp32 - confirmed for torch 2.11.0
Pull request overview
Copilot reviewed 30 out of 30 changed files in this pull request and generated 11 comments.
```python
def __enter__(self):
    for key, value in self._overrides.items():
        self._original[key] = os.environ.get(key)
        os.environ[key] = value
```
Missing return statement in `__enter__` method. Context manager `__enter__` methods should return `self` to allow usage in `with` statements like `with EnvOverride(...) as env:`. Add `return self` at the end of the method.
Suggested change:

```python
        os.environ[key] = value
    return self
```
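With that change, usage would look like this (hypothetical sketch):

```python
with EnvOverride({"MASTER_ADDR": "localhost"}) as env:
    ...  # overridden environment variables are active here
# originals are restored on exit
```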
```python
def _load_hf_model_for_dcp_comparison(
    hf_model_dir: str, dcp_modalities_config: ConfigDictType, device_hf: str
) -> GPT2ForCausalLM:
```
Missing docstring for function `_load_hf_model_for_dcp_comparison`. Although it's a private function (prefixed with `_`), other private functions in this file like `_check_conversion_criteria`, `_get_layer_norm_value`, and `_map_attention_type` have docstrings. Consider adding a docstring for consistency.
Suggested change:

```python
) -> GPT2ForCausalLM:
    """Load a Hugging Face GPT-2 model configured to match a DCP-converted modalities model.

    The model is loaded from ``hf_model_dir``, moved to the specified device, cast to the
    execution dtype defined in the FSDP mixed precision settings of ``dcp_modalities_config``,
    and its attention implementation is updated to mirror the attention type used by the
    DCP configuration. This ensures comparable outputs when validating the conversion.

    Args:
        hf_model_dir (str): Directory containing the pretrained Hugging Face GPT-2 checkpoint.
        dcp_modalities_config (ConfigDictType): Modalities configuration derived from the DCP
            checkpoint, used to determine execution dtype and attention implementation.
        device_hf (str): Device identifier (e.g. ``"cuda:0"`` or ``"cpu"``) to place the model on.

    Returns:
        GPT2ForCausalLM: The loaded and configured Hugging Face GPT-2 model.
    """
```
```python
@pytest.mark.skipif(torch.cuda.device_count() < 8, reason="This test requires 8 GPUs.")
def test_converting_dcp_gpt2_does_not_change_weights(converted_dcp_model: PreTrainedModel, dcp_checkpoint: str):
    new_config: ConfigDictType = _build_single_node_dcp_config(dcp_checkpoint)
    with MultiProcessingCudaEnv(ProcessGroupBackendType.nccl, 0, 0, 1, 24570, device_id=0):
```
The `device_id` parameter passed to `MultiProcessingCudaEnv` is not valid. The `MultiProcessingCudaEnv.__init__` method accepts `**process_group_kwargs`, which are forwarded to `dist.init_process_group()`, but `device_id` is not a valid parameter for `dist.init_process_group()`. The CUDA device is set based on the `LOCAL_RANK` environment variable in the parent `CudaEnv.__enter__` method (line 48 of cuda_env.py). This `device_id` argument will either be silently ignored or cause a `TypeError` at runtime.
Suggested change:

```python
    with MultiProcessingCudaEnv(ProcessGroupBackendType.nccl, 0, 0, 1, 24570):
```
```python
vocab_size: int = new_config["model_raw" if "model_raw" in new_config else "model"]["config"]["vocab_size"]
if isinstance(device_id_modalities, str):
    device_id_modalities = int(device_id_modalities.replace("cuda:", ""))
with MultiProcessingCudaEnv(ProcessGroupBackendType.nccl, 0, 0, 1, 24570, device_id=device_id_modalities):
```
Hardcoded port number 24570 may cause conflicts if multiple conversion processes run simultaneously or if the port is already in use. Consider using a dynamically allocated port or making it configurable. Other parts of the test code use `find_free_port()` to avoid such conflicts (e.g., tests/conversion/gpt2/conftest.py:96).
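For reference, `find_free_port` is typically implemented by binding port 0 and letting the OS choose (a sketch; the actual helper in conftest.py may differ):

```python
import socket

def find_free_port() -> int:
    # Binding to port 0 makes the OS assign an arbitrary unused port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("localhost", 0))
        return s.getsockname()[1]
```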
```python
    hf_model_dir: str, dcp_modalities_config: ConfigDictType, device_hf: str
) -> GPT2ForCausalLM:
    # Need execution dtype of FSDP2 to get same outputs from model.
    dtype = dcp_modalities_config["fsdp_model"]["config"]["mixed_precision_settings"]["param_dtype"]
```
Accessing the nested config key `dcp_modalities_config["fsdp_model"]["config"]["mixed_precision_settings"]["param_dtype"]` without validation could raise `KeyError` or `TypeError` if the config structure doesn't match expectations. Consider adding validation or using safer access patterns with clear error messages.
Suggested change:

```python
    try:
        fsdp_model_cfg = dcp_modalities_config["fsdp_model"]["config"]
        mixed_precision_settings = fsdp_model_cfg["mixed_precision_settings"]
        dtype = mixed_precision_settings["param_dtype"]
    except (KeyError, TypeError) as exc:
        raise ValueError(
            "Invalid DCP modalities config: expected "
            "'fsdp_model.config.mixed_precision_settings.param_dtype' to be present and correctly structured "
            "in order to load the HF model for comparison."
        ) from exc
```
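An alternative to per-call-site try/except is a small helper that walks the nested keys and raises one uniform error (a sketch, not part of this PR):

```python
from typing import Any, Mapping

def get_nested(config: Mapping[str, Any], *keys: str) -> Any:
    """Walk nested mapping keys, raising a readable error on the first miss."""
    current: Any = config
    for i, key in enumerate(keys):
        if not isinstance(current, Mapping) or key not in current:
            path = ".".join(keys[: i + 1])
            raise ValueError(f"Missing or malformed config entry: '{path}'")
        current = current[key]
    return current

# Usage (hypothetical):
# dtype = get_nested(dcp_modalities_config, "fsdp_model", "config",
#                    "mixed_precision_settings", "param_dtype")
```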
```python
vocab_size: int = new_config["model_raw" if "model_raw" in new_config else "model"]["config"]["vocab_size"]
if isinstance(device_id_modalities, str):
    device_id_modalities = int(device_id_modalities.replace("cuda:", ""))
with MultiProcessingCudaEnv(ProcessGroupBackendType.nccl, 0, 0, 1, 24570, device_id=device_id_modalities):
```
The `device_id` parameter passed to `MultiProcessingCudaEnv` is not valid. The `MultiProcessingCudaEnv.__init__` method accepts `**process_group_kwargs`, which are forwarded to `dist.init_process_group()`, but `device_id` is not a valid parameter for `dist.init_process_group()`. The CUDA device is set based on the `LOCAL_RANK` environment variable in the parent `CudaEnv.__enter__` method (line 48 of cuda_env.py). This `device_id` argument will either be silently ignored or cause a `TypeError` at runtime.
| f" Available keys: {list(dcp_config.keys())}" | ||
| ) | ||
| torch_config["model"] = dcp_config[model_key] | ||
| torch_config["model"]["config"]["use_meta_device"] = False |
Accessing `torch_config["model"]["config"]["use_meta_device"]` without validation could raise `KeyError` or `TypeError` if the config structure from `dcp_config[model_key]` doesn't contain these nested keys. Consider adding validation or using `.get()` with proper error handling to provide a clear error message about the expected config structure.
| torch_config["model"]["config"]["use_meta_device"] = False | |
| model_section = torch_config.get("model") | |
| if not isinstance(model_section, dict): | |
| raise TypeError( | |
| f"Expected 'model' section in config file '{config_src}' to be a mapping, " | |
| f"but got {type(model_section).__name__!r}." | |
| ) | |
| model_config = model_section.get("config") | |
| if not isinstance(model_config, dict): | |
| raise TypeError( | |
| f"Expected 'model.config' section in config file '{config_src}' to be a mapping, " | |
| f"but got {type(model_config).__name__!r}." | |
| ) | |
| model_config["use_meta_device"] = False |
| "world_size": 1, | ||
| }, | ||
| } | ||
| new_config["fsdp_model"]["config"]["model"]["instance_key"] = model_key |
Accessing the nested config key `new_config["fsdp_model"]["config"]["model"]["instance_key"]` without validation could raise `KeyError` or `TypeError` if the config structure doesn't match expectations. Consider adding validation or using safer access patterns with clear error messages.
| new_config["fsdp_model"]["config"]["model"]["instance_key"] = model_key | |
| fsdp_model_cfg = new_config.get("fsdp_model") | |
| if not isinstance(fsdp_model_cfg, dict): | |
| raise ValueError("Expected 'fsdp_model' in DCP config to be a dict.") | |
| fsdp_model_config = fsdp_model_cfg.setdefault("config", {}) | |
| if not isinstance(fsdp_model_config, dict): | |
| raise ValueError("Expected 'fsdp_model[\"config\"]' in DCP config to be a dict.") | |
| fsdp_model_model_cfg = fsdp_model_config.setdefault("model", {}) | |
| if not isinstance(fsdp_model_model_cfg, dict): | |
| raise ValueError("Expected 'fsdp_model[\"config\"][\"model\"]' in DCP config to be a dict.") | |
| fsdp_model_model_cfg["instance_key"] = model_key |
| new_config["fsdp_model"]["config"]["model"]["instance_key"] = model_key | ||
| new_config["initialized_model"]["config"]["model"] = {"instance_key": "fsdp_model", "pass_type": "BY_REFERENCE"} |
Accessing the nested config key `new_config["initialized_model"]["config"]["model"]` without validation could raise `KeyError` or `TypeError` if the config structure doesn't match expectations. Consider adding validation or using safer access patterns with clear error messages.
| new_config["fsdp_model"]["config"]["model"]["instance_key"] = model_key | |
| new_config["initialized_model"]["config"]["model"] = {"instance_key": "fsdp_model", "pass_type": "BY_REFERENCE"} | |
| # Ensure fsdp_model has the expected nested config/model structure | |
| fsdp_model_cfg = new_config.get("fsdp_model") | |
| if not isinstance(fsdp_model_cfg, dict): | |
| raise TypeError( | |
| f"Expected 'fsdp_model' to be a dict in DCP config, got {type(fsdp_model_cfg).__name__}" | |
| ) | |
| fsdp_model_config = fsdp_model_cfg.setdefault("config", {}) | |
| if not isinstance(fsdp_model_config, dict): | |
| raise TypeError( | |
| "Expected 'fsdp_model[\"config\"]' to be a dict in DCP config, " | |
| f"got {type(fsdp_model_config).__name__}" | |
| ) | |
| fsdp_model_model = fsdp_model_config.setdefault("model", {}) | |
| if not isinstance(fsdp_model_model, dict): | |
| raise TypeError( | |
| "Expected 'fsdp_model[\"config\"][\"model\"]' to be a dict in DCP config, " | |
| f"got {type(fsdp_model_model).__name__}" | |
| ) | |
| fsdp_model_model["instance_key"] = model_key | |
| # Ensure initialized_model has the expected nested config/model structure | |
| initialized_model_cfg = new_config.get("initialized_model") | |
| if not isinstance(initialized_model_cfg, dict): | |
| raise TypeError( | |
| f"Expected 'initialized_model' to be a dict in DCP config, got {type(initialized_model_cfg).__name__}" | |
| ) | |
| initialized_model_config = initialized_model_cfg.setdefault("config", {}) | |
| if not isinstance(initialized_model_config, dict): | |
| raise TypeError( | |
| "Expected 'initialized_model[\"config\"]' to be a dict in DCP config, " | |
| f"got {type(initialized_model_config).__name__}" | |
| ) | |
| initialized_model_config["model"] = {"instance_key": "fsdp_model", "pass_type": "BY_REFERENCE"} |
```python
# Restore original environment variables
self._env_override.__exit__(exc_type, exc_val, exc_tb)
super().__exit__(exc_type, exc_val, exc_tb)
```
The `__exit__` method calls `self._env_override.__exit__()` before `super().__exit__()`. This means environment variables are restored before the CUDA environment is cleaned up. If the parent `CudaEnv.__exit__` uses environment variables like `LOCAL_RANK` (which it does on line 59 of cuda_env.py), they will have been restored to their original values. Consider reversing the order: call `super().__exit__()` first, then `self._env_override.__exit__()` to ensure proper cleanup order.
Suggested change:

```python
# First, perform CUDA/distributed cleanup while overrides are still in effect
super().__exit__(exc_type, exc_val, exc_tb)
# Then restore original environment variables
self._env_override.__exit__(exc_type, exc_val, exc_tb)
```
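Relatedly, delegating to `contextlib.ExitStack` encodes exactly this last-in-first-out teardown (a design sketch; `env_override` and `cuda_env` stand in for the two context managers discussed here):

```python
from contextlib import ExitStack

with ExitStack() as stack:
    stack.enter_context(env_override)  # entered first, exited last
    stack.enter_context(cuda_env)      # entered last, exited first
    ...  # distributed work happens here
# On exit: cuda_env cleans up while the overrides are still in effect,
# then env_override restores the original environment variables.
```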
What does this PR do?
Implements checkpoint conversion for DCP checkpoints (FSDP2, PP, TP). For this, the checkpoint is first converted to a regular PyTorch checkpoint (together with a corresponding config) and then converted to HuggingFace format using the existing code.
Note: Currently, no new tests were added due to the effort of creating and manipulating a DCP checkpoint as needed for those tests.
General Changes
Breaking Changes
Checklist before submitting final PR
- Tests run successfully (`python tests/tests.py`)
- Changelog updated (`CHANGELOG_DEV.md`)