High-perf SPMD paged attention (GQA) matching CANN IFA performance #655

Draft

learning-chip wants to merge 2 commits into hw-native-sys:main from learning-chip:pa_highperf

Conversation


@learning-chip learning-chip commented Apr 22, 2026

Reaches ~1 TB/s on A2 for GQA shapes, matching the performance of CANN's torch_npu.npu_incre_flash_attention (and even 1.1x faster for some shapes).

@ChaoWao @chenshengxin2026

Reproduce

```shell
cd simpler/tests/st/a2a3/tensormap_and_ringbuffer/spmd_paged_attention_highperf/kernel
bash ./compile.sh
python ./test_pa_accuracy.py
python ./bench_pa_performance.py
```

TODOs:

The current version is 2000+ lines of raw CCE plus host-side launch code. Remaining items:

Acknowledgement

I took the tiling scheme from ATB PA tiling, but rewrote it in numpy instead of the original C++ so it interfaces more naturally with the Python launcher.


@gemini-code-assist (bot) left a comment


Code Review

This pull request adds a standalone Paged Attention kernel for NPU, including tiling logic, benchmarking, and accuracy tests. Critical issues were identified in the tiling logic where memory offsets for various tensors are calculated as element offsets instead of byte offsets, and the workspace size for split-KV reduction is missing the batch dimension. Additionally, the sequence rank assignment is incorrect, the kernel entry point fails to handle certain tiling key configurations, and the performance runner lacks necessary workspace zero-initialization between launches.

Comment on lines +402 to +403

```python
addr_q += num_heads * head_dim * q_seqlen
addr_o += num_heads * head_dim_v * q_seqlen
```


critical

The offsets addr_q and addr_o are calculated as element offsets, but the CCE kernel entry point in pa_entry.cce receives these tensors as uint8_t*. If the kernel adds these offsets directly to the base pointers without scaling by the element size, it will access incorrect memory locations. These should be converted to byte offsets using np.dtype(...).itemsize to avoid hardcoded values.

Suggested change

```diff
-addr_q += num_heads * head_dim * q_seqlen
-addr_o += num_heads * head_dim_v * q_seqlen
+addr_q += num_heads * head_dim * q_seqlen * np.dtype(np.float16 if dtype in (torch.float16, torch.bfloat16) else np.float32).itemsize
+addr_o += num_heads * head_dim_v * q_seqlen * np.dtype(np.float16 if dtype in (torch.float16, torch.bfloat16) else np.float32).itemsize
```
References
  1. To calculate the size in bytes of a numpy array, use np.dtype().itemsize instead of a hardcoded value. This is more self-documenting and robust if the data type changes in the future.
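To make the element-vs-byte distinction concrete, here is a minimal sketch; the shape values are made up for illustration and are not taken from the PR:

```python
import numpy as np

# Illustrative shapes only; not the PR's actual configuration.
num_heads, head_dim, q_seqlen = 8, 128, 1

elem_offset = num_heads * head_dim * q_seqlen              # offset in elements
byte_offset = elem_offset * np.dtype(np.float16).itemsize  # offset in bytes, as a uint8_t* kernel expects
print(elem_offset, byte_offset)
```

A kernel that receives `uint8_t*` base pointers and adds `elem_offset` directly would land half-way into the intended fp16 region.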

Comment on lines +419 to +420

```python
addr_l += kvCN * num_heads * q_seqlen
addr_ofd += num_heads * head_dim * q_seqlen  # embeddingSize for oFd
```


critical

Similar to the Q/O offsets, addr_l and addr_ofd are element offsets being passed to a kernel that treats the corresponding GM pointers as uint8_t*. These must be byte offsets to ensure correct indexing in the kernel. Use np.dtype(...).itemsize instead of hardcoded values for robustness.

Suggested change

```diff
-addr_l += kvCN * num_heads * q_seqlen
-addr_ofd += num_heads * head_dim * q_seqlen  # embeddingSize for oFd
+addr_l += kvCN * num_heads * q_seqlen * np.dtype(np.float32).itemsize
+addr_ofd += num_heads * head_dim * q_seqlen * np.dtype(np.float32).itemsize
```

```python
tiling[base + 8] = seq_idx            # original batch index
tiling[base + 9] = total_q_blk        # = 0 for all-decoder
tiling[base + 10] = 0                 # mask offset hi
tiling[base + 13] = indices[seq_idx]  # sorted position → original index
```


high

The assignment tiling[base + 13] = indices[seq_idx] appears to be incorrect. indices is a list of original sequence indices sorted by length, so indices[seq_idx] returns the original index of the sequence at rank seq_idx. Since the loop is already iterating over the original index seq_idx, this assignment does not provide the rank of the current sequence. It should instead store the rank (sorted position) of the current seq_idx to allow the kernel to process sequences in the optimized order.

Suggested change

```diff
-tiling[base + 13] = indices[seq_idx]  # sorted position → original index
+tiling[base + 13] = indices.index(seq_idx)  # Store the rank of sequence seq_idx
```
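As a side note, the rank of every sequence can also be computed once up front as the inverse permutation of `indices`, avoiding a linear `indices.index(...)` scan per sequence. A sketch with made-up sequence lengths:

```python
import numpy as np

seq_lens = [5, 9, 2, 7]                  # made-up lengths, not from the PR
indices = np.argsort(seq_lens).tolist()  # original index at each sorted position

# Invert the permutation: rank[i] = sorted position of original sequence i.
rank = [0] * len(indices)
for pos, orig in enumerate(indices):
    rank[orig] = pos
print(indices, rank)
```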

Comment on lines +439 to +440

```python
o_core = int(block_dim * SPLITKV_RATIO) * num_heads * block_dim * head_dim * 4
l_size = int(block_dim * SPLITKV_RATIO) * num_heads * block_dim * 4
```


high

The workspace size calculations for o_core and l_size are missing the batch dimension. These buffers are used to store partial results for split-KV reduction across the entire batch. Without the batch factor, the allocated workspace will be insufficient for any batch size greater than 1. Additionally, use np.dtype(...).itemsize instead of hardcoded values.

Suggested change

```diff
-o_core = int(block_dim * SPLITKV_RATIO) * num_heads * block_dim * head_dim * 4
-l_size = int(block_dim * SPLITKV_RATIO) * num_heads * block_dim * 4
+o_core = batch * int(block_dim * SPLITKV_RATIO) * num_heads * head_dim * np.dtype(np.float32).itemsize
+l_size = batch * int(block_dim * SPLITKV_RATIO) * num_heads * np.dtype(np.float32).itemsize
```
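A quick sanity check of the batch-aware sizing; all names and values below are illustrative stand-ins, not the PR's actual shapes:

```python
import numpy as np

# Made-up shapes for illustration only.
batch, block_dim, num_heads, head_dim = 4, 24, 8, 128
SPLITKV_RATIO = 0.5
itemsize = np.dtype(np.float32).itemsize  # 4 bytes, without hardcoding

# Workspace sizes including the batch factor, as suggested above.
o_core = batch * int(block_dim * SPLITKV_RATIO) * num_heads * head_dim * itemsize
l_size = batch * int(block_dim * SPLITKV_RATIO) * num_heads * itemsize
print(o_core, l_size)
```

Without the leading `batch` factor, these sizes would cover only a single sequence's partial results.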

```python
        self.null = empty_buf(device)
        torch.npu.synchronize()

    def __call__(self):
```


high

The CustomPARunner does not zero-initialize the workspace buffers before each kernel launch. If the kernel uses these buffers for atomic operations or partial sum reductions (which is common in split-KV implementations), stale data from previous iterations will corrupt the results. While torch.zeros is used in __init__, the buffers must be cleared in __call__ if they are modified by the kernel.
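One way to address this, as a sketch only: the buffer names are assumptions, and CPU tensors stand in for the NPU workspace buffers here.

```python
import torch

class RunnerSketch:
    """Sketch: clear split-KV partial-result workspaces before every launch."""

    def __init__(self, o_core_elems: int, l_elems: int):
        # Allocated once up front with torch.zeros, as in the runner's __init__.
        self.o_core = torch.zeros(o_core_elems, dtype=torch.float32)
        self.l_gm = torch.zeros(l_elems, dtype=torch.float32)

    def __call__(self):
        # Zero stale partials from the previous iteration before the kernel
        # accumulates into these buffers again.
        self.o_core.zero_()
        self.l_gm.zero_()
        # ... launch the kernel here ...
```

The in-place `zero_()` calls avoid reallocating the buffers on every launch.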

Comment on lines +52 to +96

```cpp
if (tiling_key_val == 0) { // fp16 BN
#ifdef __DAV_C220_CUBE__
    UnpadAttentionDecoderAic<false, TilingKeyType::TILING_HALF_DATA, half, half, half> pa_aic_fp16(prefill_batch_size, decoder_batch_size);
    pa_aic_fp16.SetArgs(sync, q_gm, k_gm, v_gm, block_tables_gm, o_gm, s_gm, p_gm, o_tmp_gm, gm_k16, gm_v16, tiling_para_gm, razorOffset);
    pa_aic_fp16.Run();
#elif __DAV_C220_VEC__
    UnpadAttentionDecoderAiv<TilingKeyType::TILING_HALF_DATA, half, half> pa_aiv(prefill_batch_size, decoder_batch_size);
    pa_aiv.SetArgs(sync, k_gm, v_gm, deq_scale1_gm, offset1_gm, deq_scale2_gm, offset2_gm, block_tables_gm,
                   mask_gm, o_gm, s_gm, p_gm, o_tmp_gm, go_gm, o_core_tmp_gm, l_gm, gm_k16, gm_v16, tiling_para_gm, razorOffset, logN_gm);
    pa_aiv.Run();
#endif
} else if (tiling_key_val == 1) { // bf16 BN
#ifdef __DAV_C220_CUBE__
    UnpadAttentionDecoderAic<false, TilingKeyType::TILING_BF16_DATA, __bf16, __bf16, __bf16> pa_aic_bf16(prefill_batch_size, decoder_batch_size);
    pa_aic_bf16.SetArgs(sync, q_gm, k_gm, v_gm, block_tables_gm, o_gm, s_gm, p_gm, o_tmp_gm, gm_k16, gm_v16, tiling_para_gm, razorOffset);
    pa_aic_bf16.Run();
#elif __DAV_C220_VEC__
    UnpadAttentionDecoderAiv<TilingKeyType::TILING_BF16_DATA, __bf16, __bf16> pa_aiv(prefill_batch_size, decoder_batch_size);
    pa_aiv.SetArgs(sync, k_gm, v_gm, deq_scale1_gm, offset1_gm, deq_scale2_gm, offset2_gm, block_tables_gm,
                   mask_gm, o_gm, s_gm, p_gm, o_tmp_gm, go_gm, o_core_tmp_gm, l_gm, gm_k16, gm_v16, tiling_para_gm, razorOffset, logN_gm);
    pa_aiv.Run();
#endif
} else if (tiling_key_val == 16) { // fp16 BNS split-kv
#ifdef __DAV_C220_CUBE__
    UnpadAttentionDecoderAic<true, TilingKeyType::TILING_HALF_DATA, half, half, half> pa_aic_fp16(prefill_batch_size, decoder_batch_size);
    pa_aic_fp16.SetArgs(sync, q_gm, k_gm, v_gm, block_tables_gm, o_gm, s_gm, p_gm, o_tmp_gm, gm_k16, gm_v16, tiling_para_gm, razorOffset);
    pa_aic_fp16.Run();
#elif __DAV_C220_VEC__
    UnpadAttentionDecoderAiv<TilingKeyType::TILING_HALF_DATA, half, half, true> pa_aiv(prefill_batch_size, decoder_batch_size);
    pa_aiv.SetArgs(sync, k_gm, v_gm, deq_scale1_gm, offset1_gm, deq_scale2_gm, offset2_gm, block_tables_gm,
                   mask_gm, o_gm, s_gm, p_gm, o_tmp_gm, go_gm, o_core_tmp_gm, l_gm, gm_k16, gm_v16, tiling_para_gm, razorOffset, logN_gm);
    pa_aiv.Run();
#endif
} else if (tiling_key_val == 17) { // bf16 BNS split-kv
#ifdef __DAV_C220_CUBE__
    UnpadAttentionDecoderAic<true, TilingKeyType::TILING_BF16_DATA, __bf16, __bf16, __bf16> pa_aic_bf16(prefill_batch_size, decoder_batch_size);
    pa_aic_bf16.SetArgs(sync, q_gm, k_gm, v_gm, block_tables_gm, o_gm, s_gm, p_gm, o_tmp_gm, gm_k16, gm_v16, tiling_para_gm, razorOffset);
    pa_aic_bf16.Run();
#elif __DAV_C220_VEC__
    UnpadAttentionDecoderAiv<TilingKeyType::TILING_BF16_DATA, __bf16, __bf16, true> pa_aiv(prefill_batch_size, decoder_batch_size);
    pa_aiv.SetArgs(sync, k_gm, v_gm, deq_scale1_gm, offset1_gm, deq_scale2_gm, offset2_gm, block_tables_gm,
                   mask_gm, o_gm, s_gm, p_gm, o_tmp_gm, go_gm, o_core_tmp_gm, l_gm, gm_k16, gm_v16, tiling_para_gm, razorOffset, logN_gm);
    pa_aiv.Run();
#endif
}
```


medium

The kernel entry point does not handle tiling_key_val values that include the is_split_block bit (bit 7). According to pa_tiling.py, this bit is set when block_size >= 128 and head_dim == 256. If a user provides such a configuration, the kernel will silently skip execution as none of the if/else if conditions will match.
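One possible shape for a fix, purely as a sketch: the bit layout is inferred from the review comment, not verified against pa_tiling.py, and the helper name is hypothetical.

```python
IS_SPLIT_BLOCK = 1 << 7  # assumed bit-7 flag, per the review comment

def split_dispatch_key(tiling_key_val: int):
    """Separate the split-block flag from the base key the entry point matches on."""
    split_block = bool(tiling_key_val & IS_SPLIT_BLOCK)
    base_key = tiling_key_val & ~IS_SPLIT_BLOCK  # should now be 0, 1, 16 or 17
    return base_key, split_block

print(split_dispatch_key(IS_SPLIT_BLOCK | 17))
```

Masking the flag out before the if/else-if chain (and routing it as a separate template/runtime parameter) keeps the four existing branches matching.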

```python
as_ptr(o),
as_ptr(s_gm), as_ptr(p_gm), as_ptr(o_tmp), as_ptr(go),
as_ptr(o_core), as_ptr(l_gm), as_ptr(k16), as_ptr(v16),
as_ptr(tiling),
```


Could it be that as_ptr has a non-negligible kernel launch overhead here?

```python
def __call__(self):
    stream = torch.npu.current_stream()._as_parameter_
    _launch(
        self.lib, self.eff_bd, stream,
```


Accessing the instance attributes of CustomPARunner from Python here might add non-negligible per-call overhead.
