Add Prefill–Decode Separation with Batched Prompt Ingestion and Logits Skipping #102
Open
orionpapadakis wants to merge 23 commits into beehive-lab:main from
Conversation
…ith InferenceCoreWithPrefillDecode and InferenceEngineWithPrefillDecode
… Implements `InferenceEngineWithPrefillDecode` and `TornadoVMMasterPlanWithPrefillDecode` for batched token generation. Refactor `Llama` to support the batched prefill flag.
…king state, with cuda graphs only)
…ts to dedicated classes and packages Move `LlamaFP16BatchPrefillLayers` to `tornadovm.layers.type.fp16.prefll` and `LlamaFP16FFNLayersForUnifiedDecode` to `tornadovm.layers.type.fp16.decode`
…phs` option to ease debugging
…lamaFP16FFNLayersDecode`
…6LayersBatchPrefill`
…adoVMMasterPlanWithBatchPrefillDecode`
…d refactor task graph consumption logic Introduce `LogitsFP16LayerDecode` with KV-cache pass-through. Override `consumeFromDevice` and `persistOnDevice` in LlamaFFN layers to fix cross-graph propagation for both CUDA and interpreter modes.
…nd `batched-prefill-decode` execution paths for both CPU and GPU
…tandard, prefill-decode, and batched-prefill-decode setups.
…om TornadoVM execution paths.
…for prefill and decode paths in TornadoVM
Member:
/rerun all

Contributor:
🚀 Workflow rerun started Mode:

Contributor:
✅ Workflow rerun success
…PrefillDecode` This fixes GPU prefill-decode without batching without CUDA Graphs
…`, CPU/GPU) for standard, prefill-decode and prefill-decode with batching
…code and batched-prefill-decode paths in `Mistral`, `Phi3`, `Qwen2`, and `Qwen3` models.
…U prefill-decode and batched-prefill-decode paths
…`batched-prefill-decode` test cases for Llama 3.2 1B FP16
This PR implements the prefill-decode concept in GPULlama3.java:
Prefill (or prompt ingestion) is the inference pass over the prompt (input) tokens. Currently they are ingested sequentially (one by one). However, all prompt tokens are known up front and are independent of each other, so sequential ingestion is sub-optimal. A well-known practice is to ingest prompt tokens in batches (e.g., of 32) instead of one by one. In addition, logits computation can be skipped for all prompt tokens except the last, since those logits are never used. The logits of the last prompt token yield the first generated token, which in turn feeds the decode phase.
Decode (new-token generation) is the inference pass for each generated token. In contrast to prefill, this phase remains token-sequential (as it is now) because each generated token depends on the previous one.
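The control flow described above can be sketched as follows. This is a self-contained illustration, not the actual GPULlama3.java API: the model is stubbed out and only the batching and logits-skipping logic is shown.

```java
import java.util.ArrayList;
import java.util.List;

// Sketch of batched prefill with logits skipping, followed by sequential
// decode. All class/method names here are illustrative.
public class PrefillDecodeSketch {
    static final int PREFILL_BATCH_SIZE = 32;

    static List<String> run(int[] promptTokens, int newTokens) {
        List<String> log = new ArrayList<>();
        int pos = 0;
        // Prefill: ingest prompt tokens in batches of PREFILL_BATCH_SIZE and
        // compute logits only for the batch holding the last prompt token.
        while (pos < promptTokens.length) {
            int n = Math.min(PREFILL_BATCH_SIZE, promptTokens.length - pos);
            boolean computeLogits = (pos + n == promptTokens.length);
            log.add("prefill batch [" + pos + "," + (pos + n) + ") logits=" + computeLogits);
            pos += n;
        }
        // Decode: strictly token-sequential; every step needs logits so the
        // next token can be sampled from them.
        for (int i = 0; i < newTokens; i++) {
            log.add("decode pos=" + pos++ + " logits=true");
        }
        return log;
    }

    public static void main(String[] args) {
        run(new int[100], 3).forEach(System.out::println);
    }
}
```

For a 100-token prompt with B=32, this produces four prefill batches (only the last computing logits) followed by one logits-producing step per generated token.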
Based on the above, this PR breaks the prefill-decode feature implementation down into four discrete phases:
Phase 1: CPU prefill/decode split — sequential, skip logits during prefill
Phase 2: GPU prefill/decode split — sequential, skip logits during prefill
Phase 3: CPU batched prefill (batch size B, default 32)
Phase 4: GPU batched prefill (batch size B, default 32)
Implementation
Top-level dispatch
`Llama#generateTokens`/`#generateTokensGPU` perform a three-way dispatch:

- `PREFILL_BATCH_SIZE > 1` → `InferenceEngineWithBatchPrefillDecode` (Phase 3/4)
- `PREFILL_BATCH_SIZE == 1` → `InferenceEngineWithPrefillDecode` (Phase 1/2)
- (default) → `InferenceEngine` (standard)
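A minimal sketch of this dispatch, assuming the "default" (standard) path is signalled by the batch size being unset (here modeled as ≤ 0); the engine class names come from this PR, the method shape is illustrative:

```java
// Three-way engine selection keyed on the prefill batch size.
public class EngineDispatchSketch {
    static String selectEngine(int prefillBatchSize) {
        if (prefillBatchSize > 1) {
            return "InferenceEngineWithBatchPrefillDecode"; // Phase 3/4: batched prefill
        }
        if (prefillBatchSize == 1) {
            return "InferenceEngineWithPrefillDecode";      // Phase 1/2: sequential prefill
        }
        return "InferenceEngine";                           // standard path (default)
    }

    public static void main(String[] args) {
        System.out.println(selectEngine(32)); // batched
        System.out.println(selectEngine(1));  // prefill-decode
        System.out.println(selectEngine(0));  // standard
    }
}
```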
Per-phase entry points
- `InferenceEngineWithPrefillDecode#generateTokens` → `LlamaInferenceCoreWithPrefillDecode#forwardJavaPrefill`
- `InferenceEngineWithPrefillDecode#generateTokensGPU` → `LlamaInferenceCoreWithPrefillDecode#forwardTornadoVMPrefill` → `TornadoVMMasterPlanWithPrefillDecode` (N+2 graphs)
- `InferenceEngineWithBatchPrefillDecode#generateTokens` → `LlamaInferenceCoreBatchPrefillDecode#batchForwardJavaPrefill`
- `InferenceEngineWithBatchPrefillDecode#generateTokensGPU` → `LlamaInferenceCoreBatchPrefillDecode#batchForwardTornadoVMPrefill` → `TornadoVMMasterPlanWithBatchPrefillDecode` (2N+3 graphs)

Phase 2 note: `LlamaFP16FFNLayersPrefillDecode` fixes no-CUDA-graph mode: layer 0 allocates the KV cache via `FIRST_EXECUTION`; layers 1..N use named `consumeFromDevice`.

Phase 4 note: the KV cache flows from batch prefill into decode via `persistOnDevice`/`consumeFromDevice` across the prefill→decode graph boundary, all within a single `TornadoExecutionPlan`.

Quantization

Both prefill/decode plans switch on `GGMLType` in `createExecutionPlan()`: `F16` proceeds; `Q8_0` and all others throw `UnsupportedOperationException` at plan-construction time, mirroring the `QuantizationPlannerFactory` pattern in the standard plan.

Functional Status
All four phases are fully implemented and verified for LLaMA, FP16, both with and without CUDA graphs.
Remaining work: model coverage (Mistral, Qwen2, Qwen3, DeepSeek, Phi-3, Granite) and Q8_0 support (extension points are already in place in both execution plans).
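The `GGMLType` guard described under Quantization amounts to a plain switch at plan-construction time. The enum values and exception type below match the description above; the method shape and the string placeholder standing in for the real TornadoVM plan are illustrative:

```java
// Fail-fast quantization guard: only F16 yields a plan, everything else is
// rejected when the plan is constructed rather than mid-inference.
public class QuantGuardSketch {
    enum GGMLType { F16, Q8_0, Q4_K }

    static String createExecutionPlan(GGMLType type) {
        switch (type) {
            case F16:
                return "fp16-prefill-decode-plan"; // stand-in for the real plan object
            default:
                throw new UnsupportedOperationException(
                        "Prefill-decode plans do not support " + type + " yet");
        }
    }

    public static void main(String[] args) {
        System.out.println(createExecutionPlan(GGMLType.F16));
        try {
            createExecutionPlan(GGMLType.Q8_0);
        } catch (UnsupportedOperationException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```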
Performance (RTX 5090 ROG Laptop, TornadoVM PTX, LLaMA-3.2-1B FP16, B=32)
How to run