[2/3][Feat]: Offline DFlash training #1295

h-guo18 wants to merge 5 commits into haoguo/spec-file-reorg
Conversation
Codecov Report

✅ All modified and coverable lines are covered by tests.

@@ Coverage Diff @@
## haoguo/spec-file-reorg #1295 +/- ##
==========================================================
- Coverage    75.56%    75.56%   -0.01%
==========================================================
  Files          466       466
  Lines        50238     50232       -6
==========================================================
- Hits         37962     37957       -5
+ Misses       12276     12275       -1
What does this PR do?

Type of change: new feature

Part 2 of a 3-PR series splitting #1271: `ParallelDraftHFSpecDecMixin`

Changes:

- Add a `dflash_offline` flag to `DFlashConfig` for training from pre-computed hidden states; offline mode deletes the base model layers to save memory.
- `DFlashConfig` validators:
  - `_derive_dflash_offline`: derive `dflash_offline` from `data_args.offline_data_path` in the validation context.
  - `_resolve_mask_token_id`: auto-detect `dflash_mask_token_id` from `tokenizer.mask_token_id`.
  - `_check_mask_token_id`: fail fast if still unset after resolution.
- `HFDFlashModel.modify()`: select `num_orig_hidden_layers` when offline; pick the `_base_model_lm_head` device when no base layers are present; drop the base-model `layers` module.
- `HFDFlashModel.forward()`: add an offline branch that consumes precomputed `base_model_outputs` via `DFlashBaseModelOutput.from_offline_dict`; when `dflash_self_logit_distillation` is enabled and `base_model_logits` is absent, recompute logits from `base_model_hidden_states` via `_base_model_lm_head`.
- New `DFlashBaseModelOutput` dataclass in `modeling_dflash.py` (with a `from_offline_dict` classmethod) to unify online/offline output shapes.
- `examples/speculative_decoding/main.py`: replace the inline `mask_token_id` auto-detect with `DFlashConfig.model_validate(dflash_cfg, context={"tokenizer": tokenizer, "data_args": data_args})`.

Usage
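The validation-context flow can be sketched with a self-contained Pydantic stand-in. This `DFlashConfig` is a hypothetical toy model, not the real modelopt class, and `SimpleNamespace` fakes the tokenizer and data args; only the field names and the `model_validate(..., context=...)` call shape follow the PR description:

```python
from types import SimpleNamespace
from typing import Optional

from pydantic import BaseModel, ValidationInfo, model_validator


class DFlashConfig(BaseModel):
    """Stand-in config; the real class lives in the DFlash plugin."""

    dflash_offline: bool = False
    dflash_mask_token_id: Optional[int] = None

    @model_validator(mode="after")
    def _derive_and_resolve(self, info: ValidationInfo) -> "DFlashConfig":
        ctx = info.context or {}
        # _derive_dflash_offline: an offline data path implies offline mode.
        data_args = ctx.get("data_args")
        if data_args is not None and getattr(data_args, "offline_data_path", None):
            self.dflash_offline = True
        # _resolve_mask_token_id: fall back to the tokenizer's mask token.
        tokenizer = ctx.get("tokenizer")
        if self.dflash_mask_token_id is None and tokenizer is not None:
            self.dflash_mask_token_id = tokenizer.mask_token_id
        return self


tokenizer = SimpleNamespace(mask_token_id=103)
data_args = SimpleNamespace(offline_data_path="/tmp/hidden_states")
cfg = DFlashConfig.model_validate(
    {}, context={"tokenizer": tokenizer, "data_args": data_args}
)
```

Without a context (plain `DFlashConfig()`), both validators fall through, which is what keeps the flag backward compatible.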
Testing
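A minimal sketch of the convert-path behavior the CPU tests below cover: offline mode drops the base model's decoder layers to save memory but records the original depth, which later drives target-layer selection. `TinyBase` and `modify_for_dflash` are hypothetical stand-ins, not the real `HFDFlashModel` API:

```python
class TinyBase:
    """Toy base model with a list of decoder layers."""

    def __init__(self, n_layers: int = 4):
        self.layers = [object() for _ in range(n_layers)]


def modify_for_dflash(model: TinyBase, dflash_offline: bool) -> TinyBase:
    if dflash_offline:
        # Record the depth before deletion: num_orig_hidden_layers is what
        # drives target_layer_ids selection in offline mode.
        model.num_orig_hidden_layers = len(model.layers)
        del model.layers  # offline training never runs the base layers
    return model


online = modify_for_dflash(TinyBase(), dflash_offline=False)
offline = modify_for_dflash(TinyBase(), dflash_offline=True)
```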
- `tests/unit/torch/speculative/plugins/test_hf_dflash_offline.py`: CPU unit tests for the convert path (online keeps base layers, offline deletes them; `num_orig_hidden_layers` drives `target_layer_ids` in offline mode) and the `DFlashConfig._derive_dflash_offline` validator.
- `TestDFlashOfflineForwardGPU` in `tests/gpu/torch/speculative/plugins/test_hf_dflash.py`: GPU forward smoke test with precomputed `base_model_outputs`, plus the `dflash_self_logit_distillation` logit-recompute path.

Before your PR is "Ready for review"

- Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- Backward compatibility: the `dflash_offline` flag defaults to `False`; validators fall through when no context is provided.
- `CONTRIBUTING.md`: N/A

TODO (follow-up)
- Extend `examples/speculative_decoding/collect_hidden_states/compute_hidden_states_*.py` to support DFlash offline data. The current scripts are Eagle-specific: they hardcode the `[2, N/2, N-3]` aux-layer selection and emit `{input_ids, hidden_states, aux_hidden_states}`. DFlash offline needs:
  - `build_target_layer_ids(num_orig_hidden_layers, num_draft_layers)` (or a configurable list), not the Eagle triplet.
  - A `base_model_hidden_states` key (last-layer hidden states) so `DFlashBaseModelOutput.from_offline_dict` and the `dflash_self_logit_distillation` recompute path can consume it.
  - An optional `base_model_logits` dump so offline training can skip the self-distillation logit recomputation when logits are available.

Additional Information
The base branch is #1296 (file reorg). Retarget to `main` once #1296 merges.
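For reference, a minimal sketch of the offline record shape described in the TODO above and a `from_offline_dict`-style constructor. This dataclass is a hypothetical stand-in, not the real `DFlashBaseModelOutput`; plain nested lists stand in for tensors, and only the `base_model_hidden_states` / `base_model_logits` key names come from the PR description:

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class DFlashBaseModelOutput:
    """Stand-in output container; fields mirror the offline dump keys."""

    hidden_states: Any            # last-layer hidden states (a tensor in practice)
    logits: Optional[Any] = None  # only present when logits were dumped offline

    @classmethod
    def from_offline_dict(cls, record: dict) -> "DFlashBaseModelOutput":
        return cls(
            hidden_states=record["base_model_hidden_states"],
            logits=record.get("base_model_logits"),  # optional key
        )


record = {
    "input_ids": [[5, 7, 9]],
    "base_model_hidden_states": [[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]],
}
out = DFlashBaseModelOutput.from_offline_dict(record)
# out.logits is None here: this is the case where self-logit distillation
# would recompute logits from hidden_states via the retained lm_head.
```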