diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 177ec5d40146..dbe7e02eb80f 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -2255,6 +2255,44 @@ MXNET_DLL int MXExecutorSimpleBindEx(SymbolHandle symbol_handle, NDArrayHandle** aux_states, ExecutorHandle shared_exec_handle, ExecutorHandle* out); + + +MXNET_DLL int MXExecutorSimpleBindEx64(SymbolHandle symbol_handle, + int dev_type, + int dev_id, + const uint32_t num_g2c_keys, + const char** g2c_keys, + const int* g2c_dev_types, + const int* g2c_dev_ids, + const uint32_t provided_grad_req_list_len, + const char** provided_grad_req_names, + const char** provided_grad_req_types, + const uint32_t num_provided_arg_shapes, + const char** provided_arg_shape_names, + const int64_t* provided_arg_shape_data, + const uint32_t* provided_arg_shape_idx, + const uint32_t num_provided_arg_dtypes, + const char** provided_arg_dtype_names, + const int* provided_arg_dtypes, + const uint32_t num_provided_arg_stypes, + const char** provided_arg_stype_names, + const int* provided_arg_stypes, + const uint32_t num_shared_arg_names, + const char** shared_arg_name_list, + int* shared_buffer_len, + const char** shared_buffer_name_list, + NDArrayHandle* shared_buffer_handle_list, + const char*** updated_shared_buffer_name_list, + NDArrayHandle** updated_shared_buffer_handle_list, + uint32_t* num_in_args, + NDArrayHandle** in_args, + NDArrayHandle** arg_grads, + uint32_t* num_aux_states, + NDArrayHandle** aux_states, + ExecutorHandle shared_exec_handle, + ExecutorHandle* out); + + /*! * \brief DEPRECATED. Use MXExecutorReshapeEx instead. * Return a new executor with the same symbol and shared memory, diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index b8e8db57188c..6146ab9dc50e 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -1695,42 +1695,80 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, stype_dict=None, aux_state_handles = ctypes.POINTER(NDArrayHandle)() try: - check_call(_LIB.MXExecutorSimpleBindEx(self.handle, - ctypes.c_int(ctx.device_typeid), - ctypes.c_int(ctx.device_id), - num_ctx_map_keys, - ctx_map_keys, - ctx_map_dev_types, - ctx_map_dev_ids, - mx_uint(provided_req_type_list_len), - provided_grad_req_names, - provided_grad_req_types, - mx_uint(len(provided_arg_shape_names)), - c_str_array(provided_arg_shape_names), - c_array_buf(mx_int, - array('I', provided_arg_shape_data)), - c_array_buf(mx_uint, - array('i', provided_arg_shape_idx)), - num_provided_arg_types, - provided_arg_type_names, - provided_arg_type_data, - num_provided_arg_stypes, - provided_arg_stype_names, - provided_arg_stype_data, - mx_uint(len(shared_arg_name_list)), - c_str_array(shared_arg_name_list), - ctypes.byref(shared_buffer_len), - shared_buffer_names, - shared_buffer_handles, - ctypes.byref(updated_shared_buffer_names), - ctypes.byref(updated_shared_buffer_handles), - ctypes.byref(num_in_args), - ctypes.byref(in_arg_handles), - ctypes.byref(arg_grad_handles), - ctypes.byref(num_aux_states), - ctypes.byref(aux_state_handles), - shared_exec_handle, - ctypes.byref(exe_handle))) + if sys.version_info[0] > 2 and _int64_enabled(): + check_call(_LIB.MXExecutorSimpleBindEx64(self.handle, + ctypes.c_int(ctx.device_typeid), + ctypes.c_int(ctx.device_id), + num_ctx_map_keys, + ctx_map_keys, + ctx_map_dev_types, + ctx_map_dev_ids, + mx_uint(provided_req_type_list_len), + provided_grad_req_names, + provided_grad_req_types, + mx_uint(len(provided_arg_shape_names)), + c_str_array(provided_arg_shape_names), + c_array_buf(mx_int64, + array('q', provided_arg_shape_data)), + c_array_buf(mx_uint, + array('i', provided_arg_shape_idx)), + num_provided_arg_types, + provided_arg_type_names, + provided_arg_type_data, + num_provided_arg_stypes, + provided_arg_stype_names, + provided_arg_stype_data, + mx_uint(len(shared_arg_name_list)), + c_str_array(shared_arg_name_list), + ctypes.byref(shared_buffer_len), + shared_buffer_names, + shared_buffer_handles, + ctypes.byref(updated_shared_buffer_names), + ctypes.byref(updated_shared_buffer_handles), + ctypes.byref(num_in_args), + ctypes.byref(in_arg_handles), + ctypes.byref(arg_grad_handles), + ctypes.byref(num_aux_states), + ctypes.byref(aux_state_handles), + shared_exec_handle, + ctypes.byref(exe_handle))) + else: + check_call(_LIB.MXExecutorSimpleBindEx(self.handle, + ctypes.c_int(ctx.device_typeid), + ctypes.c_int(ctx.device_id), + num_ctx_map_keys, + ctx_map_keys, + ctx_map_dev_types, + ctx_map_dev_ids, + mx_uint(provided_req_type_list_len), + provided_grad_req_names, + provided_grad_req_types, + mx_uint(len(provided_arg_shape_names)), + c_str_array(provided_arg_shape_names), + c_array_buf(mx_int, + array('I', provided_arg_shape_data)), + c_array_buf(mx_uint, + array('i', provided_arg_shape_idx)), + num_provided_arg_types, + provided_arg_type_names, + provided_arg_type_data, + num_provided_arg_stypes, + provided_arg_stype_names, + provided_arg_stype_data, + mx_uint(len(shared_arg_name_list)), + c_str_array(shared_arg_name_list), + ctypes.byref(shared_buffer_len), + shared_buffer_names, + shared_buffer_handles, + ctypes.byref(updated_shared_buffer_names), + ctypes.byref(updated_shared_buffer_handles), + ctypes.byref(num_in_args), + ctypes.byref(in_arg_handles), + ctypes.byref(arg_grad_handles), + ctypes.byref(num_aux_states), + ctypes.byref(aux_state_handles), + shared_exec_handle, + ctypes.byref(exe_handle))) except MXNetError as e: error_msg = "simple_bind error. Arguments:\n" for k, v in kwargs.items(): diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc index ff85b4fd62fa..afc64f73de7c 100644 --- a/src/c_api/c_api_executor.cc +++ b/src/c_api/c_api_executor.cc @@ -515,44 +515,11 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, API_END(); } -/*! - * \brief - * \param symbol_handle symbol handle - * \param dev_type default device type - * \param dev_id default device id - * \param num_g2c_keys number of group2ctx keys - * \param g2c_keys key list of group2ctx - * \param g2c_dev_types device type list of group2ctx - * \param g2c_dev_ids id list of group2ctx - * \param provided_grad_req_list_len grad_req length provided by users in front-end - * \param provided_grad_req_names grad_req names provided by users in front-end - * \param provided_grad_req_types req types provided by users in front-end - * \param num_provided_arg_shapes number of user provided in_arg and aux_state shapes - * \param provided_arg_shape_names name list of provided shapes - * \param provided_arg_shape_data provided shape data - * \param provided_arg_shape_idx provided shape data index - * \param num_provided_arg_dtypes number of user provided in_arg and axu_state dtypes - * \param provided_arg_dtype_names argument name list of provided dtypes - * \param provided_arg_dtypes data of provided dtypes - * \param num_provided_arg_stypes number of user provided in_arg and axu_state storage types - * \param provided_arg_stype_names argument name list of provided storage types - * \param provided_arg_stypes data of provided storage types - * \param num_shared_arg_names number of parameter names passed from _bind_ith_exec - * \param shared_arg_name_list parameter name list passed from _bind_ith_exec - * \param shared_buffer_len number of shared data arrays passed from _bind_ith_exec - * \param shared_buffer_name_list shared data array names passed from _bind_ith_exec - * \param shared_buffer_handle_list shared data array handles passed from _bind_ith_exec - * \param updated_shared_buffer_name_list updated shared data array names after binding - * \param updated_shared_buffer_handle_list updated shared data arrays after binding - * \param num_in_args number of input arguments of this sym - * \param in_args list_arguments associated with the current executor - * \param arg_grads list of gradients of in_args associated with the current executor - * \param num_aux_states number of aux states of this sym - * \param aux_states list_auxiliary_states associated with the current executor - * \param shared_exec_handle shared excutor handle passed from _bind_ith_exec - * \param out the handle of the executor to be created - */ -int MXExecutorSimpleBindEx(SymbolHandle symbol_handle, + +namespace mxnet { + +template +int _SimpleBindImpl(SymbolHandle symbol_handle, int dev_type, int dev_id, const uint32_t num_g2c_keys, @@ -564,7 +531,7 @@ int MXExecutorSimpleBindEx(SymbolHandle symbol_handle, const char** provided_grad_req_types, const uint32_t num_provided_arg_shapes, const char** provided_arg_shape_names, - const int* provided_arg_shape_data, + const DType* provided_arg_shape_data, const uint32_t* provided_arg_shape_idx, const uint32_t num_provided_arg_dtypes, const char** provided_arg_dtype_names, @@ -849,6 +816,192 @@ int MXExecutorSimpleBindEx(SymbolHandle symbol_handle, API_END(); } +} // namespace mxnet + + +/*! + * \brief Executor for simple_bind + * when INT64_TENSOR_SIZE = OFF + * \param symbol_handle symbol handle + * \param dev_type default device type + * \param dev_id default device id + * \param num_g2c_keys number of group2ctx keys + * \param g2c_keys key list of group2ctx + * \param g2c_dev_types device type list of group2ctx + * \param g2c_dev_ids id list of group2ctx + * \param provided_grad_req_list_len grad_req length provided by users in front-end + * \param provided_grad_req_names grad_req names provided by users in front-end + * \param provided_grad_req_types req types provided by users in front-end + * \param num_provided_arg_shapes number of user provided in_arg and aux_state shapes + * \param provided_arg_shape_names name list of provided shapes + * \param provided_arg_shape_data provided shape data + * \param provided_arg_shape_idx provided shape data index + * \param num_provided_arg_dtypes number of user provided in_arg and axu_state dtypes + * \param provided_arg_dtype_names argument name list of provided dtypes + * \param provided_arg_dtypes data of provided dtypes + * \param num_provided_arg_stypes number of user provided in_arg and axu_state storage types + * \param provided_arg_stype_names argument name list of provided storage types + * \param provided_arg_stypes data of provided storage types + * \param num_shared_arg_names number of parameter names passed from _bind_ith_exec + * \param shared_arg_name_list parameter name list passed from _bind_ith_exec + * \param shared_buffer_len number of shared data arrays passed from _bind_ith_exec + * \param shared_buffer_name_list shared data array names passed from _bind_ith_exec + * \param shared_buffer_handle_list shared data array handles passed from _bind_ith_exec + * \param updated_shared_buffer_name_list updated shared data array names after binding + * \param updated_shared_buffer_handle_list updated shared data arrays after binding + * \param num_in_args number of input arguments of this sym + * \param in_args list_arguments associated with the current executor + * \param arg_grads list of gradients of in_args associated with the current executor + * \param num_aux_states number of aux states of this sym + * \param aux_states list_auxiliary_states associated with the current executor + * \param shared_exec_handle shared excutor handle passed from _bind_ith_exec + * \param out the handle of the executor to be created + */ +int MXExecutorSimpleBindEx(SymbolHandle symbol_handle, + int dev_type, + int dev_id, + const uint32_t num_g2c_keys, + const char** g2c_keys, + const int* g2c_dev_types, + const int* g2c_dev_ids, + const uint32_t provided_grad_req_list_len, + const char** provided_grad_req_names, + const char** provided_grad_req_types, + const uint32_t num_provided_arg_shapes, + const char** provided_arg_shape_names, + const int* provided_arg_shape_data, + const uint32_t* provided_arg_shape_idx, + const uint32_t num_provided_arg_dtypes, + const char** provided_arg_dtype_names, + const int* provided_arg_dtypes, + const uint32_t num_provided_arg_stypes, + const char** provided_arg_stype_names, + const int* provided_arg_stypes, + const uint32_t num_shared_arg_names, + const char** shared_arg_name_list, + int* shared_buffer_len, + const char** shared_buffer_name_list, + NDArrayHandle* shared_buffer_handle_list, + const char*** updated_shared_buffer_name_list, + NDArrayHandle** updated_shared_buffer_handle_list, + uint32_t* num_in_args, + NDArrayHandle** in_args, + NDArrayHandle** arg_grads, + uint32_t* num_aux_states, + NDArrayHandle** aux_states, + ExecutorHandle shared_exec_handle, + ExecutorHandle* out) { + return mxnet::_SimpleBindImpl(symbol_handle, + dev_type, dev_id, + num_g2c_keys, g2c_keys, g2c_dev_types, g2c_dev_ids, + provided_grad_req_list_len, provided_grad_req_names, + provided_grad_req_types, + num_provided_arg_shapes, provided_arg_shape_names, + provided_arg_shape_data, provided_arg_shape_idx, + num_provided_arg_dtypes, provided_arg_dtype_names, provided_arg_dtypes, + num_provided_arg_stypes, provided_arg_stype_names, provided_arg_stypes, + num_shared_arg_names, shared_arg_name_list, + shared_buffer_len, shared_buffer_name_list, + shared_buffer_handle_list, updated_shared_buffer_name_list, + updated_shared_buffer_handle_list, + num_in_args, in_args, arg_grads, + num_aux_states, aux_states, + shared_exec_handle, out); +} + + +// TODO(ChaiBapchya): add API doc for rest of C APIs for int64 +/*! + * \brief Large tensor specific implementation for simple_bind executor + * when USE_INT64_TENSOR_SIZE = ON + * \param symbol_handle symbol handle + * \param dev_type default device type + * \param dev_id default device id + * \param num_g2c_keys number of group2ctx keys + * \param g2c_keys key list of group2ctx + * \param g2c_dev_types device type list of group2ctx + * \param g2c_dev_ids id list of group2ctx + * \param provided_grad_req_list_len grad_req length provided by users in front-end + * \param provided_grad_req_names grad_req names provided by users in front-end + * \param provided_grad_req_types req types provided by users in front-end + * \param num_provided_arg_shapes number of user provided in_arg and aux_state shapes + * \param provided_arg_shape_names name list of provided shapes + * \param provided_arg_shape_data provided shape data + * \param provided_arg_shape_idx provided shape data index + * \param num_provided_arg_dtypes number of user provided in_arg and axu_state dtypes + * \param provided_arg_dtype_names argument name list of provided dtypes + * \param provided_arg_dtypes data of provided dtypes + * \param num_provided_arg_stypes number of user provided in_arg and axu_state storage types + * \param provided_arg_stype_names argument name list of provided storage types + * \param provided_arg_stypes data of provided storage types + * \param num_shared_arg_names number of parameter names passed from _bind_ith_exec + * \param shared_arg_name_list parameter name list passed from _bind_ith_exec + * \param shared_buffer_len number of shared data arrays passed from _bind_ith_exec + * \param shared_buffer_name_list shared data array names passed from _bind_ith_exec + * \param shared_buffer_handle_list shared data array handles passed from _bind_ith_exec + * \param updated_shared_buffer_name_list updated shared data array names after binding + * \param updated_shared_buffer_handle_list updated shared data arrays after binding + * \param num_in_args number of input arguments of this sym + * \param in_args list_arguments associated with the current executor + * \param arg_grads list of gradients of in_args associated with the current executor + * \param num_aux_states number of aux states of this sym + * \param aux_states list_auxiliary_states associated with the current executor + * \param shared_exec_handle shared excutor handle passed from _bind_ith_exec + * \param out the handle of the executor to be created + */ +int MXExecutorSimpleBindEx64(SymbolHandle symbol_handle, + int dev_type, + int dev_id, + const uint32_t num_g2c_keys, + const char** g2c_keys, + const int* g2c_dev_types, + const int* g2c_dev_ids, + const uint32_t provided_grad_req_list_len, + const char** provided_grad_req_names, + const char** provided_grad_req_types, + const uint32_t num_provided_arg_shapes, + const char** provided_arg_shape_names, + const int64_t* provided_arg_shape_data, + const uint32_t* provided_arg_shape_idx, + const uint32_t num_provided_arg_dtypes, + const char** provided_arg_dtype_names, + const int* provided_arg_dtypes, + const uint32_t num_provided_arg_stypes, + const char** provided_arg_stype_names, + const int* provided_arg_stypes, + const uint32_t num_shared_arg_names, + const char** shared_arg_name_list, + int* shared_buffer_len, + const char** shared_buffer_name_list, + NDArrayHandle* shared_buffer_handle_list, + const char*** updated_shared_buffer_name_list, + NDArrayHandle** updated_shared_buffer_handle_list, + uint32_t* num_in_args, + NDArrayHandle** in_args, + NDArrayHandle** arg_grads, + uint32_t* num_aux_states, + NDArrayHandle** aux_states, + ExecutorHandle shared_exec_handle, + ExecutorHandle* out) { + return mxnet::_SimpleBindImpl(symbol_handle, + dev_type, dev_id, + num_g2c_keys, g2c_keys, g2c_dev_types, g2c_dev_ids, + provided_grad_req_list_len, provided_grad_req_names, + provided_grad_req_types, + num_provided_arg_shapes, provided_arg_shape_names, + provided_arg_shape_data, provided_arg_shape_idx, + num_provided_arg_dtypes, provided_arg_dtype_names, provided_arg_dtypes, + num_provided_arg_stypes, provided_arg_stype_names, provided_arg_stypes, + num_shared_arg_names, shared_arg_name_list, + shared_buffer_len, shared_buffer_name_list, + shared_buffer_handle_list, updated_shared_buffer_name_list, + updated_shared_buffer_handle_list, + num_in_args, in_args, arg_grads, + num_aux_states, aux_states, + shared_exec_handle, out); +} + + int MXExecutorReshape(int partial_shaping, int allow_up_sizing, int dev_type, diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index c18a95400f22..74ac179a7e60 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -1351,17 +1351,17 @@ def check_trunc(): def create_input_for_trigonometric_ops(vals): - # Creates large vector input of size(LARGE_X*10, SMALL_Y/10) from vals using tile operator + # Creates large vector input of size(LARGE_X*10, SMALL_Y/10) from vals using broadcast_to operator inp = nd.array(vals).reshape(1, 5) inp = nd.broadcast_to(inp, (LARGE_X*10, SMALL_Y//10)) return inp -def assert_correctness_of_trigonometric_ops(output, expected_vals): +def assert_correctness_of_trigonometric_ops(output, expected_vals, atol=1e-3): # checks verifies 5 values at positions(0, 1, -3, -2, -1) of the input vector output_idx_to_inspect = [0, 1, -3, -2, -1] for i in range(len(output_idx_to_inspect)): - assert np.abs(output[1][output_idx_to_inspect[i]].asnumpy()-expected_vals[i]) <= 1e-3 + assert np.abs(output[1][output_idx_to_inspect[i]].asnumpy()-expected_vals[i]) <= atol def test_trigonometric_ops(): diff --git a/tests/nightly/test_large_vector.py b/tests/nightly/test_large_vector.py index b8edc83220bd..c6a99a5d0826 100644 --- a/tests/nightly/test_large_vector.py +++ b/tests/nightly/test_large_vector.py @@ -64,7 +64,7 @@ def test_ndarray_random_randint(): high = 2**34 a = nd.random.randint(low, high, dtype=np.int64, shape=LARGE_X).asnumpy() assert a.shape == (LARGE_X,) - assert (a >= low).all() and (a < high).all() + assert (a >= low).all() and (a < high).all() def test_ndarray_empty(): @@ -710,6 +710,39 @@ def test_full(): assert a[-1] == 3 +def test_regression(): + shape = (LARGE_X, ) + + def check_regression(symbol, forward, shape): + # init executor + data_s = mx.symbol.Variable('data') + label_s = mx.symbol.Variable('label') + out_s = symbol(data=data_s, label=label_s) + exe = out_s.simple_bind(ctx=mx.cpu(0), data=shape, label=shape) + + arg_map = dict(zip(out_s.list_arguments(), exe.arg_arrays)) + + # init data + data = mx.random.uniform(-1, -1, shape) + arg_map["data"][:] = data + atol = 1e-5 + density = 0.5 + stype = 'default' + label = arg_map["label"] + label[:] = rand_ndarray(shape, stype, density=density) + exe.forward(is_train=True) + exe.backward() + np_out = forward(data.asnumpy()) + assert_almost_equal(exe.outputs[0].asnumpy(), np_out, atol=atol) + + check_regression(mx.symbol.LogisticRegressionOutput, + lambda x: 1.0 / (1.0 + np.exp(-x)), + shape) + check_regression(mx.symbol.LinearRegressionOutput, + lambda x: x, + shape) + + def test_sign(): a = mx.nd.random.normal(-1, 1, shape=LARGE_X) mx_res = mx.nd.sign(a) @@ -978,11 +1011,11 @@ def test_add_n(): def test_modulo(): x = mx.nd.ones(LARGE_X)*6 y = mx.nd.ones(LARGE_X)*4 - z = (x%y) + z = (x % y) assert z[0] == 2 assert z[-1] == 2 x = mx.nd.ones(LARGE_X)*5 - z = nd.modulo(x,y) + z = nd.modulo(x, y) assert z[0] == 1 assert z[-1] == 1 @@ -1022,6 +1055,16 @@ def test_gather(): assert np.sum(arr[idx] == 2) == 10 +def test_infer_shape(): + data_1 = mx.symbol.Variable('data_1') + data_2 = mx.symbol.Variable('data_2') + add = data_1+data_2 + # > add.infer_shape(data_1=(LARGE_X,), data_2=(LARGE_X,)) + # OUTPUT - arg_shapes, out_shapes, aux_shapes + _, out_shapes, _ = add.infer_shape(data_1=(LARGE_X,), data_2=(LARGE_X,)) + assert out_shapes == [(LARGE_X,)] + + if __name__ == '__main__': import nose nose.runmodule()