Two small fixes to AMDCPU codegen for LLVM 10+ and ROCm 3.5+#5920
Two small fixes to AMDCPU codegen for LLVM 10+ and ROCm 3.5+#5920tqchen merged 1 commit intoapache:masterfrom
Conversation
3206394 to
4f6e499
Compare
|
@masahi This is needed for the latest ROCm and also recent LLVM. I'm not particularly fond of needing to go through the device API, but I haven't found a better way, as the codegen wants to be compilable without having HIP installed. |
- For LLVM 10+ we need to avoid calling Align with 0, or else we get a crash. - For ROCm 3.5+ we need to use code object 3 (the default in LLVM 9+) but for ROCm < 3.5 we want the code object 2. - As we want to separate codegen from the API, we need to add a device api query for the version. But every one else wants now one, too. (But I only filled it in for CUDA for now.) - I'm throwing in an addition of kMaxRegistersPerBlock for ROCm. This was introduced for CUDA in apache#5898.
|
Sorry, another small thing:
|
|
NOTE: using runtime detection of rocm features will only work if we are building on the same machine and won't work for cross compilation. While it is OK for now, let us keep that in mind and once we land https://discuss.tvm.ai/t/rfc-tvm-target-specification/6844, we might want to allow user to explicitly specify the attr and only use auto detect if the attr is not specified(or march=native is used) |
|
@tqchen Yeah, so the background to this is that the recent release of ROCm 3.5 brings rather sweeping changes (changing the compiler backend for the HIP compilation among other things). My conclusion from this would be
|
…5920) - For LLVM 10+ we need to avoid calling Align with 0, or else we get a crash. - For ROCm 3.5+ we need to use code object 3 (the default in LLVM 9+) but for ROCm < 3.5 we want the code object 2. - As we want to separate codegen from the API, we need to add a device api query for the version. But every one else wants now one, too. (But I only filled it in for CUDA for now.) - I'm throwing in an addition of kMaxRegistersPerBlock for ROCm. This was introduced for CUDA in apache#5898.
…5920) - For LLVM 10+ we need to avoid calling Align with 0, or else we get a crash. - For ROCm 3.5+ we need to use code object 3 (the default in LLVM 9+) but for ROCm < 3.5 we want the code object 2. - As we want to separate codegen from the API, we need to add a device api query for the version. But every one else wants now one, too. (But I only filled it in for CUDA for now.) - I'm throwing in an addition of kMaxRegistersPerBlock for ROCm. This was introduced for CUDA in apache#5898.
Uh oh!
There was an error while loading. Please reload this page.