[TOPI][OP] cuda for argwhere #6868
Conversation
@zhiics I have a branch with the changes you'd need, but I haven't opened a PR because I've been fighting that memory corruption issue with topk. Would you like me to submit a PR to enable the other dynamic tests and include my refactors to strided slice?
@mbrookhart Thanks. That would be cool.
    # Limit threads to a single block to make sure atomic_add works normally.
    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
    nthread_tx = max_threads
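For context, the sizing in the diff above caps the kernel at a single block of `max_threads` threads, with each thread covering its share of elements via a strided loop. A rough Python sketch of that sizing decision (the helper `launch_dims` is hypothetical, not TVM API; `max_threads` and the single-block cap mirror the snippet):

```python
import math

def launch_dims(n, max_threads, single_block=True):
    """Compute (num_blocks, threads_per_block) for n elements.

    With single_block=True (what this PR does so atomic_add behaves),
    everything runs in one block and each thread loops with a stride.
    """
    nthread_tx = max_threads
    if single_block:
        return 1, nthread_tx
    # Unconstrained sizing: enough blocks to cover all n elements.
    return math.ceil(n / nthread_tx), nthread_tx

# One block of 1024 threads covers 5000 elements via a strided loop.
print(launch_dims(5000, 1024))         # (1, 1024)
print(launch_dims(5000, 1024, False))  # (5, 1024)
```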
CUDA does have a kernel-level atomic add (https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#atomic-functions). Is it just slower, or do we not have access to it from TIR?
We use atomicAdd. However, if the number of blocks is larger than a threshold (like 18), it returns an incorrect result.
I'm surprised. I'd expect atomicAdd to work with any number of blocks. Could you maybe expand this comment with why and when atomicAdd fails?
The observation is that when the input data size is large (> 300 * 300, for example) and the number of blocks is not limited, as was previously the case, the output of the IR routine is incorrect. I haven't dug deeper into it at this time.
In addition, we need to use thrust; otherwise the TVM implementation of topk can also generate incorrect results.
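The scheme under discussion is the atomic-counter compaction pattern GPU argwhere relies on: each thread tests its element and, on a hit, atomically bumps a shared counter to claim an output slot. A pure-Python emulation of that pattern (a lock stands in for CUDA's atomicAdd; this is a sketch of the idea, not the TIR code):

```python
import threading

def parallel_argwhere_1d(data, num_workers=4):
    """Collect indices of nonzero elements using an atomically
    bumped output counter, mimicking the GPU compaction scheme."""
    out = [None] * len(data)
    counter = [0]
    lock = threading.Lock()  # stands in for CUDA atomicAdd

    def worker(start):
        for i in range(start, len(data), num_workers):  # strided loop
            if data[i] != 0:
                with lock:            # atomic "fetch-and-add"
                    slot = counter[0]
                    counter[0] += 1
                out[slot] = i         # the claimed slot is ours alone

    threads = [threading.Thread(target=worker, args=(w,))
               for w in range(num_workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    # Workers finish in arbitrary order, so the result must be sorted
    # afterwards -- the same reason this PR sorts the argwhere output.
    return sorted(out[:counter[0]])

print(parallel_argwhere_1d([0, 3, 0, 0, 7, 1, 0]))  # [1, 4, 5]
```

The sort at the end matches the "sort argwhere result" step in the commit list: the compaction itself writes hits in nondeterministic order.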
Thanks @zhiics @mbrookhart @tkonolige |
* argwhere
* cuda schedule
* sort argwhere result
* Use single block and thrust to fix flaky behavior
* format
* used dynamic strided_slice
* Fix dynamic strided_slice
* try new strided_slice
* Improve dynamic strided slice to bind data dependent shape var.
* all tests pass
* remove print
* use new strided_slice
* clean

Co-authored-by: Yao Wang <kevinthesunwy@gmail.com>
This PR adds a CUDA schedule for argwhere.
Will ping reviewers when we can run the argwhere Relay tests.
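For reference, argwhere returns the coordinates of every nonzero element as an (N, ndim) result. A minimal dependency-free Python version of the 2-D case, illustrating the semantics the schedule targets (not the TOPI kernel itself):

```python
def argwhere_2d(mat):
    """Return [row, col] coordinates of nonzero entries, row-major order."""
    return [[r, c]
            for r, row in enumerate(mat)
            for c, v in enumerate(row)
            if v != 0]

print(argwhere_2d([[0, 1],
                   [2, 0]]))  # [[0, 1], [1, 0]]
```

Because N depends on the data, the output shape is dynamic, which is why this work ties into dynamic strided_slice in the commit list.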