
[Cherry-Pick][BugFix] Fix real token exceeding max_batched_tokens limit(#7438) #7440

Merged
freeliuzc merged 3 commits into PaddlePaddle:release/2.5 from freeliuzc:cherry-pick-2.5-fix-max-num-batched-tokens
Apr 17, 2026
Conversation

@freeliuzc
Collaborator

Motivation

💡 If this PR is a Cherry Pick, the PR title needs to follow the format by adding the [Cherry-Pick] label at the very beginning and appending the original PR ID at the end. For example, [Cherry-Pick][CI] Add check trigger and logic(#5191)


Modifications

Usage or Command

Accuracy Tests

Checklist

  • Add at least one tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code; run pre-commit before committing.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

Copilot AI review requested due to automatic review settings April 16, 2026 12:18
@paddle-bot

paddle-bot bot commented Apr 16, 2026

Thanks for your contribution!

@freeliuzc freeliuzc changed the title [Cherry-Pick][BugFix] Fix real token exceeding max_batched_tokens limit [Cherry-Pick][BugFix] Fix real token exceeding max_batched_tokens limit(#7438) Apr 16, 2026
Contributor

Copilot AI left a comment


Pull request overview

This PR fixes an issue where the scheduler's actual token count can exceed the max_num_batched_tokens limit in certain cases, by adjusting how the token budget is computed in ResourceManagerV1.schedule() (particularly for speculative decoding scenarios).

Changes:

  • schedule() introduces new token_budget computation logic based on a reservation for speculative tokens.
  • A lower bound is added to token_budget to avoid a negative budget.

Comment on lines +728 to +729
    # temperatory solution to avoid negative token_budget
    token_budget = max(token_budget, min(self.config.scheduler_config.max_num_batched_tokens, 512))

Copilot AI Apr 16, 2026


The fallback token_budget = max(token_budget, min(max_num_batched_tokens, 512)) raises the budget to 512 (or max_num_batched_tokens) whenever token_budget goes negative, which can directly cancel out the tokens reserved for speculative decoding above, so the batch token count can still exceed the limit. At minimum, lower the floor to 0 (or 1); additionally (or alternatively), add a config-time constraint (e.g. max_num_batched_tokens >= max_num_seqs * (num_speculative_tokens + 1)), or compute the reservation dynamically from the number of decoding sequences actually in the batch.

Suggested change

    - # temperatory solution to avoid negative token_budget
    - token_budget = max(token_budget, min(self.config.scheduler_config.max_num_batched_tokens, 512))
    + # Clamp to zero to avoid negative values without undoing speculative reservation.
    + token_budget = max(token_budget, 0)

Copilot uses AI. Check for mistakes.
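The reviewer's two proposals above (clamp at zero; validate the configuration up front) can be sketched as standalone Python. This is a minimal illustration, not FastDeploy code: the function names and flat parameter lists are hypothetical, though the field names mirror the config attributes discussed in the PR.

```python
def compute_token_budget(max_num_batched_tokens: int,
                         max_num_seqs: int,
                         num_speculative_tokens=None) -> int:
    """Reserve speculative decoding tokens up front; clamp at 0 so the
    floor never undoes the reservation (the reviewer's first proposal)."""
    if num_speculative_tokens is None:
        return max_num_batched_tokens
    reserved = max_num_seqs * (num_speculative_tokens + 1)
    return max(max_num_batched_tokens - reserved, 0)


def check_config(max_num_batched_tokens: int,
                 max_num_seqs: int,
                 num_speculative_tokens: int) -> None:
    """Config-time constraint (the reviewer's second proposal):
    max_num_batched_tokens >= max_num_seqs * (num_speculative_tokens + 1)."""
    required = max_num_seqs * (num_speculative_tokens + 1)
    if max_num_batched_tokens < required:
        raise ValueError(
            f"max_num_batched_tokens={max_num_batched_tokens} is below the "
            f"speculative reservation {required}"
        )
```

With the config check in place, a negative budget becomes a configuration error surfaced at startup rather than a silent clamp at schedule time.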
        if self.config.speculative_config is not None
        else 1
    )
    # temperatory solution to avoid negative token_budget

Copilot AI Apr 16, 2026


Typo in the comment: temperatory should be temporary (keeping the code comment in English is fine).

Suggested change

    - # temperatory solution to avoid negative token_budget
    + # temporary solution to avoid negative token_budget

    error_reqs: list[tuple[str, str]] = []
    token_budget = self.config.scheduler_config.max_num_batched_tokens
    token_budget = (
        self.config.scheduler_config.max_num_batched_tokens

Copilot AI Apr 16, 2026


This is a Cherry-Pick PR, but the title does not append the original develop PR number as the template requires (e.g. ...(#1234)). Please add the original PR ID to avoid issues with the Cherry-Pick validation/approval workflow.

    error_reqs: list[tuple[str, str]] = []
    token_budget = self.config.scheduler_config.max_num_batched_tokens
    token_budget = (
        self.config.scheduler_config.max_num_batched_tokens

Copilot AI Apr 16, 2026


The Motivation/Modifications/Usage/Accuracy Tests sections of the PR description are currently empty. Per the repository template, the PR should at least explain why the change is needed, what problem it solves, and how it was verified (commands, a minimal reproduction, or test results). Please fill these in to ease traceability and the release-branch merge review.

        self.config.scheduler_config.max_num_batched_tokens
        - self.config.scheduler_config.max_num_seqs
        * (self.config.speculative_config.num_speculative_tokens + 1)
        if self.config.speculative_config is not None

Copilot AI Apr 16, 2026


This branch reads num_speculative_tokens and deducts the reservation whenever speculative_config is not None, but the common convention in this codebase is to treat speculative_config.method is not None as the "speculative enabled" check (in many cases the speculative_config object always exists but method=None). As written, the budget is wrongly deducted when speculative decoding is disabled, and some unit tests or hand-built configs (e.g. SimpleNamespace(method=None)) could even hit a missing-attribute error. Consider checking speculative_config is not None and speculative_config.method is not None.

Suggested change

    - if self.config.speculative_config is not None
    + if self.config.speculative_config is not None
    + and self.config.speculative_config.method is not None

Comment on lines +721 to +729
    token_budget = (
        self.config.scheduler_config.max_num_batched_tokens
        - self.config.scheduler_config.max_num_seqs
        * (self.config.speculative_config.num_speculative_tokens + 1)
        if self.config.speculative_config is not None
        else 1
    )
    # temperatory solution to avoid negative token_budget
    token_budget = max(token_budget, min(self.config.scheduler_config.max_num_batched_tokens, 512))

Copilot AI Apr 16, 2026


This change alters schedule()'s token-budget computation (in particular the speculative-decoding deduction and fallback), but the current unit tests for this file mostly cover the non-speculative (or method=None) path. Consider adding at least one test: with speculative_method=mtp/suffix, construct running+waiting requests and assert that the scheduled prefill tokens never push the actual token count (including draft tokens) above max_num_batched_tokens, also covering the boundary where token_budget is negative or near zero.

Copilot generated this review using guidance from repository custom instructions.
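A test along the lines suggested above cannot be sketched faithfully without the real ResourceManagerV1 API, but the invariant it should assert can be expressed standalone. Below is a hypothetical simulation: schedule_prefill and every value are made up, serving only to show the invariant, not the FastDeploy scheduler.

```python
def schedule_prefill(waiting, token_budget):
    """Greedily admit prefill requests while they fit in the budget
    (a toy stand-in for schedule(); not FastDeploy code)."""
    scheduled = []
    for need in waiting:
        if need <= token_budget:
            scheduled.append(need)
            token_budget -= need
    return scheduled

max_num_batched_tokens = 2048
tokens_per_seq = 6                  # num_speculative_tokens + 1 (assumed)
num_running_decode = 200            # decoding requests already in the batch

# Reserve only for decode requests actually running, then schedule prefill.
budget = max(max_num_batched_tokens - num_running_decode * tokens_per_seq, 0)
prefill = schedule_prefill([300, 500, 700, 900], budget)

# Invariant the suggested unit test should assert: actual tokens
# (prefill plus draft-inclusive decode consumption) stay within the limit.
actual = sum(prefill) + num_running_decode * tokens_per_seq
assert actual <= max_num_batched_tokens
```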
Comment on lines +721 to +727
    token_budget = (
        self.config.scheduler_config.max_num_batched_tokens
        - self.config.scheduler_config.max_num_seqs
        * (self.config.speculative_config.num_speculative_tokens + 1)
        if self.config.speculative_config is not None
        else 1
    )

Copilot AI Apr 16, 2026


When speculative_config is None, this ternary sets token_budget to 1 (instead of the original max_num_batched_tokens), which would leave scheduling almost unable to make progress. Make the "speculative enabled" branch explicit: use max_num_batched_tokens directly when speculative decoding is disabled, and deduct the speculative reservation only when it is enabled.

Suggested change

    - token_budget = (
    -     self.config.scheduler_config.max_num_batched_tokens
    -     - self.config.scheduler_config.max_num_seqs
    -     * (self.config.speculative_config.num_speculative_tokens + 1)
    -     if self.config.speculative_config is not None
    -     else 1
    - )
    + if self.config.speculative_config is None:
    +     token_budget = self.config.scheduler_config.max_num_batched_tokens
    + else:
    +     token_budget = (
    +         self.config.scheduler_config.max_num_batched_tokens
    +         - self.config.scheduler_config.max_num_seqs
    +         * (self.config.speculative_config.num_speculative_tokens + 1)
    +     )

PaddlePaddle-bot

This comment was marked as outdated.

        else 1
    )
    # temperatory solution to avoid negative token_budget
    token_budget = max(token_budget, min(self.config.scheduler_config.max_num_batched_tokens, 512))
Collaborator


Suggestion: assert instead, and raise an error when budget < 0.



@PaddlePaddle-bot PaddlePaddle-bot left a comment


🤖 AI Code Review | 2026-04-16 21:30 CST

📋 Review Summary

PR overview: Cherry-pick of the fix for the actual token count exceeding the max_num_batched_tokens limit under speculative decoding
Scope: engine/sched/resource_manager_v1.py — the scheduler's token-budget computation logic
Impact tags: Scheduler, Speculative Decoding

📝 PR Template Check

The PR title follows the Cherry-Pick convention, but the Motivation and Modifications sections of the description are empty; please describe the root cause of the bug and the approach of the fix.

Suggested description (copy-paste ready):

    ## Motivation

    Under speculative decoding, each decoding request actually consumes `num_speculative_tokens + 1` tokens,
    but the scheduler did not reserve this decoding overhead when initializing token_budget, so the actual
    batched token count could exceed the max_num_batched_tokens limit.

    ## Modifications

    In the `schedule()` method of `resource_manager_v1.py`, deduct `max_num_seqs * tokens_per_seq` up front
    as a decoding reservation when initializing token_budget, and add a floor to keep the budget from going negative.

Issues

Level         | File                       | Summary
❓ Question   | resource_manager_v1.py:728 | Reserving for the full max_num_seqs may over-compress the prefill budget
🟡 Suggestion | resource_manager_v1.py:731 | The magic number 512 is unexplained, and the floor may be insufficient in extreme cases

Overall

The direction of the fix is correct: the token budget now accounts for the actual decoding consumption under speculative decoding. But as a "temporary solution" the reservation is conservative (the full max_num_seqs is reserved up front), which may cost some prefill throughput, and the hard-coded floor of 512 can still allow the limit to be exceeded under extreme configurations. Recommend iterating on this in a follow-up.
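As a rough, self-contained illustration of the budget arithmetic described in the suggested Motivation/Modifications above (all numbers are hypothetical; variable names mirror the PR's config fields, but this is not the actual scheduler code):

```python
# Hypothetical configuration values, chosen only for illustration.
max_num_batched_tokens = 2048
max_num_seqs = 256
num_speculative_tokens = 5        # each decode step consumes this + 1 tokens
tokens_per_seq = num_speculative_tokens + 1

# Up-front decoding reservation, as in the patch:
token_budget = max_num_batched_tokens - max_num_seqs * tokens_per_seq

# The patch's floor ("temperatory solution" in the diff):
token_budget = max(token_budget, min(max_num_batched_tokens, 512))
```

With these numbers the reservation is 256 * 6 = 1536, leaving 512 for prefill, and the floor has no effect; the floor only kicks in when the reservation exceeds max_num_batched_tokens.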

    )
    token_budget = (
        self.config.scheduler_config.max_num_batched_tokens
        - self.config.scheduler_config.max_num_seqs * tokens_per_seq


❓ Question: This reserves decoding token budget for the full max_num_seqs (the maximum number of concurrent sequences), but the number of decoding requests actually in the running queue is usually far below max_num_seqs, so the prefill budget gets over-compressed and prefill throughput may suffer.

Have you considered using len(self.running), or the actual number of decoding requests, to compute a more precise reservation? For example:

num_running_decode = sum(1 for r in self.running if r.num_computed_tokens >= r.need_prefill_tokens)
token_budget = (
    self.config.scheduler_config.max_num_batched_tokens
    - num_running_decode * tokens_per_seq
)

        - self.config.scheduler_config.max_num_seqs * tokens_per_seq
    )
    # temperatory solution to avoid negative token_budget
    token_budget = max(token_budget, min(self.config.scheduler_config.max_num_batched_tokens, 512))


🟡 Suggestion: The magic number 512 is unexplained; add a comment describing its rationale or extract it into a named constant.

More importantly, when max_num_seqs * tokens_per_seq is large enough to drive the deducted budget negative, the floor lifts token_budget back up to min(max_num_batched_tokens, 512). But the subsequent loop only subtracts 1 per decoding request (line 791), while each decoding request actually consumes tokens_per_seq tokens. In an extreme case (e.g. max_num_batched_tokens=2048, max_num_seqs=256, tokens_per_seq=6), the post-floor budget of 512 goes to prefill, and adding the decoding consumption of N * tokens_per_seq can still exceed max_num_batched_tokens.

Suggestions:

  1. Add a comment next to 512 explaining how it was chosen
  2. Consider tying the floor to tokens_per_seq, for example:

    min_budget = max(self.config.cache_config.block_size, tokens_per_seq)
    token_budget = max(token_budget, min_budget)
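Expanded into a runnable form, the bot's floor suggestion might look like this. It is a sketch under assumed values: block_size and tokens_per_seq are illustrative numbers, and the self.config attribute paths are dropped for self-containment.

```python
# Illustrative values; in the PR these would come from self.config.*.
block_size = 64          # cache_config.block_size (assumed value)
tokens_per_seq = 6       # num_speculative_tokens + 1 (assumed value)

token_budget = -128      # e.g. after an overly large up-front reservation

# Tie the floor to per-sequence consumption instead of a hard-coded 512:
min_budget = max(block_size, tokens_per_seq)
token_budget = max(token_budget, min_budget)
```

A floor derived from tokens_per_seq (or block_size) keeps the guarantee meaningful across configurations, whereas a fixed 512 can over- or under-shoot depending on max_num_batched_tokens.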

@codecov-commenter

Codecov Report

✅ All modified and coverable lines are covered by tests.
⚠️ Please upload report for BASE (release/2.5@12e4fff). Learn more about missing BASE report.

Additional details and impacted files
@@              Coverage Diff               @@
##             release/2.5    #7440   +/-   ##
==============================================
  Coverage               ?   68.96%           
==============================================
  Files                  ?      390           
  Lines                  ?    54389           
  Branches               ?     8579           
==============================================
  Hits                   ?    37508           
  Misses                 ?    14186           
  Partials               ?     2695           
Flag | Coverage Δ
GPU  | 68.96% <100.00%> (?)

Flags with carried forward coverage won't be shown.



@freeliuzc freeliuzc merged commit 6119a07 into PaddlePaddle:release/2.5 Apr 17, 2026
34 of 38 checks passed

6 participants