Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions benchmark/python/ffi/benchmark_ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def prepare_workloads():
OpArgMngr.add_workload("tensordot", pool['2x2'], pool['2x2'], ((1, 0), (0, 1)))
OpArgMngr.add_workload("cumsum", pool['3x2'], axis=0, out=pool['3x2'])
OpArgMngr.add_workload("add", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("linalg.svd", pool['3x3'])
OpArgMngr.add_workload("split", pool['3x3'], (0, 1, 2), axis=1)
OpArgMngr.add_workload("random.uniform", low=0, high=1, size=1)


Expand Down
18 changes: 18 additions & 0 deletions include/mxnet/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,24 @@ MXNET_DLL int MXNetFuncListGlobalNames(int* out_size,
*/
MXNET_DLL int MXNetObjectFree(MXNetObjectHandle obj);


/*!
* \brief Get the type_index from an object.
*
* \param obj The object handle.
* \param out_tindex the output type index.
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNetObjectGetTypeIndex(MXNetObjectHandle obj, unsigned* out_tindex);

/*!
* \brief Convert type key to type index.
* \param type_key The key of the type.
* \param out_tindex the corresponding type index.
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNetObjectTypeKey2Index(const char* type_key, unsigned* out_tindex);

#ifdef __cplusplus
} // extern "C"
#endif
Expand Down
2 changes: 1 addition & 1 deletion include/mxnet/runtime/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,8 @@ class ADTObj : public Object, public InplaceArrayBase<ADTObj, ObjectRef> {
uint32_t size{0};
// The fields of the structure follows directly in memory.

static constexpr const uint32_t _type_index = TypeIndex::kMXNetADT;
static constexpr const char* _type_key = "MXNet.ADT";
static constexpr const uint32_t _type_index = TypeIndex::kMXNetADT;
MXNET_DECLARE_FINAL_OBJECT_INFO(ADTObj, Object)

private:
Expand Down
52 changes: 52 additions & 0 deletions include/mxnet/runtime/ndarray_handle.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file ndarray_handle.h
* \brief NDArray handle types
*/
#ifndef MXNET_RUNTIME_NDARRAY_HANDLE_H_
#define MXNET_RUNTIME_NDARRAY_HANDLE_H_
#include <mxnet/ndarray.h>
#include <mxnet/runtime/object.h>

namespace mxnet {

class NDArrayHandleObj : public Object {
public:
/*! \brief the Internal value. */
NDArray* value;

static constexpr const char* _type_key = "MXNet.NDArrayHandle";
MXNET_DECLARE_FINAL_OBJECT_INFO(NDArrayHandleObj, Object)
};

class NDArrayHandle : public ObjectRef {
public:
explicit NDArrayHandle(NDArray* value) {
runtime::ObjectPtr<NDArrayHandleObj> node = make_object<NDArrayHandleObj>();
node->value = value;
data_ = std::move(node);
}
MXNET_DEFINE_OBJECT_REF_METHODS(NDArrayHandle, ObjectRef, NDArrayHandleObj)
};

}; // namespace mxnet

#endif // MXNET_RUNTIME_NDARRAY_HANDLE_H_
11 changes: 10 additions & 1 deletion include/mxnet/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <mxnet/runtime/object.h>
#include <mxnet/runtime/ndarray.h>
#include <mxnet/runtime/container.h>
#include <mxnet/runtime/ndarray_handle.h>
#include <mxnet/runtime/ffi_helper.h>
#include <mxnet/runtime/data_type.h>
#include <mxnet/runtime/py_arg.h>
Expand Down Expand Up @@ -651,6 +652,9 @@ class MXNetRetValue : public MXNetPODValue_ {
return *this;
}
MXNetRetValue& operator=(ObjectRef other) {
if (other.as<NDArrayHandleObj>()) {
return operator=(Downcast<NDArrayHandle, ObjectRef>(other));
}
return operator=(std::move(other.data_));
}
template<typename T>
Expand All @@ -670,11 +674,16 @@ class MXNetRetValue : public MXNetPODValue_ {
this->Assign(other);
return *this;
}
MXNetRetValue& operator=(::mxnet::NDArray* value) {
MXNetRetValue& operator=(NDArray* value) {
this->SwitchToPOD(kNDArrayHandle);
value_.v_handle = reinterpret_cast<void*>(value);
return *this;
}
MXNetRetValue& operator=(NDArrayHandle value) {
this->SwitchToPOD(kNDArrayHandle);
value_.v_handle = reinterpret_cast<void*>(value->value);
return *this;
}
MXNetRetValue& operator=(const PythonArg& value) {
this->SwitchToPOD(kPyArg);
value_.v_int64 = value.offset();
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,4 @@

from . import _api_internal
from . import api
from . import container
23 changes: 18 additions & 5 deletions python/mxnet/_ffi/_ctypes/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@
from ..base import c_str
from .types import MXNetValue, TypeCode
from .types import RETURN_SWITCH
from .object import ObjectBase
from ..node_generic import convert_to_node
from ..._ctypes.ndarray import NDArrayBase
from .object import ObjectBase, _set_class_object
from . import object as _object

ObjectHandle = ctypes.c_void_p

Expand Down Expand Up @@ -118,8 +119,20 @@ def __call__(self, *args):
else RETURN_SWITCH[ret_tcode.value](ret_val, args))


_CLASS_OBJECT = None
def __init_handle_by_constructor__(fconstructor, args):
"""Initialize handle by constructor"""
temp_args = []
values, tcodes, num_args = _make_mxnet_args(args, temp_args)
ret_val = MXNetValue()
ret_tcode = ctypes.c_int()
if _LIB.MXNetFuncCall(
fconstructor.handle, values, tcodes, ctypes.c_int(num_args),
ctypes.byref(ret_val), ctypes.byref(ret_tcode)) != 0:
raise get_last_ffi_error()
_ = temp_args
_ = args
assert ret_tcode.value == TypeCode.OBJECT_HANDLE
handle = ret_val.v_handle
return handle

def _set_class_object(obj_class):
global _CLASS_OBJECT
_CLASS_OBJECT = obj_class
_object.__init_by_constructor__ = __init_handle_by_constructor__
47 changes: 43 additions & 4 deletions python/mxnet/_ffi/_ctypes/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,35 @@
"""
import ctypes
from ...base import _LIB, check_call
from . import function
from .types import RETURN_SWITCH, TypeCode

ObjectHandle = ctypes.c_void_p
__init_by_constructor__ = None

"""Maps object type to its constructor"""
OBJECT_TYPE = {}

_CLASS_OBJECT = None

def _set_class_object(object_class):
global _CLASS_OBJECT
_CLASS_OBJECT = object_class

def _register_object(index, cls):
"""register object class"""
# if issubclass(cls, NDArrayBase):
# _register_ndarray(index, cls)
# return
OBJECT_TYPE[index] = cls


def _return_object(x):
handle = x.v_handle
if not isinstance(handle, ObjectHandle):
handle = ObjectHandle(handle)
# Does not support specific cpp node class for now
cls = function._CLASS_OBJECT
tindex = ctypes.c_uint()
check_call(_LIB.MXNetObjectGetTypeIndex(handle, ctypes.byref(tindex)))
cls = OBJECT_TYPE.get(tindex.value, _CLASS_OBJECT)
# Avoid calling __init__ of cls, instead directly call __new__
# This allows child class to implement their own __init__
obj = cls.__new__(cls)
Expand All @@ -50,4 +67,26 @@ def __del__(self):
if _LIB is not None:
check_call(_LIB.MXNetObjectFree(self.handle))

# Does not support creation of cpp node class via python class
def __init_handle_by_constructor__(self, fconstructor, *args):
"""Initialize the handle by calling constructor function.

Parameters
----------
fconstructor : Function
Constructor function.

args: list of objects
The arguments to the constructor

Note
----
We have a special calling convention to call constructor functions.
So the return handle is directly set into the Node object
instead of creating a new Node.
"""
# assign handle first to avoid error raising
self.handle = None
handle = __init_by_constructor__(fconstructor, args)
if not isinstance(handle, ObjectHandle):
handle = ObjectHandle(handle)
self.handle = handle
2 changes: 2 additions & 0 deletions python/mxnet/_ffi/_cython/base.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ cdef extern from "mxnet/runtime/c_runtime_api.h":
MXNetValue* ret_val,
int* ret_type_code)
int MXNetFuncFree(MXNetFunctionHandle func)
int MXNetObjectFree(ObjectHandle obj)
int MXNetObjectGetTypeIndex(ObjectHandle obj, unsigned* out_index)


cdef inline py_str(const char* x):
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/_ffi/_cython/core.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@
include "./base.pxi"
include "./ndarray.pxi"
include "./convert.pxi"
include "./object.pxi"
include "./function.pxi"
26 changes: 26 additions & 0 deletions python/mxnet/_ffi/_cython/function.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ cdef inline int make_arg(object arg,
elif arg is None:
value[0].v_handle = NULL
tcode[0] = kNull
elif isinstance(arg, ObjectBase):
value[0].v_handle = (<ObjectBase>arg).chandle
tcode[0] = kObjectHandle
elif isinstance(arg, Number):
value[0].v_float64 = arg
tcode[0] = kFloat
Expand All @@ -77,6 +80,8 @@ cdef inline object make_ret(MXNetValue value, int tcode, tuple args):
return args[value.v_int64]
elif tcode == kNull:
return None
elif tcode == kObjectHandle:
return make_ret_object(value.v_handle)
elif tcode == kInt:
return value.v_int64
elif tcode == kFloat:
Expand Down Expand Up @@ -130,6 +135,19 @@ cdef inline int FuncCall(void* chandle,
return 0


cdef inline int ConstructorCall(void* constructor_handle,
int type_code,
tuple args,
void** handle) except -1:
"""Call contructor of a handle function"""
cdef MXNetValue ret_val
cdef int ret_tcode
FuncCall(constructor_handle, args, &ret_val, &ret_tcode)
assert ret_tcode == type_code
handle[0] = ret_val.v_handle
return 0


cdef class FunctionBase:
cdef MXNetFunctionHandle chandle
cdef int is_global
Expand Down Expand Up @@ -169,3 +187,11 @@ cdef class FunctionBase:
cdef int ret_tcode
FuncCall(self.chandle, args, &ret_val, &ret_tcode)
return make_ret(ret_val, ret_tcode, args)


_CLASS_OBJECT = None

Comment thread
haojin2 marked this conversation as resolved.

def _set_class_object(obj_class):
global _CLASS_OBJECT
_CLASS_OBJECT = obj_class
Loading