From 8d371a396c2aa2408299b01b778b90735d6089ac Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Mon, 9 Mar 2020 13:54:16 +0800 Subject: [PATCH 01/11] Support ADT as FFI return value --- include/mxnet/runtime/c_runtime_api.h | 18 +++++ include/mxnet/runtime/container.h | 2 +- include/mxnet/runtime/ndarray_handle.h | 47 ++++++++++++ python/mxnet/__init__.py | 2 + python/mxnet/_ffi/_ctypes/function.py | 22 +++++- python/mxnet/_ffi/_ctypes/object.py | 46 ++++++++++- python/mxnet/_ffi/object.py | 101 ++++++++++++++++++++++++- python/mxnet/container.py | 53 +++++++++++++ python/mxnet/ndarray_handle.py | 29 +++++++ src/runtime/container.cc | 72 ++++++++++++++++++ src/runtime/ndarray_handle.cc | 42 ++++++++++ src/runtime/object.cc | 15 ++++ src/runtime/object_internal.h | 8 ++ 13 files changed, 447 insertions(+), 10 deletions(-) create mode 100644 include/mxnet/runtime/ndarray_handle.h create mode 100644 python/mxnet/container.py create mode 100644 python/mxnet/ndarray_handle.py create mode 100644 src/runtime/container.cc create mode 100644 src/runtime/ndarray_handle.cc diff --git a/include/mxnet/runtime/c_runtime_api.h b/include/mxnet/runtime/c_runtime_api.h index bbc8862d5439..69de9ca27d12 100644 --- a/include/mxnet/runtime/c_runtime_api.h +++ b/include/mxnet/runtime/c_runtime_api.h @@ -156,6 +156,24 @@ MXNET_DLL int MXNetFuncListGlobalNames(int* out_size, */ MXNET_DLL int MXNetObjectFree(MXNetObjectHandle obj); + +/*! + * \brief Get the type_index from an object. + * + * \param obj The object handle. + * \param out_tindex the output type index. + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXNetObjectGetTypeIndex(MXNetObjectHandle obj, unsigned* out_tindex); + +/*! + * \brief Convert type key to type index. + * \param type_key The key of the type. + * \param out_tindex the corresponding type index. + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXNetObjectTypeKey2Index(const char* type_key, unsigned* out_tindex); + #ifdef __cplusplus } // extern "C" #endif diff --git a/include/mxnet/runtime/container.h b/include/mxnet/runtime/container.h index cd719aaa51a6..fc1d4a173669 100644 --- a/include/mxnet/runtime/container.h +++ b/include/mxnet/runtime/container.h @@ -171,8 +171,8 @@ class ADTObj : public Object, public InplaceArrayBase { uint32_t size{0}; // The fields of the structure follows directly in memory. - static constexpr const uint32_t _type_index = TypeIndex::kMXNetADT; static constexpr const char* _type_key = "MXNet.ADT"; + static constexpr const uint32_t _type_index = TypeIndex::kMXNetADT; MXNET_DECLARE_FINAL_OBJECT_INFO(ADTObj, Object) private: diff --git a/include/mxnet/runtime/ndarray_handle.h b/include/mxnet/runtime/ndarray_handle.h new file mode 100644 index 000000000000..aa19595b50cd --- /dev/null +++ b/include/mxnet/runtime/ndarray_handle.h @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ndarray_handle.h + * \brief NDArray handle types + */ +#ifndef MXNET_RUNTIME_NDARRAY_HANDLE_H_ +#define MXNET_RUNTIME_NDARRAY_HANDLE_H_ +#include +#include + +namespace mxnet { + +class NDArrayHandleObj : public Object { + public: + /*! \brief the Internal value. */ + NDArray* value; + + static constexpr const char* _type_key = "MXNet.NDArrayHandle"; + MXNET_DECLARE_FINAL_OBJECT_INFO(NDArrayHandleObj, Object) +}; + +class NDArrayHandle : public ObjectRef { +public: + MXNET_DEFINE_OBJECT_REF_METHODS(NDArrayHandle, ObjectRef, NDArrayHandleObj) +}; + +}; // namespace mxnet + +#endif // MXNET_RUNTIME_NDARRAY_HANDLE_H_ diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py index d6d6a1f49e8e..c9fac79ea085 100644 --- a/python/mxnet/__init__.py +++ b/python/mxnet/__init__.py @@ -108,3 +108,5 @@ from . import _api_internal from . import api +from . import container +from . import ndarray_handle diff --git a/python/mxnet/_ffi/_ctypes/function.py b/python/mxnet/_ffi/_ctypes/function.py index 0a005dd7b749..51f18bbb9a38 100644 --- a/python/mxnet/_ffi/_ctypes/function.py +++ b/python/mxnet/_ffi/_ctypes/function.py @@ -31,6 +31,8 @@ from .object import ObjectBase from ..node_generic import convert_to_node from ..._ctypes.ndarray import NDArrayBase +from .object import ObjectBase, _set_class_object +from . import object as _object ObjectHandle = ctypes.c_void_p @@ -118,8 +120,20 @@ def __call__(self, *args): else RETURN_SWITCH[ret_tcode.value](ret_val, args)) -_CLASS_OBJECT = None +def __init_handle_by_constructor__(fconstructor, args): + """Initialize handle by constructor""" + temp_args = [] + values, tcodes, num_args = _make_mxnet_args(args, temp_args) + ret_val = MXNetValue() + ret_tcode = ctypes.c_int() + if _LIB.MXNetFuncCall( + fconstructor.handle, values, tcodes, ctypes.c_int(num_args), + ctypes.byref(ret_val), ctypes.byref(ret_tcode)) != 0: + raise get_last_ffi_error() + _ = temp_args + _ = args + assert ret_tcode.value == TypeCode.OBJECT_HANDLE + handle = ret_val.v_handle + return handle -def _set_class_object(obj_class): - global _CLASS_OBJECT - _CLASS_OBJECT = obj_class +_object.__init_by_constructor__ = __init_handle_by_constructor__ diff --git a/python/mxnet/_ffi/_ctypes/object.py b/python/mxnet/_ffi/_ctypes/object.py index 85ab415692f6..8692b2a9b54c 100644 --- a/python/mxnet/_ffi/_ctypes/object.py +++ b/python/mxnet/_ffi/_ctypes/object.py @@ -25,14 +25,32 @@ from .types import RETURN_SWITCH, TypeCode ObjectHandle = ctypes.c_void_p +__init_by_constructor__ = None + +"""Maps object type to its constructor""" +OBJECT_TYPE = {} + +_CLASS_OBJECT = None + +def _set_class_object(object_class): + global _CLASS_OBJECT + _CLASS_OBJECT = object_class + +def _register_object(index, cls): + """register object class""" + # if issubclass(cls, NDArrayBase): + # _register_ndarray(index, cls) + # return + OBJECT_TYPE[index] = cls def _return_object(x): handle = x.v_handle if not isinstance(handle, ObjectHandle): handle = ObjectHandle(handle) - # Does not support specific cpp node class for now - cls = function._CLASS_OBJECT + tindex = ctypes.c_uint() + check_call(_LIB.MXNetObjectGetTypeIndex(handle, ctypes.byref(tindex))) + cls = OBJECT_TYPE.get(tindex.value, _CLASS_OBJECT) # Avoid calling __init__ of cls, instead directly call __new__ # This allows child class to implement their own __init__ obj = cls.__new__(cls) @@ -50,4 +68,26 @@ def __del__(self): if _LIB is not None: check_call(_LIB.MXNetObjectFree(self.handle)) - # Does not support creation of cpp node class via python class + def __init_handle_by_constructor__(self, fconstructor, *args): + """Initialize the handle by calling constructor function. + + Parameters + ---------- + fconstructor : Function + Constructor function. + + args: list of objects + The arguments to the constructor + + Note + ---- + We have a special calling convention to call constructor functions. + So the return handle is directly set into the Node object + instead of creating a new Node. + """ + # assign handle first to avoid error raising + self.handle = None + handle = __init_by_constructor__(fconstructor, args) + if not isinstance(handle, ObjectHandle): + handle = ObjectHandle(handle) + self.handle = handle diff --git a/python/mxnet/_ffi/object.py b/python/mxnet/_ffi/object.py index e0a4aa600f25..5435a70928a6 100644 --- a/python/mxnet/_ffi/object.py +++ b/python/mxnet/_ffi/object.py @@ -17,10 +17,107 @@ # pylint: disable=invalid-name """Runtime Object API Acknowledgement: This file originates from incubator-tvm""" -from ._ctypes.function import _set_class_object -from ._ctypes.object import ObjectBase as _ObjectBase +import os +import ctypes +from ..base import _LIB, check_call, c_str + +try: + if int(os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0: + from ._ctypes.function import _set_class_object + from ._ctypes.object import ObjectBase as _ObjectBase + from ._ctypes.object import _register_object + else: + from ._cy3.core import _set_class_object + from ._cy3.core import ObjectBase as _ObjectBase + from ._cy3.core import _register_object +except ImportError: + if int(os.environ.get("MXNET_ENFORCE_CYTHON", False)) != 0: + raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1") + from ._ctypes.function import _set_class_object + from ._ctypes.object import ObjectBase as _ObjectBase + from ._ctypes.object import _register_object class Object(_ObjectBase): """Base class for all mxnet's runtime objects.""" + +def register_object(type_key=None): + """register object type. + + Parameters + ---------- + type_key : str or cls + The type key of the node + + Examples + -------- + The following code registers MyObject + using type key "test.MyObject" + + .. code-block:: python + + @register_object("test.MyObject") + class MyObject(Object): + pass + """ + object_name = type_key if isinstance(type_key, str) else type_key.__name__ + + def register(cls): + """internal register function""" + if hasattr(cls, "_type_index"): + tindex = cls._type_index + else: + tidx = ctypes.c_uint() + check_call(_LIB.MXNetObjectTypeKey2Index( + c_str(object_name), ctypes.byref(tidx))) + tindex = tidx.value + _register_object(tindex, cls) + return cls + + if isinstance(type_key, str): + return register + + return register(type_key) + + +def getitem_helper(obj, elem_getter, length, idx): + """Helper function to implement a pythonic getitem function. + + Parameters + ---------- + obj: object + The original object + + elem_getter : function + A simple function that takes index and return a single element. + + length : int + The size of the array + + idx : int or slice + The argument passed to getitem + + Returns + ------- + result : object + The result of getitem + """ + if isinstance(idx, slice): + start = idx.start if idx.start is not None else 0 + stop = idx.stop if idx.stop is not None else length + step = idx.step if idx.step is not None else 1 + if start < 0: + start += length + if stop < 0: + stop += length + return [elem_getter(obj, i) for i in range(start, stop, step)] + + if idx < -length or idx >= length: + raise IndexError("Index out of range. size: {}, got index {}" + .format(length, idx)) + if idx < 0: + idx += length + return elem_getter(obj, idx) + + _set_class_object(Object) diff --git a/python/mxnet/container.py b/python/mxnet/container.py new file mode 100644 index 000000000000..441ff6bbe3da --- /dev/null +++ b/python/mxnet/container.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Container data structures. +Acknowledgement: This file originates from incubator-tvm +""" +from ._ffi.object import Object, register_object, getitem_helper +from ._ffi.function import _init_api + +@register_object("MXNet.ADT") +class ADT(Object): + """Algebatic data type(ADT) object. + + Parameters + ---------- + tag : int + The tag of ADT. + + fields : list[Object] or tuple[Object] + The source tuple. + """ + def __init__(self, tag, fields): + for f in fields: + assert isinstance(f, (Object)), "Expect object" \ + ", but received : {0}".format(type(f)) + self.__init_handle_by_constructor__(_ADT, tag, *fields) + + @property + def tag(self): + return _GetADTTag(self) + + def __getitem__(self, idx): + return getitem_helper( + self, _GetADTFields, len(self), idx) + + def __len__(self): + return _GetADTSize(self) + +_init_api("mxnet.container") diff --git a/python/mxnet/ndarray_handle.py b/python/mxnet/ndarray_handle.py new file mode 100644 index 000000000000..fbfdcce632b9 --- /dev/null +++ b/python/mxnet/ndarray_handle.py @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +NDarray Handle +""" +from ._ffi.object import Object, register_object +from ._ffi.function import _init_api + +@register_object("MXNet.NDArrayHandle") +class NDArrayHandle(Object): + @property + def value(self): + return _GetNDArrayHandleValue(self) + +_init_api("mxnet.ndarray_handle") diff --git a/src/runtime/container.cc b/src/runtime/container.cc new file mode 100644 index 000000000000..f47e2290789f --- /dev/null +++ b/src/runtime/container.cc @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/container.cc + * \brief Implementations of common plain old data (POD) containers. + */ +// Acknowledgement: This file originates from incubator-tvm +#include +#include +#include +#include + +namespace mxnet { +namespace runtime { + +MXNET_REGISTER_GLOBAL("container._GetADTTag") +.set_body([](MXNetArgs args, MXNetRetValue* rv) { + ObjectRef obj = args[0]; + const auto& adt = Downcast(obj); + *rv = static_cast(adt.tag()); +}); + +MXNET_REGISTER_GLOBAL("container._GetADTSize") +.set_body([](MXNetArgs args, MXNetRetValue* rv) { + ObjectRef obj = args[0]; + const auto& adt = Downcast(obj); + *rv = static_cast(adt.size()); +}); + + +MXNET_REGISTER_GLOBAL("container._GetADTFields") +.set_body([](MXNetArgs args, MXNetRetValue* rv) { + ObjectRef obj = args[0]; + int idx = args[1]; + const auto& adt = Downcast(obj); + CHECK_LT(idx, adt.size()); + *rv = adt[idx]; +}); + +MXNET_REGISTER_GLOBAL("container._ADT") +.set_body([](MXNetArgs args, MXNetRetValue* rv) { + int itag = args[0]; + size_t tag = static_cast(itag); + std::vector fields; + for (int i = 1; i < args.size(); i++) { + fields.push_back(args[i]); + } + *rv = ADT(tag, fields); +}); + +MXNET_REGISTER_OBJECT_TYPE(ADTObj); + +} // namespace runtime + +} // namespace mxnet diff --git a/src/runtime/ndarray_handle.cc b/src/runtime/ndarray_handle.cc new file mode 100644 index 000000000000..482a3076805b --- /dev/null +++ b/src/runtime/ndarray_handle.cc @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/api/ndarary_handle.cc + * \brief Implementations of NDArrayHandle + */ +#include +#include +#include +#include + +namespace mxnet { +namespace runtime { + +MXNET_REGISTER_GLOBAL("ndarray_handle._GetNDArrayHandleValue") +.set_body([](MXNetArgs args, MXNetRetValue* rv) { + ObjectRef obj = args[0]; + const auto& handle = Downcast(obj); + *rv = handle; +}); + +MXNET_REGISTER_OBJECT_TYPE(NDArrayHandleObj); + +} // namespace runtime +} // namespace mxnet diff --git a/src/runtime/object.cc b/src/runtime/object.cc index 76d8b9776d8a..11bada181cec 100644 --- a/src/runtime/object.cc +++ b/src/runtime/object.cc @@ -84,6 +84,7 @@ class TypeContext { uint32_t num_child_slots, bool child_slots_can_overflow) { std::lock_guard lock(mutex_); + std::cout << skey << std::endl; auto it = type_key2index_.find(skey); if (it != type_key2index_.end()) { return it->second; @@ -213,3 +214,17 @@ int MXNetObjectFree(MXNetObjectHandle obj) { mxnet::runtime::ObjectInternal::ObjectFree(obj); API_END(); } + +int MXNetObjectGetTypeIndex(MXNetObjectHandle obj, unsigned* out_tindex) { + API_BEGIN(); + CHECK(obj != nullptr); + out_tindex[0] = static_cast(obj)->type_index(); + API_END(); +} + +int MXNetObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) { + API_BEGIN(); + out_tindex[0] = mxnet::runtime::ObjectInternal::ObjectTypeKey2Index( + type_key); + API_END(); +} diff --git a/src/runtime/object_internal.h b/src/runtime/object_internal.h index 002468456741..12165bbfe703 100644 --- a/src/runtime/object_internal.h +++ b/src/runtime/object_internal.h @@ -46,6 +46,14 @@ class ObjectInternal { static_cast(obj)->DecRef(); } } + /*! + * \brief Expose TypeKey2Index + * \param type_key The original type key. + * \return the corresponding index. + */ + static uint32_t ObjectTypeKey2Index(const std::string& type_key) { + return Object::TypeKey2Index(type_key); + } }; } // namespace runtime From 842ec726b207f1a1365df979ad033376755a6adb Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Tue, 10 Mar 2020 00:50:36 +0800 Subject: [PATCH 02/11] Special operator= for NDArrayHandle --- include/mxnet/runtime/ndarray_handle.h | 5 +++++ include/mxnet/runtime/packed_func.h | 11 ++++++++++- src/runtime/ndarray_handle.cc | 2 +- 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/include/mxnet/runtime/ndarray_handle.h b/include/mxnet/runtime/ndarray_handle.h index aa19595b50cd..c0f632b653d9 100644 --- a/include/mxnet/runtime/ndarray_handle.h +++ b/include/mxnet/runtime/ndarray_handle.h @@ -39,6 +39,11 @@ class NDArrayHandleObj : public Object { class NDArrayHandle : public ObjectRef { public: + explicit NDArrayHandle(NDArray* value) { + runtime::ObjectPtr node = make_object(); + node->value = value; + data_ = std::move(node); + } MXNET_DEFINE_OBJECT_REF_METHODS(NDArrayHandle, ObjectRef, NDArrayHandleObj) }; diff --git a/include/mxnet/runtime/packed_func.h b/include/mxnet/runtime/packed_func.h index ac7b462ce471..066ec8538338 100644 --- a/include/mxnet/runtime/packed_func.h +++ b/include/mxnet/runtime/packed_func.h @@ -30,6 +30,7 @@ #include #include #include +#include #include #include #include @@ -651,6 +652,9 @@ class MXNetRetValue : public MXNetPODValue_ { return *this; } MXNetRetValue& operator=(ObjectRef other) { + if (other.as()) { + return operator=(Downcast(other)); + } return operator=(std::move(other.data_)); } template @@ -670,11 +674,16 @@ class MXNetRetValue : public MXNetPODValue_ { this->Assign(other); return *this; } - MXNetRetValue& operator=(::mxnet::NDArray* value) { + MXNetRetValue& operator=(NDArray* value) { this->SwitchToPOD(kNDArrayHandle); value_.v_handle = reinterpret_cast(value); return *this; } + MXNetRetValue& operator=(NDArrayHandle value) { + this->SwitchToPOD(kNDArrayHandle); + value_.v_handle = reinterpret_cast(value->value); + return *this; + } MXNetRetValue& operator=(const PythonArg& value) { this->SwitchToPOD(kPyArg); value_.v_int64 = value.offset(); diff --git a/src/runtime/ndarray_handle.cc b/src/runtime/ndarray_handle.cc index 482a3076805b..5afb1984740b 100644 --- a/src/runtime/ndarray_handle.cc +++ b/src/runtime/ndarray_handle.cc @@ -33,7 +33,7 @@ MXNET_REGISTER_GLOBAL("ndarray_handle._GetNDArrayHandleValue") .set_body([](MXNetArgs args, MXNetRetValue* rv) { ObjectRef obj = args[0]; const auto& handle = Downcast(obj); - *rv = handle; + *rv = handle->value; }); MXNET_REGISTER_OBJECT_TYPE(NDArrayHandleObj); From a5a61461c7d2b566eaedc08eb7490622bb350266 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Tue, 10 Mar 2020 01:58:30 +0800 Subject: [PATCH 03/11] SVD --- python/mxnet/ndarray/numpy/linalg.py | 4 +- src/api/operator/numpy/linalg/np_gesvd.cc | 46 +++++++++++++++++++++++ 2 files changed, 49 insertions(+), 1 deletion(-) create mode 100644 src/api/operator/numpy/linalg/np_gesvd.cc diff --git a/python/mxnet/ndarray/numpy/linalg.py b/python/mxnet/ndarray/numpy/linalg.py index 0acedf4bbab4..4b642f460a55 100644 --- a/python/mxnet/ndarray/numpy/linalg.py +++ b/python/mxnet/ndarray/numpy/linalg.py @@ -19,6 +19,7 @@ from . import _op as _mx_nd_np from . import _internal as _npi +from . import _api_internal __all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'solve', 'tensorinv', 'tensorsolve', 'pinv', 'eigvals', 'eig', 'eigvalsh', 'eigh'] @@ -329,7 +330,8 @@ def svd(a): >>> (ret - a < -1e-3).sum() array(0.) """ - return tuple(_npi.svd(a)) + # return tuple(_npi.svd(a)) + return tuple(_api_internal.svd(a)) def cholesky(a): diff --git a/src/api/operator/numpy/linalg/np_gesvd.cc b/src/api/operator/numpy/linalg/np_gesvd.cc new file mode 100644 index 000000000000..2e3b2ce22382 --- /dev/null +++ b/src/api/operator/numpy/linalg/np_gesvd.cc @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_gesvd.cc + * \brief Implementation of the API of functions in src/operator/numpy/linalg/np_gesvd.cc + */ +#include +#include +#include "../../utils.h" +#include "../../../../operator/numpy/linalg/np_gesvd-inl.h" + +namespace mxnet { + +MXNET_REGISTER_API("_npi.svd") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + const nnvm::Op* op = Op::Get("_npi_svd"); + attrs.op = op; + // inputs + NDArray* inputs[] = {args[0].operator NDArray*()}; + int num_inputs = 1; + // outputs + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ADT(0, {NDArrayHandle(ndoutputs[0]), NDArrayHandle(ndoutputs[1]), NDArrayHandle(ndoutputs[2])}); +}); + +} // namespace mxnet From 43f717bf368da4118e0b50aa60b56cf4896a6d64 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Tue, 10 Mar 2020 14:58:41 +0800 Subject: [PATCH 04/11] Support cython --- python/mxnet/_ffi/_cython/base.pxi | 2 + python/mxnet/_ffi/_cython/core.pyx | 1 + python/mxnet/_ffi/_cython/function.pxi | 25 +++++++ python/mxnet/_ffi/_cython/object.pxi | 98 ++++++++++++++++++++++++++ 4 files changed, 126 insertions(+) create mode 100644 python/mxnet/_ffi/_cython/object.pxi diff --git a/python/mxnet/_ffi/_cython/base.pxi b/python/mxnet/_ffi/_cython/base.pxi index bc2273bacd0d..84d02e0452c4 100644 --- a/python/mxnet/_ffi/_cython/base.pxi +++ b/python/mxnet/_ffi/_cython/base.pxi @@ -59,6 +59,8 @@ cdef extern from "mxnet/runtime/c_runtime_api.h": MXNetValue* ret_val, int* ret_type_code) int MXNetFuncFree(MXNetFunctionHandle func) + int MXNetObjectFree(ObjectHandle obj) + int MXNetObjectGetTypeIndex(ObjectHandle obj, unsigned* out_index) cdef inline py_str(const char* x): diff --git a/python/mxnet/_ffi/_cython/core.pyx b/python/mxnet/_ffi/_cython/core.pyx index 482f494b6e5e..0110eaaedc1b 100644 --- a/python/mxnet/_ffi/_cython/core.pyx +++ b/python/mxnet/_ffi/_cython/core.pyx @@ -20,4 +20,5 @@ include "./base.pxi" include "./ndarray.pxi" include "./convert.pxi" +include "./object.pxi" include "./function.pxi" diff --git a/python/mxnet/_ffi/_cython/function.pxi b/python/mxnet/_ffi/_cython/function.pxi index d4c629a618d5..97b0af786af3 100644 --- a/python/mxnet/_ffi/_cython/function.pxi +++ b/python/mxnet/_ffi/_cython/function.pxi @@ -53,6 +53,9 @@ cdef inline int make_arg(object arg, elif arg is None: value[0].v_handle = NULL tcode[0] = kNull + elif isinstance(arg, ObjectBase): + value[0].v_handle = (arg).chandle + tcode[0] = kObjectHandle elif isinstance(arg, Number): value[0].v_float64 = arg tcode[0] = kFloat @@ -77,6 +80,8 @@ cdef inline object make_ret(MXNetValue value, int tcode, tuple args): return args[value.v_int64] elif tcode == kNull: return None + elif tcode == kObjectHandle: + return make_ret_object(value.v_handle) elif tcode == kInt: return value.v_int64 elif tcode == kFloat: @@ -130,6 +135,19 @@ cdef inline int FuncCall(void* chandle, return 0 +cdef inline int ConstructorCall(void* constructor_handle, + int type_code, + tuple args, + void** handle) except -1: + """Call contructor of a handle function""" + cdef MXNetValue ret_val + cdef int ret_tcode + FuncCall(constructor_handle, args, &ret_val, &ret_tcode) + assert ret_tcode == type_code + handle[0] = ret_val.v_handle + return 0 + + cdef class FunctionBase: cdef MXNetFunctionHandle chandle cdef int is_global @@ -169,3 +187,10 @@ cdef class FunctionBase: cdef int ret_tcode FuncCall(self.chandle, args, &ret_val, &ret_tcode) return make_ret(ret_val, ret_tcode, args) + + +_CLASS_OBJECT = None + +def _set_class_object(obj_class): + global _CLASS_OBJECT + _CLASS_OBJECT = obj_class diff --git a/python/mxnet/_ffi/_cython/object.pxi b/python/mxnet/_ffi/_cython/object.pxi new file mode 100644 index 000000000000..4f31f2ad5aaa --- /dev/null +++ b/python/mxnet/_ffi/_cython/object.pxi @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Maps object type to its constructor +Acknowledgement: This file originates from incubator-tvm +""" +cdef list OBJECT_TYPE = [] + +def _register_object(int index, object cls): + """register object class""" + global OBJECT_TYPE + while len(OBJECT_TYPE) <= index: + OBJECT_TYPE.append(None) + OBJECT_TYPE[index] = cls + + +cdef inline object make_ret_object(void* chandle): + global OBJECT_TYPE + global _CLASS_OBJECT + cdef unsigned tindex + cdef object cls + object_type = OBJECT_TYPE + CALL(MXNetObjectGetTypeIndex(chandle, &tindex)) + if tindex < len(OBJECT_TYPE): + cls = OBJECT_TYPE[tindex] + if cls is not None: + obj = cls.__new__(cls) + else: + obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) + else: + obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) + (obj).chandle = chandle + return obj + + +cdef class ObjectBase: + cdef void* chandle + + cdef inline _set_handle(self, handle): + cdef unsigned long long ptr + if handle is None: + self.chandle = NULL + else: + ptr = handle.value + self.chandle = (ptr) + + property handle: + def __get__(self): + if self.chandle == NULL: + return None + else: + return ctypes_handle(self.chandle) + + def __set__(self, value): + self._set_handle(value) + + def __dealloc__(self): + CALL(MXNetObjectFree(self.chandle)) + + def __init_handle_by_constructor__(self, fconstructor, *args): + """Initialize the handle by calling constructor function. + + Parameters + ---------- + fconstructor : Function + Constructor function. + + args: list of objects + The arguments to the constructor + + Note + ---- + We have a special calling convention to call constructor functions. + So the return handle is directly set into the Node object + instead of creating a new Node. + """ + # avoid error raised during construction. + self.chandle = NULL + cdef void* chandle + ConstructorCall( + (fconstructor).chandle, + kObjectHandle, args, &chandle) + self.chandle = chandle From c4cc3a7b92baaed785b3c462bd032b10d7e7f6c2 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Tue, 10 Mar 2020 17:44:48 +0800 Subject: [PATCH 05/11] Clear --- python/mxnet/__init__.py | 1 - python/mxnet/ndarray_handle.py | 29 ----------------------------- 2 files changed, 30 deletions(-) delete mode 100644 python/mxnet/ndarray_handle.py diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py index c9fac79ea085..83cf72d4c179 100644 --- a/python/mxnet/__init__.py +++ b/python/mxnet/__init__.py @@ -109,4 +109,3 @@ from . import _api_internal from . import api from . import container -from . import ndarray_handle diff --git a/python/mxnet/ndarray_handle.py b/python/mxnet/ndarray_handle.py deleted file mode 100644 index fbfdcce632b9..000000000000 --- a/python/mxnet/ndarray_handle.py +++ /dev/null @@ -1,29 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -NDarray Handle -""" -from ._ffi.object import Object, register_object -from ._ffi.function import _init_api - -@register_object("MXNet.NDArrayHandle") -class NDArrayHandle(Object): - @property - def value(self): - return _GetNDArrayHandleValue(self) - -_init_api("mxnet.ndarray_handle") From d9148c0d254c9297a0bbabf8463e4efee3da009e Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Tue, 10 Mar 2020 19:09:57 +0800 Subject: [PATCH 06/11] Add split --- python/mxnet/ndarray/numpy/_op.py | 32 ++++++++++++----------- src/api/operator/numpy/np_matrix_op.cc | 35 ++++++++++++++++++++++++++ src/operator/tensor/matrix_op-inl.h | 11 ++++++++ src/runtime/object.cc | 1 - 4 files changed, 63 insertions(+), 16 deletions(-) diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 3d30333d6da2..518bd31a079c 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -3698,21 +3698,23 @@ def split(ary, indices_or_sections, axis=0): If `indices_or_sections` is given as an integer, but a split does not result in equal division. """ - axis_size = ary.shape[axis] - if isinstance(indices_or_sections, integer_types): - sections = indices_or_sections - if axis_size % sections: - raise ValueError('array split does not result in an equal division') - section_size = int(axis_size / sections) - indices = [i * section_size for i in range(sections)] - elif isinstance(indices_or_sections, (list, set, tuple)): - indices = [0] + list(indices_or_sections) - else: - raise ValueError('indices_or_sections must be either int, or tuple / list / set of ints') - ret = _npi.split(ary, indices, axis, False) - assert isinstance(ret, list), 'Output of split should be list,' \ - ' got a return type {}'.format(type(ret)) - return ret + return list(_api_internal.split(ary, indices_or_sections, axis)) + # axis_size = ary.shape[axis] + # if isinstance(indices_or_sections, integer_types): + # sections = indices_or_sections + # if axis_size % sections: + # raise ValueError('array split does not result in an equal division') + # section_size = int(axis_size / sections) + # indices = [i * section_size for i in range(sections)] + # elif isinstance(indices_or_sections, (list, set, tuple)): + # indices = [0] + list(indices_or_sections) + # else: + # raise ValueError('indices_or_sections must be either int, or tuple / list / set of ints') + # ret = _npi.split(ary, indices, axis, False) + # assert isinstance(ret, list), 'Output of split should be list,' \ + # ' got a return type {}'.format(type(ret)) + # return ret + # return list(_api_internal.split(ary, indices_or_sections, axis, False)) # pylint: enable=redefined-outer-name diff --git a/src/api/operator/numpy/np_matrix_op.cc b/src/api/operator/numpy/np_matrix_op.cc index 080cca867ecb..2348e8b62af2 100644 --- a/src/api/operator/numpy/np_matrix_op.cc +++ b/src/api/operator/numpy/np_matrix_op.cc @@ -46,4 +46,39 @@ MXNET_REGISTER_API("_npi.expand_dims") *ret = ndoutputs[0]; }); +MXNET_REGISTER_API("_npi.split") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_split"); + nnvm::NodeAttrs attrs; + op::SplitParam param; + param.axis = args[2].operator int(); + param.squeeze_axis = false; + if (args[1].type_code() == kDLInt) { + param.indices = TShape(0, 0); + param.sections = args[1].operator int(); + } else { + TShape t = TShape(args[1].operator ObjectRef()); + param.indices = TShape(t.ndim() + 1, 0); + for (size_t i = 0; i < t.ndim(); ++i) { + param.indices[i + 1] = t[i]; + } + param.sections = 0; + } + attrs.parsed = std::move(param); + attrs.op = op; + SetAttrDict(&attrs); + + int num_inputs = 1; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + std::vector ndarray_handles; + ndarray_handles.reserve(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + ndarray_handles.emplace_back(ndoutputs[i]); + } + *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); +}); + } // namespace mxnet diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 670104bcfdb0..6efde79f202b 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -2729,6 +2729,17 @@ struct SplitParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(sections).set_default(0) .describe("Number of sections if equally splitted. Default to 0 which means split by indices."); } + void SetAttrDict(std::unordered_map* dict) { + std::ostringstream indices_s, axis_s, squeeze_axis_s, sections_s; + indices_s << indices; + axis_s << axis; + squeeze_axis_s << squeeze_axis; + sections_s << sections; + (*dict)["indices"] = indices_s.str(); + (*dict)["axis"] = axis_s.str(); + (*dict)["squeeze_axis"] = squeeze_axis_s.str(); + (*dict)["sections"] = sections_s.str(); + } }; // struct SplitParam inline mxnet::TShape GetSplitIndices(const mxnet::TShape& ishape, int axis, int sections) { diff --git a/src/runtime/object.cc b/src/runtime/object.cc index 11bada181cec..ee7ed74ecb88 100644 --- a/src/runtime/object.cc +++ b/src/runtime/object.cc @@ -84,7 +84,6 @@ class TypeContext { uint32_t num_child_slots, bool child_slots_can_overflow) { std::lock_guard lock(mutex_); - std::cout << skey << std::endl; auto it = type_key2index_.find(skey); if (it != type_key2index_.end()) { return it->second; From 2a4198c1a05d1bae47d948b315c9a8e3e3be57d0 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Thu, 12 Mar 2020 01:06:46 +0800 Subject: [PATCH 07/11] Refine --- benchmark/python/ffi/benchmark_ffi.py | 2 ++ include/mxnet/runtime/ndarray_handle.h | 2 +- python/mxnet/_ffi/_ctypes/function.py | 1 - python/mxnet/_ffi/_ctypes/object.py | 1 - python/mxnet/container.py | 1 + python/mxnet/ndarray/numpy/_op.py | 20 +++----------------- src/api/operator/numpy/linalg/np_gesvd.cc | 4 +++- src/api/operator/numpy/np_matrix_op.cc | 8 ++++++-- src/api/operator/numpy/np_tensordot_op.cc | 8 ++++++-- src/operator/numpy/np_cumsum.cc | 1 + 10 files changed, 23 insertions(+), 25 deletions(-) diff --git a/benchmark/python/ffi/benchmark_ffi.py b/benchmark/python/ffi/benchmark_ffi.py index 88af3cf3d55e..dfe2f58dce5b 100644 --- a/benchmark/python/ffi/benchmark_ffi.py +++ b/benchmark/python/ffi/benchmark_ffi.py @@ -54,6 +54,8 @@ def prepare_workloads(): OpArgMngr.add_workload("tensordot", pool['2x2'], pool['2x2'], ((1, 0), (0, 1))) OpArgMngr.add_workload("cumsum", pool['3x2'], axis=0, out=pool['3x2']) OpArgMngr.add_workload("add", pool['2x2'], pool['2x2']) + OpArgMngr.add_workload("linalg.svd", pool['3x3']) + OpArgMngr.add_workload("split", pool['3x3'], (0, 1, 2), axis=1) OpArgMngr.add_workload("random.uniform", low=0, high=1, size=1) diff --git a/include/mxnet/runtime/ndarray_handle.h b/include/mxnet/runtime/ndarray_handle.h index c0f632b653d9..e6da83c0c213 100644 --- a/include/mxnet/runtime/ndarray_handle.h +++ b/include/mxnet/runtime/ndarray_handle.h @@ -38,7 +38,7 @@ class NDArrayHandleObj : public Object { }; class NDArrayHandle : public ObjectRef { -public: + public: explicit NDArrayHandle(NDArray* value) { runtime::ObjectPtr node = make_object(); node->value = value; diff --git a/python/mxnet/_ffi/_ctypes/function.py b/python/mxnet/_ffi/_ctypes/function.py index 51f18bbb9a38..58823404909f 100644 --- a/python/mxnet/_ffi/_ctypes/function.py +++ b/python/mxnet/_ffi/_ctypes/function.py @@ -28,7 +28,6 @@ from ..base import c_str from .types import MXNetValue, TypeCode from .types import RETURN_SWITCH -from .object import ObjectBase from ..node_generic import convert_to_node from ..._ctypes.ndarray import NDArrayBase from .object import ObjectBase, _set_class_object diff --git a/python/mxnet/_ffi/_ctypes/object.py b/python/mxnet/_ffi/_ctypes/object.py index 8692b2a9b54c..241ac100de86 100644 --- a/python/mxnet/_ffi/_ctypes/object.py +++ b/python/mxnet/_ffi/_ctypes/object.py @@ -21,7 +21,6 @@ """ import ctypes from ...base import _LIB, check_call -from . import function from .types import RETURN_SWITCH, TypeCode ObjectHandle = ctypes.c_void_p diff --git a/python/mxnet/container.py b/python/mxnet/container.py index 441ff6bbe3da..f0760b3b1d21 100644 --- a/python/mxnet/container.py +++ b/python/mxnet/container.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=undefined-variable """ Container data structures. Acknowledgement: This file originates from incubator-tvm diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 518bd31a079c..1101c33a0ce4 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -990,7 +990,7 @@ def add(x1, x2, out=None, **kwargs): * If both inputs are of integer types (including boolean), not supported yet. """ if isinstance(x1, numeric_types) and isinstance(x2, numeric_types): - _np.add(x1, x2, out=out) + return _np.add(x1, x2, out=out) return _api_internal.add(x1, x2, out) @@ -3698,23 +3698,9 @@ def split(ary, indices_or_sections, axis=0): If `indices_or_sections` is given as an integer, but a split does not result in equal division. """ + if isinstance(indices_or_sections, set): + indices_or_sections = list(indices_or_sections) return list(_api_internal.split(ary, indices_or_sections, axis)) - # axis_size = ary.shape[axis] - # if isinstance(indices_or_sections, integer_types): - # sections = indices_or_sections - # if axis_size % sections: - # raise ValueError('array split does not result in an equal division') - # section_size = int(axis_size / sections) - # indices = [i * section_size for i in range(sections)] - # elif isinstance(indices_or_sections, (list, set, tuple)): - # indices = [0] + list(indices_or_sections) - # else: - # raise ValueError('indices_or_sections must be either int, or tuple / list / set of ints') - # ret = _npi.split(ary, indices, axis, False) - # assert isinstance(ret, list), 'Output of split should be list,' \ - # ' got a return type {}'.format(type(ret)) - # return ret - # return list(_api_internal.split(ary, indices_or_sections, axis, False)) # pylint: enable=redefined-outer-name diff --git a/src/api/operator/numpy/linalg/np_gesvd.cc b/src/api/operator/numpy/linalg/np_gesvd.cc index 2e3b2ce22382..0c9e922315d4 100644 --- a/src/api/operator/numpy/linalg/np_gesvd.cc +++ b/src/api/operator/numpy/linalg/np_gesvd.cc @@ -40,7 +40,9 @@ MXNET_REGISTER_API("_npi.svd") // outputs int num_outputs = 0; auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = ADT(0, {NDArrayHandle(ndoutputs[0]), NDArrayHandle(ndoutputs[1]), NDArrayHandle(ndoutputs[2])}); + *ret = ADT(0, {NDArrayHandle(ndoutputs[0]), + NDArrayHandle(ndoutputs[1]), + NDArrayHandle(ndoutputs[2])}); }); } // namespace mxnet diff --git a/src/api/operator/numpy/np_matrix_op.cc b/src/api/operator/numpy/np_matrix_op.cc index 2348e8b62af2..a71197438f72 100644 --- a/src/api/operator/numpy/np_matrix_op.cc +++ b/src/api/operator/numpy/np_matrix_op.cc @@ -50,6 +50,8 @@ MXNET_REGISTER_API("_npi.split") .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_split"); + int num_inputs = 1; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; nnvm::NodeAttrs attrs; op::SplitParam param; param.axis = args[2].operator int(); @@ -57,6 +59,10 @@ MXNET_REGISTER_API("_npi.split") if (args[1].type_code() == kDLInt) { param.indices = TShape(0, 0); param.sections = args[1].operator int(); + CHECK_GT(param.sections, 0) + << "ValueError: number sections must be larger than 0"; + CHECK_EQ(inputs[0]->shape()[param.axis] % param.sections, 0) + << "ValueError: array split does not result in an equal division"; } else { TShape t = TShape(args[1].operator ObjectRef()); param.indices = TShape(t.ndim() + 1, 0); @@ -69,8 +75,6 @@ MXNET_REGISTER_API("_npi.split") attrs.op = op; SetAttrDict(&attrs); - int num_inputs = 1; - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; int num_outputs = 0; auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); std::vector ndarray_handles; diff --git a/src/api/operator/numpy/np_tensordot_op.cc b/src/api/operator/numpy/np_tensordot_op.cc index b163757f85b1..eef58b5b3389 100644 --- a/src/api/operator/numpy/np_tensordot_op.cc +++ b/src/api/operator/numpy/np_tensordot_op.cc @@ -39,8 +39,9 @@ inline static void _npi_tensordot_int_axes(runtime::MXNetArgs args, attrs.parsed = param; SetAttrDict(&attrs); int num_outputs = 0; + int num_inputs = 2; NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()}; - auto ndoutputs = Invoke(op, &attrs, 2, inputs, &num_outputs, nullptr); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); *ret = reinterpret_cast(ndoutputs[0]); } @@ -52,9 +53,11 @@ inline static void _npi_tensordot(runtime::MXNetArgs args, nnvm::NodeAttrs attrs; ADT adt = Downcast(args[2].operator ObjectRef()); if (const IntegerObj* lop = adt[0].as()) { + // axes is a tuple of int, like axes=(0, 1) param.a_axes_summed = Tuple(1, lop->value); param.b_axes_summed = Tuple(1, Downcast(adt[1])->value); } else { + // axes is a tuple of tuples of int, like axes=((0, 1), (1, 0)) param.a_axes_summed = Tuple(adt[0]); param.b_axes_summed = Tuple(adt[1]); } @@ -62,8 +65,9 @@ inline static void _npi_tensordot(runtime::MXNetArgs args, attrs.parsed = std::move(param); SetAttrDict(&attrs); int num_outputs = 0; + int num_inputs = 2; NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()}; - auto ndoutputs = Invoke(op, &attrs, 2, inputs, &num_outputs, nullptr); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); *ret = reinterpret_cast(ndoutputs[0]); } diff --git a/src/operator/numpy/np_cumsum.cc b/src/operator/numpy/np_cumsum.cc index ea0f9b6b11bc..594f7b796347 100644 --- a/src/operator/numpy/np_cumsum.cc +++ b/src/operator/numpy/np_cumsum.cc @@ -66,6 +66,7 @@ inline bool CumsumType(const nnvm::NodeAttrs& attrs, DMLC_REGISTER_PARAMETER(CumsumParam); NNVM_REGISTER_OP(_npi_cumsum) +.add_alias("cumsum") .describe(R"code(Return the cumulative sum of the elements along a given axis.)code" ADD_FILELINE) .set_attr_parser(ParamParser) .set_num_inputs(1) From 4706207ee75ec5ec2b2209dceafea9a47239f1d2 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Thu, 12 Mar 2020 11:40:05 +0800 Subject: [PATCH 08/11] Fix ci --- src/api/operator/numpy/linalg/np_gesvd.cc | 1 - src/api/operator/numpy/np_matrix_op.cc | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/api/operator/numpy/linalg/np_gesvd.cc b/src/api/operator/numpy/linalg/np_gesvd.cc index 0c9e922315d4..a4517849cbaf 100644 --- a/src/api/operator/numpy/linalg/np_gesvd.cc +++ b/src/api/operator/numpy/linalg/np_gesvd.cc @@ -24,7 +24,6 @@ #include #include #include "../../utils.h" -#include "../../../../operator/numpy/linalg/np_gesvd-inl.h" namespace mxnet { diff --git a/src/api/operator/numpy/np_matrix_op.cc b/src/api/operator/numpy/np_matrix_op.cc index a71197438f72..cc268c202c9b 100644 --- a/src/api/operator/numpy/np_matrix_op.cc +++ b/src/api/operator/numpy/np_matrix_op.cc @@ -66,7 +66,7 @@ MXNET_REGISTER_API("_npi.split") } else { TShape t = TShape(args[1].operator ObjectRef()); param.indices = TShape(t.ndim() + 1, 0); - for (size_t i = 0; i < t.ndim(); ++i) { + for (int i = 0; i < t.ndim(); ++i) { param.indices[i + 1] = t[i]; } param.sections = 0; From 45ed1ba55ea3708f91f39ef29a2525216d08423c Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Fri, 13 Mar 2020 17:02:25 +0800 Subject: [PATCH 09/11] Fix typo --- src/api/operator/op_utils.cc | 2 +- src/api/operator/op_utils.h | 2 +- src/operator/numpy/np_cumsum-inl.h | 2 +- src/operator/tensor/init_op.h | 6 +++--- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/api/operator/op_utils.cc b/src/api/operator/op_utils.cc index bb54662e7a62..1cf813eb8688 100644 --- a/src/api/operator/op_utils.cc +++ b/src/api/operator/op_utils.cc @@ -28,7 +28,7 @@ namespace mxnet { -std::string String2MXNetTypeWithBool(int dtype) { +std::string MXNetTypeWithBool2String(int dtype) { switch (dtype) { case mshadow::kFloat32: return "float32"; diff --git a/src/api/operator/op_utils.h b/src/api/operator/op_utils.h index f41680df6fd6..285919cd14c4 100644 --- a/src/api/operator/op_utils.h +++ b/src/api/operator/op_utils.h @@ -28,7 +28,7 @@ namespace mxnet { -std::string String2MXNetTypeWithBool(int dtype); +std::string MXNetTypeWithBool2String(int dtype); std::string MXNetPercentileType2String(int interpolation); } // namespace mxnet diff --git a/src/operator/numpy/np_cumsum-inl.h b/src/operator/numpy/np_cumsum-inl.h index b6e0eab5a8f5..665a546f6587 100644 --- a/src/operator/numpy/np_cumsum-inl.h +++ b/src/operator/numpy/np_cumsum-inl.h @@ -64,7 +64,7 @@ struct CumsumParam : public dmlc::Parameter { dtype_s << dtype; (*dict)["axis"] = axis_s.str(); if (dtype.has_value()) { - (*dict)["dtype"] = String2MXNetTypeWithBool(dtype.value()); + (*dict)["dtype"] = MXNetTypeWithBool2String(dtype.value()); } else { (*dict)["dtype"] = dtype_s.str(); } diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h index f6610a980f6a..fb739c690607 100644 --- a/src/operator/tensor/init_op.h +++ b/src/operator/tensor/init_op.h @@ -39,6 +39,7 @@ #include "../elemwise_op_common.h" #include "../mxnet_op.h" #include "../mshadow_op.h" +#include "../../api/operator/op_utils.h" namespace mxnet { @@ -61,11 +62,10 @@ struct InitOpParam : public dmlc::Parameter { .describe("Target data type."); } void SetAttrDict(std::unordered_map* dict) { - std::ostringstream shape_s, dtype_s; + std::ostringstream shape_s; shape_s << shape; - dtype_s << dtype; (*dict)["shape"] = shape_s.str(); - (*dict)["dtype"] = dtype_s.str(); + (*dict)["dtype"] = MXNetTypeWithBool2String(dtype); // We do not set ctx, because ctx has been set in dict instead of InitOpParam. // Setting ctx here results in an error. } From 9e9f3f394211126690488693e9e485f43b06ab55 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Mon, 16 Mar 2020 11:53:13 +0800 Subject: [PATCH 10/11] Clear --- python/mxnet/ndarray/numpy/linalg.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/mxnet/ndarray/numpy/linalg.py b/python/mxnet/ndarray/numpy/linalg.py index 4b642f460a55..fdcbdac2247a 100644 --- a/python/mxnet/ndarray/numpy/linalg.py +++ b/python/mxnet/ndarray/numpy/linalg.py @@ -330,7 +330,6 @@ def svd(a): >>> (ret - a < -1e-3).sum() array(0.) """ - # return tuple(_npi.svd(a)) return tuple(_api_internal.svd(a)) From ef7d877123104cebbcb27f067ed829df7d0ff731 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Tue, 17 Mar 2020 18:45:33 +0800 Subject: [PATCH 11/11] Resolve sanity issues --- python/mxnet/_ffi/_cython/function.pxi | 1 + src/runtime/object_internal.h | 1 + 2 files changed, 2 insertions(+) diff --git a/python/mxnet/_ffi/_cython/function.pxi b/python/mxnet/_ffi/_cython/function.pxi index 97b0af786af3..1e6aa8625ec2 100644 --- a/python/mxnet/_ffi/_cython/function.pxi +++ b/python/mxnet/_ffi/_cython/function.pxi @@ -191,6 +191,7 @@ cdef class FunctionBase: _CLASS_OBJECT = None + def _set_class_object(obj_class): global _CLASS_OBJECT _CLASS_OBJECT = obj_class diff --git a/src/runtime/object_internal.h b/src/runtime/object_internal.h index 12165bbfe703..7252bace9491 100644 --- a/src/runtime/object_internal.h +++ b/src/runtime/object_internal.h @@ -46,6 +46,7 @@ class ObjectInternal { static_cast(obj)->DecRef(); } } + /*! * \brief Expose TypeKey2Index * \param type_key The original type key.