[TIR] Introduce Pass InjectPTXLDG32#13973
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot
I hope TIR
this->stream << "asm volatile (\n" ;
this->PrintIndent();
stream << "\"{.reg .pred p;\\n\"\n" ;
this->PrintIndent();
stream << "\" setp.ne.b32 p, %2, 0;\\n\"\n" ;
this->PrintIndent();
stream << "\" @!p mov.b32 %0, 0;\\n\"\n";
this->PrintIndent();
stream << "\" @p ld.global.nc.f32 %0, [%1];}\\n\"\n" ;
// stream << "\" @p ld.global.nc.L2::128B.f32 %0, [%1];}\\n\"\n" ;
this->PrintIndent();
stream << ": \"=f\"(" << reg << "[" << local_addr << "]" << ")\n" ;
this->PrintIndent();
stream << ": \"l\"((void*)(" << global_buffer << "+" << global_addr << ")), \"r\"((int)" << guard << ")\n" ;
this->PrintIndent();
stream << ");\n" ;
nit: you may use multi-line string in C++
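For reference, here is a rough Python sketch of the text the snippet above is meant to emit, assembled from one multi-line template instead of many separate stream writes. `render_pred_ldg32` and its parameters are hypothetical names used only for illustration; they are not TVM APIs, and the actual change in the codegen would presumably use a C++ raw string literal rather than Python.

```python
from string import Template

# Hypothetical template, for illustration only. It mirrors the inline asm
# that the C++ codegen streams out line by line. The raw string keeps each
# "\n" as a literal backslash-n, which is what the generated CUDA source needs.
_PRED_LDG32 = Template(r'''asm volatile (
  "{.reg .pred p;\n"
  " setp.ne.b32 p, %2, 0;\n"
  " @!p mov.b32 %0, 0;\n"
  " @p ld.global.nc.f32 %0, [%1];}\n"
  : "=f"($reg[$local_addr])
  : "l"((void*)($global_buffer+$global_addr)), "r"((int)$guard)
);
''')


def render_pred_ldg32(reg, local_addr, global_buffer, global_addr, guard):
    """Return the inline asm text for one predicated 32-bit global load."""
    return _PRED_LDG32.substitute(
        reg=reg,
        local_addr=local_addr,
        global_buffer=global_buffer,
        global_addr=global_addr,
        guard=guard,
    )
```

Keeping the asm in one template makes the operand order (`%0` output register, `%1` global address, `%2` guard) easier to check against the constraint lists than when it is spread over many `stream <<` lines.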
// The pass can now be invoked via the pass infrastructure, but we also add a Python binding for it
TVM_REGISTER_GLOBAL("tir.transform.InjectPTXLDG32").set_body_typed(InjectPTXLDG32);
nit: you may use clang-format to somehow organize the file slightly better
TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_async_commit_queue_scope", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity", Integer);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.ptx_pred_ldg32", Bool);
the name is a bit confusing, can you discuss with @rainy-memory and figure out together something more comprehensible?
our key objective is that users may need to set at most one flag (zero is the best if possible) so that they could deliver the best GEMM performance out of the box
Let's fix the lint and merge it in asap. If you don't like what pylint complains about regarding variable naming, just do:
# pylint: disable=invalid-name
your code
# pylint: enable=invalid-name
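A minimal sketch of that pattern, in case it helps. The function and its argument names are made up for illustration; the point is only that the single-letter uppercase names would normally trip pylint's `invalid-name` check, and the two comments scope the suppression to this region.

```python
# pylint: disable=invalid-name
def gemm_flops(M, N, K):
    """Count floating-point ops for an (M x K) by (K x N) matmul."""
    # One multiply and one add per inner-product term.
    return 2 * M * N * K
# pylint: enable=invalid-name
```

Re-enabling the check right after the block keeps the rest of the file under the normal naming rules.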
 * \brief tvm intrinsic for ptx predicate load with 32-bit data type.
 *
 */
TVM_DLL const Op& inject_ptx_ldg32();
naming: we do not need inject prefix as it can just be ptx_ldg32
This PR introduces a new pass, InjectPTXLDG32, that rewrites `if_then_else`
call nodes into `ptx_pred_ldg32` call nodes. When the store buffer is local
and the loaded value comes from a global buffer, the pass replaces the
`if_then_else` pattern with a predicated PTX load.
Test the pass with:
```python
with tvm.transform.PassContext(config={"tir.ptx_pred_ldg32": True}):
    mod = tvm.build(f, target="cuda")
```
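To make the intended equivalence concrete, here is a small pure-Python model of the two forms. It does not use TVM; `if_then_else_load` and `pred_ldg32` are hypothetical stand-ins, and it assumes the else branch is zero, as the `mov.b32 %0, 0` in the emitted asm suggests.

```python
def if_then_else_load(global_buf, addr, guard):
    """Model of the original pattern: load when the guard holds, else 0."""
    return global_buf[addr] if guard else 0.0


def pred_ldg32(global_buf, addr, guard):
    """Model of the predicated PTX sequence emitted by the codegen."""
    p = int(guard) != 0            # setp.ne.b32 p, %2, 0
    value = None
    if not p:
        value = 0.0                # @!p mov.b32 %0, 0
    if p:
        value = global_buf[addr]   # @p ld.global.nc.f32 %0, [%1]
    return value
```

In both forms the global buffer is only read when the guard is true, so an out-of-range address behind a false guard is never touched.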
this->stream << "asm volatile (\n";
this->stream << "\"{.reg .pred p;\\n\"\n";
this->stream << "\" setp.ne.b32 p, %2, 0;\\n\"\n";
this->stream << "\" @!p mov.b32 %0, 0;\\n\"\n";
this->stream << "\" @p ld.global.nc.f32 %0, [%1];}\\n\"\n";
// stream << "\" @p ld.global.nc.L2::128B.f32 %0, [%1];}\\n\"\n" ;
stream << ": \"=f\"(" << reg << "[" << local_addr << "]"
       << ")\n";
stream << ": \"l\"((void*)(" << global_buffer << "+" << global_addr << ")), \"r\"((int)"
       << guard << ")\n";
stream << ");\n";
perhaps it would be clearer to write this way:
tqchen left a comment
my comments have been addressed, will let @junrushao handle this
@andy-yang-1 please fix the unittests and we are good to go