[BYOC][TRT] Support batch norm for all ranks <=5, and all axes#7026
[BYOC][TRT] Support batch norm for all ranks <=5, and all axes#7026zhiics merged 3 commits intoapache:mainfrom
Conversation
ae0c87b to
27982b8
Compare
|
All tests in test_tensorrt.py passed locally |
c03ece8 to
f51f808
Compare
| auto input_dims = TrtDimsToVector(input->getDimensions()); | ||
| const size_t min_rank = TRT_HAS_IMPLICIT_BATCH(params) ? 3 : 4; | ||
| const size_t max_rank = TRT_HAS_IMPLICIT_BATCH(params) ? 4 : 5; | ||
| ICHECK_LE(input_dims.size(), max_rank); |
There was a problem hiding this comment.
Could you convert these checks to use Diagnostic instead of generating an assertion, we should strive to replace most of these with end-user readable errors.
There was a problem hiding this comment.
Hi @jroesch, thanks for reviewing!
These checks are more for sanity checking, since the annotation functions in python/tvm/relay/op/contrib/tensorrt.py will filter out the unsupported ops before they ever get to this code. I don't expect users to ever see these.
Anyway, I can make a separate PR to port all of the ICHECK to Diagnostics.
zhiics
left a comment
There was a problem hiding this comment.
LGTM. Let's have a separate PR to migrate all errors to diagnostic.
|
Thanks @trevor-m @anijain2305 @jroesch |
…e#7026) * [TRT] Support batch norm for all ranks <=5, and all axis * Add return false * Fix TRT < 6 build
…e#7026) * [TRT] Support batch norm for all ranks <=5, and all axis * Add return false * Fix TRT < 6 build
…e#7026) * [TRT] Support batch norm for all ranks <=5, and all axis * Add return false * Fix TRT < 6 build
Previous batch norm only supported rank 4 inputs with axis 1 or 3. Now we support input ranks and axes 1-5.