[Relay][ADT]Static Tensor Array#5103
Conversation
|
There are lots of AST written by hand. Can you try using the relay parser? |
|
Is the constraint only fixed "rank"? I have a use case for fixed "shape" tensor array (a tighter constraint) |
@MarisaKirisame This PR follows the patten of current Tensor Array. What do you mean by using relay parser? |
@masahi fixed shape is a subset of fixed rank and supported in this PR. For stack operator, even if we have all tensors in the array with the same static shape, the first axis output shape will still be Any(). |
|
@kevinthesun Great! I've been waiting for this (thanks @wweic). I think the first output axis of stack being Any makes sense and it is no problem for me, assuming it is the length of the array.
I think Marisa is talking about writing the new ADT and functions in the "Relay language" itself and use the parser to make it available to python, similar to the way List ADT and its helper functions are implemented in Prelude. |
I see. My understanding for this is List ADT primitives are small and mature enough to be implemented in prelude. We can also move tensor array related primitives into prelude when they becomes more mature. In addition, we might want to move the generic tensor array as well, not just static tensor array. Does this make sense? |
Yes. Since the generic tensor array is already implemented in this way, I also think it is better to let this in first and later port all tensor array stuff to Relay lang. There could be common code between generic/static that should be refactored in the process. |
fa9e4a7 to
92d7eaf
Compare
8a4b4fc to
769cdbd
Compare
|
@MarisaKirisame The reason we can not use relay parser is that we want to dynamically generate the operators for specific shape while we convert the TF model to relay IR. Current relay text parser can not easily do that. |
| shape_str = str(self.shape).replace('[', '').replace(']', '')\ | ||
| .replace('(', '').replace(')', '').replace(', ', '_')\ | ||
| .replace(',', '') |
There was a problem hiding this comment.
maybe we should improve this a bit. can we use '_'.join(self.shape)?
| """Defines the dynamic tensor ADT, which is the container for tensors | ||
| with variable shapes.""" |
There was a problem hiding this comment.
| """Defines the dynamic tensor ADT, which is the container for tensors | |
| with variable shapes.""" | |
| """Defines the static tensor ADT, which is the container for tensors | |
| with fixed shape.""" |
| lower = Var('lower', scalar_type('int32')) | ||
| upper = Var('upper', scalar_type('int32')) | ||
| tvar = Var('t') | ||
| case = Clause(PatternConstructor(self.get_var('tensor_constructor'), [PatternVar(tvar)]), |
There was a problem hiding this comment.
| case = Clause(PatternConstructor(self.get_var('tensor_constructor'), [PatternVar(tvar)]), | |
| case = Clause(PatternConstructor(tensor_constructor, [PatternVar(tvar)]), |
| Function([x], Match(x, [case], False), tensor_type_var(), []) | ||
|
|
||
| def define_tensor_array_read(self): | ||
| """Defines a function to get the head of a list. Assume the list has at least one |
There was a problem hiding this comment.
| """Defines a function to get the head of a list. Assume the list has at least one | |
| """Defines a function to get the nth element of a list. Assume the list has at least one |
| list[tensor_t] -> Tensor[(Any), int32] -> tensor_t -> list[tensor_t] | ||
|
|
||
| Set static indices shape by specifying indices_shape. | ||
| Set for_update to get static indices shape operator. |
There was a problem hiding this comment.
| Set for_update to get static indices shape operator. | |
| Set force_update to get static indices shape operator. |
|
@wweic Comments addressed. PTAL. |
| """ | ||
| ndim = len(self.shape) | ||
| # Skip scalar case | ||
| # We don't register unstask for scalar tensor array |
There was a problem hiding this comment.
| # We don't register unstask for scalar tensor array | |
| # We don't register unstack for scalar tensor array |
4a9dd1d to
183ac9e
Compare
|
@wweic Fixed. |
|
Thanks @wweic @masahi @MarisaKirisame |
* Add other static tensor array ops * Add tensor array get data * Minor refactor * Fix pylint * Update docstring * Make get data more generic * Improve test * Improve split test * Improve get data * Minor fix * Further improvement for static shape * Improve shape parsing * Unify get_static_name
* Add other static tensor array ops * Add tensor array get data * Minor refactor * Fix pylint * Update docstring * Make get data more generic * Improve test * Improve split test * Improve get data * Minor fix * Further improvement for static shape * Improve shape parsing * Unify get_static_name
* Add other static tensor array ops * Add tensor array get data * Minor refactor * Fix pylint * Update docstring * Make get data more generic * Improve test * Improve split test * Improve get data * Minor fix * Further improvement for static shape * Improve shape parsing * Unify get_static_name
Add static tensor array for fixed rank tensor array. With this change, we can more easily do type inference and optimization for most tensor array use cases.
Thanks for @wweic working on the base infra of StaticTensorArrayOps class.
@wweic @zhiics @yongwww @icemelon9