[TensorRT] Add transpose_a/b for TensorRT batch_matmul #8607
trevor-m merged 3 commits into apache:main from
Conversation
```diff
 y = relay.var("y", shape=(y_shape), dtype="float32")
-out = relay.nn.batch_matmul(x, y)
+out = relay.nn.batch_matmul(
+    relay.transpose(x, [0, 2, 1]) if transa else x,
```
I don't think you need these relay.transpose on the inputs to test the functionality of the transa/transb args.
Good point, I've changed to using x/y_shape instead.
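(A minimal sketch of the shape-based approach; the helper below is hypothetical and only illustrates exercising the transpose_a/b attributes directly instead of inserting relay.transpose ops:)

```python
from tvm import relay

def make_batch_matmul(x_shape, y_shape, transa=False, transb=False):
    # Swap the last two dims of the shapes themselves rather than inserting
    # relay.transpose ops, so the transpose_a/b attributes of nn.batch_matmul
    # are what actually get exercised.
    if transa:
        x_shape = (x_shape[0], x_shape[2], x_shape[1])
    if transb:
        y_shape = (y_shape[0], y_shape[2], y_shape[1])
    x = relay.var("x", shape=x_shape, dtype="float32")
    y = relay.var("y", shape=y_shape, dtype="float32")
    out = relay.nn.batch_matmul(x, y, transpose_a=transa, transpose_b=transb)
    return relay.Function([x, y], out)
```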
```diff
-# Transpose matrix dimensions of b.
-b = _op.transpose(b, [0, 2, 1])
 # Perform a batch matmul.
-output = _op.nn.batch_matmul(a, b)
+output = _op.nn.batch_matmul(a, b, transpose_b=False)
```
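(For reference: relay's nn.batch_matmul defaults to the NT convention, i.e. transpose_b=True, so the old pre-transpose path and the new transpose_b=False path build numerically equivalent graphs. A minimal sketch, with illustrative shapes:)

```python
from tvm import relay

a = relay.var("a", shape=(1, 4, 8), dtype="float32")
b = relay.var("b", shape=(1, 8, 16), dtype="float32")

# NT path: pre-transpose b, then rely on the default transpose_b=True,
# which computes a x transpose(b_t) == a x b.
nt = relay.nn.batch_matmul(a, relay.transpose(b, [0, 2, 1]))

# NN path: pass b as-is and disable the implicit transpose.
nn = relay.nn.batch_matmul(a, b, transpose_b=False)

# Both expressions produce an output of shape (1, 4, 16).
```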
Thanks! @ymwangg
Just a small concern about changing the default behavior of a framework frontend: currently the default topi schedule support for the NN format is not as strong as for the original NT one.
This may cause confusion for those who have used the onnx frontend before or are using it now.
To give an example, I've added an extra config to the TensorFlow frontend which uses the NT format by default but provides an option to use the normal (NN) format. I think that would be better until we have prepared strong enough topi support.
p.s.: Note that I've also kept the default layout for nn.batch_matmul as the original NT.
tvm/python/tvm/relay/frontend/tensorflow_ops.py, lines 1191 to 1199 (at 7653972)
@jcf94 Thanks for the pointer. I will refactor to make NN optional.
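(The refactor could look roughly like the following sketch, mirroring the TensorFlow-frontend pattern described above. The config key name "use_nt_batch_matmul" and the helper convert_matmul are assumptions based on this thread, not necessarily the merged spelling:)

```python
from tvm.relay import op as _op

# Sketch: a frontend-level config gating the matmul layout. The key name
# "use_nt_batch_matmul" is an assumption based on this discussion.
ONNX_DEFAULT_CONFIGS = {
    # Default to the original NT format so existing ONNX-frontend users
    # see no behavior change; the NN format becomes opt-in.
    "use_nt_batch_matmul": True,
}

def convert_matmul(a, b, config=ONNX_DEFAULT_CONFIGS):
    if config["use_nt_batch_matmul"]:
        # NT path: pre-transpose b, then use the default transpose_b=True.
        b = _op.transpose(b, [0, 2, 1])
        return _op.nn.batch_matmul(a, b)
    # NN path: feed b as-is and disable the implicit transpose.
    return _op.nn.batch_matmul(a, b, transpose_b=False)
```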
* Add transpose support for tensorrt batch_matmul
* Address PR comment
* Refactor to add ONNX_DEFAULT_CONFIGS
This PR adds transpose_a/b support for the TensorRT batch_matmul converter and fixes a warning and a compilation error with TensorRT 8. It also removes the redundant transpose op in the ONNX matmul converter. Tested with both TensorRT 7 and TensorRT 8.
cc @trevor-m @comaniac
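(For context, the converter change amounts to mapping relay's transpose_a/b attributes onto TensorRT's matrix-multiply operation flags rather than inserting explicit transpose layers. A minimal sketch using TensorRT's Python API; the actual converter lives in TVM's C++ codegen, and the helper name below is hypothetical:)

```python
import tensorrt as trt

def add_batch_matmul(network, a, b, transpose_a=False, transpose_b=True):
    # Map relay's transpose_a/b attrs onto TensorRT MatrixOperation flags
    # instead of materializing shuffle/transpose layers.
    op_a = trt.MatrixOperation.TRANSPOSE if transpose_a else trt.MatrixOperation.NONE
    op_b = trt.MatrixOperation.TRANSPOSE if transpose_b else trt.MatrixOperation.NONE
    layer = network.add_matrix_multiply(a, op_a, b, op_b)
    return layer.get_output(0)
```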