High-perf SPMD paged attention (GQA) matching CANN IFA performance#655
learning-chip wants to merge 2 commits into hw-native-sys:main
Conversation
Code Review
This pull request adds a standalone Paged Attention kernel for NPU, including tiling logic, benchmarking, and accuracy tests. Critical issues were identified in the tiling logic where memory offsets for various tensors are calculated as element offsets instead of byte offsets, and the workspace size for split-KV reduction is missing the batch dimension. Additionally, the sequence rank assignment is incorrect, the kernel entry point fails to handle certain tiling key configurations, and the performance runner lacks necessary workspace zero-initialization between launches.
addr_q += num_heads * head_dim * q_seqlen
addr_o += num_heads * head_dim_v * q_seqlen
The offsets addr_q and addr_o are calculated as element offsets, but the CCE kernel entry point in pa_entry.cce receives these tensors as uint8_t*. If the kernel adds these offsets directly to the base pointers without scaling by the element size, it will access incorrect memory locations. These should be converted to byte offsets using np.dtype(...).itemsize to avoid hardcoded values.
Suggested change:
- addr_q += num_heads * head_dim * q_seqlen
- addr_o += num_heads * head_dim_v * q_seqlen
+ addr_q += num_heads * head_dim * q_seqlen * np.dtype(np.float16 if dtype in (torch.float16, torch.bfloat16) else np.float32).itemsize
+ addr_o += num_heads * head_dim_v * q_seqlen * np.dtype(np.float16 if dtype in (torch.float16, torch.bfloat16) else np.float32).itemsize
References
- To calculate the size in bytes of a numpy array, use np.dtype().itemsize instead of a hardcoded value. This is more self-documenting and robust if the data type changes in the future.
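For reference, `itemsize` reads the per-element byte width straight from the dtype, so the scaling stays correct if the kernel's data type ever changes:

```python
import numpy as np

# itemsize is the size in bytes of one element of the given dtype.
assert np.dtype(np.float16).itemsize == 2
assert np.dtype(np.float32).itemsize == 4
```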
addr_l += kvCN * num_heads * q_seqlen
addr_ofd += num_heads * head_dim * q_seqlen  # embeddingSize for oFd
Similar to the Q/O offsets, addr_l and addr_ofd are element offsets being passed to a kernel that treats the corresponding GM pointers as uint8_t*. These must be byte offsets to ensure correct indexing in the kernel. Use np.dtype(...).itemsize instead of hardcoded values for robustness.
Suggested change:
- addr_l += kvCN * num_heads * q_seqlen
- addr_ofd += num_heads * head_dim * q_seqlen  # embeddingSize for oFd
+ addr_l += kvCN * num_heads * q_seqlen * np.dtype(np.float32).itemsize
+ addr_ofd += num_heads * head_dim * q_seqlen * np.dtype(np.float32).itemsize
tiling[base + 8] = seq_idx  # original batch index
tiling[base + 9] = total_q_blk  # = 0 for all-decoder
tiling[base + 10] = 0  # mask offset hi
tiling[base + 13] = indices[seq_idx]  # sorted position → original index
The assignment tiling[base + 13] = indices[seq_idx] appears to be incorrect. indices is a list of original sequence indices sorted by length, so indices[seq_idx] returns the original index of the sequence at rank seq_idx. Since the loop is already iterating over the original index seq_idx, this assignment does not provide the rank of the current sequence. It should instead store the rank (sorted position) of the current seq_idx to allow the kernel to process sequences in the optimized order.
Suggested change:
- tiling[base + 13] = indices[seq_idx]  # sorted position → original index
+ tiling[base + 13] = indices.index(seq_idx)  # Store the rank of sequence seq_idx
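A small worked example (with made-up sequence lengths) shows the difference between `indices[seq_idx]` and `indices.index(seq_idx)`:

```python
# Hypothetical batch of 3 sequences with KV lengths [50, 200, 100].
seq_lens = [50, 200, 100]
# indices: original sequence indices sorted by descending length.
indices = sorted(range(len(seq_lens)), key=lambda i: -seq_lens[i])
assert indices == [1, 2, 0]

# For seq_idx = 0 (the shortest sequence):
# indices[0] is the original index of the sequence at rank 0 -- NOT
# the rank of sequence 0.
assert indices[0] == 1
# indices.index(0) is the rank (sorted position) of sequence 0,
# which is what the tiling entry is meant to carry.
assert indices.index(0) == 2
```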
o_core = int(block_dim * SPLITKV_RATIO) * num_heads * block_dim * head_dim * 4
l_size = int(block_dim * SPLITKV_RATIO) * num_heads * block_dim * 4
The workspace size calculations for o_core and l_size are missing the batch dimension. These buffers are used to store partial results for split-KV reduction across the entire batch. Without the batch factor, the allocated workspace will be insufficient for any batch size greater than 1. Additionally, use np.dtype(...).itemsize instead of hardcoded values.
Suggested change:
- o_core = int(block_dim * SPLITKV_RATIO) * num_heads * block_dim * head_dim * 4
- l_size = int(block_dim * SPLITKV_RATIO) * num_heads * block_dim * 4
+ o_core = batch * int(block_dim * SPLITKV_RATIO) * num_heads * head_dim * np.dtype(np.float32).itemsize
+ l_size = batch * int(block_dim * SPLITKV_RATIO) * num_heads * np.dtype(np.float32).itemsize
self.null = empty_buf(device)
torch.npu.synchronize()

def __call__(self):
The CustomPARunner does not zero-initialize the workspace buffers before each kernel launch. If the kernel uses these buffers for atomic operations or partial sum reductions (which is common in split-KV implementations), stale data from previous iterations will corrupt the results. While torch.zeros is used in __init__, the buffers must be cleared in __call__ if they are modified by the kernel.
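A minimal sketch of the fix, with illustrative buffer names (`o_core`, `l_buf`) standing in for the runner's actual workspace tensors; on the NPU the zeroing would run as `tensor.zero_()` before each launch:

```python
import torch

class RunnerSketch:
    """Illustrative stand-in for CustomPARunner (names hypothetical)."""
    def __init__(self):
        # Workspace buffers allocated once, as in the runner's __init__.
        self.o_core = torch.zeros(8)
        self.l_buf = torch.zeros(8)

    def __call__(self):
        # Clear split-KV partial buffers before every launch so stale
        # partial sums from the previous iteration cannot be folded
        # into the reduction.
        self.o_core.zero_()
        self.l_buf.zero_()
        # ... kernel launch would go here ...

runner = RunnerSketch()
runner.o_core += 3.0  # simulate the kernel writing partials
runner()              # the next call starts from a clean workspace
assert runner.o_core.sum().item() == 0.0
```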
if (tiling_key_val == 0) { // fp16 BN
#ifdef __DAV_C220_CUBE__
    UnpadAttentionDecoderAic<false, TilingKeyType::TILING_HALF_DATA, half, half, half> pa_aic_fp16(prefill_batch_size, decoder_batch_size);
    pa_aic_fp16.SetArgs(sync, q_gm, k_gm, v_gm, block_tables_gm, o_gm, s_gm, p_gm, o_tmp_gm, gm_k16, gm_v16, tiling_para_gm, razorOffset);
    pa_aic_fp16.Run();
#elif __DAV_C220_VEC__
    UnpadAttentionDecoderAiv<TilingKeyType::TILING_HALF_DATA, half, half> pa_aiv(prefill_batch_size, decoder_batch_size);
    pa_aiv.SetArgs(sync, k_gm, v_gm, deq_scale1_gm, offset1_gm, deq_scale2_gm, offset2_gm, block_tables_gm,
                   mask_gm, o_gm, s_gm, p_gm, o_tmp_gm, go_gm, o_core_tmp_gm, l_gm, gm_k16, gm_v16, tiling_para_gm, razorOffset, logN_gm);
    pa_aiv.Run();
#endif
} else if (tiling_key_val == 1) { // bf16 BN
#ifdef __DAV_C220_CUBE__
    UnpadAttentionDecoderAic<false, TilingKeyType::TILING_BF16_DATA, __bf16, __bf16, __bf16> pa_aic_bf16(prefill_batch_size, decoder_batch_size);
    pa_aic_bf16.SetArgs(sync, q_gm, k_gm, v_gm, block_tables_gm, o_gm, s_gm, p_gm, o_tmp_gm, gm_k16, gm_v16, tiling_para_gm, razorOffset);
    pa_aic_bf16.Run();
#elif __DAV_C220_VEC__
    UnpadAttentionDecoderAiv<TilingKeyType::TILING_BF16_DATA, __bf16, __bf16> pa_aiv(prefill_batch_size, decoder_batch_size);
    pa_aiv.SetArgs(sync, k_gm, v_gm, deq_scale1_gm, offset1_gm, deq_scale2_gm, offset2_gm, block_tables_gm,
                   mask_gm, o_gm, s_gm, p_gm, o_tmp_gm, go_gm, o_core_tmp_gm, l_gm, gm_k16, gm_v16, tiling_para_gm, razorOffset, logN_gm);
    pa_aiv.Run();
#endif
} else if (tiling_key_val == 16) { // fp16 BNS split-kv
#ifdef __DAV_C220_CUBE__
    UnpadAttentionDecoderAic<true, TilingKeyType::TILING_HALF_DATA, half, half, half> pa_aic_fp16(prefill_batch_size, decoder_batch_size);
    pa_aic_fp16.SetArgs(sync, q_gm, k_gm, v_gm, block_tables_gm, o_gm, s_gm, p_gm, o_tmp_gm, gm_k16, gm_v16, tiling_para_gm, razorOffset);
    pa_aic_fp16.Run();
#elif __DAV_C220_VEC__
    UnpadAttentionDecoderAiv<TilingKeyType::TILING_HALF_DATA, half, half, true> pa_aiv(prefill_batch_size, decoder_batch_size);
    pa_aiv.SetArgs(sync, k_gm, v_gm, deq_scale1_gm, offset1_gm, deq_scale2_gm, offset2_gm, block_tables_gm,
                   mask_gm, o_gm, s_gm, p_gm, o_tmp_gm, go_gm, o_core_tmp_gm, l_gm, gm_k16, gm_v16, tiling_para_gm, razorOffset, logN_gm);
    pa_aiv.Run();
#endif
} else if (tiling_key_val == 17) { // bf16 BNS split-kv
#ifdef __DAV_C220_CUBE__
    UnpadAttentionDecoderAic<true, TilingKeyType::TILING_BF16_DATA, __bf16, __bf16, __bf16> pa_aic_bf16(prefill_batch_size, decoder_batch_size);
    pa_aic_bf16.SetArgs(sync, q_gm, k_gm, v_gm, block_tables_gm, o_gm, s_gm, p_gm, o_tmp_gm, gm_k16, gm_v16, tiling_para_gm, razorOffset);
    pa_aic_bf16.Run();
#elif __DAV_C220_VEC__
    UnpadAttentionDecoderAiv<TilingKeyType::TILING_BF16_DATA, __bf16, __bf16, true> pa_aiv(prefill_batch_size, decoder_batch_size);
    pa_aiv.SetArgs(sync, k_gm, v_gm, deq_scale1_gm, offset1_gm, deq_scale2_gm, offset2_gm, block_tables_gm,
                   mask_gm, o_gm, s_gm, p_gm, o_tmp_gm, go_gm, o_core_tmp_gm, l_gm, gm_k16, gm_v16, tiling_para_gm, razorOffset, logN_gm);
    pa_aiv.Run();
#endif
}
The kernel entry point does not handle tiling_key_val values that include the is_split_block bit (bit 7). According to pa_tiling.py, this bit is set when block_size >= 128 and head_dim == 256. If a user provides such a configuration, the kernel will silently skip execution as none of the if/else if conditions will match.
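One defensive option is to reject unhandled keys on the host before launch, so a misconfiguration fails loudly instead of silently skipping the kernel. This is a sketch, not code from the PR; the handled-key set is taken from the dispatch above and the bit-7 semantics from this comment:

```python
# Tiling keys the CCE entry point currently dispatches on.
HANDLED_TILING_KEYS = {0, 1, 16, 17}
# Set by pa_tiling.py when block_size >= 128 and head_dim == 256.
IS_SPLIT_BLOCK_BIT = 1 << 7

def check_tiling_key(tiling_key_val: int) -> None:
    """Raise before launch if no kernel branch matches this key."""
    if tiling_key_val not in HANDLED_TILING_KEYS:
        raise NotImplementedError(
            f"tiling_key {tiling_key_val:#x} has no kernel branch "
            f"(is_split_block={bool(tiling_key_val & IS_SPLIT_BLOCK_BIT)})"
        )
```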
as_ptr(o),
as_ptr(s_gm), as_ptr(p_gm), as_ptr(o_tmp), as_ptr(go),
as_ptr(o_core), as_ptr(l_gm), as_ptr(k16), as_ptr(v16),
as_ptr(tiling),
Could it be that as_ptr has a non-negligible kernel launch overhead here?
def __call__(self):
    stream = torch.npu.current_stream()._as_parameter_
    _launch(
        self.lib, self.eff_bd, stream,
Looking up these attributes of CustomPARunner from Python on every call might add non-negligible overhead.
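A common mitigation for per-call attribute-lookup cost (a sketch, not code from this PR) is to freeze the launch arguments that never change into a local tuple once, leaving only the per-call values (e.g. the stream) to be fetched in the hot path:

```python
class CachedLauncher:
    """Illustrative pattern: hoist per-call attribute lookups out of
    the hot path. `fn` stands in for a _launch-style callable."""
    def __init__(self, fn, *args):
        self._fn = fn
        # Freeze everything that does not change between launches.
        self._frozen_args = tuple(args)

    def __call__(self, stream):
        # Only the stream varies per call; all other lookups were
        # paid once, at construction time.
        return self._fn(stream, *self._frozen_args)

launcher = CachedLauncher(lambda stream, *a: (stream, a), "lib", 20)
assert launcher("s0") == ("s0", ("lib", 20))
```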
Reaches ~1 TB/s on A2 for GQA shapes, matching the performance of CANN's torch_npu.npu_incre_flash_attention (for some shapes even 1.1x faster). @ChaoWao @chenshengxin2026
Reproduce
TODOs:
Current version is 2000+ lines of raw CCE & host-side launch. Remaining items:
Acknowledgement
I took the tiling scheme from ATB PA tiling, but rewrote it in NumPy instead of the original C++ so that it interfaces more easily with the Python launcher.