[Unity] Added known tir.Expr to relax.PrimValue#15577
[Unity] Added known tir.Expr to relax.PrimValue#15577Lunderberg merged 5 commits intoapache:unityfrom
Conversation
|
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot |
Prior to this commit, a `relax.PrimValue` could have a datatype, but
couldn't have a corresponding `tir.PrimExpr`. As a result, it could
not be used to specify tensor shapes. This makes some expressions
require fallback to `R.Tensor(ndim=ndim)`, even though the shape could
still be inferred.
```python
@R.function
def func(
A: R.Tensor(16, 16),
first_n_rows: R.prim("int64"),
) -> R.Tensor([first_n_rows, 16]):
# ^^^^^^^^^^^^
# R.Tensor requires a PrimExpr, not relax.Expr
#
# Operations may require PrimExpr
# vvvvvvvvvvvv
out = R.op.strided_slice(axis=[0], begin=[0], end=[first_n_rows])
return out
```
This commit adds a `Optional<PrimExpr> value` field to the
`PrimStructInfo`. This field acts similarly to the `PrimExpr` fields
already used in `ShapeStructInfo`, and may contain symbolic variables.
```python
@R.function
def func(
A: R.Tensor(16, 16),
# TIR definitions in signature allow in-line definitions,
# similar to R.Tensor and R.Shape. R.Prim takes `dtype` or
# `value` kwarg to distinguish between in-line symbolic variable
# and string representation of dtype.
first_n_rows: R.prim(value="first_n_rows_tir"),
) -> R.Tensor(["first_n_rows_tir", 16]):
# Body contains a TIR variable definition, which may be used
# in function calls, inferred shape annotations.
first_n_rows_tir = T.int64()
out = R.op.strided_slice(axis=[0], begin=[0], end=[first_n_rows])
return out
```
Use distinct PrimStructInfo arguments for dtype/value
Update TVMScript printer
Parser updates, Support R.Prim(value=...) annotations in function signature
4d2eac6 to
087e636
Compare
slyubomirsky
left a comment
There was a problem hiding this comment.
I think this change is a good addition and that it's implemented very well, thanks @Lunderberg! I will start updating the spec to account for this.
My main concern are for the cases where shape variables are used "before" they're defined (going left to right in a function). When writing the spec, I assume this wasn't a case we wanted. If that is a case we want to permit, then should our policy be to scan the arguments for binding positions first? That would be good to figure out.
| # Guard against incorrect usage. For backwards compatibility, | ||
| # the dtype and value are in the opposite order from most | ||
| # usages. While PrimStructInfo could take a single positional | ||
| # argument and check the type, this would require an API | ||
| # difference from TVMScript's PrimProxy, which cannot. | ||
| # (PrimProxy uses string arguments for datatype, and also for | ||
| # inline variable definitions when used in a function | ||
| # signature, and requires separate arguments to distinguish | ||
| # the two cases.) |
There was a problem hiding this comment.
Very user-friendly touch :)
| """The bound variable should be replaced when appearing in R.Shape""" | ||
|
|
||
| @R.function(private=True) | ||
| def before(A: R.Tensor(["M * N"]), x: R.Shape(["M", "N"])): |
There was a problem hiding this comment.
Are we sure we want to permit cases like this? I believe I wrote the specification draft to require argument vars to be in binding positions, left to right. If we permit uses before the binding position, then should the type-checker scan through the arguments, identify the bindings first, and only then check the position?
There was a problem hiding this comment.
I was thinking so, mainly because it would be the same support that is provided in TIR. (ArgBinder::BindDLTensor defines the symbolic variables prior to asserting based on the shapes.)
There's also a couple of edge cases that I think would be easier to handle by allowing it. Swapping the order of commutative operators, hoisting R.match_cast from the body into the function signature, and handling cases with mutually-dependent shapes. (e.g. The first parameter is R.Tensor(["a * b", "b"]) and the second parameter is R.Tensor(["a * b", "a"]).)
There was a problem hiding this comment.
Okay, I'll have to modify the specification to account for it.
There was a problem hiding this comment.
Sounds good, and thank you for all your work keeping the spec up to date!
Hzfengsy
left a comment
There was a problem hiding this comment.
The changes generally look good to me. Thanks @Lunderberg for the great addition!
Only one question: should the value be general PrimExpr, or be a tir.Var. IIUC, it maps the prim value to one symbolic var and is used in the shape deduction. A direct var should be easier than a generic expr in this case.
|
@Hzfengsy Thank you, and I do think the value should be a
|
Hzfengsy
left a comment
There was a problem hiding this comment.
Thanks for the clarification. LGTM now!
This test was implemented in apache#15626, but was initially disabled as it depended on functionality not introduced until apache#15577. Since that PR has landed, cleaning up and enabling the unit test.
This test was implemented in apache#15626, but was initially disabled as it depended on functionality not introduced until apache#15577. Since that PR has landed, cleaning up and enabling the unit test.
Prior to this commit, a
relax.PrimValuecould have a datatype, but couldn't have a correspondingtir.PrimExpr. As a result, it could not be used to specify tensor shapes. This makes some expressions require fallback toR.Tensor(ndim=ndim), even though the shape could still be inferred.This commit adds a
Optional<PrimExpr> valuefield to thePrimStructInfo. This field acts similarly to thePrimExprfields already used inShapeStructInfo, and may contain symbolic variables.