Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 5 additions & 12 deletions nnvm/python/nnvm/top/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,16 +167,16 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, _):
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
kh, kw = attrs.get_int_tuple('kernel_size')
groups = attrs.get_int("groups")
channels = attrs.get_int("channels")
layout = attrs.get_string("layout")
out_layout = attrs.get_string("out_layout")
out_dtype = attrs.get_string("out_dtype")
out_dtype = inputs[0].dtype if out_dtype == "same" else out_dtype
assert dilation == (1, 1), "not support dilate now"
if groups == 1:
# pylint: disable=assignment-from-no-return
out = topi.nn.conv2d_NCHWc(inputs[0], inputs[1], channels, (kh, kw),
strides, padding, layout, out_layout)
out = topi.nn.conv2d_NCHWc(inputs[0], inputs[1], strides, padding,
layout, out_layout, out_dtype)
# pylint: enable=assignment-from-no-return
else:
raise ValueError("not support arbitrary group number > 1 for now")
Expand All @@ -190,16 +190,9 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, _):
def schedule_contrib_conv2d_NCHWc(attrs, outs, target):
"""Schedule definition of conv2d NCHWc"""
groups = attrs.get_int("groups")
kh, kw = attrs.get_int_tuple('kernel_size')
oc = attrs.get_int("channels")
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
layout = attrs.get_string("layout")
out_layout = attrs.get_string("out_layout")
with tvm.target.create(target):
if groups == 1:
return topi.generic.schedule_conv2d_NCHWc(oc, (kh, kw), strides, padding,
layout, out_layout, outs)
return topi.generic.schedule_conv2d_NCHWc(outs)
else:
raise ValueError("not support group number > 1 for now")

Expand Down
113 changes: 81 additions & 32 deletions python/tvm/autotvm/task/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,53 @@ def query(self, target, workload):
ret = self._old_ctx.query(target, workload)
return ret

def update(self, target, workload, cfg):
"""
Update context with a specific config.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a quite critical design here, and I think it is best if we could provide some motivation(with an example) on why do we need to expose this API)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can add a Note block to this function


Parameters
----------
target: Target
The current target
workload : Workload
The current workload.
cfg : ConfigSpace
The specific configuration.

Note
----
This interface is for cases when TVM decides to replace an operator in the graph.
For example, `AlterOpLayout` pass (enables when `opt_level = 3`) replaces `NCHW`
convolution with `NCHW[x]c` implementation on x86 CPUs.
Thus in TOPI, we first query schedule using original `NCHW` workload,
then update the dispatcher with the new `NCHW[x]c` workload.
So that later on, `NCHW[x]c` convolution can get schedule from the dispatcher using
its own workload directly.

.. code-block:: python

@conv2d_alter_layout.register("cpu")
def _alter_conv2d_layout(attrs, inputs, tinfo):
workload = get_conv2d_workload(...)
dispatch_ctx = autotvm.task.DispatchContext.current
target = tvm.target.current_target()
config = dispatch_ctx.query(target, workload)

# Get conv2d_NCHWc workload from config
# new_workload = ...
# new_inputs = ...
# new_attrs = ...

# Store altered operator's config
dispatch_ctx.update(target, new_workload, config)
return sym.contrib.conv2d_NCHWc(*new_inputs, **new_attrs)

We directly store `config` back because `conv2d_NCHW` and `conv2d_NCHWc`
share the same schedule parameters.
One can construct a new `ConfigEntity` if this is not the case.
"""
raise NotImplementedError()

def _query_inside(self, target, workload):
"""
Query the context to get the specific config for a template.
Expand Down Expand Up @@ -179,6 +226,11 @@ def _query_inside(self, target, workload):
self.workload = workload
return self._config

def update(self, target, workload, cfg):
"""Override update"""
self.workload = workload
self._config = cfg


class ApplyHistoryBest(DispatchContext):
"""
Expand All @@ -197,6 +249,7 @@ def __init__(self, records):

self.best_by_targetkey = {}
self.best_by_model = {}
self._best_user_defined = {}

if records:
self.load(records)
Expand Down Expand Up @@ -264,17 +317,32 @@ def _query_inside(self, target, workload):
if opt.startswith("-model"):
model = opt[7:]
key = (model, workload)
if key in self._best_user_defined:
return self._best_user_defined[key]
if key in self.best_by_model:
return self.best_by_model[key][0].config

# then try matching by target key
for k in target.keys:
key = (k, workload)
if key in self._best_user_defined:
return self._best_user_defined[key]
if key in self.best_by_targetkey:
return self.best_by_targetkey[key][0].config

return None

def update(self, target, workload, cfg):
for opt in target.options:
if opt.startswith("-model"):
model = opt[7:]
key = (model, workload)
self._best_user_defined[key] = cfg

for k in target.keys:
key = (k, workload)
self._best_user_defined[key] = cfg


class FallbackContext(DispatchContext):
"""
Expand Down Expand Up @@ -324,6 +392,10 @@ def clear_cache(self, target, workload):
if key in self.memory:
del self.memory[key]

def update(self, target, workload, cfg):
key = (str(target), workload)
self.memory[key] = cfg

DispatchContext.current = FallbackContext()

def clear_fallback_cache(target, workload):
Expand Down Expand Up @@ -391,37 +463,14 @@ def _query_inside(self, target, workload):
cfg : ConfigSpace
The specific configuration.
"""
cfg = self._records[self._counter][0].config
self._counter += 1
return cfg

def query_global_dict(self, key):
"""
Query the context to get config from global
config dictionary.

Parameters
----------
key : str
Key to query the config.

Returns
-------
cfg : ConfigSpace
The specific configuration.
"""
if self._counter < len(self._records):
cfg = self._records[self._counter][0].config
self._counter += 1
self.update(target, workload, cfg)
return cfg
key = (str(target), workload)
return self._global_cfg_dict[key]

def update_global_dict(self, key, val):
"""
Update the global config dictionary.

Parameters
----------
key : str
Key of config.

val : ConfigSpace
Value of config.
"""
self._global_cfg_dict[key] = val
def update(self, target, workload, cfg):
key = (str(target), workload)
self._global_cfg_dict[key] = cfg
14 changes: 13 additions & 1 deletion python/tvm/autotvm/task/space.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# pylint: disable=too-few-public-methods,invalid-name,unused-argument,arguments-differ
# pylint: disable=consider-using-enumerate
# pylint: disable=consider-using-enumerate,too-many-lines
"""
Template configuration space.

Expand Down Expand Up @@ -996,5 +996,17 @@ def fallback_with_reference_log(self, ref_log):
if not isinstance(self.space_map[knob_name], SplitSpace):
self._entity_map[knob_name] = best_match_cfg[knob_name]

def __setitem__(self, name, entity):
"""set the entity(knob) of by name

Parameters
----------
name: str
name of the entity
entity: SplitEntity, ReorderEntity, AnnotateEntity, OtherOptionEntity
value of the entity
"""
self._entity_map[name] = entity

def __repr__(self):
return "%s,%s,%s" % (str(self._entity_map)[12:-1], self.template_key, self.code_hash)
15 changes: 9 additions & 6 deletions python/tvm/autotvm/task/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def create(func_name, args, target, target_host=None, template_key=None):

return ret

def args_to_workload(x):
def args_to_workload(x, topi_compute_func=None):
"""Convert argument list to hashable workload tuple.
This function will convert list to tuple, tvm node to python value and
flatten tvm.tensor.Tensor to a tuple
Expand All @@ -191,25 +191,28 @@ def args_to_workload(x):
----------
x: primitive hashable types or tensor.Tensor
The original value
topi_compute_func: topi compute function
The function name will be added as first element of the workload tuple

Returns
-------
ret: hashable
The hashable value
"""
if isinstance(x, tensor.Tensor):
return get_const_tuple(x.shape) + (x.dtype, )
workload = get_const_tuple(x.shape) + (x.dtype, )
elif isinstance(x, (tuple, list, container.Array)):
return tuple([args_to_workload(a) for a in x])
workload = tuple([args_to_workload(a) for a in x])
elif isinstance(x, (str, int, float, np.int, np.float)):
return x
workload = x
elif isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
return x.value
workload = x.value
elif x is None:
return 0
workload = 0
else:
raise RuntimeError('Do not support type "%s" in argument. Consider to use'
'primitive types only' % type(x))
return (get_func_name(topi_compute_func), ) + workload if topi_compute_func else workload

def template(func):
"""
Expand Down
17 changes: 7 additions & 10 deletions python/tvm/autotvm/task/topi_integration.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# pylint: disable=unused-variable,invalid-name
# pylint: disable=unused-variable,invalid-name,unused-argument
"""
Decorators for registering tunable templates to TOPI.

Expand All @@ -13,7 +13,6 @@

from ... import _api_internal, tensor

from ..util import get_func_name
from .task import args_to_workload, dispatcher


Expand Down Expand Up @@ -55,8 +54,6 @@ def register_topi_compute(topi_compute, target_keys, template_keys, func=None):
--------
See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage.
"""
fname = get_func_name(topi_compute)

def _decorator(f):
targets = [target_keys] if isinstance(target_keys, str) else target_keys
for target_key in targets:
Expand All @@ -68,7 +65,7 @@ def _decorator(f):
def config_dispatcher(*args, **kwargs):
"""override topi call as a config dispatcher"""
assert not kwargs, "Do not support kwargs in template function call"
return (fname, ) + args_to_workload(args)
return args_to_workload(args, topi_compute)
_REGISTED_DISPATHCER[target_key][topi_compute] = config_dispatcher

config_dispatcher = _REGISTED_DISPATHCER[target_key][topi_compute]
Expand All @@ -88,7 +85,7 @@ def template_call(cfg, *args, **kwargs):
attrs = {}
for k, v in node.op.attrs.items():
attrs[k] = v
attrs['workload'] = (fname, ) + args_to_workload(args)
attrs['workload'] = args_to_workload(args, topi_compute)
if isinstance(op, tensor.ComputeOp):
op = _api_internal._ComputeOp(
op.name, op.tag, attrs, op.axis, op.body)
Expand Down Expand Up @@ -153,7 +150,7 @@ def _decorator(f):
if topi_schedule not in _REGISTED_DISPATHCER[target_key]:
@topi_schedule.register(target_key)
@dispatcher
def config_dispatcher(outs):
def config_dispatcher(outs, *args, **kwargs):
"""override topi call as a workload dispatcher"""
def traverse(tensors):
"""traverse all ops to find attached workload"""
Expand All @@ -179,11 +176,11 @@ def traverse(tensors):
config_dispatcher = _REGISTED_DISPATHCER[target_key][topi_schedule]

@config_dispatcher.register(template_keys)
def template_call(cfg, outs):
def template_call(cfg, outs, *args, **kwargs):
"""call the schedule func"""
if f == topi_schedule.fdefault:
return f(outs)
return f(cfg, outs)
return f(outs, *args, **kwargs)
return f(cfg, outs, *args, **kwargs)

return f

Expand Down
22 changes: 2 additions & 20 deletions topi/python/topi/generic/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,33 +55,15 @@ def schedule_conv2d_nhwc(outs):


@tvm.target.generic_func
def schedule_conv2d_NCHWc(num_filter, kernel_size, strides,
padding, layout, out_layout, outs):
def schedule_conv2d_NCHWc(outs):
"""Schedule for conv2d_NCHW[x]c

Parameters
----------
num_filter : int
The number of filter, i.e., the output channel.

kernel_size : tuple of int
(kernel_height, kernel_width)

strides : tuple of int
(stride_of_height, stride_of_width)

padding : tuple of int
(pad_of_height, pad_of_width)

layout : str
Input data layout

out_layout : str
Output data layout

outs : Array of Tensor
The computation graph description of conv2d_NCHWc
in the format of an array of tensors.
The number of filter, i.e., the output channel.

Returns
-------
Expand Down
Loading