Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,9 @@ struct ScatterNDAttrs : public tvm::AttrsNode<ScatterNDAttrs> {
String mode;

TVM_DECLARE_ATTRS(ScatterNDAttrs, "relay.attrs.ScatterNDAttrs") {
TVM_ATTR_FIELD(mode).describe(
"Accumulation mode of the scatter, either \"update\" or \"add\".");
TVM_ATTR_FIELD(mode).set_default("update").describe(
"Accumulation mode of the ScatterND, "
"either \"update\", \"add\", \"mul\", \"min\" or \"max\".");
}
};

Expand Down
53 changes: 52 additions & 1 deletion python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2856,12 +2856,63 @@ def _impl_v1(cls, inputs, attr, params):
class ScatterND(OnnxOpConverter):
"""Operator converter for ScatterND."""

@classmethod
def _inputs_check(cls, inputs):
assert (
len(inputs) == 3
), "ScatterND takes 3 inputs (data, indices, updates), {} given".format(len(inputs))
assert infer_type(inputs[1]).checked_type.dtype == "int64"

data_rank = len(infer_shape(inputs[0]))
assert data_rank > 0, "Data rank higher than 0 is expected"
indices_rank = len(infer_shape(inputs[1]))
assert indices_rank > 0, "Indices rank higher than 0 is expected"
updates_rank = len(infer_shape(inputs[2]))
assert (
updates_rank == data_rank + indices_rank - infer_shape(inputs[1])[-1] - 1
), "Updates rank should be equal to data_rank + indices_rank - indices_shape[-1] - 1"

@classmethod
def _reduction_check(cls, attr, red_valids=None):
reduction = attr.get("reduction", None)
if reduction is None:
reduction = b"update"
reduction = reduction.decode("utf-8")
if red_valids is None:
red_valids = ["update"]
assert reduction in red_valids, "Only {} reductions are supported, but {} is gotten".format(
red_valids, reduction
)

return reduction

@classmethod
def _impl_v11(cls, inputs, attr, params):
cls._inputs_check(inputs)
indices_dim = len(infer_shape(inputs[1]))
axes = list(range(indices_dim))
return _op.scatter_nd(inputs[0], _op.transpose(inputs[1], axes[-1:] + axes[:-1]), inputs[2])

@classmethod
def _impl_v16(cls, inputs, attr, params):
cls._inputs_check(inputs)
reduction = cls._reduction_check(attr, ["update", "add", "mul"])

indices_dim = len(infer_shape(inputs[1]))
axes = list(range(indices_dim))
return _op.scatter_nd(
inputs[0], _op.transpose(inputs[1], axes[-1:] + axes[:-1]), inputs[2], reduction
)

@classmethod
def _impl_v18(cls, inputs, attr, params):
cls._inputs_check(inputs)
reduction = cls._reduction_check(attr, ["update", "add", "mul", "min", "max"])

indices_dim = len(infer_shape(inputs[1]))
axes = list(range(indices_dim))
return _op.scatter_nd(
inputs[0], _op.transpose(inputs[1], axes[-1:] + axes[:-1]), inputs[2], "update"
inputs[0], _op.transpose(inputs[1], axes[-1:] + axes[:-1]), inputs[2], reduction
)


Expand Down
8 changes: 7 additions & 1 deletion python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,13 @@ def scatter_nd(data, indices, updates, mode="update"):
The values to update.

mode : string, optional
The accumulation mode for scatter. "update" or "add"
The accumulation mode for scatter. "update", "add", "mul", "min" or "max"
If update, the update values will replace the input data
If add, the update values will be added to the input data
If mul, the update values will be multiply to the input data
If min, there is choice of minimal between the update values and the input data
If max, there is choice of maximal between the update values and the input data
It is "update" by default

Returns
-------
Expand Down
16 changes: 14 additions & 2 deletions python/tvm/topi/cuda/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
"""Scatter operator """
import tvm
from tvm import te, autotvm
from tvm import te, tir, autotvm
from ..scatter import _verify_scatter_nd_inputs
from ..generic import schedule_extern
from .nms import atomic_add
Expand Down Expand Up @@ -871,8 +871,20 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
out[index] = updates[i * fused_updates_dimension + j]
elif mode == "add":
out[index] += updates[i * fused_updates_dimension + j]
elif mode == "mul":
out[index] *= updates[i * fused_updates_dimension + j]
elif mode == "min":
out[index] = tir.min(
out[index], updates[i * fused_updates_dimension + j]
)
elif mode == "max":
out[index] = tir.max(
out[index], updates[i * fused_updates_dimension + j]
)
else:
raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)
raise NotImplementedError(
"scatter_nd mode not in [update, add, mul, min, max]:", mode
)

return ib.get()

Expand Down
28 changes: 18 additions & 10 deletions python/tvm/topi/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
# under the License.
# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
"""Scatter operator"""
from ..te import extern, hybrid
from ..tir import decl_buffer, expr, ir_builder
from tvm import te, tir # hide redefinition of min and max
from tvm.tir import expr


@hybrid.script
@te.hybrid.script
def _scatter_1d(data, indices, updates):
out = output_tensor(data.shape, data.dtype)
for i in range(data.shape[0]):
Expand All @@ -30,7 +30,7 @@ def _scatter_1d(data, indices, updates):
return out


@hybrid.script
@te.hybrid.script
def _scatter_2d(data, indices, updates, axis):
out = output_tensor(data.shape, data.dtype)
for i in range(data.shape[0]):
Expand All @@ -52,7 +52,7 @@ def _scatter_2d(data, indices, updates, axis):
return out


@hybrid.script
@te.hybrid.script
def _scatter_3d(data, indices, updates, axis):
out = output_tensor(data.shape, data.dtype)
for i in range(data.shape[0]):
Expand Down Expand Up @@ -96,7 +96,7 @@ def _scatter_3d(data, indices, updates, axis):
return out


@hybrid.script
@te.hybrid.script
def _scatter_4d(data, indices, updates, axis):
out = output_tensor(data.shape, data.dtype)
for i in range(data.shape[0]):
Expand Down Expand Up @@ -269,7 +269,7 @@ def scatter_nd(data, indices, updates, mode):

def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
# pylint: disable=invalid-name
ib = ir_builder.create()
ib = tir.ir_builder.create()

data = ib.buffer_ptr(data_ptr)
indices = ib.buffer_ptr(indices_ptr)
Expand Down Expand Up @@ -308,13 +308,21 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
out[index] = updates[i * fused_updates_dimension + j]
elif mode == "add":
out[index] += updates[i * fused_updates_dimension + j]
elif mode == "mul":
out[index] *= updates[i * fused_updates_dimension + j]
elif mode == "min":
out[index] = tir.min(out[index], updates[i * fused_updates_dimension + j])
elif mode == "max":
out[index] = tir.max(out[index], updates[i * fused_updates_dimension + j])
else:
raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)
raise NotImplementedError(
"scatter_nd mode not in [update, add, mul, min, max]:", mode
)

return ib.get()

out_buf = decl_buffer(data.shape, data.dtype, "out_buf")
return extern(
out_buf = tir.decl_buffer(data.shape, data.dtype, "out_buf")
return te.extern(
[data.shape],
[data, indices, updates],
lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]),
Expand Down
2 changes: 0 additions & 2 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -5398,8 +5398,6 @@ def verify_eyelike(indata, dynamic=False):
"test_reduce_sum_negative_axes_keepdims_random",
"test_roialign_aligned_true",
"test_scatter_elements_with_duplicate_indices",
"test_scatternd_add",
"test_scatternd_multiply",
"test_sequence_insert_at_back",
"test_sequence_insert_at_front",
"test_sequence_map_add_1_sequence_1_tensor",
Expand Down
19 changes: 15 additions & 4 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1983,6 +1983,7 @@ def verify_scatter_nd_with_stack(
)
tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=rtol, atol=atol)

# TODO(vcchernov): check frameworks' int type requirements. ONNX expects int64 only
for indice_dtype in ["uint8", "uint16", "uint32"]:
data = np.zeros((2, 2)).astype("int64")
indices = np.array([[1, 1, 0], [0, 1, 0]]).astype(indice_dtype)
Expand All @@ -2009,7 +2010,7 @@ def verify_scatter_nd_with_stack(
verify_scatter_nd(data, indices, updates, out, mode="add")
verify_scatter_nd_with_stack(data, indices, updates, out)

for mode in ["add", "update"]:
for mode in ["update", "add", "mul", "min", "max"]:
indices = np.stack((np.random.randint(2, size=5), np.random.randint(7, size=5))).astype(
indice_dtype
)
Expand All @@ -2019,10 +2020,20 @@ def verify_scatter_nd_with_stack(
out = data.copy()
for i in range(indices.shape[1]):
for j in range(updates.shape[1]):
if mode == "add":
out[indices[0, i], indices[1, i], j] += updates[i, j]
elif mode == "update":
if mode == "update":
out[indices[0, i], indices[1, i], j] = updates[i, j]
elif mode == "add":
out[indices[0, i], indices[1, i], j] += updates[i, j]
elif mode == "mul":
out[indices[0, i], indices[1, i], j] *= updates[i, j]
elif mode == "min":
out[indices[0, i], indices[1, i], j] = min(
out[indices[0, i], indices[1, i], j], updates[i, j]
)
elif mode == "max":
out[indices[0, i], indices[1, i], j] = max(
out[indices[0, i], indices[1, i], j], updates[i, j]
)
verify_scatter_nd(data, indices, updates, out, mode)
verify_scatter_nd_with_stack(data, indices, updates, out, mode)

Expand Down
18 changes: 14 additions & 4 deletions tests/python/topi/python/test_topi_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def check_scatter_nd(data, indices, updates, out, mode="add"):
out[0, :] += updates[2, :]
check_scatter_nd(data, indices, updates, out)

for mode in ["add", "update"]:
for mode in ["update", "add", "mul", "min", "max"]:
updates = np.ones((5, 3)).astype("float64")
indices = np.stack((np.random.randint(2, size=5), np.random.randint(7, size=5))).astype(
"int64"
Expand All @@ -71,10 +71,20 @@ def check_scatter_nd(data, indices, updates, out, mode="add"):
out = data.copy()
for i in range(indices.shape[1]):
for j in range(updates.shape[1]):
if mode == "add":
out[indices[0, i], indices[1, i], j] += updates[i, j]
elif mode == "update":
if mode == "update":
out[indices[0, i], indices[1, i], j] = updates[i, j]
elif mode == "add":
out[indices[0, i], indices[1, i], j] += updates[i, j]
elif mode == "mul":
out[indices[0, i], indices[1, i], j] *= updates[i, j]
elif mode == "min":
out[indices[0, i], indices[1, i], j] = min(
out[indices[0, i], indices[1, i], j], updates[i, j]
)
elif mode == "max":
out[indices[0, i], indices[1, i], j] = max(
out[indices[0, i], indices[1, i], j], updates[i, j]
)

check_scatter_nd(data, indices, updates, out, mode)

Expand Down