Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions include/mxnet/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,11 @@ typedef enum {
kNull = 4U,
kMXNetType = 5U,
kMXNetContext = 6U,
kArrayHandle = 7U,
kObjectHandle = 8U,
kModuleHandle = 9U,
kFuncHandle = 10U,
kStr = 11U,
kBytes = 12U,
kNDArrayContainer = 13U,
kNDArrayHandle = 14U,
kObjectHandle = 7U,
kStr = 8U,
kBytes = 9U,
kPyArg = 10U,
kNDArrayHandle = 11U,
// Extension codes for other frameworks to integrate MXNet PackedFunc.
// To make sure each framework's id do not conflict, use first and
// last sections to mark ranges.
Expand Down
36 changes: 6 additions & 30 deletions include/mxnet/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <mxnet/runtime/container.h>
#include <mxnet/runtime/ffi_helper.h>
#include <mxnet/runtime/data_type.h>
#include <mxnet/runtime/py_arg.h>
#include <mxnet/node/container.h>
#include <mxnet/ir/expr.h>
#include <mxnet/ndarray.h>
Expand Down Expand Up @@ -416,7 +417,6 @@ class MXNetPODValue_ {
}
operator void*() const {
if (type_code_ == kNull) return nullptr;
if (type_code_ == kArrayHandle) return value_.v_handle;
MXNET_CHECK_TYPE_CODE(type_code_, kHandle);
return value_.v_handle;
}
Expand Down Expand Up @@ -520,11 +520,6 @@ class MXNetArgValue : public MXNetPODValue_ {
MXNET_CHECK_TYPE_CODE(type_code_, kNDArrayHandle);
return reinterpret_cast<::mxnet::NDArray*>(value_.v_handle);
}
operator PackedFunc() const {
if (type_code_ == kNull) return PackedFunc();
MXNET_CHECK_TYPE_CODE(type_code_, kFuncHandle);
return *ptr<PackedFunc>();
}
template<typename FType>
operator TypedPackedFunc<FType>() const {
return TypedPackedFunc<FType>(operator PackedFunc());
Expand Down Expand Up @@ -597,11 +592,6 @@ class MXNetRetValue : public MXNetPODValue_ {
operator MXNetDataType() const {
return MXNetDataType(operator DLDataType());
}
operator PackedFunc() const {
if (type_code_ == kNull) return PackedFunc();
MXNET_CHECK_TYPE_CODE(type_code_, kFuncHandle);
return *ptr<PackedFunc>();
}
template<typename FType>
operator TypedPackedFunc<FType>() const {
return TypedPackedFunc<FType>(operator PackedFunc());
Expand Down Expand Up @@ -668,10 +658,6 @@ class MXNetRetValue : public MXNetPODValue_ {
SwitchToObject(kObjectHandle, std::move(other));
return *this;
}
MXNetRetValue& operator=(PackedFunc f) {
this->SwitchToClass(kFuncHandle, f);
return *this;
}
template<typename FType>
MXNetRetValue& operator=(const TypedPackedFunc<FType>& f) {
return operator=(f.packed());
Expand All @@ -689,6 +675,11 @@ class MXNetRetValue : public MXNetPODValue_ {
value_.v_handle = reinterpret_cast<void*>(value);
return *this;
}
MXNetRetValue& operator=(const PythonArg& value) {
this->SwitchToPOD(kPyArg);
value_.v_int64 = value.offset();
return *this;
}
template<typename T,
typename = typename std::enable_if<
extension_type_info<T>::code != 0>::type>
Expand Down Expand Up @@ -717,7 +708,6 @@ class MXNetRetValue : public MXNetPODValue_ {
/*! \return The value field, if the data is POD */
const MXNetValue& value() const {
CHECK(type_code_ != kObjectHandle &&
type_code_ != kFuncHandle &&
type_code_ != kStr) << "MXNetRetValue.value can only be used for POD data";
return value_;
}
Expand All @@ -741,10 +731,6 @@ class MXNetRetValue : public MXNetPODValue_ {
SwitchToClass<std::string>(kBytes, other);
break;
}
case kFuncHandle: {
SwitchToClass<PackedFunc>(kFuncHandle, other);
break;
}
case kObjectHandle: {
*this = other.operator ObjectRef();
break;
Expand Down Expand Up @@ -792,7 +778,6 @@ class MXNetRetValue : public MXNetPODValue_ {
if (type_code_ == kNull) return;
switch (type_code_) {
case kStr: delete ptr<std::string>(); break;
case kFuncHandle: delete ptr<PackedFunc>(); break;
case kObjectHandle: {
static_cast<Object*>(value_.v_handle)->DecRef();
break;
Expand Down Expand Up @@ -857,7 +842,6 @@ inline const char* TypeCode2Str(int type_code) {
case kBytes: return "bytes";
case kHandle: return "handle";
case kNull: return "NULL";
case kFuncHandle: return "FunctionHandle";
case kObjectHandle: return "ObjectCell";
default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return "";
Expand Down Expand Up @@ -1012,10 +996,6 @@ class MXNetArgsSetter {
values_[i].v_handle = value;
type_codes_[i] = kHandle;
}
void operator()(size_t i, DLTensor* value) const {
values_[i].v_handle = value;
type_codes_[i] = kArrayHandle;
}
void operator()(size_t i, const char* value) const {
values_[i].v_str = value;
type_codes_[i] = kStr;
Expand All @@ -1038,10 +1018,6 @@ class MXNetArgsSetter {
values_[i].v_handle = const_cast<MXNetByteArray*>(&value);
type_codes_[i] = kBytes;
}
void operator()(size_t i, const PackedFunc& value) const { // NOLINT(*)
values_[i].v_handle = const_cast<PackedFunc*>(&value);
type_codes_[i] = kFuncHandle;
}
template<typename FType>
void operator()(size_t i, const TypedPackedFunc<FType>& value) const { // NOLINT(*)
operator()(i, value.packed());
Expand Down
42 changes: 42 additions & 0 deletions include/mxnet/runtime/py_arg.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*
* \file py_arg.h
* \brief Python runtime arguments specifier.
*/
#ifndef MXNET_RUNTIME_PY_ARG_H_
#define MXNET_RUNTIME_PY_ARG_H_

namespace mxnet {
namespace runtime {

class PythonArg {
public:
explicit PythonArg(int offset): offset_(offset) {}
int offset() const {
return offset_;
}
private:
int offset_;
};

} // namespace runtime

} // namespace mxnet
#endif // MXNET_RUNTIME_PY_ARG_H_
7 changes: 6 additions & 1 deletion python/mxnet/_ffi/_ctypes/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"""
import ctypes
from numbers import Number, Integral
import numpy as onp

from ...base import get_last_ffi_error, _LIB
from ..base import c_str
Expand Down Expand Up @@ -66,6 +67,9 @@ def _make_mxnet_args(args, temp_args):
elif isinstance(arg, ctypes.c_void_p):
values[i].v_handle = arg
type_codes[i] = TypeCode.HANDLE
elif isinstance(arg, type):
values[i].v_str = c_str(onp.dtype(arg).name)
type_codes[i] = TypeCode.STR
else:
raise TypeError("Don't know how to handle type %s" % type(arg))
return values, type_codes, num_args
Expand Down Expand Up @@ -110,7 +114,8 @@ def __call__(self, *args):
raise get_last_ffi_error()
_ = temp_args
_ = args
return RETURN_SWITCH[ret_tcode.value](ret_val)
return (RETURN_SWITCH[ret_tcode.value](ret_val) if ret_tcode.value != TypeCode.PYARG
else RETURN_SWITCH[ret_tcode.value](ret_val, args))


_CLASS_OBJECT = None
Expand Down
16 changes: 7 additions & 9 deletions python/mxnet/_ffi/_ctypes/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,11 @@ class TypeCode(object):
NULL = 4
MXNET_TYPE = 5
MXNET_CONTEXT = 6
ARRAY_HANDLE = 7
OBJECT_HANDLE = 8
MODULE_HANDLE = 9
FUNC_HANDLE = 10
STR = 11
BYTES = 12
NDARRAY_CONTAINER = 13
NDARRAYHANDLE = 14
OBJECT_HANDLE = 7
STR = 8
BYTES = 9
PYARG = 10
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's this type code for?

Copy link
Copy Markdown
Contributor Author

@hzfan hzfan Mar 3, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Basically it designates a python argument to be returned. For example, PythonArg(2) says the 2nd python argument should be returned.

It is introduced for out. If parameter out is not None, then out should serve as the return value. Note that we should not return the NDArray* specified by out, as it creates another NDArrayBase in frontend, and thus one NDArray* in backend gets destructed twice by two frontend NDArrayBase

NDARRAYHANDLE = 11
EXT_BEGIN = 15


Expand All @@ -54,5 +51,6 @@ class MXNetValue(ctypes.Union):
TypeCode.INT: lambda x: x.v_int64,
TypeCode.FLOAT: lambda x: x.v_float64,
TypeCode.NULL: lambda x: None,
TypeCode.NDARRAYHANDLE: lambda x: _global_var._np_ndarray_cls(handle=NDArrayHandle(x.v_handle))
TypeCode.NDARRAYHANDLE: lambda x: _global_var._np_ndarray_cls(handle=NDArrayHandle(x.v_handle)),
TypeCode.PYARG: lambda x, args: args[x.v_int64],
}
13 changes: 5 additions & 8 deletions python/mxnet/_ffi/_cython/base.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,11 @@ cdef enum MXNetTypeCode:
kNull = 4
kMXNetType = 5
kMXNetContext = 6
kArrayHandle = 7
kObjectHandle = 8
kModuleHandle = 9
kFuncHandle = 10
kStr = 11
kBytes = 12
kNDArrayContainer = 13
kNDArrayHandle = 14
kObjectHandle = 7
kStr = 8
kBytes = 9
kPyArg = 10
kNDArrayHandle = 11
kExtBegin = 15

cdef extern from "mxnet/runtime/c_runtime_api.h":
Expand Down
18 changes: 13 additions & 5 deletions python/mxnet/_ffi/_cython/function.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""Acknowledgement: This file originates from incubator-tvm"""

import ctypes
import numpy as onp
import traceback
from ...ndarray._internal import NDArrayBase
from numbers import Number, Integral
Expand Down Expand Up @@ -58,14 +59,23 @@ cdef inline int make_arg(object arg,
elif isinstance(arg, ctypes.c_void_p):
value[0].v_handle = c_handle(arg)
tcode[0] = kHandle
elif isinstance(arg, type):
tstr = c_str(onp.dtype(arg).name)
value[0].v_str = tstr
tcode[0] = kStr
temp_args.append(tstr)
else:
raise TypeError("Don't know how to handle type %s" % type(arg))
return 0


cdef inline object make_ret(MXNetValue value, int tcode):
cdef inline object make_ret(MXNetValue value, int tcode, tuple args):
"""convert result to return value."""
if tcode == kNull:
if tcode == kNDArrayHandle:
return c_make_array(value.v_handle)
elif tcode == kPyArg:
return args[value.v_int64]
elif tcode == kNull:
return None
elif tcode == kInt:
return value.v_int64
Expand All @@ -75,8 +85,6 @@ cdef inline object make_ret(MXNetValue value, int tcode):
return py_str(value.v_str)
elif tcode == kHandle:
return ctypes_handle(value.v_handle)
elif tcode == kNDArrayHandle:
return c_make_array(value.v_handle)
raise ValueError("Unhandled type code %d" % tcode)


Expand Down Expand Up @@ -160,4 +168,4 @@ cdef class FunctionBase:
cdef MXNetValue ret_val
cdef int ret_tcode
FuncCall(self.chandle, args, &ret_val, &ret_tcode)
return make_ret(ret_val, ret_tcode)
return make_ret(ret_val, ret_tcode, args)
51 changes: 0 additions & 51 deletions python/mxnet/_numpy_op_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,57 +134,6 @@ def _np_sometrue(a, axis=None, keepdims=False, out=None):
pass


def _np_cumsum(a, axis=None, dtype=None, out=None):
"""
Return the cumulative sum of the elements along a given axis.

Parameters
----------
a : array_like
Input array.
axis : int, optional
Axis along which the cumulative sum is computed. The default
(None) is to compute the cumsum over the flattened array.
dtype : dtype, optional
Type of the returned array and of the accumulator in which the
elements are summed. If `dtype` is not specified, it defaults
to the dtype of `a`, unless `a` has an integer dtype with a
precision less than that of the default platform integer. In
that case, the default platform integer is used.
out : ndarray, optional
Alternative output array in which to place the result. It must
have the same shape and buffer length as the expected output
but the type will be cast if necessary. See `doc.ufuncs`
(Section "Output arguments") for more details.

Returns
-------
cumsum_along_axis : ndarray.
A new array holding the result is returned unless `out` is
specified, in which case a reference to `out` is returned. The
result has the same size as `a`, and the same shape as `a` if
`axis` is not None or `a` is a 1-d array.

Examples
--------
>>> a = np.array([[1,2,3], [4,5,6]])
>>> a
array([[1, 2, 3],
[4, 5, 6]])
>>> np.cumsum(a)
array([ 1, 3, 6, 10, 15, 21])
>>> np.cumsum(a, dtype=float) # specifies type of output value(s)
array([ 1., 3., 6., 10., 15., 21.])
>>> np.cumsum(a,axis=0) # sum over rows for each of the 3 columns
array([[1, 2, 3],
[5, 7, 9]])
>>> np.cumsum(a,axis=1) # sum over columns for each of the 2 rows
array([[ 1, 3, 6],
[ 4, 9, 15]])
"""
pass


def _npx_nonzero(a):
"""
Return the indices of the elements that are non-zero.
Expand Down
Loading