diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index 274a421e5719..82403d7c40ee 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -168,8 +168,9 @@ struct ScatterNDAttrs : public tvm::AttrsNode<ScatterNDAttrs> {
   String mode;
 
   TVM_DECLARE_ATTRS(ScatterNDAttrs, "relay.attrs.ScatterNDAttrs") {
-    TVM_ATTR_FIELD(mode).describe(
-        "Accumulation mode of the scatter, either \"update\" or \"add\".");
+    TVM_ATTR_FIELD(mode).set_default("update").describe(
+        "Accumulation mode of the ScatterND, "
+        "either \"update\", \"add\", \"mul\", \"min\" or \"max\".");
   }
 };
 
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 8de5e0e08bd8..aebc6daa5ebe 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -2856,12 +2856,63 @@ def _impl_v1(cls, inputs, attr, params):
 class ScatterND(OnnxOpConverter):
     """Operator converter for ScatterND."""
 
+    @classmethod
+    def _inputs_check(cls, inputs):
+        assert (
+            len(inputs) == 3
+        ), "ScatterND takes 3 inputs (data, indices, updates), {} given".format(len(inputs))
+        assert infer_type(inputs[1]).checked_type.dtype == "int64"
+
+        data_rank = len(infer_shape(inputs[0]))
+        assert data_rank > 0, "Data rank higher than 0 is expected"
+        indices_rank = len(infer_shape(inputs[1]))
+        assert indices_rank > 0, "Indices rank higher than 0 is expected"
+        updates_rank = len(infer_shape(inputs[2]))
+        assert (
+            updates_rank == data_rank + indices_rank - infer_shape(inputs[1])[-1] - 1
+        ), "Updates rank should be equal to data_rank + indices_rank - indices_shape[-1] - 1"
+
+    @classmethod
+    def _reduction_check(cls, attr, red_valids=None):
+        reduction = attr.get("reduction", None)
+        if reduction is None:
+            reduction = b"update"
+        reduction = reduction.decode("utf-8")
+        if red_valids is None:
+            red_valids = ["update"]
+        assert reduction in red_valids, "Only {} reductions are supported, but {} is gotten".format(
+            red_valids, reduction
+        )
+
+        return reduction
+
     @classmethod
     def _impl_v11(cls, inputs, attr, params):
+        cls._inputs_check(inputs)
+        indices_dim = len(infer_shape(inputs[1]))
+        axes = list(range(indices_dim))
+        return _op.scatter_nd(inputs[0], _op.transpose(inputs[1], axes[-1:] + axes[:-1]), inputs[2])
+
+    @classmethod
+    def _impl_v16(cls, inputs, attr, params):
+        cls._inputs_check(inputs)
+        reduction = cls._reduction_check(attr, ["update", "add", "mul"])
+
+        indices_dim = len(infer_shape(inputs[1]))
+        axes = list(range(indices_dim))
+        return _op.scatter_nd(
+            inputs[0], _op.transpose(inputs[1], axes[-1:] + axes[:-1]), inputs[2], reduction
+        )
+
+    @classmethod
+    def _impl_v18(cls, inputs, attr, params):
+        cls._inputs_check(inputs)
+        reduction = cls._reduction_check(attr, ["update", "add", "mul", "min", "max"])
+
         indices_dim = len(infer_shape(inputs[1]))
         axes = list(range(indices_dim))
         return _op.scatter_nd(
-            inputs[0], _op.transpose(inputs[1], axes[-1:] + axes[:-1]), inputs[2], "update"
+            inputs[0], _op.transpose(inputs[1], axes[-1:] + axes[:-1]), inputs[2], reduction
         )
 
 
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index c7234f340395..782797dadb83 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -420,7 +420,13 @@ def scatter_nd(data, indices, updates, mode="update"):
         The values to update.
 
     mode : string, optional
-        The accumulation mode for scatter. "update" or "add"
+        The accumulation mode for scatter. "update", "add", "mul", "min" or "max"
+        If update, the update values will replace the input data
+        If add, the update values will be added to the input data
+        If mul, the input data will be multiplied by the update values
+        If min, the minimum of the update values and the input data is taken
+        If max, the maximum of the update values and the input data is taken
+        It is "update" by default
 
     Returns
     -------
diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py
index fa7545cd323a..1bdd53156623 100644
--- a/python/tvm/topi/cuda/scatter.py
+++ b/python/tvm/topi/cuda/scatter.py
@@ -17,7 +17,7 @@
 # pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
 """Scatter operator """
 import tvm
-from tvm import te, autotvm
+from tvm import te, tir, autotvm
 from ..scatter import _verify_scatter_nd_inputs
 from ..generic import schedule_extern
 from .nms import atomic_add
@@ -871,8 +871,20 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
                     out[index] = updates[i * fused_updates_dimension + j]
                 elif mode == "add":
                     out[index] += updates[i * fused_updates_dimension + j]
+                elif mode == "mul":
+                    out[index] *= updates[i * fused_updates_dimension + j]
+                elif mode == "min":
+                    out[index] = tir.min(
+                        out[index], updates[i * fused_updates_dimension + j]
+                    )
+                elif mode == "max":
+                    out[index] = tir.max(
+                        out[index], updates[i * fused_updates_dimension + j]
+                    )
                 else:
-                    raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)
+                    raise NotImplementedError(
+                        "scatter_nd mode not in [update, add, mul, min, max]:", mode
+                    )
 
     return ib.get()
 
diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py
index e0578aab41b9..45629c005f79 100644
--- a/python/tvm/topi/scatter.py
+++ b/python/tvm/topi/scatter.py
@@ -16,11 +16,11 @@
 # under the License.
 # pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
 """Scatter operator"""
-from ..te import extern, hybrid
-from ..tir import decl_buffer, expr, ir_builder
+from tvm import te, tir  # hide redefinition of min and max
+from tvm.tir import expr
 
 
-@hybrid.script
+@te.hybrid.script
 def _scatter_1d(data, indices, updates):
     out = output_tensor(data.shape, data.dtype)
     for i in range(data.shape[0]):
@@ -30,7 +30,7 @@ def _scatter_1d(data, indices, updates):
     return out
 
 
-@hybrid.script
+@te.hybrid.script
 def _scatter_2d(data, indices, updates, axis):
     out = output_tensor(data.shape, data.dtype)
     for i in range(data.shape[0]):
@@ -52,7 +52,7 @@ def _scatter_2d(data, indices, updates, axis):
     return out
 
 
-@hybrid.script
+@te.hybrid.script
 def _scatter_3d(data, indices, updates, axis):
     out = output_tensor(data.shape, data.dtype)
     for i in range(data.shape[0]):
@@ -96,7 +96,7 @@ def _scatter_3d(data, indices, updates, axis):
     return out
 
 
-@hybrid.script
+@te.hybrid.script
 def _scatter_4d(data, indices, updates, axis):
     out = output_tensor(data.shape, data.dtype)
     for i in range(data.shape[0]):
@@ -269,7 +269,7 @@ def scatter_nd(data, indices, updates, mode):
 
     def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
         # pylint: disable=invalid-name
-        ib = ir_builder.create()
+        ib = tir.ir_builder.create()
 
         data = ib.buffer_ptr(data_ptr)
         indices = ib.buffer_ptr(indices_ptr)
@@ -308,13 +308,21 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
                 out[index] = updates[i * fused_updates_dimension + j]
             elif mode == "add":
                 out[index] += updates[i * fused_updates_dimension + j]
+            elif mode == "mul":
+                out[index] *= updates[i * fused_updates_dimension + j]
+            elif mode == "min":
+                out[index] = tir.min(out[index], updates[i * fused_updates_dimension + j])
+            elif mode == "max":
+                out[index] = tir.max(out[index], updates[i * fused_updates_dimension + j])
             else:
-                raise NotImplementedError("scatter_nd mode not in [update, add]:", mode)
+                raise NotImplementedError(
+                    "scatter_nd mode not in [update, add, mul, min, max]:", mode
+                )
 
     return ib.get()
 
-    out_buf = decl_buffer(data.shape, data.dtype, "out_buf")
-    return extern(
+    out_buf = tir.decl_buffer(data.shape, data.dtype, "out_buf")
+    return te.extern(
         [data.shape],
         [data, indices, updates],
         lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]),
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 0a032843267a..470a67e86c93 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -5398,8 +5398,6 @@ def verify_eyelike(indata, dynamic=False):
     "test_reduce_sum_negative_axes_keepdims_random",
     "test_roialign_aligned_true",
     "test_scatter_elements_with_duplicate_indices",
-    "test_scatternd_add",
-    "test_scatternd_multiply",
     "test_sequence_insert_at_back",
     "test_sequence_insert_at_front",
     "test_sequence_map_add_1_sequence_1_tensor",
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index efd37f2ecd22..225210f4d617 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -1983,6 +1983,7 @@ def verify_scatter_nd_with_stack(
         )
         tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=rtol, atol=atol)
 
+    # TODO(vcchernov): check frameworks' int type requirements. ONNX expects int64
     for indice_dtype in ["uint8", "uint16", "uint32"]:
         data = np.zeros((2, 2)).astype("int64")
         indices = np.array([[1, 1, 0], [0, 1, 0]]).astype(indice_dtype)
@@ -2009,7 +2010,7 @@ def verify_scatter_nd_with_stack(
         verify_scatter_nd(data, indices, updates, out, mode="add")
         verify_scatter_nd_with_stack(data, indices, updates, out)
 
-    for mode in ["add", "update"]:
+    for mode in ["update", "add", "mul", "min", "max"]:
         indices = np.stack((np.random.randint(2, size=5), np.random.randint(7, size=5))).astype(
             indice_dtype
         )
@@ -2019,10 +2020,20 @@ def verify_scatter_nd_with_stack(
         out = data.copy()
         for i in range(indices.shape[1]):
             for j in range(updates.shape[1]):
-                if mode == "add":
-                    out[indices[0, i], indices[1, i], j] += updates[i, j]
-                elif mode == "update":
+                if mode == "update":
                     out[indices[0, i], indices[1, i], j] = updates[i, j]
+                elif mode == "add":
+                    out[indices[0, i], indices[1, i], j] += updates[i, j]
+                elif mode == "mul":
+                    out[indices[0, i], indices[1, i], j] *= updates[i, j]
+                elif mode == "min":
+                    out[indices[0, i], indices[1, i], j] = min(
+                        out[indices[0, i], indices[1, i], j], updates[i, j]
+                    )
+                elif mode == "max":
+                    out[indices[0, i], indices[1, i], j] = max(
+                        out[indices[0, i], indices[1, i], j], updates[i, j]
+                    )
         verify_scatter_nd(data, indices, updates, out, mode)
         verify_scatter_nd_with_stack(data, indices, updates, out, mode)
 
diff --git a/tests/python/topi/python/test_topi_scatter.py b/tests/python/topi/python/test_topi_scatter.py
index 025e44889d63..ccc34837a05a 100644
--- a/tests/python/topi/python/test_topi_scatter.py
+++ b/tests/python/topi/python/test_topi_scatter.py
@@ -61,7 +61,7 @@ def check_scatter_nd(data, indices, updates, out, mode="add"):
     out[0, :] += updates[2, :]
     check_scatter_nd(data, indices, updates, out)
 
-    for mode in ["add", "update"]:
+    for mode in ["update", "add", "mul", "min", "max"]:
         updates = np.ones((5, 3)).astype("float64")
         indices = np.stack((np.random.randint(2, size=5), np.random.randint(7, size=5))).astype(
             "int64"
@@ -71,10 +71,20 @@ def check_scatter_nd(data, indices, updates, out, mode="add"):
         out = data.copy()
         for i in range(indices.shape[1]):
             for j in range(updates.shape[1]):
-                if mode == "add":
-                    out[indices[0, i], indices[1, i], j] += updates[i, j]
-                elif mode == "update":
+                if mode == "update":
                     out[indices[0, i], indices[1, i], j] = updates[i, j]
+                elif mode == "add":
+                    out[indices[0, i], indices[1, i], j] += updates[i, j]
+                elif mode == "mul":
+                    out[indices[0, i], indices[1, i], j] *= updates[i, j]
+                elif mode == "min":
+                    out[indices[0, i], indices[1, i], j] = min(
+                        out[indices[0, i], indices[1, i], j], updates[i, j]
+                    )
+                elif mode == "max":
+                    out[indices[0, i], indices[1, i], j] = max(
+                        out[indices[0, i], indices[1, i], j], updates[i, j]
+                    )
         check_scatter_nd(data, indices, updates, out, mode)