Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions example/extensions/lib_pass/test_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,30 +48,30 @@
sym = mx.sym.log(d)

def test_model(pass_name):
args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}
# execute in MXNet
print('-------------------------------')
print('Testing regular MXNet execution')
exe = sym.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))})
exe = sym.bind(ctx=mx.cpu(), args=args)
out = exe.forward()
print(out)

# Symbol optimize_for
# with propogating shapes/types
print('-------------------------------')
print('Testing pass "%s" with shapes/types' % pass_name)
arg_array = [mx.nd.ones((3,2),dtype='float32'), mx.nd.ones((3,2),dtype='float32')]
aux = []
mysym2 = sym.optimize_for(pass_name,arg_array,aux)
aux = {}
mysym2 = sym.optimize_for(pass_name,args,aux)
print(mysym2.tojson())
exe2 = mysym2.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))})
exe2 = mysym2.bind(ctx=mx.cpu(), args=args)
out2 = exe2.forward()
print(out2)

# without propogating shapes/types
print('-------------------------------')
print('Testing pass "%s" without shapes/types' % pass_name)
mysym3 = sym.optimize_for(pass_name, myOpt='yello')
exe3 = mysym3.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))})
exe3 = mysym3.bind(ctx=mx.cpu(), args=args)
out3 = exe3.forward()
print(out3)

Expand Down
25 changes: 12 additions & 13 deletions example/extensions/lib_subgraph/test_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,40 +49,39 @@
sym2 = mx.sym.log(d2)

def test(backend):
args = {'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}
###############################################
# Test with subgraph not consuming params
###############################################
#execute in MXNet
print('-------------------------------')
print('Testing regular MXNet execution')
exe = sym.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))})
exe = sym.bind(ctx=mx.cpu(), args=args)
out = exe.forward()
print(out)

# with propogating shapes/types
print('-------------------------------')
print('Testing %s partitioning with shapes/types' % backend)
arg_array = [mx.nd.ones((3,2),dtype='float32'), mx.nd.ones((3,2),dtype='float32')]
mysym2 = sym.optimize_for(backend,arg_array)
mysym2 = sym.optimize_for(backend,args)
print(mysym2.tojson())
exe2 = mysym2.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))})
exe2 = mysym2.bind(ctx=mx.cpu(), args=args)
out2 = exe2.forward()
print(out2)

# with propogating shapes/types, rejecting subgraph
print('-------------------------------')
print('Testing %s partitioning with shapes/types - rejecting subgraph' % backend)
arg_array = [mx.nd.ones((3,2),dtype='float32'), mx.nd.ones((3,2),dtype='float32')]
mysym2 = sym.optimize_for(backend, arg_array, reject=True)
exe2 = mysym2.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))})
mysym2 = sym.optimize_for(backend, args, reject=True)
exe2 = mysym2.bind(ctx=mx.cpu(), args=args)
out2 = exe2.forward()
print(out2)

# without propogating shapes/types
print('-------------------------------')
print('Testing %s partitioning without shapes/types' % backend)
mysym3 = sym.optimize_for(backend, myOpt='yello')
exe3 = mysym3.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))})
exe3 = mysym3.bind(ctx=mx.cpu(), args=args)
out3 = exe3.forward()
print(out3)

Expand All @@ -108,28 +107,28 @@ def test(backend):
###############################################
# Test with subgraph directly consuming params
###############################################
args = {'a':mx.nd.ones((3,2))}
#execute in MXNet
print('-------------------------------')
print('Testing regular MXNet execution')
exe5 = sym2.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2))})
exe5 = sym2.bind(ctx=mx.cpu(), args=args)
out5 = exe5.forward()
print(out5)

# with propogating shapes/types
print('-------------------------------')
print('Testing %s partitioning with shapes/types' % backend)
arg_array = [mx.nd.ones((3,2),dtype='float32')]
mysym6 = sym2.optimize_for(backend, arg_array, reqArgs=True)
mysym6 = sym2.optimize_for(backend, args, reqArgs=True)
print(mysym6.tojson())
exe6 = mysym6.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2))})
exe6 = mysym6.bind(ctx=mx.cpu(), args=args)
out6 = exe6.forward()
print(out6)

# without propogating shapes/types
print('-------------------------------')
print('Testing %s partitioning without shapes/types' % backend)
mysym7 = sym2.optimize_for(backend, reqArgs=True)
exe7 = mysym7.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2))})
exe7 = mysym7.bind(ctx=mx.cpu(), args=args)
out7 = exe7.forward()
print(out7)

Expand Down
10 changes: 5 additions & 5 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -1033,12 +1033,12 @@ def _build_cache(self, *args):
if self._backend:
ctx = args[0].context
# get list of params in the order of out.list_arguments
arg_array = [args[data_names[name]] if name in data_names.keys() else params[name].data()
for name in out.list_arguments()]
aux_array = [args[data_names[name]] if name in data_names.keys() else params[name].data()
for name in out.list_auxiliary_states()]
arg_dict = {name:args[data_names[name]] if name in data_names.keys() else params[name].data()
for name in out.list_arguments()}
aux_dict = {name:args[data_names[name]] if name in data_names.keys() else params[name].data()
for name in out.list_auxiliary_states()}
# Partition the graph.
out = out.optimize_for(self._backend, arg_array, aux_array, ctx, **self._backend_opts)
out = out.optimize_for(self._backend, arg_dict, aux_dict, ctx, **self._backend_opts)
#update cached graph with partitioned graph
self._cached_graph = data, out
self._cached_op = ndarray.CachedOp(out, flags)
Expand Down
40 changes: 16 additions & 24 deletions python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -1456,17 +1456,15 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
backend : str
The name of backend, as registered in `SubgraphBackendRegistry`

args : list of NDArray or dict of str to NDArray, optional
args : dict of str to NDArray, optional
Input arguments to the symbol, required to infer shapes/types before partitioning

- If type is a list of `NDArray`, the order is the same as that of `list_arguments()`.
- If type is a dict of str to `NDArray`, then it maps the name of arguments
to the corresponding `NDArray`.

aux : list of NDArray or dict of str to NDArray, optional
aux : dict of str to NDArray, optional
Input auxiliary arguments to the symbol

- If type is a list of `NDArray`, the order is the same as that of `list_arguments()`.
- If type is a dict of str to `NDArray`, then it maps the name of arguments
to the corresponding `NDArray`.

Expand All @@ -1483,6 +1481,8 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
"""
out = SymbolHandle()
assert isinstance(backend, str)
assert isinstance(args, dict) or args is None
assert isinstance(aux, dict) or aux is None

if args is None or len(args) == 0:
args_ = []
Expand Down Expand Up @@ -1530,30 +1530,22 @@ def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
ctypes.byref(new_aux_size),
ctypes.byref(new_aux_handle),
ctypes.byref(new_aux_names)))
arg_names = self.list_arguments()
if isinstance(args, dict):
# add new args/aux
if not args is None:
for i in range(new_args_size.value):
args[py_str(new_arg_names[i])] = NDArray(NDArrayHandle(new_args_handle[i]))
elif isinstance(args, list):
for i in range(new_args_size.value):
name = py_str(new_arg_names[i])
if name in arg_names:
idx = arg_names.index(name)
args[idx] = NDArray(NDArrayHandle(new_args_handle[i]))
else:
args.append(NDArray(NDArrayHandle(new_args_handle[i])))
aux_names = self.list_auxiliary_states()
if isinstance(aux, dict):
elif new_args_size.value > 0:
raise RuntimeError('Cannot add new args in optimize_for since args is None\n' +
'Provide a dictionary to the args argument to optimize_for')

if not aux is None:
for i in range(new_aux_size.value):
aux[py_str(new_aux_names[i])] = NDArray(NDArrayHandle(new_aux_handle[i]))
elif isinstance(aux, list):
for i in range(new_aux_size.value):
name = py_str(new_aux_names[i])
if name in aux_names:
idx = aux_names.index(name)
aux[idx] = NDArray(NDArrayHandle(new_aux_handle[i]))
else:
aux.append(NDArray(NDArrayHandle(new_aux_handle[i])))
elif new_aux_size.value > 0:
raise RuntimeError('Cannot add new aux in optimize_for since aux is None\n' +
'Provide a dictionary to the aux argument to optimize_for')

# return modified symbol
return Symbol(out)


Expand Down
6 changes: 2 additions & 4 deletions tests/python/unittest/test_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,6 @@ def test_subgraph():
sym = mx.sym.log(d)

args = {'a':mx.nd.ones((3,2),ctx=mx.cpu()), 'b':mx.nd.ones((3,2),ctx=mx.cpu())}
arg_array = [mx.nd.ones((3,2),dtype='float32',ctx=mx.cpu()),
mx.nd.ones((3,2),dtype='float32',ctx=mx.cpu())]

# baseline - regular execution in MXNet
exe = sym.bind(ctx=mx.cpu(), args=args)
Expand All @@ -147,14 +145,14 @@ def test_subgraph():

# with propogating shapes/types, rejecting subgraph
# this tests creating the subgraph and having the subgraph prop reject it
mysym2 = sym.optimize_for("myProp", arg_array, reject=True)
mysym2 = sym.optimize_for("myProp", args, reject=True)
exe2 = mysym2.bind(ctx=mx.cpu(), args=args)
out2 = exe2.forward()
# check that result matches one executed by MXNet
assert_almost_equal(out[0].asnumpy(), out2[0].asnumpy(), rtol=1e-3, atol=1e-3)

# with propogating shapes/types
mysym3 = sym.optimize_for("myProp",arg_array)
mysym3 = sym.optimize_for("myProp",args)
exe3 = mysym3.bind(ctx=mx.cpu(), args=args)
out3 = exe3.forward()
# check that result matches one executed by MXNet
Expand Down
14 changes: 8 additions & 6 deletions tests/python/unittest/test_subgraph_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,18 +353,20 @@ def test_subgraph_exe8(sym, subgraph_backend, op_names):
# bind
sym, _, _ = sym
arg_shapes, _, aux_shapes = sym.infer_shape()
arg_array = [mx.nd.random.uniform(shape=shape) for shape in arg_shapes]
aux_array = [mx.nd.random.uniform(shape=shape) for shape in aux_shapes]
exe1 = sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null')
arg_names = sym.list_arguments()
aux_names = sym.list_auxiliary_states()
arg_dict = {name:mx.nd.random.uniform(shape=shape) for name,shape in zip(arg_names,arg_shapes)}
aux_dict = {name:mx.nd.random.uniform(shape=shape) for name,shape in zip(aux_names,aux_shapes)}
exe1 = sym.bind(ctx=mx.current_context(), args=arg_dict, aux_states=aux_dict, grad_req='null')
exe1.forward()

# infer shape/type before partition before bind
check_call(_LIB.MXSetSubgraphPropertyOpNamesV2(c_str(subgraph_backend), mx_uint(len(op_names)),
c_str_array(op_names)))
part_sym = sym.optimize_for(subgraph_backend, arg_array, aux_array)
c_str_array(op_names)))
part_sym = sym.optimize_for(subgraph_backend, arg_dict, aux_dict)
check_call(_LIB.MXRemoveSubgraphPropertyOpNamesV2(c_str(subgraph_backend)))

exe2 = part_sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null')
exe2 = part_sym.bind(ctx=mx.current_context(), args=arg_dict, aux_states=aux_dict, grad_req='null')
exe2.forward()

# compare outputs
Expand Down