You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Currently, dense op's input layouts are NC and NK for data and weight respectively. This causes an issue when the output of dense is used as the weight for another dense. We get "Incompatible layouts: NC vs NK" error.
There is no need to distinguish the second dim of data and weight, since both must be the same value. So I replaced NK with NC.
@masahi Thanks for fixing this issue! Just one random thought out of this issue that may further improve TVM: is the layout annotation imposing too many unnecessary constraints (i.e., asserting semantics of each dimension)? Will it be better to not annotate layout itself, but instead annotate how the layout got transformed?
As in this case, previously TVM assumes NK for the second operand, which essentially constraints the 1st dimension to have the semantic of "batch" and the 2nd "units". This is why this issue happens: another operation assumes a semantic of NC which conflicts with NK.
However, I feel like what AlterOpLayout needs is only how the layouts are transformed before and after the pass (e.g., the 2nd dimension is split by a factor of xxx), and it does not care about the semantics at all. The current fix solves the issue for nn.dense, but it seems the fundamental issue still exists. For example, F.conv2d(x, F.conv2d(y, z)) will be broken too (though I guess this is a weird pattern that may not occur in practice?).
@lazycal That is an interesting suggestion. Layout annotation in TVM was introduced a long time ago, and I believe this is the simplest solution that works in most cases. Indeed, F.conv2d(x, F.conv2d(y, z)) would break in the current system. What you suggested makes a lot of sense and I like I0_I1_10i0 annotation you mentioned in the discuss post.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Fixes the issue reported in https://discuss.tvm.apache.org/t/pytorch-layout-cannot-convert-f-linear-x-f-linear-y-z/10866/
Currently,
denseop's input layouts areNCandNKfor data and weight respectively. This causes an issue when the output ofdenseis used as the weight for anotherdense. We get "Incompatible layouts: NC vs NK" error.There is no need to distinguish the second dim of
dataandweight, since both must be the same value. So I replacedNKwithNC.@comaniac @yzhliu