diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py index 6483b99454a3..da194f885d1c 100644 --- a/python/tvm/tir/tensor_intrin/cuda.py +++ b/python/tvm/tir/tensor_intrin/cuda.py @@ -392,7 +392,7 @@ def mma_store_impl(a: T.handle, c: T.handle) -> None: a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", offset_factor=1 ) C = T.match_buffer( - c, [M_DIM, N_DIM], dtype=dtype, scope="global", offset_factor=1, strides=[s0, s1] + c, [M_DIM, N_DIM], dtype=dtype, scope=scope, offset_factor=1, strides=[s0, s1] ) with T.block("root"):