[Codegen][CUDA] Fix make_int4x cuda codegen vectorize#8137
Merged
vinx13 merged 1 commit intoapache:mainfrom May 26, 2021
Merged
[Codegen][CUDA] Fix make_int4x cuda codegen vectorize#8137vinx13 merged 1 commit intoapache:mainfrom
vinx13 merged 1 commit intoapache:mainfrom
Conversation
Member
|
@vinx13 please help to manage this PR |
vinx13
approved these changes
May 26, 2021
trevor-m
pushed a commit
to trevor-m/tvm
that referenced
this pull request
Jun 17, 2021
Co-authored-by: wangyucheng <wangyucheng@sensetime.com>
trevor-m
pushed a commit
to neo-ai/tvm
that referenced
this pull request
Jun 17, 2021
Co-authored-by: wangyucheng <wangyucheng@sensetime.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Added support for int4x32 int4x16 int4x4 in BroadcastNode.
In the int4x4 testcase, the IR is:
Before the fix in codegen_c.cc, the codegen cuda is:
For int16_t, this index
(((int)blockIdx.x) * 4)) / 8is a bug.After the fix in codegen_c.cc, the codegen cuda is:
Could you please help review this fix? @vinx13 @Hzfengsy