Skip to content

[Speculative Decoding][BugFix] Fix apply repeat times penalty kernel and change spec default verify strategy#7467

Merged
freeliuzc merged 2 commits intoPaddlePaddle:developfrom
freeliuzc:fix_repeat_times_dev
Apr 17, 2026
Merged

[Speculative Decoding][BugFix] Fix apply repeat times penalty kernel and change spec default verify strategy#7467
freeliuzc merged 2 commits intoPaddlePaddle:developfrom
freeliuzc:fix_repeat_times_dev

Conversation

@freeliuzc
Copy link
Copy Markdown
Collaborator

@freeliuzc freeliuzc commented Apr 17, 2026

Motivation

  1. 之前重构 token_ids_all 时,对投机解码的 repeat kernel 重构错误,导致会访问越界
  2. 更改目前使用更多的 target_match 为默认验证策略

💡 If this PR is a Cherry Pick, the PR title needs to follow the format by adding the [Cherry-Pick] label at the very beginning and appending the original PR ID at the end. For example, [Cherry-Pick][CI] Add check trigger and logic(#5191)

💡 如若此PR是Cherry Pick,PR标题需遵循格式,在最开始加上[Cherry-Pick]标签,以及最后面加上原PR ID,例如[Cherry-Pick][CI] Add check trigger and logic(#5191)

Modifications

  1. 修复 Kernel
  2. 更改config默认值

Usage or Command

None

Accuracy Tests

Checklist

  • Add at least a tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

Copilot AI review requested due to automatic review settings April 17, 2026 09:55
@paddle-bot
Copy link
Copy Markdown

paddle-bot bot commented Apr 17, 2026

Thanks for your contribution!

yuanlehome
yuanlehome previously approved these changes Apr 17, 2026
@freeliuzc freeliuzc changed the title [Speculative Decoding] Fix apply repeat times penalty kernel and change spec default verify strategy [Speculative Decoding][BugFix] Fix apply repeat times penalty kernel and change spec default verify strategy Apr 17, 2026
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

该 PR 聚焦于投机解码(Speculative Decoding)链路中的两个点:修复 repeat-times penalty 相关 CUDA kernel 在重构后可能出现的越界访问,并将 speculative decoding 的默认验证策略调整为更偏向 target_match 的行为。

Changes:

  • 修复 speculate_get_token_penalty_multi_scores 中 repeat times 统计的循环边界,避免 prompt 偏移后访问越界。
  • 更新对应的 Python reference 测试逻辑,使其按 cur_len 限制遍历范围。
  • SpeculativeConfig.verify_strategy 默认值从 topp 调整为 target_match

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

File Description
tests/operators/test_speculate_get_token_penalty_multi_scores.py 调整 reference 侧 repeat times 统计范围以匹配新的 cur_len 上界假设(但当前仍未覆盖 prompt 偏移场景)。
fastdeploy/config.py 修改 speculative decoding 的默认 verify strategy 为 target_match
custom_ops/gpu_ops/speculate_decoding/speculate_get_token_penalty_multi_scores.cu 修复 CUDA kernel 中 repeat times 统计的 loop 上界,降低越界风险并减少无效遍历。

Comment thread fastdeploy/config.py
Comment on lines 775 to 779
"benchmark_mode": False,
"enf_gen_phase_tag": False,
"enable_draft_logprob": False,
"verify_strategy": "topp",
"verify_strategy": "target_match",
"accept_policy": "normal",
Copy link

Copilot AI Apr 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里把 speculative decoding 的默认 verify_strategy 从 "topp" 改成了 "target_match",属于用户可见行为变更;目前文档仍写的是 topp (default),会造成配置说明与实际默认值不一致。建议同步更新 docs/features/speculative_decoding.md 与 docs/zh/features/speculative_decoding.md 中关于默认策略的描述,或在文档中明确默认值已调整。

Copilot uses AI. Check for mistakes.
Comment on lines +96 to 101
const int64_t* pre_ids_now = token_ids_all + bi * length_id + prompt_lens[bi];
int* repeat_times_now = repeat_times + token_idx * length;
for (int i = tid; i < cur_len[bi]; i += blockDim.x) {
int64_t id = pre_ids_now[i];
if (id < 0) break;
atomicAdd(&repeat_times_now[id], 1);
Copy link

Copilot AI Apr 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

update_repeat_times 的循环上界从 length_id 改为 cur_len[bi] 解决了 prompt_lens 偏移后可能越界的问题,但当前单测 tests/operators/test_speculate_get_token_penalty_multi_scores.py 里 prompt_lens 全为 0,无法覆盖这次修复的关键路径。建议补充一个 prompt_lens 非 0(且接近 length_id 边界)的用例,确保不会再出现 bi*length_id+prompt_lens+cur_len 越界访问。

Copilot generated this review using guidance from repository custom instructions.
Comment on lines 61 to 66
token_ids_all_now = token_ids_all[bi]
repeat_times_now = repeat_times[token_idx]

for i in range(length_id):
for i in range(cur_len[bi]):
id = token_ids_all_now[i]
if id < 0:
Copy link

Copilot AI Apr 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

测试里的 reference 实现 update_repeat_times 目前完全不考虑 prompt_lens(而 CUDA kernel 会用 prompt_lens 做偏移)。这会导致一旦补充 prompt_lens 非 0 的用例时,reference 结果与 kernel 行为不一致,从而掩盖/误报问题。建议让 reference 也按 prompt_lens 偏移读取 token_ids_all,再结合 cur_len 统计重复次数。

Copilot generated this review using guidance from repository custom instructions.
Deleter-D
Deleter-D previously approved these changes Apr 17, 2026
Copy link
Copy Markdown
Collaborator

@Deleter-D Deleter-D left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

PaddlePaddle-bot

This comment was marked as outdated.

ming1753
ming1753 previously approved these changes Apr 17, 2026
Copy link
Copy Markdown
Collaborator

@ming1753 ming1753 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@codecov-commenter
Copy link
Copy Markdown

codecov-commenter commented Apr 17, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
⚠️ Please upload report for BASE (develop@a729e0f). Learn more about missing BASE report.

Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #7467   +/-   ##
==========================================
  Coverage           ?   74.17%           
==========================================
  Files              ?      398           
  Lines              ?    54987           
  Branches           ?     8616           
==========================================
  Hits               ?    40786           
  Misses             ?    11468           
  Partials           ?     2733           
Flag Coverage Δ
GPU 74.17% <ø> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Copy Markdown

@PaddlePaddle-bot PaddlePaddle-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤖 AI Code Review | 2026-04-17 22:19 CST

📋 Review 摘要

PR 概述:修复投机解码 repeat penalty kernel 的越界访问 bug,并将默认验证策略从 topp 改为 target_match
变更范围custom_ops/gpu_ops/speculate_decoding/fastdeploy/config.py、测试文件
影响面 TagSpeculative Decoding OP FDConfig

📝 PR 规范检查

Accuracy Tests 章节为空。此 PR 修复了 CUDA kernel 的越界访问 bug,建议补充精度对比测试结果以证明修复正确性。

问题

级别 文件 概述
🟡 建议 tests/layers/test_speculative_sampler.py 缺少 target_match 验证策略的测试用例

总体评价

核心 bug 修复正确且重要:update_repeat_times kernel 中 pre_ids_now 已通过 prompt_lens[bi] 做了偏移,之前以 length_id(整行 buffer 长度)为循环上界会导致读取超出当前 batch item 行边界的内存,改为 cur_len[bi] 后仅遍历实际有效的 generated token,修复了越界问题。Python 参考实现和单元测试已同步修改。默认策略变更合理,测试也做了相应适配。


# Use ngram method for speculative decoding
fd_config = _create_fd_config(max_model_len, method="ngram")
fd_config = _create_fd_config(max_model_len, method="ngram", verify_strategy="topp")
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 缺少 target_match 验证策略的测试覆盖

当前所有 speculative sampler 测试用例都显式指定了 verify_strategy="topp",但本 PR 已将默认策略改为 target_match。建议至少新增一个使用 verify_strategy="target_match" 的测试用例(或直接使用默认值不传参),以确保新默认策略路径也有测试覆盖。

freeliuzc added a commit that referenced this pull request Apr 17, 2026
…nalty kernel and change spec default verify strategy(#7467) (#7468)

* fix repeat_time kernel and change default spec verify strategy

* fix unit_test
@freeliuzc freeliuzc merged commit 22a4f60 into PaddlePaddle:develop Apr 17, 2026
51 of 57 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants