-
Notifications
You must be signed in to change notification settings - Fork 393
Add support for SymInt start values in slice_scatter_decomposition #4185
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -197,8 +197,13 @@ def slice_scatter_decomposition( | |
| dim_size = input_tensor.shape[dim] | ||
| device_input_tensor = input_tensor.device | ||
|
|
||
| start = 0 if start is None else start # Ensure start is int | ||
| start = get_positive_dim(start, input_tensor.shape[dim]) | ||
| if start is None: | ||
| start = 0 | ||
| elif isinstance(start, int): | ||
| start = get_positive_dim(start, dim_size) | ||
| elif isinstance(start, torch.SymInt): | ||
| if start < 0: | ||
| start = start + dim_size | ||
| if end is None: # Ensure end is int | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi can you have a testcase for the above?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi, I tested with my ZoomASR code. Without this, it will crash. After adding this it will run.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the info! Generally for any change in the codebase, we recommend adding a corresponding test case to capture the corner case for future CI runs.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi, do you know where I need to add the test? Do you have examples?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| end = dim_size | ||
| end = ( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
According to the aten schema, all of
start,end, andstepcould be SymInt, not juststart. Can you fix all of them btw? I think a more general way is to support SymInt inget_positive_dim()like: