diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index b1ab40e1bf02..247b39301b78 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -493,18 +493,129 @@ def convert_pad(node, **kwargs): return [node] +def create_helper_tensor_node(input_vals, output_name, kwargs): + """create extra tensor node from numpy values""" + data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[input_vals.dtype] + + tensor_node = onnx.helper.make_tensor_value_info( + name=output_name, + elem_type=data_type, + shape=input_vals.shape + ) + kwargs["initializer"].append( + onnx.helper.make_tensor( + name=output_name, + data_type=data_type, + dims=input_vals.shape, + vals=input_vals.flatten(), + raw=False, + ) + ) + + return [tensor_node] + +def create_helper_reshape_node(input_name, output_name, shape, kwargs): + """create extra reshape node with static shape""" + shape_tensor_node, = create_helper_tensor_node( + np.asarray(shape, dtype=np.int64), output_name + "__shape", kwargs + ) + reshape_node = onnx.helper.make_node( + "Reshape", + inputs=[input_name, shape_tensor_node.name], + outputs=[output_name], + name=output_name + ) + + return [shape_tensor_node, reshape_node] -def create_helper_trans_node(op_name, input_node, node_name): - """create extra transpose node for dot operator""" - node_name = op_name + "_" + node_name +def create_helper_trans_node(input_name, output_name, perm=None): + """create extra transpose node""" + attrs = {} + if perm is not None: + attrs['perm'] = perm trans_node = onnx.helper.make_node( 'Transpose', - inputs=[input_node], - outputs=[node_name], - name=node_name + inputs=[input_name], + outputs=[output_name], + name=output_name, + **attrs ) - return trans_node + return [trans_node] +def create_helper_concat_node(inputs, output_name, axis=0): + """create extra concat node""" + concat_node = onnx.helper.make_node( + "Concat", + inputs=inputs, + outputs=[output_name], + name=output_name, + axis=axis, + ) + return [concat_node] + +def create_helper_expand_node(input_name, output_name, expand_shape): + """create extra expand node""" + expand_node = onnx.helper.make_node( + "Expand", + inputs=[input_name, expand_shape], + outputs=[output_name], + name=output_name, + ) + return [expand_node] + +def create_helper_gather_node( + input_name, output_name, + indices, kwargs, + axis=None + ): + """create extra gather node with static indices""" + attrs = {} + if axis is not None: + attrs['axis'] = axis + gather_tensor_node, = create_helper_tensor_node( + np.asarray(indices, np.int64), output_name + "__indices", kwargs + ) + gather_node = onnx.helper.make_node( + "Gather", + inputs=[input_name, gather_tensor_node.name], + outputs=[output_name], + name=output_name, + **attrs + ) + return [gather_tensor_node, gather_node] + +def create_helper_build_values_node( + inputs, output_name, + dtype, kwargs, axis=0 + ): + """create extra node, with specified values + + (allows mixing node names and static values) + """ + values = [] + tensor_nodes = [] + for idx, inp in enumerate(inputs): + if not isinstance(inp, (str, bytes)): + inp, = create_helper_tensor_node( + np.array([inp], dtype=dtype), + output_name + "__value" + str(idx), + kwargs + ) + tensor_nodes.append(inp) + inp = inp.name + values.append(inp) + concat_node, = create_helper_concat_node(values, output_name, axis=axis) + return tensor_nodes + [concat_node,] + +def create_helper_shape_node(input_name, output_name): + """create extra shape node for specified input node""" + shape_node = onnx.helper.make_node( + "Shape", + inputs=[input_name], + outputs=[output_name], + name=output_name, + ) + return [shape_node] @mx_op.register("dot") def convert_dot(node, **kwargs): @@ -524,11 +635,11 @@ def convert_dot(node, **kwargs): op_name = "transpose" + str(kwargs["idx"]) if trans_a: - trans_a_node = create_helper_trans_node(op_name, input_nodes[0], 'a') - input_node_a = op_name+"_a" + input_node_a = op_name + "_a" + trans_a_node, = create_helper_trans_node(input_nodes[0], input_node_a) if trans_b: - trans_b_node = create_helper_trans_node(op_name, input_nodes[1], 'b') - input_node_b = op_name+"_b" + input_node_b = op_name + "_b" + trans_b_node, = create_helper_trans_node(input_nodes[1], input_node_b) matmul_node = onnx.helper.make_node( 'MatMul', @@ -1503,16 +1614,34 @@ def convert_slice_axis(node, **kwargs): in_shape = kwargs['in_shape'][0] ends = in_shape[axes] + export_nodes = [] + + starts = np.atleast_1d(np.asarray(starts, dtype=np.int)) + ends = np.atleast_1d(np.asarray(ends, dtype=np.int)) + axes = np.atleast_1d(np.asarray(axes, dtype=np.int)) + + starts_node = create_helper_tensor_node(starts, name + '__starts', kwargs) + export_nodes.extend(starts_node) + starts_node = starts_node[-1].name + + ends_node = create_helper_tensor_node(ends, name + '__ends', kwargs) + export_nodes.extend(ends_node) + ends_node = ends_node[-1].name + + axes_node = create_helper_tensor_node(axes, name + '__axes', kwargs) + export_nodes.extend(axes_node) + axes_node = axes_node[-1].name + + input_node = input_nodes[0] node = onnx.helper.make_node( "Slice", - input_nodes, + [input_node, starts_node, ends_node, axes_node], [name], - axes=[axes], - starts=[starts], - ends=[int(ends)], name=name, ) - return [node] + export_nodes.extend([node]) + + return export_nodes @mx_op.register("SliceChannel") @@ -2070,14 +2199,22 @@ def convert_topk(node, **kwargs): else: raise NotImplementedError("ONNX expects both value and indices as output") + export_nodes = [] + + k = np.asarray([k], dtype=np.int) + k_node = create_helper_tensor_node(k, name + '__k', kwargs) + export_nodes.extend(k_node) + k_node = k_node[-1].name + + input_node = input_nodes[0] topk_node = onnx.helper.make_node( "TopK", - input_nodes, + [input_node, k_node], outputs, axis=axis, - k=k, name=name ) + export_nodes.extend([topk_node]) return [topk_node] diff --git a/tests/python/unittest/onnx/test_cases.py b/tests/python/unittest/onnx/test_cases.py index 9a72d58e0490..932310281bd6 100644 --- a/tests/python/unittest/onnx/test_cases.py +++ b/tests/python/unittest/onnx/test_cases.py @@ -39,9 +39,6 @@ 'test_transpose', 'test_globalmaxpool', 'test_globalaveragepool', - 'test_slice_cpu', - 'test_slice_neg', - 'test_slice_end', 'test_reciprocal', 'test_sqrt', 'test_pow', @@ -54,19 +51,19 @@ 'test_operator_maxpool', 'test_operator_params', 'test_operator_permute2', - 'test_cos', - 'test_sin', + 'test_cos[^h]', + 'test_sin[^h]', 'test_tan', - 'test_acos', - 'test_asin', - 'test_atan', + 'test_acos[^h]', + 'test_asin[^h]', + 'test_atan[^h]', 'test_squeeze', - 'test_matmul', + 'test_matmul_', 'test_depthtospace', 'test_hardsigmoid', 'test_instancenorm', 'test_shape', - 'test_cast', + 'test_cast((?!STRING).)*$', 'test_clip', 'test_size', 'test_dropout', @@ -80,7 +77,6 @@ 'test_softplus', 'test_reduce_', 'test_split_equal', - 'test_top_k', 'test_gather' ], 'import': ['test_softsign', @@ -116,7 +112,7 @@ 'test_softmax_functional', 'test_softmax_lastdim', ], - 'export': ['test_ConvTranspose2d'] + 'export': [] } STANDARD_MODEL = { diff --git a/tests/python/unittest/onnx/test_node.py b/tests/python/unittest/onnx/test_node.py index f7fc5c855213..3e2786c1bacc 100644 --- a/tests/python/unittest/onnx/test_node.py +++ b/tests/python/unittest/onnx/test_node.py @@ -208,9 +208,8 @@ def test_imports(self): npt.assert_almost_equal(np_out, mxnet_out, decimal=4) def test_exports(self): - input_shape = (2,1,3,1) for test in export_test_cases: - test_name, onnx_name, mx_op, attrs = test + test_name, onnx_name, mx_op, input_shape, attrs = test input_sym = mx.sym.var('data') outsym = mx_op(input_sym, **attrs) converted_model = onnx_mxnet.export_model(outsym, {}, [input_shape], np.float32, @@ -287,10 +286,12 @@ def test_exports(self): ("test_lpnormalization_ord2", "LpNormalization", [get_rnd([5, 3, 3, 2])], np.linalg.norm, {'ord':2, 'axis':1}) ] -# test_case = ("test_case_name", "ONNX_op_name", mxnet_op, attribute map) +# test_case = ("test_case_name", "ONNX_op_name", mxnet_op, input_shape, attribute map) export_test_cases = [ - ("test_expand", "Expand", mx.sym.broadcast_to, {'shape': (2,1,3,1)}), - ("test_tile", "Tile", mx.sym.tile, {'reps': (2,3)}) + ("test_expand", "Expand", mx.sym.broadcast_to, (2,1,3,1), {'shape': (2,1,3,1)}), + ("test_tile", "Tile", mx.sym.tile, (2,1,3,1), {'reps': (2,3)}), + ("test_topk", "TopK", mx.sym.topk, (2, 10, 2), {'k': 3, 'axis': 1, 'ret_typ': 'both', 'dtype': np.int64}), + ("test_slice_axis", "Slice", mx.sym.slice_axis, (2, 10, 2), {'begin': 3, 'end': 7, 'axis': 1}), ] if __name__ == '__main__':