diff --git a/CMakeLists.txt b/CMakeLists.txt index 8a1da1e92f6c..2612854ffdce 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -676,7 +676,9 @@ else() endif() add_library(sample_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_custom_op/gemm_lib.cc) +add_library(subgraph_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/example/extensions/lib_subgraph/subgraph_lib.cc) target_include_directories(sample_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet) +target_include_directories(subgraph_lib PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include/mxnet) set(MXNET_INSTALL_TARGETS mxnet) if(UNIX) string(APPEND CMAKE_CUDA_FLAGS "${CUDA_ARCH_FLAGS_SPACES}") @@ -690,10 +692,13 @@ if(UNIX) target_link_libraries(mxnet PRIVATE mxnet_static) target_link_libraries(mxnet_static PUBLIC ${CMAKE_DL_LIBS}) target_compile_options(sample_lib PUBLIC -shared) + target_compile_options(subgraph_lib PUBLIC -shared) set_target_properties(mxnet_static PROPERTIES OUTPUT_NAME mxnet) elseif(MSVC) target_compile_options(sample_lib PUBLIC /LD) + target_compile_options(subgraph_lib PUBLIC /LD) set_target_properties(sample_lib PROPERTIES PREFIX "lib") + set_target_properties(subgraph_lib PROPERTIES PREFIX "lib") if(USE_CUDA) if(MSVC) diff --git a/Makefile b/Makefile index 7e65e5a7f1b7..44c719f6d4d9 100644 --- a/Makefile +++ b/Makefile @@ -457,7 +457,7 @@ endif .PHONY: clean all extra-packages test lint clean_all rcpplint rcppexport roxygen\ cython2 cython3 cython cyclean -all: lib/libmxnet.a lib/libmxnet.so $(BIN) extra-packages sample_lib +all: lib/libmxnet.a lib/libmxnet.so $(BIN) extra-packages sample_lib subgraph_lib SRC = $(wildcard src/*/*/*/*.cc src/*/*/*.cc src/*/*.cc src/*.cc) OBJ = $(patsubst %.cc, build/%.o, $(SRC)) @@ -667,6 +667,8 @@ pylint: # sample lib for MXNet extension dynamically loading custom operator sample_lib: $(CXX) -shared -fPIC -std=c++11 example/extensions/lib_custom_op/gemm_lib.cc -o libsample_lib.so -I include/mxnet +subgraph_lib: + $(CXX) -shared -fPIC -std=c++11 
example/extensions/lib_subgraph/subgraph_lib.cc -o libsubgraph_lib.so -I include/mxnet # Cython build cython: @@ -762,7 +764,7 @@ clean: rclean cyclean $(EXTRA_PACKAGES_CLEAN) cd $(NNVM_PATH); $(MAKE) clean; cd - cd $(TVM_PATH); $(MAKE) clean; cd - cd $(AMALGAMATION_PATH); $(MAKE) clean; cd - - $(RM) libsample_lib.so + $(RM) libsample_lib.so libsubgraph_lib.so $(RM) -r $(patsubst %, %/*.d, $(EXTRA_OPERATORS)) $(patsubst %, %/*/*.d, $(EXTRA_OPERATORS)) $(RM) -r $(patsubst %, %/*.o, $(EXTRA_OPERATORS)) $(patsubst %, %/*/*.o, $(EXTRA_OPERATORS)) else @@ -774,7 +776,7 @@ clean: rclean mkldnn_clean cyclean testclean $(EXTRA_PACKAGES_CLEAN) cd $(NNVM_PATH); $(MAKE) clean; cd - cd $(TVM_PATH); $(MAKE) clean; cd - cd $(AMALGAMATION_PATH); $(MAKE) clean; cd - - $(RM) libsample_lib.so + $(RM) libsample_lib.so libsubgraph_lib.so endif clean_all: clean diff --git a/ci/jenkins/Jenkins_steps.groovy b/ci/jenkins/Jenkins_steps.groovy index 81479328f054..c697e1e58788 100644 --- a/ci/jenkins/Jenkins_steps.groovy +++ b/ci/jenkins/Jenkins_steps.groovy @@ -23,23 +23,23 @@ utils = load('ci/Jenkinsfile_utils.groovy') // mxnet libraries -mx_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a' -mx_lib_cython = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so' +mx_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a' +mx_lib_cython = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, python/mxnet/_cy2/*.so, 
python/mxnet/_cy3/*.so' // Python wheels mx_pip = 'build/*.whl' // mxnet cmake libraries, in cmake builds we do not produce a libnvvm static library by default. mx_cmake_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so' -mx_cmake_lib_no_tvm_op = 'build/libmxnet.so, build/libmxnet.a, build/libsample_lib.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so' +mx_cmake_lib_no_tvm_op = 'build/libmxnet.so, build/libmxnet.a, build/libsample_lib.so, build/libsubgraph_lib.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so' mx_cmake_lib_cython = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so' // mxnet cmake libraries, in cmake builds we do not produce a libnvvm static library by default. 
-mx_cmake_lib_debug = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/libsample_lib.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests' +mx_cmake_lib_debug = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/libsample_lib.so, build/libsubgraph_lib.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests' mx_cmake_mkldnn_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so' -mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a' +mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a' mx_tensorrt_lib = 'build/libmxnet.so, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, lib/libnvonnxparser_runtime.so.0, lib/libnvonnxparser.so.0, lib/libonnx_proto.so, lib/libonnx.so' -mx_lib_cpp_examples = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so' -mx_lib_cpp_examples_no_tvm_op = 'lib/libmxnet.so, lib/libmxnet.a, libsample_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so' 
+mx_lib_cpp_examples = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so' +mx_lib_cpp_examples_no_tvm_op = 'lib/libmxnet.so, lib/libmxnet.a, libsample_lib.so, libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so' mx_lib_cpp_examples_cpu = 'build/libmxnet.so, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/cpp-package/example/*' // Python unittest for CPU diff --git a/example/extensions/lib_custom_op/Makefile b/example/extensions/lib_custom_op/Makefile index 66079a16a338..090d17d98a22 100644 --- a/example/extensions/lib_custom_op/Makefile +++ b/example/extensions/lib_custom_op/Makefile @@ -15,13 +15,10 @@ # specific language governing permissions and limitations # under the License. -all: subgraph_lib gemm_lib +all: gemm_lib gemm_lib: g++ -shared -fPIC -std=c++11 gemm_lib.cc -o libgemm_lib.so -I ../../../include/mxnet -subgraph_lib: - g++ -shared -fPIC -std=c++11 subgraph_lib.cc -o libsubgraph_lib.so -I ../../../include/mxnet - clean: - rm -rf libsubgraph_lib.so libgemm_lib.so + rm -rf libgemm_lib.so diff --git a/example/extensions/lib_custom_op/subgraph_lib.cc b/example/extensions/lib_custom_op/subgraph_lib.cc deleted file mode 100644 index 27da0fd9a324..000000000000 --- a/example/extensions/lib_custom_op/subgraph_lib.cc +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2019 by Contributors - * \file subgraph_lib.cc - * \brief subgraph operator implementation library file - */ - -#include -#include "lib_api.h" - -MXReturnValue parseAttrs(std::map attrs, - int* num_in, int* num_out) { - *num_in = 1; - *num_out = 1; - if (attrs.count(SUBGRAPH_SYM_JSON)) { - // example of subgraph json parsing - JsonParser jp; - JsonVal val = jp.parse_to_json(attrs[SUBGRAPH_SYM_JSON]); - int input = 0; - for (auto &item : val.map[JsonVal("nodes")].list) { - if (item.map[JsonVal("op")].str == "null") - input++; - } - int output = val.map[JsonVal("heads")].list.size(); - *num_in = input; - *num_out = output; - } - return MX_SUCCESS; -} - -class MyStatefulOp : public CustomStatefulOp { - public: - explicit MyStatefulOp(std::string sym) : subgraph_sym(sym) {} - - MXReturnValue Forward(std::vector inputs, - std::vector outputs, - OpResource op_res) { - std::cout << "Info: subgraph symbol is: " << std::endl; - std::cout << subgraph_sym << std::endl; - float* in_data = inputs[0].data(); - float* out_data = outputs[0].data(); - std::cout << "Info: output is: " << std::endl; - for (int i = 0; i < inputs[0].size(); i++) { - out_data[i] = in_data[i]; - } - return MX_SUCCESS; - } - - private: - std::string subgraph_sym; -}; - -MXReturnValue 
createOpState(std::map attrs, - CustomStatefulOp** op_inst) { - std::string serialized_subgraph = "[empty]"; - // MXNet subgraph is stored as Symbol in operator node attrs subgraphs field - // custom subgraph is stored as json string in custom operator attrs map entry - if (attrs.count(SUBGRAPH_SYM_JSON)) { - // user can now parse json and run other custom ops inside subgraph - serialized_subgraph = attrs[SUBGRAPH_SYM_JSON]; - } - *op_inst = new MyStatefulOp(serialized_subgraph); - std::cout << "Info: stateful operator created" << std::endl; - return MX_SUCCESS; -} - -REGISTER_OP(_custom_subgraph_op) -.setParseAttrs(parseAttrs) -.setIsSubgraphOp() -.setCreateOpState(createOpState); - -MXReturnValue initialize(int version) { - if (version >= 10400) { - std::cout << "MXNet version " << version << " supported" << std::endl; - return MX_SUCCESS; - } else { - std::cout << "MXNet version " << version << " not supported" << std::endl; - return MX_FAIL; - } -} diff --git a/example/extensions/lib_subgraph/Makefile b/example/extensions/lib_subgraph/Makefile new file mode 100644 index 000000000000..c45100b69ef7 --- /dev/null +++ b/example/extensions/lib_subgraph/Makefile @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +all: subgraph_lib + +subgraph_lib: + g++ -shared -fPIC -std=c++11 subgraph_lib.cc -o libsubgraph_lib.so -I ../../../include/mxnet + +clean: + rm -rf libsubgraph_lib.so diff --git a/example/extensions/lib_subgraph/subgraph_lib.cc b/example/extensions/lib_subgraph/subgraph_lib.cc new file mode 100644 index 000000000000..3ebdfc138a79 --- /dev/null +++ b/example/extensions/lib_subgraph/subgraph_lib.cc @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * Copyright (c) 2019 by Contributors + * \file subgraph_lib.cc + * \brief subgraph operator implementation library file + */ + +#include +#include +#include +#include "lib_api.h" + +/* function to execute log operator on floats */ +void myLog(MXTensor &in, MXTensor &out) { + float* inp = in.data(); + float* outp = out.data(); + for (int64_t i = 0; i < in.size(); i++) { + outp[i] = logf(inp[i]); + } +} +/* function to execute exp operator on floats */ +void myExp(MXTensor &in, MXTensor &out) { + float* inp = in.data(); + float* outp =out.data(); + for (int64_t i = 0; i < in.size(); i++) { + outp[i] = expf(inp[i]); + } +} + +/* function to execute ops in subgraph + * In MXNet, subgraphs are sorted in topological order + * so all we need to do is go through the ops in order + * and execute each op. + */ +MXReturnValue myExecutor(std::vector inputs, + std::vector outputs, + std::string subgraph_sym) { + std::cout << "Info: subgraph symbol is: " << std::endl; + std::cout << subgraph_sym << std::endl; + + // convert json string to json object + JsonParser parser; + JsonVal json_val = parser.parse_to_json(subgraph_sym); + // get nodes list + JsonVal nodes = json_val.map[JsonVal("nodes")]; + //counter for inputs + int input_cnt = 0; + // temporary tensor storage + std::vector data; + // track memory allocations to free later + std::vector to_free; + + // loop over nodes + for(int i=0; i(); + float *res_data = result.data(); + // loop and copy data + for (int64_t i = 0; i < result.size(); i++) { + out_data[i] = res_data[i]; + } + } + + // free allocated temporary storage + for (void* ptr : to_free) { + free(ptr); + } + + return MX_SUCCESS; +} + +class MyStatefulOp : public CustomStatefulOp { + public: + explicit MyStatefulOp(std::string sym, std::map attrs) + : subgraph_sym(sym), attrs_(attrs) { + for (auto kv : attrs) { + std::cout << "subgraphOp attributes: " << kv.first << " ==> " << kv.second << std::endl; + } + } + + MXReturnValue Forward(std::vector inputs, + 
std::vector outputs, + OpResource op_res) { + return myExecutor(inputs, outputs, subgraph_sym); + } + + private: + std::string subgraph_sym; + std::map attrs_; +}; + +MXReturnValue createOpState(std::map attrs, + CustomStatefulOp** op_inst) { + std::string serialized_subgraph = "[empty]"; + // MXNet subgraph is stored as Symbol in operator node attrs subgraphs field + // custom subgraph is stored as json string in custom operator attrs map entry + if (attrs.count(SUBGRAPH_SYM_JSON)) { + // user can now parse json and run other custom ops inside subgraph + serialized_subgraph = attrs[SUBGRAPH_SYM_JSON]; + } + attrs.erase(SUBGRAPH_SYM_JSON); + *op_inst = new MyStatefulOp(serialized_subgraph, attrs); + std::cout << "Info: stateful operator created" << std::endl; + return MX_SUCCESS; +} + +REGISTER_OP(_custom_subgraph_op) +.setIsSubgraphOp() +.setCreateOpState(createOpState); + +const std::vector op_names({"exp","log"}); + +MXReturnValue mySupportedOps(std::string json, + const int num_ids, + int *ids, + std::unordered_map& options) { + for (auto kv : options) { + std::cout << "option: " << kv.first << " ==> " << kv.second << std::endl; + } + //convert json string to json object + JsonParser parser; + JsonVal json_val = parser.parse_to_json(json); + //get nodes list + JsonVal nodes = json_val.map[JsonVal("nodes")]; + + //loop over nodes + for(int i=0; i& options, + std::unordered_map& attrs) { + for (auto kv : options) { + std::cout << "option: " << kv.first << " ==> " << kv.second << std::endl; + } + if(options.find("reject") != options.end() && + options["reject"].compare("True") == 0) { + *accept = false; + std::cout << "rejecting subgraph" << std::endl; + } else { + *accept = true; + std::cout << "accepting subgraph" << std::endl; + attrs["myKey"] = "myVal"; + } + return MX_SUCCESS; +} + +REGISTER_PARTITIONER(myProp) +.addStrategy("strategy1", mySupportedOps, "_custom_subgraph_op") +.setAcceptSubgraph("strategy1", myAcceptSubgraph); + +MXReturnValue initialize(int 
version) { + if (version >= 10400) { + std::cout << "MXNet version " << version << " supported" << std::endl; + return MX_SUCCESS; + } else { + std::cout << "MXNet version " << version << " not supported" << std::endl; + return MX_FAIL; + } +} diff --git a/example/extensions/lib_custom_op/test_subgraph.py b/example/extensions/lib_subgraph/test_subgraph.py similarity index 52% rename from example/extensions/lib_custom_op/test_subgraph.py rename to example/extensions/lib_subgraph/test_subgraph.py index 2625b13f6794..8169261d4d42 100644 --- a/example/extensions/lib_custom_op/test_subgraph.py +++ b/example/extensions/lib_subgraph/test_subgraph.py @@ -39,22 +39,38 @@ b = mx.sym.var('b') c = a + b d = mx.sym.exp(c) -ret = mx.sym.log(d) +sym = mx.sym.log(d) -op_names = ['exp','log'] -out = SymbolHandle() +#execute in MXNet +print('-------------------------------') +print('Testing regular MXNet execution') +exe = sym.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}) +out = exe.forward() +print(out) -check_call(_LIB.MXBuildSubgraphByOpNames(ret.handle, - c_str('default'), - mx_uint(len(op_names)), - c_str_array(op_names), - ctypes.byref(out))) -partitioned_sym = mx.sym.Symbol(out) -json_sym = partitioned_sym.tojson() +# with propagating shapes/types +print('-------------------------------') +print('Testing partitioning with shapes/types') +arg_array = [mx.nd.ones((3,2),dtype='float32'), mx.nd.ones((3,2),dtype='float32')] +mysym2 = sym.optimize_for("myProp",arg_array) +print(mysym2.tojson()) +exe2 = mysym2.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}) +out2 = exe2.forward() +print(out2) -mystr = json_sym.replace("_CachedOp","_custom_subgraph_op") -mysym = mx.sym.load_json(mystr) +# with propagating shapes/types, rejecting subgraph +print('-------------------------------') +print('Testing partitioning with shapes/types - rejecting subgraph') +arg_array = [mx.nd.ones((3,2),dtype='float32'), mx.nd.ones((3,2),dtype='float32')]
+mysym2 = sym.optimize_for("myProp", arg_array, reject=True) +exe2 = mysym2.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}) +out2 = exe2.forward() +print(out2) -exe = mysym.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}) -out = exe.forward() -print(out) +# without propagating shapes/types +print('-------------------------------') +print('Testing partitioning without shapes/types') +mysym3 = sym.optimize_for("myProp", myOpt='yello') +exe3 = mysym3.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}) +out3 = exe3.forward() +print(out3) diff --git a/include/mxnet/lib_api.h b/include/mxnet/lib_api.h index c7887aad378f..cc0ec0f938af 100644 --- a/include/mxnet/lib_api.h +++ b/include/mxnet/lib_api.h @@ -30,8 +30,10 @@ #include #include +#include #include #include +#include #include #include #include @@ -554,7 +556,7 @@ typedef MXReturnValue (*inferShape_t)(std::map, std::vector >&, std::vector >&); typedef MXReturnValue (*mutateInputs_t)(std::map, - std::vector&); + std::vector&); typedef MXReturnValue (*createOpState_t)(std::map, CustomStatefulOp**); @@ -614,6 +616,52 @@ class CustomOp { bool isSGop; }; +/*! \brief Custom Subgraph Create function template */ +typedef MXReturnValue (*supportedOps_t)(std::string, int, int*, + std::unordered_map&); +typedef MXReturnValue (*acceptSubgraph_t)(std::string, int, bool*, + std::unordered_map&, + std::unordered_map&); + +/*!
+ * \brief An abstract class for subgraph property + */ +class CustomPartitioner { + public: + CustomPartitioner() : name("ERROR") {} + explicit CustomPartitioner(const char* backend_name) : + name(backend_name) {} + CustomPartitioner& addStrategy(const char* prop_name, + supportedOps_t fn, + const char* sg_name) { + strategies.push_back(prop_name); + supportedOps.push_back(fn); + op_names.push_back(sg_name); + return *this; + } + CustomPartitioner& setAcceptSubgraph(const char* prop_name, acceptSubgraph_t fn) { + accept_map[std::string(prop_name)] = fn; + return *this; + } + acceptSubgraph_t getAcceptSubgraph(int stg_id) { + std::string prop(strategies[stg_id]); + if (accept_map.find(prop) != accept_map.end()) + return accept_map[prop]; + else + return nullptr; + } + + /*! \brief partitioner name */ + const char* name; + std::map accept_map; + /*! \brief strategy names */ + std::vector strategies; + /*! \brief supported ops function */ + std::vector supportedOps; + /*! \brief subgraph operator name */ + std::vector op_names; +}; + /*! * \brief Registry class to registers things (ops, properties) * Singleton class @@ -670,10 +718,17 @@ class Registry { #define MX_REGISTER_NAME_(Name) MXNet ## _CustomOp ## _ #define MX_REGISTER_DEF_(Name) CustomOp MX_REGISTER_NAME_(Name) +#define MX_REGISTER_PROP_NAME_(Name) MXNet ## _CustomSubProp ## _ +#define MX_REGISTER_PROP_DEF_(Name) CustomPartitioner MX_REGISTER_PROP_NAME_(Name) + /*! \brief assign a var to a value */ #define REGISTER_OP(Name) MX_STR_CONCAT(MX_REGISTER_DEF_(Name), __COUNTER__) = \ Registry::get()->add(MX_TOSTRING(Name)) +#define REGISTER_PARTITIONER(Name) \ + MX_STR_CONCAT(MX_REGISTER_PROP_DEF_(Name), __COUNTER__) = \ + Registry::get()->add(MX_TOSTRING(Name)) + /* -------------- BELOW ARE CTYPE FUNCTIONS PROTOTYPES --------------- */ /*! 
@@ -688,7 +743,7 @@ typedef int (*opRegSize_t)(void); typedef int (*opRegGet_t)(int, const char**, fcomp_t*, fcomp_t*, parseAttrs_t*, inferType_t*, inferShape_t*, mutateInputs_t*, - createOpState_t*, bool*); + createOpState_t*, int*); #define MXLIB_OPCALLFREE_STR "_opCallFree" typedef int (*opCallFree_t)(void*); @@ -721,10 +776,30 @@ typedef int (*opCallCreateOpState_t)(createOpState_t, const char* const*, const void**); #define MXLIB_OPCALLFSTATEFULCOMP_STR "_opCallFStatefulCompute" -typedef int (*opCallFStatefulComp_t)(bool, void*, const int64_t**, int*, void**, int*, size_t*, +typedef int (*opCallFStatefulComp_t)(int, void*, const int64_t**, int*, void**, int*, size_t*, int, const int64_t**, int*, void**, int*, size_t*, int, xpu_malloc_t, void*); +#define MXLIB_PARTREGSIZE_STR "_partRegSize" +typedef int (*partRegSize_t)(void); + +#define MXLIB_PARTREGGETCOUNT_STR "_partRegGetCount" +typedef int (*partRegGetCount_t)(int, const char**); + +#define MXLIB_PARTREGGET_STR "_partRegGet" +typedef void (*partRegGet_t)(int, int, const char**, supportedOps_t*, + acceptSubgraph_t*, const char**); + +#define MXLIB_PARTCALLSUPPORTEDOPS_STR "_partCallSupportedOps" +typedef int (*partCallSupportedOps_t)(supportedOps_t, const char*, int, int *, + const char* const*, const char* const*, int); +#define MXLIB_PARTCALLACCEPTSUBGRAPH_STR "_partCallAcceptSubgraph" +typedef int (*partCallAcceptSubgraph_t)(acceptSubgraph_t acceptSubgraph, + const char *json, int subgraph_id, + int *accept, const char* const*, + const char* const*, int, + char***, char***, int*); + #define MXLIB_INITIALIZE_STR "initialize" typedef int (*initialize_t)(int); @@ -761,7 +836,7 @@ extern "C" { _opRegGet(int idx, const char** name, fcomp_t* fcomp, fcomp_t* fgrad, parseAttrs_t* parse, inferType_t* type, inferShape_t* shape, mutateInputs_t* mutate, - createOpState_t* create_op, bool *isSGop) { + createOpState_t* create_op, int *isSGop) { CustomOp op = Registry::get()->get(idx); *name = op.name; *fcomp = 
op.forward; @@ -980,7 +1055,7 @@ extern "C" { #else int #endif - _opCallFStatefulCompute(bool is_forward, void* state_op, + _opCallFStatefulCompute(int is_forward, void* state_op, const int64_t** inshapes, int* indims, void** indata, int* intypes, size_t* inIDs, int num_in, const int64_t** outshapes, int* outdims, @@ -1006,6 +1081,106 @@ extern "C" { return op_ptr->Backward(inputs, outputs, res); } + /*! \brief returns number of partitioners registered in this library */ +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + __declspec(dllexport) int __cdecl +#else + int +#endif + _partRegSize() { + return Registry::get()->size(); + } + + /* returns number of strategies registered for partitioner + * at specified index */ +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + __declspec(dllexport) int __cdecl +#else + int +#endif + _partRegGetCount(int idx, const char** name) { + CustomPartitioner part = Registry::get()->get(idx); + *name = part.name; + return part.strategies.size(); + } + + /*! \brief returns partitioner registration at specified index */ +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + __declspec(dllexport) void __cdecl +#else + void +#endif + _partRegGet(int part_idx, int stg_idx, const char** strategy, supportedOps_t* supportedOps, + acceptSubgraph_t* acceptSubgraph, const char** op_name) { + CustomPartitioner part = Registry::get()->get(part_idx); + *strategy = part.strategies[stg_idx]; + *supportedOps = part.supportedOps[stg_idx]; + *op_name = part.op_names[stg_idx]; + *acceptSubgraph = part.getAcceptSubgraph(stg_idx); + } + + /*! 
\brief returns status of calling parse attributes function for operator from library */ +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + __declspec(dllexport) int __cdecl +#else + int +#endif + _partCallSupportedOps(supportedOps_t supportedOps, const char *json, + int num_ids, int *ids, const char* const* opt_keys, + const char* const* opt_vals, int num_opts) { + std::string subgraph_json(json); + // create map of attributes from list + std::unordered_map opts; + for (int i = 0; i < num_opts; i++) { + opts[std::string(opt_keys[i])] = std::string(opt_vals[i]); + } + return supportedOps(subgraph_json, num_ids, ids, opts); + } + + /*! \brief returns status of calling parse attributes function for operator from library */ +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + __declspec(dllexport) int __cdecl +#else + int +#endif + _partCallAcceptSubgraph(acceptSubgraph_t acceptSubgraph, const char *json, + int subgraph_id, int *accept, const char* const* opt_keys, + const char* const* opt_vals, int num_opts, + char*** attr_keys, char*** attr_vals, int *num_attrs) { + std::string subgraph_json(json); + bool accept_bool = false; + // create map of attributes from list + std::unordered_map opts; + for (int i = 0; i < num_opts; i++) { + opts[std::string(opt_keys[i])] = std::string(opt_vals[i]); + } + + // attributes to set on subgraph node + std::unordered_map attrs; + + MXReturnValue retval = acceptSubgraph(subgraph_json, subgraph_id, &accept_bool, opts, attrs); + *accept = accept_bool; + + if (attrs.size() > 0) { + *num_attrs = attrs.size(); + // allocate space for attributes + *attr_keys = static_cast(malloc (attrs.size() * sizeof(char*))); + *attr_vals = static_cast(malloc (attrs.size() * sizeof(char*))); + + // copy attributes + int i = 0; + for (auto kv : attrs) { + (*attr_keys)[i] = static_cast(malloc ((kv.first.size()+1) * sizeof(char))); + (*attr_vals)[i] = static_cast(malloc ((kv.second.size()+1) * sizeof(char))); + 
snprintf((*attr_keys)[i], kv.first.size()+1, "%s", kv.first.c_str()); + snprintf((*attr_vals)[i], kv.second.size()+1, "%s", kv.second.c_str()); + i++; + } + } + + return retval; + } + /*! * \brief Checks if the MXNet version is supported by the library. * If supported, initializes the library. diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index bb98d96f733a..11b1ddcce5a7 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -54,6 +54,8 @@ #include "../operator/subgraph/common.h" #include "../operator/tensor/matrix_op-inl.h" #include "../operator/tvmop/op_module.h" +#include "../operator/subgraph/partitioner/custom_subgraph_property.h" +#include "../operator/subgraph/subgraph_property.h" #include "../common/utils.h" #include "nnvm/pass_functions.h" @@ -142,6 +144,13 @@ int MXLoadLib(const char *path) { opCallFStatefulComp_t callFStatefulComp = get_func(lib, const_cast(MXLIB_OPCALLFSTATEFULCOMP_STR)); + partCallSupportedOps_t callSupportedOps = + get_func(lib, const_cast(MXLIB_PARTCALLSUPPORTEDOPS_STR)); + + + partCallAcceptSubgraph_t callAcceptSubgraph = + get_func(lib, const_cast(MXLIB_PARTCALLACCEPTSUBGRAPH_STR)); + // get number of operators registered in the library opRegSize_t opRegSize = get_func(lib, const_cast(MXLIB_OPREGSIZE_STR)); int numOps = opRegSize(); @@ -164,15 +173,18 @@ int MXLoadLib(const char *path) { mutateInputs_t mutate_fp = nullptr; createOpState_t create_opstate_fp = nullptr; bool isSubgraphOp = false; + int _isSubgraphOp = 0; // get custom operator implemenation from the dynamic library opRegGet(i, &name, &fcomp_fp, &fgrad_fp, &parse_fp, &type_fp, &shape_fp, - &mutate_fp, &create_opstate_fp, &isSubgraphOp); + &mutate_fp, &create_opstate_fp, &_isSubgraphOp); + // set bool, dont pass bool across ABI boundary + isSubgraphOp = _isSubgraphOp; - CHECK(parse_fp != nullptr) << "Error loading '" << name - << "' custom op, ParseAttrs function was not set."; if (!isSubgraphOp) { // validate custom operator functions from the dynamic 
library + CHECK(parse_fp != nullptr) << "Error loading '" << name + << "' custom op, ParseAttrs function was not set."; CHECK(fcomp_fp != nullptr || create_opstate_fp != nullptr) << "Error loading '" << name << "' custom op, Forward or CreateOpState function was not set."; CHECK(type_fp != nullptr) << "Error loading '" << name @@ -650,9 +662,6 @@ int MXLoadLib(const char *path) { // check if operator is already registered const nnvm::Op *regOpPtr = dmlc::Registry::Get()->Find(name); nnvm::Op ®Op = dmlc::Registry::Get()->__REGISTER_OR_GET__(name); - regOp.set_attr_parser(attr_parser); - regOp.set_num_inputs(num_inputs); - regOp.set_num_outputs(num_outputs); int plevel = 10; if (regOpPtr != nullptr) { // overwrite registration of existing op with custom op @@ -662,6 +671,9 @@ int MXLoadLib(const char *path) { plevel++; } if (!isSubgraphOp) { + regOp.set_attr_parser(attr_parser); + regOp.set_num_inputs(num_inputs); + regOp.set_num_outputs(num_outputs); regOp.set_attr("FInferType", infer_type, plevel); regOp.set_attr("FInferShape", infer_shape, plevel); regOp.set_attr("FInferStorageType", infer_storage_type, plevel); @@ -671,6 +683,8 @@ int MXLoadLib(const char *path) { regOp.set_attr("FMutateInputs", mutate_inputs, plevel); } else { using namespace mxnet::op; + regOp.set_num_inputs(DefaultSubgraphOpNumInputs); + regOp.set_num_outputs(DefaultSubgraphOpNumOutputs); regOp.set_attr("FInferType", DefaultSubgraphOpType, plevel); regOp.set_attr("FInferShape", @@ -712,6 +726,57 @@ int MXLoadLib(const char *path) { } regOp.add_argument("data", "NDArray[]", "Source inputs"); } + + // get number of partitioners registered in the library + partRegSize_t partRegSize = get_func(lib, + const_cast(MXLIB_PARTREGSIZE_STR)); + int numParts = partRegSize(); + LOG(INFO) << "Found " << numParts << " partitioners in library"; + + /* + * Get all custom partitioners implementation from custom library + * loop and register each partitioner in the library to NNVM + */ + partRegGetCount_t 
partRegGetCount = get_func(lib, + const_cast(MXLIB_PARTREGGETCOUNT_STR)); + partRegGet_t partRegGet = get_func(lib, const_cast(MXLIB_PARTREGGET_STR)); + for (int i = 0; i < numParts; i++) { + const char* name; + // get custom partitioner strategy count from the dynamic library + int count = partRegGetCount(i, &name); + CHECK(count > 0) << "Error loading '" << name + << "' custom partitioner, no strategies defined"; + std::string name_str(name); + LOG(INFO) << "\tPartitioner[" << i << "] " << name; + + mxnet::op::SubgraphBackendRegistry::Get()->__REGISTER_BACKEND__(name); + + for (int j = 0; j < count; j++) { + const char* strategy; + // function pointers holding implementation from custom library + supportedOps_t supportedOps_fp = nullptr; + acceptSubgraph_t acceptSubgraph_fp = nullptr; + // name of subgraph op + const char* op_name = nullptr; + + // get custom partitioner strategy from the dynamic library + partRegGet(i, j, &strategy, &supportedOps_fp, &acceptSubgraph_fp, &op_name); + // validate custom partitioner functions from the dynamic library + CHECK(supportedOps_fp != nullptr) << "Error loading '" << name + << "' custom partitioner strategy '" << strategy + << "', supportedOps function was not set."; + std::string strategy_str(strategy); + std::string op_name_str(op_name); + LOG(INFO) << "\t\tStrategy[" << j << "] " << strategy_str + << " subgraphOp: '" << op_name_str << "'"; + + // MXNET_REGISTER_SUBGRAPH_PROPERTY(customBackend, CustomSubgraphProperty); + mxnet::op::SubgraphBackendRegistry::Get()->__REGISTER_CUSTOM_PROPERTY__(name_str, + std::make_shared( + strategy_str, callSupportedOps, supportedOps_fp, + callAcceptSubgraph, acceptSubgraph_fp, callFree, op_name_str)); + } + } API_END(); } diff --git a/src/operator/subgraph/build_subgraph.cc b/src/operator/subgraph/build_subgraph.cc index 8e7617d57c44..b5380b702f6d 100644 --- a/src/operator/subgraph/build_subgraph.cc +++ b/src/operator/subgraph/build_subgraph.cc @@ -559,10 +559,28 @@ void 
CutGraphInputs(const std::vector &input_entries, ++(it->second); } nnvm::NodePtr n = nnvm::CreateVariableNode(var_name + std::to_string(name_count_map[var_name])); + // set attribute for subgraph input to indicate if it is from an arg/param to model + if (e->node->is_variable()) + n->attrs.dict["isArg"] = "True"; + else + n->attrs.dict["isArg"] = "False"; *e = nnvm::NodeEntry{n, 0, 0}; } } +/*! + * \brief This function reattaches the original input nodes that were cut + * by CutGraphInputs. This function is used when subgraphs are rejected, it + * reattaches the subgraph back to the main graph where it was cut earlier. + */ +void ReattachGraphInputs(const std::vector &input_entries, + std::vector *orig_entries) { + for (size_t i = 0; i < input_entries.size(); ++i) { + nnvm::NodeEntry *e = input_entries[i]; + *e = orig_entries->at(i); + } +} + /*! * \brief Replace a set of nodes belonging to the same subgraph with a subgrpah node * and keep the subgraph in the subgraph node. @@ -620,6 +638,8 @@ void CreateSubgraphNode(nnvm::Graph* g, sn->outputs[n.get()].push_back(i); } } + } else { + ReattachGraphInputs(input_entries, &orig_input_entries); } #if DEBUG_SUBGRAPH if (n) diff --git a/src/operator/subgraph/partitioner/custom_subgraph_property.h b/src/operator/subgraph/partitioner/custom_subgraph_property.h new file mode 100644 index 000000000000..b4ea1a087d71 --- /dev/null +++ b/src/operator/subgraph/partitioner/custom_subgraph_property.h @@ -0,0 +1,235 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * This file contains an implementation of a subgraph property + * that interfaces between MXNet and custom subgraph properties + * created by users in external libraries. It does not implement + * any custom subgraphing logic itself, rather it calls APIs + * in the user's custom library to enable control of partitioning + */ + +#ifndef MXNET_OPERATOR_SUBGRAPH_PARTITIONER_CUSTOM_SUBGRAPH_PROPERTY_H_ +#define MXNET_OPERATOR_SUBGRAPH_PARTITIONER_CUSTOM_SUBGRAPH_PROPERTY_H_ + +#include +#include +#include +#include +#include +#include "../common.h" +#include "../subgraph_property.h" +#include "../../include/mxnet/lib_api.h" +namespace mxnet { +namespace op { + +/* + * This selects nodes for a subgraph based on node name as supplied + * by the supportedOps from an external library. It visits nodes via + * both input and output links. 
+ */ +class CustomContainOpSelector: public SubgraphSelector { + public: + explicit CustomContainOpSelector(std::unordered_set supported_nodes) : + supported_nodes_(supported_nodes) {} + virtual bool Select(const nnvm::Node &n) { + return supported_nodes_.count(n.attrs.name) > 0; + } + virtual bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) { + return supported_nodes_.count(new_node.attrs.name) > 0; + } + virtual bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) { + return supported_nodes_.count(new_node.attrs.name) > 0; + } + std::unordered_set supported_nodes_; +}; + +/* + * This subgraph property finds a subgraph that only contains + * nodes as specified by the supportedOps from an external library. + * The operators in the subgraph will be executed by the operator + * specified by the external library too. + */ +class CustomSubgraphProperty: public SubgraphProperty { + public: + CustomSubgraphProperty() : + subgraph_prop("error"), + call_supported_ops_(nullptr), + supported_ops_(nullptr), + call_accept_subgraph_(nullptr), + accept_subgraph_(nullptr), + subgraph_op_name("error") {} + CustomSubgraphProperty(std::string subgraph_prop_name, + partCallSupportedOps_t call_supported_ops, + supportedOps_t supported_ops, + partCallAcceptSubgraph_t call_accept_subgraph, + acceptSubgraph_t accept_subgraph, + opCallFree_t call_free, + std::string op_name) : + subgraph_prop(subgraph_prop_name), + call_supported_ops_(call_supported_ops), + supported_ops_(supported_ops), + call_accept_subgraph_(call_accept_subgraph), + accept_subgraph_(accept_subgraph), + call_free_(call_free), + subgraph_op_name(op_name) {} + + // create custom subgraph property + static SubgraphPropertyPtr Create() { + return std::make_shared(); + } + + void PrePartition(const nnvm::Graph& g, + const std::vector>& options_map) { + // clear supported_nodes to remove state from previous calls + supported_nodes.clear(); + + // remove all graph attrs, some cannot be saved to 
json + nnvm::Graph graph = std::move(g); + graph.attrs.clear(); + const nnvm::IndexedGraph& indexed_graph = graph.indexed_graph(); + + // set shape attrs for each node in the graph + if (g.HasAttr("shape")) { + mxnet::ShapeVector shapes = g.GetAttr("shape"); + for (unsigned i = 0; i < indexed_graph.num_nodes(); i++) { + nnvm::Node* node = const_cast(indexed_graph[i].source); + mxnet::TShape shape = shapes[i]; + std::stringstream ss; + ss << shape; + node->attrs.dict["shape"] = ss.str(); + } + } + // set dtype attrs for each node in the graph + if (g.HasAttr("dtype")) { + std::vector dtypes = g.GetAttr >("dtype"); + for (unsigned i = 0; i < indexed_graph.num_nodes(); i++) { + nnvm::Node* node = const_cast(indexed_graph[i].source); + int dtype = dtypes[i]; + std::stringstream ss; + ss << dtype; + node->attrs.dict["dtype"] = ss.str(); + } + } + + CHECK(supported_ops_ != nullptr) + << "supported_ops_ is null for " << subgraph_prop << std::endl; + CHECK(call_supported_ops_ != nullptr) + << "call_supported_ops_ is null for " << subgraph_prop << std::endl; + + std::string subgraph_json = nnvm::pass::SaveJSON(graph); + std::vector supported_node_IDs(indexed_graph.num_nodes(), 0); + const char* json = subgraph_json.c_str(); + int *ids = supported_node_IDs.data(); + + // clear options from previous call + opt_keys_.clear(); + opt_vals_.clear(); + options_map_.clear(); + for (auto kv : options_map) { + options_map_.push_back(kv); + opt_keys_.push_back(options_map_.back().first.c_str()); + opt_vals_.push_back(options_map_.back().second.c_str()); + } + + CHECK(call_supported_ops_(supported_ops_, json, supported_node_IDs.size(), ids, + opt_keys_.data(), opt_vals_.data(), opt_keys_.size())) + << "Error calling supported_ops for '" << subgraph_prop << "'"; + + const auto& idx = g.indexed_graph(); + // loop and add node names for each supported node ID + for (unsigned i = 0; i < supported_node_IDs.size(); i++) { + if (supported_node_IDs[i]) { + 
supported_nodes.insert(idx[i].source->attrs.name); + } + } + } + // override CreateSubgraphNode + virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym, + const int subgraph_id = 0) const { + int accept = 1; + int num_attr = 0; + char** attr_keys = nullptr; + char** attr_vals = nullptr; + if (accept_subgraph_) { + nnvm::Graph g; + g.outputs = sym.outputs; + const auto& idx = g.indexed_graph(); + + // set isArg/isAux for each null op/param in the graph + const std::vector aux_names = sym.ListInputNames(nnvm::Symbol::kAuxiliaryStates); + std::unordered_set aux_set(aux_names.begin(), aux_names.end()); + for (unsigned i = 0; i < idx.num_nodes(); i++) { + nnvm::Node* node = const_cast(idx[i].source); + // check if this node is input to subgraph + if (node->is_variable()) { + // check if this node is an aux param + if (aux_set.count(node->attrs.name)) + node->attrs.dict["isAux"] = "True"; + else + node->attrs.dict["isAux"] = "False"; + } + } + + std::string subgraph_json = nnvm::pass::SaveJSON(g); + CHECK(call_accept_subgraph_(accept_subgraph_, subgraph_json.c_str(), + subgraph_id, &accept, opt_keys_.data(), + opt_vals_.data(), opt_keys_.size(), + &attr_keys, &attr_vals, &num_attr)) + << "Error calling accept_subgraph for '" << subgraph_prop << "'"; + } + if (accept) { + nnvm::NodePtr n = nnvm::Node::Create(); + n->attrs.op = Op::Get(subgraph_op_name); + n->attrs.name = "_op" + std::to_string(subgraph_id); + n->attrs.subgraphs.push_back(std::make_shared(sym)); + // set user specified attributes + for (int i=0; i < num_attr; i++) { + n->attrs.dict[attr_keys[i]] = attr_vals[i]; + call_free_(attr_vals[i]); + call_free_(attr_keys[i]); + } + // free memory used by custom op to allocate attributes + call_free_(attr_vals); + call_free_(attr_keys); + return n; + } else { + return NULL; + } + } + // override CreateSubgraphSelector + virtual SubgraphSelectorPtr CreateSubgraphSelector() const { + return std::make_shared(supported_nodes); + } + + std::string 
subgraph_prop; + partCallSupportedOps_t call_supported_ops_; + supportedOps_t supported_ops_; + partCallAcceptSubgraph_t call_accept_subgraph_; + acceptSubgraph_t accept_subgraph_; + opCallFree_t call_free_; + std::unordered_set supported_nodes; + std::string subgraph_op_name; + std::vector> options_map_; + std::vector opt_keys_, opt_vals_; +}; +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_SUBGRAPH_PARTITIONER_CUSTOM_SUBGRAPH_PROPERTY_H_ diff --git a/src/operator/subgraph/subgraph_property.h b/src/operator/subgraph/subgraph_property.h index bbc8f4076899..643c02a82b13 100644 --- a/src/operator/subgraph/subgraph_property.h +++ b/src/operator/subgraph/subgraph_property.h @@ -520,6 +520,15 @@ class SubgraphBackendRegistry { return SubgraphPropertyEntry(prop); } + SubgraphPropertyEntry __REGISTER_CUSTOM_PROPERTY__(const std::string& name, + SubgraphPropertyPtr cprop) { + auto it = backend_map_.find(name); + CHECK(it != backend_map_.end()) + << "Subgraph backend " << name << " is not found in SubgraphBackendRegistry"; + auto prop = it->second->RegisterSubgraphProperty(cprop); + return SubgraphPropertyEntry(prop); + } + SubgraphBackendRegistry() = default; SubgraphBackendRegistry(const SubgraphBackendRegistry&) = delete; SubgraphBackendRegistry(SubgraphBackendRegistry&&) = delete; diff --git a/tests/python/unittest/test_extensions.py b/tests/python/unittest/test_extensions.py index cc7858dce0fd..63b54d8516b4 100644 --- a/tests/python/unittest/test_extensions.py +++ b/tests/python/unittest/test_extensions.py @@ -84,3 +84,67 @@ def test_custom_op(): assert_almost_equal(in_grad_base[0].asnumpy(), in_grad1[0].asnumpy(), rtol=1e-3, atol=1e-3) assert_almost_equal(in_grad_base[0].asnumpy(), in_grad2[0].asnumpy(), rtol=1e-3, atol=1e-3) + +@unittest.skipIf(check_platform(), "not all machine types supported") +@unittest.skipIf(is_cd_run(), "continuous delivery run - ignoring test") +def test_subgraph(): + # possible places to find library file + if 
(os.name=='posix'): + lib = 'libsubgraph_lib.so' + if os.path.exists(lib): + # plain make build, when run in the CI + fname = lib + elif os.path.exists('build/'+lib): + # plain cmake build when run in the CI + fname = 'build/'+lib + else: + raise MXNetError("library %s not found " % lib) + elif (os.name=='nt'): + lib = 'libsubgraph_lib.dll' + if os.path.exists('windows_package\\lib\\'+lib): + # plain make build, when run in the CI + fname = 'windows_package\\lib\\'+lib + else: + # plain cmake build when run in the CI + raise MXNetError("library %s not found " % lib) + + fname = os.path.abspath(fname) + mx.library.load(fname) + + # test simple graph with add, exp and log operators, library supports exp/log + a = mx.sym.var('a') + b = mx.sym.var('b') + c = a + b + d = mx.sym.exp(c) + sym = mx.sym.log(d) + + args = {'a':mx.nd.ones((3,2),ctx=mx.cpu()), 'b':mx.nd.ones((3,2),ctx=mx.cpu())} + arg_array = [mx.nd.ones((3,2),dtype='float32',ctx=mx.cpu()), + mx.nd.ones((3,2),dtype='float32',ctx=mx.cpu())] + + # baseline - regular execution in MXNet + exe = sym.bind(ctx=mx.cpu(), args=args) + out = exe.forward() + + # without propagating shapes/types, passing a custom option to subgraph prop "myOpt" + # should not create subgraph since subgraph prop requires type info + mysym1 = sym.optimize_for("myProp", myOpt='yello') + exe1 = mysym1.bind(ctx=mx.cpu(), args=args) + out1 = exe1.forward() + # check that result matches one executed by MXNet + assert_almost_equal(out[0].asnumpy(), out1[0].asnumpy(), rtol=1e-3, atol=1e-3) + + # with propagating shapes/types, rejecting subgraph + # this tests creating the subgraph and having the subgraph prop reject it + mysym2 = sym.optimize_for("myProp", arg_array, reject=True) + exe2 = mysym2.bind(ctx=mx.cpu(), args=args) + out2 = exe2.forward() + # check that result matches one executed by MXNet + assert_almost_equal(out[0].asnumpy(), out2[0].asnumpy(), rtol=1e-3, atol=1e-3) + + # with propagating shapes/types + mysym3 =
sym.optimize_for("myProp",arg_array) + exe3 = mysym3.bind(ctx=mx.cpu(), args=args) + out3 = exe3.forward() + # check that result matches one executed by MXNet + assert_almost_equal(out[0].asnumpy(), out3[0].asnumpy(), rtol=1e-3, atol=1e-3)