diff --git a/nnvm/python/nnvm/top/nn.py b/nnvm/python/nnvm/top/nn.py index 49192cacd713..a4b36ea853d5 100644 --- a/nnvm/python/nnvm/top/nn.py +++ b/nnvm/python/nnvm/top/nn.py @@ -167,16 +167,16 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, _): padding = attrs.get_int_tuple("padding") strides = attrs.get_int_tuple("strides") dilation = attrs.get_int_tuple("dilation") - kh, kw = attrs.get_int_tuple('kernel_size') groups = attrs.get_int("groups") - channels = attrs.get_int("channels") layout = attrs.get_string("layout") out_layout = attrs.get_string("out_layout") + out_dtype = attrs.get_string("out_dtype") + out_dtype = inputs[0].dtype if out_dtype == "same" else out_dtype assert dilation == (1, 1), "not support dilate now" if groups == 1: # pylint: disable=assignment-from-no-return - out = topi.nn.conv2d_NCHWc(inputs[0], inputs[1], channels, (kh, kw), - strides, padding, layout, out_layout) + out = topi.nn.conv2d_NCHWc(inputs[0], inputs[1], strides, padding, + layout, out_layout, out_dtype) # pylint: enable=assignment-from-no-return else: raise ValueError("not support arbitrary group number > 1 for now") @@ -190,16 +190,9 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, _): def schedule_contrib_conv2d_NCHWc(attrs, outs, target): """Schedule definition of conv2d NCHWc""" groups = attrs.get_int("groups") - kh, kw = attrs.get_int_tuple('kernel_size') - oc = attrs.get_int("channels") - padding = attrs.get_int_tuple("padding") - strides = attrs.get_int_tuple("strides") - layout = attrs.get_string("layout") - out_layout = attrs.get_string("out_layout") with tvm.target.create(target): if groups == 1: - return topi.generic.schedule_conv2d_NCHWc(oc, (kh, kw), strides, padding, - layout, out_layout, outs) + return topi.generic.schedule_conv2d_NCHWc(outs) else: raise ValueError("not support group number > 1 for now") diff --git a/python/tvm/autotvm/task/dispatcher.py b/python/tvm/autotvm/task/dispatcher.py index 164877e3b451..fd91d60e7982 100644 --- 
a/python/tvm/autotvm/task/dispatcher.py +++ b/python/tvm/autotvm/task/dispatcher.py @@ -60,6 +60,53 @@ def query(self, target, workload): ret = self._old_ctx.query(target, workload) return ret + def update(self, target, workload, cfg): + """ + Update context with a specific config. + + Parameters + ---------- + target: Target + The current target + workload : Workload + The current workload. + cfg : ConfigSpace + The specific configuration. + + Note + ---- + This interface is for cases when TVM decides to replace an operator in the graph. + For example, `AlterOpLayout` pass (enables when `opt_level = 3`) replaces `NCHW` + convolution with `NCHW[x]c` implementation on x86 CPUs. + Thus in TOPI, we first query schedule using original `NCHW` workload, + then update the dispatcher with the new `NCHW[x]c` workload. + So that later on, `NCHW[x]c` convolution can get schedule from the dispatcher using + its own workload directly. + + .. code-block:: python + + @conv2d_alter_layout.register("cpu") + def _alter_conv2d_layout(attrs, inputs, tinfo): + workload = get_conv2d_workload(...) + dispatch_ctx = autotvm.task.DispatchContext.current + target = tvm.target.current_target() + config = dispatch_ctx.query(target, workload) + + # Get conv2d_NCHWc workload from config + # new_workload = ... + # new_inputs = ... + # new_attrs = ... + + # Store altered operator's config + dispatch_ctx.update(target, new_workload, config) + return sym.contrib.conv2d_NCHWc(*new_inputs, **new_attrs) + + We directly store `config` back because `conv2d_NCHW` and `conv2d_NCHWc` + share the same schedule parameters. + One can construct a new `ConfigEntity` if this is not the case. + """ + raise NotImplementedError() + def _query_inside(self, target, workload): """ Query the context to get the specific config for a template. 
@@ -179,6 +226,11 @@ def _query_inside(self, target, workload): self.workload = workload return self._config + def update(self, target, workload, cfg): + """Override update""" + self.workload = workload + self._config = cfg + class ApplyHistoryBest(DispatchContext): """ @@ -197,6 +249,7 @@ def __init__(self, records): self.best_by_targetkey = {} self.best_by_model = {} + self._best_user_defined = {} if records: self.load(records) @@ -264,17 +317,32 @@ def _query_inside(self, target, workload): if opt.startswith("-model"): model = opt[7:] key = (model, workload) + if key in self._best_user_defined: + return self._best_user_defined[key] if key in self.best_by_model: return self.best_by_model[key][0].config # then try matching by target key for k in target.keys: key = (k, workload) + if key in self._best_user_defined: + return self._best_user_defined[key] if key in self.best_by_targetkey: return self.best_by_targetkey[key][0].config return None + def update(self, target, workload, cfg): + for opt in target.options: + if opt.startswith("-model"): + model = opt[7:] + key = (model, workload) + self._best_user_defined[key] = cfg + + for k in target.keys: + key = (k, workload) + self._best_user_defined[key] = cfg + class FallbackContext(DispatchContext): """ @@ -324,6 +392,10 @@ def clear_cache(self, target, workload): if key in self.memory: del self.memory[key] + def update(self, target, workload, cfg): + key = (str(target), workload) + self.memory[key] = cfg + DispatchContext.current = FallbackContext() def clear_fallback_cache(target, workload): @@ -391,37 +463,14 @@ def _query_inside(self, target, workload): cfg : ConfigSpace The specific configuration. """ - cfg = self._records[self._counter][0].config - self._counter += 1 - return cfg - - def query_global_dict(self, key): - """ - Query the context to get config from global - config dictionary. - - Parameters - ---------- - key : str - Key to query the config. 
- - Returns - ------- - cfg : ConfigSpace - The specific configuration. - """ + if self._counter < len(self._records): + cfg = self._records[self._counter][0].config + self._counter += 1 + self.update(target, workload, cfg) + return cfg + key = (str(target), workload) return self._global_cfg_dict[key] - def update_global_dict(self, key, val): - """ - Update the global config dictionary. - - Parameters - ---------- - key : str - Key of config. - - val : ConfigSpace - Value of config. - """ - self._global_cfg_dict[key] = val + def update(self, target, workload, cfg): + key = (str(target), workload) + self._global_cfg_dict[key] = cfg diff --git a/python/tvm/autotvm/task/space.py b/python/tvm/autotvm/task/space.py index f9bf60237776..32bd66b6c12d 100644 --- a/python/tvm/autotvm/task/space.py +++ b/python/tvm/autotvm/task/space.py @@ -1,5 +1,5 @@ # pylint: disable=too-few-public-methods,invalid-name,unused-argument,arguments-differ -# pylint: disable=consider-using-enumerate +# pylint: disable=consider-using-enumerate,too-many-lines """ Template configuration space. 
@@ -996,5 +996,17 @@ def fallback_with_reference_log(self, ref_log): if not isinstance(self.space_map[knob_name], SplitSpace): self._entity_map[knob_name] = best_match_cfg[knob_name] + def __setitem__(self, name, entity): + """set the entity(knob) of by name + + Parameters + ---------- + name: str + name of the entity + entity: SplitEntity, ReorderEntity, AnnotateEntity, OtherOptionEntity + value of the entity + """ + self._entity_map[name] = entity + def __repr__(self): return "%s,%s,%s" % (str(self._entity_map)[12:-1], self.template_key, self.code_hash) diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py index ab52788c8d91..22a15143b96e 100644 --- a/python/tvm/autotvm/task/task.py +++ b/python/tvm/autotvm/task/task.py @@ -182,7 +182,7 @@ def create(func_name, args, target, target_host=None, template_key=None): return ret -def args_to_workload(x): +def args_to_workload(x, topi_compute_func=None): """Convert argument list to hashable workload tuple. This function will convert list to tuple, tvm node to python value and flatten tvm.tensor.Tensor to a tuple @@ -191,6 +191,8 @@ def args_to_workload(x): ---------- x: primitive hashable types or tensor.Tensor The original value + topi_compute_func: topi compute function + The function name will be added as first element of the workload tuple Returns ------- @@ -198,18 +200,19 @@ def args_to_workload(x): The hashable value """ if isinstance(x, tensor.Tensor): - return get_const_tuple(x.shape) + (x.dtype, ) + workload = get_const_tuple(x.shape) + (x.dtype, ) elif isinstance(x, (tuple, list, container.Array)): - return tuple([args_to_workload(a) for a in x]) + workload = tuple([args_to_workload(a) for a in x]) elif isinstance(x, (str, int, float, np.int, np.float)): - return x + workload = x elif isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)): - return x.value + workload = x.value elif x is None: - return 0 + workload = 0 else: raise RuntimeError('Do not support type "%s" in 
argument. Consider to use' 'primitive types only' % type(x)) + return (get_func_name(topi_compute_func), ) + workload if topi_compute_func else workload def template(func): """ diff --git a/python/tvm/autotvm/task/topi_integration.py b/python/tvm/autotvm/task/topi_integration.py index 18f45f8d6708..f005ee0c9a54 100644 --- a/python/tvm/autotvm/task/topi_integration.py +++ b/python/tvm/autotvm/task/topi_integration.py @@ -1,4 +1,4 @@ -# pylint: disable=unused-variable,invalid-name +# pylint: disable=unused-variable,invalid-name,unused-argument """ Decorators for registering tunable templates to TOPI. @@ -13,7 +13,6 @@ from ... import _api_internal, tensor -from ..util import get_func_name from .task import args_to_workload, dispatcher @@ -55,8 +54,6 @@ def register_topi_compute(topi_compute, target_keys, template_keys, func=None): -------- See tvm/topi/python/topi/arm_cpu/depthwise_conv2d.py for example usage. """ - fname = get_func_name(topi_compute) - def _decorator(f): targets = [target_keys] if isinstance(target_keys, str) else target_keys for target_key in targets: @@ -68,7 +65,7 @@ def _decorator(f): def config_dispatcher(*args, **kwargs): """override topi call as a config dispatcher""" assert not kwargs, "Do not support kwargs in template function call" - return (fname, ) + args_to_workload(args) + return args_to_workload(args, topi_compute) _REGISTED_DISPATHCER[target_key][topi_compute] = config_dispatcher config_dispatcher = _REGISTED_DISPATHCER[target_key][topi_compute] @@ -88,7 +85,7 @@ def template_call(cfg, *args, **kwargs): attrs = {} for k, v in node.op.attrs.items(): attrs[k] = v - attrs['workload'] = (fname, ) + args_to_workload(args) + attrs['workload'] = args_to_workload(args, topi_compute) if isinstance(op, tensor.ComputeOp): op = _api_internal._ComputeOp( op.name, op.tag, attrs, op.axis, op.body) @@ -153,7 +150,7 @@ def _decorator(f): if topi_schedule not in _REGISTED_DISPATHCER[target_key]: @topi_schedule.register(target_key) @dispatcher - def 
config_dispatcher(outs): + def config_dispatcher(outs, *args, **kwargs): """override topi call as a workload dispatcher""" def traverse(tensors): """traverse all ops to find attached workload""" @@ -179,11 +176,11 @@ def traverse(tensors): config_dispatcher = _REGISTED_DISPATHCER[target_key][topi_schedule] @config_dispatcher.register(template_keys) - def template_call(cfg, outs): + def template_call(cfg, outs, *args, **kwargs): """call the schedule func""" if f == topi_schedule.fdefault: - return f(outs) - return f(cfg, outs) + return f(outs, *args, **kwargs) + return f(cfg, outs, *args, **kwargs) return f diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py index e99ce263296b..765b48d286bc 100644 --- a/topi/python/topi/generic/nn.py +++ b/topi/python/topi/generic/nn.py @@ -55,33 +55,15 @@ def schedule_conv2d_nhwc(outs): @tvm.target.generic_func -def schedule_conv2d_NCHWc(num_filter, kernel_size, strides, - padding, layout, out_layout, outs): +def schedule_conv2d_NCHWc(outs): """Schedule for conv2d_NCHW[x]c Parameters ---------- - num_filter : int - The number of filter, i.e., the output channel. - - kernel_size : tuple of int - (kernel_height, kernel_width) - - strides : tuple of int - (stride_of_height, stride_of_width) - - padding : tuple of int - (pad_of_height, pad_of_width) - - layout : str - Input data layout - - out_layout : str - Output data layout - outs : Array of Tensor The computation graph description of conv2d_NCHWc in the format of an array of tensors. + The number of filter, i.e., the output channel. 
Returns ------- diff --git a/topi/python/topi/hls/nn.py b/topi/python/topi/hls/nn.py index 8c986d7a5663..536453fc629c 100644 --- a/topi/python/topi/hls/nn.py +++ b/topi/python/topi/hls/nn.py @@ -73,30 +73,11 @@ def schedule_conv2d_nhwc(outs): @generic.schedule_conv2d_NCHWc.register(["hls"]) -def schedule_conv2d_NCHWc(num_filter, kernel_size, strides, - padding, layout, out_layout, outs): +def schedule_conv2d_NCHWc(outs): """Schedule for conv2d_NCHW[x]c Parameters ---------- - num_filter : int - The number of filter, i.e., the output channel. - - kernel_size : tuple of int - (kernel_height, kernel_width) - - strides : tuple of int - (stride_of_height, stride_of_width) - - padding : tuple of int - (pad_of_height, pad_of_width) - - layout : str - Input data layout - - out_layout : str - Output data layout - outs : Array of Tensor The computation graph description of conv2d_NCHWc in the format of an array of tensors. diff --git a/topi/python/topi/intel_graphics/conv2d.py b/topi/python/topi/intel_graphics/conv2d.py index 4dae00e9c146..f6767b68afa1 100644 --- a/topi/python/topi/intel_graphics/conv2d.py +++ b/topi/python/topi/intel_graphics/conv2d.py @@ -61,8 +61,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos): return sym.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs) @conv2d_NCHWc.register(["intel_graphics"]) -def _decl_conv2d(data, kernel, num_filter, kernel_size, stride, padding, layout,\ - out_layout, out_dtype='float32'): +def _decl_conv2d(data, kernel, stride, padding, layout, out_layout, out_dtype='float32'): """Conv2D operator for Intel Graphics backend. 
Parameters @@ -101,7 +100,7 @@ def _decl_conv2d(data, kernel, num_filter, kernel_size, stride, padding, layout, return _decl_cl_spatialpack_NCHWc(data, kernel, stride, padding, out_dtype) @generic.schedule_conv2d_NCHWc.register(["intel_graphics"]) -def schedule_conv2d_NCHWc(num_filter, kernel_size, stride, padding, layout, out_layout, outs): +def schedule_conv2d_NCHWc(outs): """Schedule for conv2d_nchw for Intel Graphics Parameters diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index 4d70c4903a3f..7636350dfbf6 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -84,32 +84,6 @@ def _get_workload(data, kernel, stride, padding, out_dtype): '{} vs. {}".format(data.dtype, kernel.dtype) return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR) -def _get_workload_int8(data, kernel, stride, padding, out_dtype): - """ Get the workload structure. """ - _, CI, IH, IW = [x.value for x in data.shape] - CO, _, KH, KW = [x.value for x in kernel.shape] - HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel) - if isinstance(stride, (tuple, list)): - HSTR, WSTR = stride - else: - HSTR, WSTR = stride, stride - assert (data.dtype == kernel.dtype) or (data.dtype == 'uint8' and kernel.dtype == 'int8'), \ - "Do not support inputs with different data types now. ' \ - '{} vs. {}".format(data.dtype, kernel.dtype) - return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR) - - - -@tvm.target.generic_func -def _get_alter_layout_schedule(wkl): - # pylint: disable=unreachable - """ Get the platform specific schedule for conv2d_alter_layout. 
""" - target = tvm.target.current_target() - raise RuntimeError( - "No schedule for current target:{}".format(target)) - # This return has no use, merely to supress pylint warning - return wkl - @tvm.target.generic_func def _get_schedule(wkl): @@ -122,28 +96,6 @@ def _get_schedule(wkl): return wkl -@tvm.target.generic_func -def _get_schedule_NCHWc(wkl, layout, out_layout): - # pylint: disable=unreachable - """ Get the platform specific schedule. """ - target = tvm.target.current_target() - raise RuntimeError( - "No schedule for current target:{}".format(target)) - # This return has no use, merely to supress pylint warning - return wkl - - -@tvm.target.generic_func -def _get_schedule_NCHWc_int8(wkl, layout, out_layout): - # pylint: disable=unreachable - """ Get the platform specific schedule. """ - target = tvm.target.current_target() - raise RuntimeError( - "No schedule for current target:{}".format(target)) - # This return has no use, merely to supress pylint warning - return wkl - - def conv2d_nchw(Input, Filter, stride, padding, out_dtype=None): """Convolution operator in NCHW layout. @@ -302,8 +254,7 @@ def conv2d_nhwc(Input, Filter, stride, padding, out_dtype='float32'): @tvm.target.generic_func -def conv2d_NCHWc(data, kernel, num_filter, kernel_size, stride, - padding, layout, out_layout, out_dtype='float32'): +def conv2d_NCHWc(data, kernel, stride, padding, layout, out_layout, out_dtype='float32'): """Conv2D operator for nChw[x]c layout. 
Parameters @@ -316,12 +267,6 @@ def conv2d_NCHWc(data, kernel, num_filter, kernel_size, stride, [num_filter_chunk, in_channel_chunk, filter_height, filter_width, in_channel_block, num_filter_block] - num_filter : int - number of filters, i.e., output channel size - - kernel_size : tuple of two ints - [kernel_height, kernel_width] - stride : int or a list/tuple of two ints stride size, or [stride_height, stride_width] diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py index f766d827686d..c588e74432a4 100644 --- a/topi/python/topi/x86/conv2d.py +++ b/topi/python/topi/x86/conv2d.py @@ -2,203 +2,15 @@ """Conv2D schedule on x86""" import tvm from tvm import autotvm -from tvm.autotvm.task.dispatcher import ApplyGraphBest from tvm.autotvm.task.nnvm_integration import deserialize_args from tvm.autotvm.task import register, get_config from .. import generic, tag from .. import nn from ..util import get_const_tuple -from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_alter_layout, \ - _get_workload_int8, _get_schedule, _get_schedule_NCHWc, \ - _get_schedule_NCHWc_int8, _get_alter_layout_schedule, Workload +from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_alter_layout, _get_workload from ..nn.pad import pad from . 
import conv2d_avx_1x1, conv2d_avx_common -from .conv2d_avx_common import AVXConvCommonFwd -from .conv2d_avx_1x1 import AVXConv1x1Fwd -from .check_targets import check_skylake - -@_get_schedule.register("cpu") -def _get_schedule_conv(wkl): - _WORKLOADS_AVX = [ - # workloads of resnet18_v1 on imagenet - Workload('float32', 'float32', 224, 224, 3, 64, 7, 7, 3, 3, 2, 2), - Workload('float32', 'float32', 56, 56, 64, 64, 3, 3, 1, 1, 1, 1), - Workload('float32', 'float32', 56, 56, 64, 64, 1, 1, 0, 0, 1, 1), - Workload('float32', 'float32', 56, 56, 64, 128, 3, 3, 1, 1, 2, 2), - Workload('float32', 'float32', 56, 56, 64, 128, 1, 1, 0, 0, 2, 2), - Workload('float32', 'float32', 28, 28, 128, 128, 3, 3, 1, 1, 1, 1), - Workload('float32', 'float32', 28, 28, 128, 256, 3, 3, 1, 1, 2, 2), - Workload('float32', 'float32', 28, 28, 128, 256, 1, 1, 0, 0, 2, 2), - Workload('float32', 'float32', 14, 14, 256, 256, 3, 3, 1, 1, 1, 1), - Workload('float32', 'float32', 14, 14, 256, 512, 3, 3, 1, 1, 2, 2), - Workload('float32', 'float32', 14, 14, 256, 512, 1, 1, 0, 0, 2, 2), - Workload('float32', 'float32', 7, 7, 512, 512, 3, 3, 1, 1, 1, 1), - # workloads of resnet34_v1 on imagenet, no extra workload required - # workloads of resnet50_v1 on imagenet - Workload('float32', 'float32', 56, 56, 64, 256, 1, 1, 0, 0, 1, 1), - Workload('float32', 'float32', 56, 56, 256, 64, 1, 1, 0, 0, 1, 1), - Workload('float32', 'float32', 56, 56, 256, 128, 1, 1, 0, 0, 2, 2), - Workload('float32', 'float32', 28, 28, 128, 512, 1, 1, 0, 0, 1, 1), - Workload('float32', 'float32', 56, 56, 256, 512, 1, 1, 0, 0, 2, 2), - Workload('float32', 'float32', 28, 28, 512, 128, 1, 1, 0, 0, 1, 1), - Workload('float32', 'float32', 28, 28, 512, 256, 1, 1, 0, 0, 2, 2), - Workload('float32', 'float32', 14, 14, 256, 1024, 1, 1, 0, 0, 1, 1), - Workload('float32', 'float32', 28, 28, 512, 1024, 1, 1, 0, 0, 2, 2), - Workload('float32', 'float32', 14, 14, 1024, 256, 1, 1, 0, 0, 1, 1), - Workload('float32', 'float32', 14, 14, 1024, 512, 1, 
1, 0, 0, 2, 2), - Workload('float32', 'float32', 7, 7, 512, 2048, 1, 1, 0, 0, 1, 1), - Workload('float32', 'float32', 14, 14, 1024, 2048, 1, 1, 0, 0, 2, 2), - Workload('float32', 'float32', 7, 7, 2048, 512, 1, 1, 0, 0, 1, 1), - # workloads of resnet101_v1 on imagenet, no extra workload required - # workloads of resnet152_v1 on imagenet, no extra workload required - # workloads of resnet18_v2 on imagenet, no extra workload required - # workloads of resnet34_v2 on imagenet, no extra workload required - ] - - fp32_vec_len = 8 - target = tvm.target.current_target(allow_none=False) - for opt in target.options: - if opt == '-mcpu=skylake-avx512': - fp32_vec_len = 16 - - _SCHEDULES_AVX = [ - # workloads of resnet18_v1 on imagenet - AVXConvCommonFwd(3, fp32_vec_len, 28, False), - AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 28, False), - AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 1, 28), - AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 28, False), - AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 1, 28), - AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 28, False), - AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 14, False), - AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14), - AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 14, True), - AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 7, True), - AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 1, 7), - AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 7, True), - # workloads of resnet34_v1 on imagenet, no extra workload required - # workloads of resnet50_v1 on imagenet - AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28), - AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28), - AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28), - AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28), - AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28), - AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28), - AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14), - AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14), - AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14), - 
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14), - AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7), - AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7), - AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7), - AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7), - # workloads of resnet101_v1 on imagenet, no extra workload required - # workloads of resnet152_v1 on imagenet, no extra workload required - # workloads of resnet18_v2 on imagenet, no extra workload required - # workloads of resnet34_v2 on imagenet, no extra workload required - ] - - if wkl not in _WORKLOADS_AVX: - if wkl.hkernel == 1 and wkl.wkernel == 1: - return conv2d_avx_1x1._get_default_schedule(wkl, fp32_vec_len) - return conv2d_avx_common._get_default_schedule(wkl, fp32_vec_len) - idx = _WORKLOADS_AVX.index(wkl) - sch = _SCHEDULES_AVX[idx] - return sch - -def _get_schedule_conv_int8(wkl): - _WORKLOADS_AVX = [ - ## Following are for INT8 kernels - Workload('uint8', 'int32', 56, 56, 64, 64, 3, 3, 1, 1, 1, 1), - Workload('uint8', 'int32', 56, 56, 64, 64, 1, 1, 0, 0, 1, 1), - Workload('uint8', 'int32', 56, 56, 64, 128, 3, 3, 1, 1, 2, 2), - Workload('uint8', 'int32', 56, 56, 64, 128, 1, 1, 0, 0, 2, 2), - Workload('uint8', 'int32', 28, 28, 128, 128, 3, 3, 1, 1, 1, 1), - Workload('uint8', 'int32', 28, 28, 128, 256, 3, 3, 1, 1, 2, 2), - Workload('uint8', 'int32', 28, 28, 128, 256, 1, 1, 0, 0, 2, 2), - Workload('uint8', 'int32', 14, 14, 256, 256, 3, 3, 1, 1, 1, 1), - Workload('uint8', 'int32', 14, 14, 256, 512, 3, 3, 1, 1, 2, 2), - Workload('uint8', 'int32', 14, 14, 256, 512, 1, 1, 0, 0, 2, 2), - Workload('uint8', 'int32', 7, 7, 512, 512, 3, 3, 1, 1, 1, 1), - # workloads of resnet34_v1 on imagenet, no extra workload required - # workloads of resnet50_v1 on imagenet - Workload('uint8', 'int32', 56, 56, 64, 256, 1, 1, 0, 0, 1, 1), - Workload('uint8', 'int32', 56, 56, 256, 64, 1, 1, 0, 0, 1, 1), - Workload('uint8', 'int32', 56, 56, 256, 128, 1, 1, 0, 0, 2, 2), - Workload('uint8', 'int32', 28, 28, 128, 512, 1, 1, 0, 0, 
1, 1), - Workload('uint8', 'int32', 56, 56, 256, 512, 1, 1, 0, 0, 2, 2), - Workload('uint8', 'int32', 28, 28, 512, 128, 1, 1, 0, 0, 1, 1), - Workload('uint8', 'int32', 28, 28, 512, 256, 1, 1, 0, 0, 2, 2), - Workload('uint8', 'int32', 14, 14, 256, 1024, 1, 1, 0, 0, 1, 1), - Workload('uint8', 'int32', 28, 28, 512, 1024, 1, 1, 0, 0, 2, 2), - Workload('uint8', 'int32', 14, 14, 1024, 256, 1, 1, 0, 0, 1, 1), - Workload('uint8', 'int32', 14, 14, 1024, 512, 1, 1, 0, 0, 2, 2), - Workload('uint8', 'int32', 7, 7, 512, 2048, 1, 1, 0, 0, 1, 1), - Workload('uint8', 'int32', 14, 14, 1024, 2048, 1, 1, 0, 0, 2, 2), - Workload('uint8', 'int32', 7, 7, 2048, 512, 1, 1, 0, 0, 1, 1), - ] - - fp32_vec_len = 8 - target = tvm.target.current_target(allow_none=False) - if check_skylake(target): - fp32_vec_len = 16 - - _SCHEDULES_AVX = [ - # Following are for INT8 operations - # workloads of resnet18_v1 on imagenet - AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 28, False), - AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 1, 28), - AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 28, False), - AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 1, 28), - AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 28, False), - AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 14, False), - AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14), - AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 14, True), - AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 7, True), - AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 1, 7), - AVXConvCommonFwd(fp32_vec_len, fp32_vec_len, 7, True), - # workloads of resnet34_v1 on imagenet, no extra workload required - # workloads of resnet50_v1 on imagenet - AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28), - AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28), - AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28), - AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28), - AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28), - AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 28), - AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14), - 
AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14), - AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14), - AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 14), - AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7), - AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7), - AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7), - AVXConv1x1Fwd(fp32_vec_len, fp32_vec_len, 2, 7), - # workloads of resnet101_v1 on imagenet, no extra workload required - # workloads of resnet152_v1 on imagenet, no extra workload required - # workloads of resnet18_v2 on imagenet, no extra workload required - # workloads of resnet34_v2 on imagenet, no extra workload required - ] - - if wkl not in _WORKLOADS_AVX: - if wkl.hkernel == 1 and wkl.wkernel == 1: - return conv2d_avx_1x1._get_default_schedule(wkl, fp32_vec_len) - return conv2d_avx_common._get_default_schedule(wkl, fp32_vec_len) - idx = _WORKLOADS_AVX.index(wkl) - sch = _SCHEDULES_AVX[idx] - return sch - -@_get_schedule_NCHWc.register("cpu") -def _get_schedule_NCHWc_x86(wkl, layout, out_layout): - return _get_schedule_conv(wkl) - -@_get_schedule_NCHWc_int8.register("cpu") -def _get_schedule_NCHWc_x86_int8(wkl, layout, out_layout): - return _get_schedule_conv_int8(wkl) - -@_get_alter_layout_schedule.register("cpu") -def _get_alter_layout_schedule_x86(wkl): - return _get_schedule_conv(wkl) - def _get_fp32_len(): fp32_vec_len = 8 @@ -210,18 +22,23 @@ def _get_fp32_len(): return fp32_vec_len -def _get_default_sch(workload): +def _get_default_config(cfg, workload): + """ + Get default schedule config for the workload + Parameters + ---------- + workload : topi.nn.conv2d.Workload + Convolution workload + """ fp32_vec_len = _get_fp32_len() - _, _, kh, kw, _ = workload[2] - is_kernel_1x1 = kh == 1 and kw == 1 + is_kernel_1x1 = workload.hkernel == 1 and workload.wkernel == 1 if is_kernel_1x1: - cfg = conv2d_avx_1x1._fallback_schedule(workload, fp32_vec_len) + conv2d_avx_1x1._fallback_schedule(cfg, workload, fp32_vec_len) else: - cfg = 
conv2d_avx_common._fallback_schedule(workload, fp32_vec_len) - return cfg + conv2d_avx_common._fallback_schedule(cfg, workload, fp32_vec_len) -def _create_schedule_template(cfg, data, kernel, strides, padding, layout): +def _create_tuning_space(cfg, data, kernel, strides, padding, layout): """Create schedule configuration from input arguments""" dshape = get_const_tuple(data.shape) kshape = get_const_tuple(kernel.shape) @@ -247,38 +64,17 @@ def _create_schedule_template(cfg, data, kernel, strides, padding, layout): cfg.define_knob("unroll_kw", [True, False]) -def conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype): - """convert argument to workload""" - if len(kernel.shape) == 4: - raw_kernel = kernel - else: # the input kernel is transformed by alter_op_layout - shape = get_const_tuple(kernel.shape) - raw_kernel = tvm.placeholder((shape[0] * shape[4], shape[1], shape[2], shape[3]), - dtype=kernel.dtype) - return ('conv2d', ) + autotvm.task.args_to_workload( - [data, raw_kernel, strides, padding, layout, out_dtype]) - - -@conv2d.register("cpu") -@autotvm.task.dispatcher -def conv2d_x86(data, kernel, strides, padding, layout, out_dtype): - """x86 conv2d declaration.""" - return conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype) - - -@conv2d_x86.register(["direct"]) +@autotvm.register_topi_compute(conv2d, 'cpu', 'direct') def _declaration_conv(cfg, data, kernel, strides, padding, layout, out_dtype): out_dtype = data.dtype if out_dtype is None else out_dtype padding = padding if isinstance(padding, (tuple, list)) else (padding, padding) strides = strides if isinstance(strides, (tuple, list)) else (strides, strides) if layout == 'NCHW': - _create_schedule_template(cfg, data, kernel, strides, padding, layout) + _create_tuning_space(cfg, data, kernel, strides, padding, layout) if cfg.is_fallback: - workload = conv_arg_to_workload(data, kernel, strides, padding, - layout, out_dtype) - cfg = _get_default_sch(workload) - args = [cfg, 
data, kernel, strides, padding, layout, out_dtype] - return _declaration_conv_impl(*args) + wkl = _get_workload(data, kernel, strides, padding, out_dtype) + _get_default_config(cfg, wkl) + return _declaration_conv_impl(cfg, data, kernel, strides, padding, layout, out_dtype) elif layout == 'HWCN': return nn.conv2d_hwcn(data, kernel, strides, padding, out_dtype) elif layout == 'NHWC': @@ -345,11 +141,7 @@ def _declaration_conv_impl(cfg, data, kernel, strides, padding, layout, out_dtyp lambda n, c, h, w: conv[n, c // oc_bn, h, w, c % oc_bn] .astype(out_dtype), name='output_unpack', - tag='conv2d_nchw', - attrs={'workload': - conv_arg_to_workload(data, kernel, strides, - padding, layout, - out_dtype)}) + tag='conv2d_nchw') return unpack @@ -385,18 +177,7 @@ def traverse(op): _, _, kh, kw = get_const_tuple(kernel.shape) is_kernel_1x1 = kh == 1 and kw == 1 - current_cfg = cfg - if cfg.is_fallback: - workload_attr = op.attrs["workload"] - strides = (int(workload_attr[3][0].value), int(workload_attr[3][1].value)) - padding = (int(workload_attr[4][0].value), int(workload_attr[4][1].value)) - layout = workload_attr[5].value - out_dtype = workload_attr[6].value - workload = conv_arg_to_workload(data, kernel, strides, padding, - layout, out_dtype) - current_cfg = _get_default_sch(workload) - args = [s, current_cfg, data, data_pad, data_vec, kernel_vec, conv_out, - output, outs[0]] + args = [s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, outs[0]] if is_kernel_1x1: conv2d_avx_1x1._schedule_conv(*args) else: @@ -470,17 +251,13 @@ def traverse(op): @register("topi_x86_conv2d_NCHWc") def _topi_nn_conv2d_NCHWc(*args, **kwargs): assert not kwargs, "Do not support kwargs in template function call" - args = deserialize_args(args) - data, kernel = args[:2] - strides = args[4] - padding = args[5] - layout = args[6] + data, kernel, strides, padding, origin_layout, dtype = deserialize_args(args) raw_data_shape = get_const_tuple(data.shape) raw_kernel_shape = 
get_const_tuple(kernel.shape) # get config here cfg = get_config() - _create_schedule_template(cfg, data, kernel, strides, padding, layout) + _create_tuning_space(cfg, data, kernel, strides, padding, origin_layout) # change shape with the value in config ic_bn, oc_bn, ow_bn = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1], @@ -491,50 +268,13 @@ def _topi_nn_conv2d_NCHWc(*args, **kwargs): out_layout = "NCHW%dc" % oc_bn new_kernel_shape = (raw_kernel_shape[0] // oc_bn, raw_kernel_shape[1] // ic_bn, raw_kernel_shape[2], raw_kernel_shape[3], ic_bn, oc_bn) - args[0] = tvm.placeholder(new_data_shape, data.dtype) - args[1] = tvm.placeholder(new_kernel_shape, kernel.dtype) - args[6] = data_layout - args[7] = out_layout + new_data = tvm.placeholder(new_data_shape, data.dtype) + new_kernel = tvm.placeholder(new_kernel_shape, kernel.dtype) - C = _declaration_conv_NCHWc(cfg, *args, **kwargs) - s = _schedule_conv2d_NCHWc(cfg, args[2], args[3], args[4], args[5], - args[6], args[7], [C]) - return s, [args[0], args[1], C] - - -def conv_NCHWc_arg_to_workload(data, kernel, kernel_size, strides, - padding, layout, out_layout, out_dtype): - """convert argument to workload""" - dshape = get_const_tuple(data.shape) - kshape = get_const_tuple(kernel.shape) - if len(dshape) > 4: - raw_data = tvm.placeholder((dshape[0], dshape[1] * dshape[4], dshape[2], - dshape[3]), dtype=kernel.dtype) - else: - raw_data = data - if len(kshape) > 4: - raw_kernel = tvm.placeholder((kshape[0] * kshape[5], kshape[1] * kshape[4], - kshape[2], kshape[3]), dtype=kernel.dtype) - else: - raw_kernel = kernel - return ('conv2d_NCHWc', ) + autotvm.task.args_to_workload( - [raw_data, raw_kernel, strides, padding, layout, out_layout, - out_dtype]) - - -def _query_dispatcher(workload, in_alter_op=False): - dispatch_ctx = autotvm.task.DispatchContext.current - if isinstance(dispatch_ctx, ApplyGraphBest): - if in_alter_op: - cfg = dispatch_ctx.query(None, None) - else: - cfg = dispatch_ctx.query_global_dict(workload) 
- else: - target = tvm.target.current_target() - cfg = dispatch_ctx.query(target, workload) - if cfg.is_fallback: - cfg = _get_default_sch(workload) - return cfg + C = _declaration_conv_NCHWc(cfg, new_data, new_kernel, strides, padding, + data_layout, out_layout, dtype) + s = _schedule_conv2d_NCHWc(cfg, [C]) + return s, [new_data, new_kernel, C] @conv2d_alter_layout.register("cpu") @@ -546,87 +286,72 @@ def _alter_conv2d_layout(attrs, inputs, tinfo): # only optimize for NCHW, groups=1 conv if attrs['layout'] != 'NCHW' or attrs.get_int("groups") != 1: return None + batch_size, in_channel, height, width = get_const_tuple(data.shape) + out_channel, _, kh, kw = get_const_tuple(kernel.shape) - kernel_size = attrs.get_int_tuple("kernel_size") padding = attrs.get_int_tuple("padding") strides = attrs.get_int_tuple("strides") layout = attrs['layout'] - out_layout = layout if attrs["out_layout"] == "__undef__" else attrs["out_layout"] dtype = data.dtype out_dtype = dtype if attrs["out_dtype"] == "same" else attrs["out_dtype"] - workload = conv_NCHWc_arg_to_workload(data, kernel, kernel_size, strides, - padding, layout, out_layout, out_dtype) - cfg = _query_dispatcher(workload, True) - ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] - new_attrs['layout'] = 'NCHW%dc' % ic_bn - new_attrs['out_layout'] = 'NCHW%dc' % oc_bn - # Store global schedule dictionary for ApplyGraphBest dispatcher + workload = autotvm.task.args_to_workload( + [data, kernel, strides, padding, layout, out_dtype], conv2d) dispatch_ctx = autotvm.task.DispatchContext.current - if isinstance(dispatch_ctx, ApplyGraphBest): - workload = conv_NCHWc_arg_to_workload(data, kernel, kernel_size, strides, - padding, new_attrs['layout'], - new_attrs['out_layout'], out_dtype) - global_dict_key = workload - dispatch_ctx.update_global_dict(global_dict_key, cfg) + target = tvm.target.current_target() + cfg = dispatch_ctx.query(target, workload) + if cfg.is_fallback: + wkl = _get_workload(data, kernel, 
strides, padding, out_dtype) + _get_default_config(cfg, wkl) + ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] + new_attrs['layout'] = 'NCHW%dc' % ic_bn + new_attrs['out_layout'] = 'NCHW%dc' % oc_bn # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc) new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn) - return sym.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs) + # Store altered operator's config + new_data = tvm.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn), + dtype=data.dtype) + new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn, kh, kw, ic_bn, oc_bn), + dtype=kernel.dtype) + new_workload = autotvm.task.args_to_workload( + [new_data, new_kernel, strides, padding, new_attrs['layout'], + new_attrs['out_layout'], out_dtype], conv2d_NCHWc) + dispatch_ctx.update(target, new_workload, cfg) - -@conv2d_NCHWc.register("cpu") -def conv2d_NCHWc_cpu(data, kernel, num_filter, kernel_size, strides, - padding, layout, out_layout, out_dtype): - """x86 conv2d_NCHWc declaration.""" - dispatch_ctx = autotvm.task.DispatchContext.current - if not isinstance(dispatch_ctx, ApplyGraphBest): - layout = out_layout = "NCHW" - workload = conv_NCHWc_arg_to_workload(data, kernel, kernel_size, strides, - padding, layout, out_layout, out_dtype) - cfg = _query_dispatcher(workload) - return _declaration_conv_NCHWc(cfg, data, kernel, num_filter, kernel_size, strides, - padding, layout, out_layout, out_dtype) + return sym.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs) -def _declaration_conv_NCHWc(cfg, data, kernel, num_filter, kernel_size, strides, +@autotvm.register_topi_compute(conv2d_NCHWc, 'cpu', 'direct') +def _declaration_conv_NCHWc(cfg, data, kernel, strides, padding, layout, out_layout, out_dtype): - n, ic_chunk, h, w, ic_block = [x.value for x in data.shape] - ic = ic_chunk * ic_block - kh, kw = kernel_size if isinstance(kernel_size, (tuple, list)) else \ - (kernel_size, kernel_size) - is_kernel_1x1 = kh == 1 and kw == 1 - ph, pw = 
padding if isinstance(padding, (tuple, list)) else (padding, padding) - sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides) + # layout and out_layout are not used here, + # we keep them for debug convenience when dumping autotvm workload + HPAD, WPAD = padding if isinstance(padding, (tuple, list)) else (padding, padding) + HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides) + n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape) + in_channel = ic_chunk * ic_bn if data.dtype == 'uint8': - wkl = _get_workload_int8(tvm.placeholder((n, ic, h, w), dtype=data.dtype), - tvm.placeholder((num_filter, ic, kh, kw), - dtype=kernel.dtype), - strides, padding, out_dtype) - sch = _get_schedule_NCHWc_int8(wkl, layout, out_layout) - return conv2d_avx_1x1._declaration_conv_NCHWc_int8(wkl, sch, data, kernel) \ - if is_kernel_1x1 \ - else conv2d_avx_common._declaration_conv_NCHWc_int8(wkl, sch, data, kernel) - - args = [cfg, data, kernel, (kh, kw), (sh, sw), (ph, pw), layout, out_layout, out_dtype] - return _declaration_conv_NCHWc_impl(*args) - - -def _declaration_conv_NCHWc_impl(cfg, data, kernel, kernel_size, strides, padding, layout, - out_layout, out_dtype): - HPAD, WPAD = padding - HSTR, WSTR = strides - - n, ic_chunk, ih, iw, ic_block = get_const_tuple(data.shape) - ic = ic_chunk * ic_block - kh, kw = kernel_size - oc_chunk, _, _, _, _, oc_block = get_const_tuple(kernel.shape) - oc = oc_chunk * oc_block - oh = (ih + 2 * HPAD - kh) // HSTR + 1 - ow = (iw + 2 * WPAD - kw) // WSTR + 1 + oc_chunk, _, kernel_height, kernel_width, _, oc_bn, _ = get_const_tuple(kernel.shape) + else: + oc_chunk, _, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape) + num_filter = oc_chunk * oc_bn + + # get workload and related schedule config + wkl = _get_workload(tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype), + tvm.placeholder((num_filter, in_channel, kernel_height, kernel_width), + dtype=kernel.dtype), + 
strides, padding, out_dtype) + if cfg.is_fallback: + _get_default_config(cfg, wkl) + + # output shape + out_height = (ih + 2 * HPAD - kernel_height) // HSTR + 1 + out_width = (iw + 2 * WPAD - kernel_width) // WSTR + 1 + oshape = (n, oc_chunk, out_height, out_width, oc_bn) # DOPAD DOPAD = (HPAD != 0 or WPAD != 0) @@ -635,51 +360,43 @@ def _declaration_conv_NCHWc_impl(cfg, data, kernel, kernel_size, strides, paddin else: data_pad = data - # fetch schedule - ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] - if ic_bn != ic_block: - raise RuntimeError("ic_bn in config is not equal to actual data ic_block: %d vs %d." - % (ic_bn, ic_block)) - if oc_bn != oc_block: - raise RuntimeError("oc_bn in config is not equal to actual kernel oc_block: %d vs %d." - % (oc_bn, oc_block)) - - # convolution - oshape = (n, oc//oc_bn, oh, ow, oc_bn) - - ic = tvm.reduce_axis((0, ic), name='ic') - kh = tvm.reduce_axis((0, kernel_size[0]), name='kh') - kw = tvm.reduce_axis((0, kernel_size[1]), name='kw') + ic = tvm.reduce_axis((0, in_channel), name='ic') + kh = tvm.reduce_axis((0, kernel_height), name='kh') + kw = tvm.reduce_axis((0, kernel_width), name='kw') - workload = conv_NCHWc_arg_to_workload(data, kernel, kernel_size, - strides, padding, layout, - out_layout, out_dtype), - attrs = {'workload': workload} - conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block: + if data.dtype == 'uint8': + assert out_dtype == "int32", \ + "INT8 convolution requires input dtype = uint8 and output dtype=int32" + # Intel performs dot product of 2 "4" Int8 values + # Current implementation requires ic_bn to be a multiple of 4 + n_elems = 4 + assert ic_bn % n_elems == 0 + + ic_outer = tvm.reduce_axis((0, wkl.in_filter//ic_bn), name='ic_outer') + ic_f_inner = tvm.reduce_axis((0, ic_bn//n_elems), name='ic_f_inner') + ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner') + return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block: + tvm.sum(data_pad[n, ic_outer, 
oh*HSTR+kh, ow*WSTR+kw, + ic_f_inner * n_elems + ic_s_inner] + .astype(out_dtype) * + kernel[oc_chunk, ic_outer, kh, kw, ic_f_inner, + oc_block, ic_s_inner].astype(out_dtype), + axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]), + name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8") + # else: fp implementation + return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block: tvm.sum(data_pad[n, ic//ic_bn, oh*HSTR+kh, ow*WSTR+kw, ic%ic_bn].astype(out_dtype) * kernel[oc_chunk, ic//ic_bn, kh, kw, ic%ic_bn, oc_block], axis=[ic, kh, kw]), - name='conv2d_NCHWc', tag="conv2d_NCHWc", attrs=attrs) - return conv - - -@generic.schedule_conv2d_NCHWc.register("cpu") -def schedule_conv2d_NCHWc(num_filter, kernel_size, strides, padding, - layout, out_layout, outs): - """x86 conv2d_NCHWc schedule""" - return _schedule_conv2d_NCHWc(None, num_filter, kernel_size, strides, padding, - layout, out_layout, outs) + name='conv2d_NCHWc', tag="conv2d_NCHWc") -def _schedule_conv2d_NCHWc(cfg, num_filter, kernel_size, strides, padding, - layout, out_layout, outs): +@autotvm.register_topi_schedule(generic.schedule_conv2d_NCHWc, 'cpu', ['direct']) +def _schedule_conv2d_NCHWc(cfg, outs): """Create schedule for tensors""" s = tvm.create_schedule([x.op for x in outs]) scheduled_ops = [] - dispatch_ctx = autotvm.task.DispatchContext.current - if not isinstance(dispatch_ctx, ApplyGraphBest): - layout = out_layout = "NCHW" def traverse(op): """Traverse operators from computation graph""" @@ -702,34 +419,17 @@ def traverse(op): data_pad = data data = data_pad.op.input_tensors[0] - kh, kw = kernel_size if isinstance(kernel_size, (tuple, list)) else \ - (kernel_size, kernel_size) - is_kernel_1x1 = kh == 1 and kw == 1 - n, ic_chunk, h, w, ic_block = [x.value for x in data.shape] - ic = ic_chunk * ic_block - original_data = tvm.placeholder((n, ic, h, w), dtype=data.dtype) - - kh, kw = kernel_size - original_kernel = tvm.placeholder((num_filter, ic, kh, kw), - dtype=kernel.dtype) + args = [s, cfg, data_vec, 
conv_out, outs[0]] if data.dtype == 'uint8': - wkl = _get_workload_int8(original_data, original_kernel, - strides, padding, conv_out.dtype) - sch = _get_schedule_NCHWc_int8(wkl, layout, out_layout) - args = [s, wkl, sch, data_vec, kernel, conv_out, outs[0]] - if is_kernel_1x1: + # int8 conv kernel is 7-dim + _, _, kh, kw, _, _, _ = get_const_tuple(kernel.shape) + if kh == 1 and kw == 1: conv2d_avx_1x1._schedule_conv_NCHWc_int8(*args) else: conv2d_avx_common._schedule_conv_NCHWc_int8(*args) else: - current_cfg = cfg - if current_cfg is None: - workload = conv_NCHWc_arg_to_workload(data, kernel, kernel_size, strides, - padding, layout, out_layout, - conv_out.dtype) - current_cfg = _query_dispatcher(workload) - args = [s, current_cfg, data_vec, conv_out, outs[0]] - if is_kernel_1x1: + _, _, kh, kw, _, _, = get_const_tuple(kernel.shape) + if kh == 1 and kw == 1: conv2d_avx_1x1._schedule_conv_NCHWc(*args) else: conv2d_avx_common._schedule_conv_NCHWc(*args) diff --git a/topi/python/topi/x86/conv2d_avx_1x1.py b/topi/python/topi/x86/conv2d_avx_1x1.py index 96affc7b9d23..ce70ec83828b 100644 --- a/topi/python/topi/x86/conv2d_avx_1x1.py +++ b/topi/python/topi/x86/conv2d_avx_1x1.py @@ -1,21 +1,15 @@ # pylint: disable=invalid-name,unused-variable,unused-argument,invalid-name """1x1 Conv2D schedule on for Intel CPU""" from __future__ import absolute_import as _abs -from collections import namedtuple import tvm -from tvm.autotvm.task import ConfigEntity - -import topi +from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity from ..nn.util import infer_pad -from ..nn.pad import pad +from ..util import get_const_tuple from .tensor_intrin import dot_16x1x16_int8_int8_int32 from .check_targets import check_skylake -AVXConv1x1Fwd = namedtuple('AVXConv1x1Fwd', ['ic_bn', 'oc_bn', 'oh_factor', 'ow_factor']) - - -def _get_default_schedule(wkl, simd_width): +def _fallback_schedule(cfg, wkl, simd_width): HPAD, WPAD = wkl.hpad, wkl.wpad HSTR, WSTR = wkl.hstride, wkl.wstride 
out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1 @@ -37,45 +31,11 @@ def _get_default_schedule(wkl, simd_width): if out_width % ow_factor == 0: for oh_factor in range(out_height, 0, -1): if out_height % oh_factor == 0 and ow_factor * oh_factor < 32: - return AVXConv1x1Fwd(ic_bn, oc_bn, oh_factor, ow_factor) - - raise ValueError("cannot decide default schedule for workload: {}".format(wkl)) - - -def _fallback_schedule(wkl, simd_width): - batch_size, in_channel, height, width, _ = wkl[1] - out_channel, _, hkernel, wkernel, _ = wkl[2] - HPAD, WPAD = wkl[4] - HSTR, WSTR = wkl[3] - out_height = (height + 2 * HPAD - hkernel) // HSTR + 1 - out_width = (width + 2 * WPAD - wkernel) // WSTR + 1 - - oc_bn = 1 - for bn in range(simd_width, 0, -1): - if out_channel % bn == 0: - oc_bn = bn - break - - ic_bn = 1 - for bn in range(oc_bn, 0, -1): - if in_channel % bn == 0: - ic_bn = bn - break - - for ow_factor in range(out_width, 0, -1): - if out_width % ow_factor == 0: - for oh_factor in range(out_height, 0, -1): - if out_height % oh_factor == 0 and ow_factor * oh_factor < 32: - cfg_dict = {"i": -1, - "c": None, - "e": [["tile_ic", "sp", [in_channel // ic_bn, ic_bn]], - ["tile_oc", "sp", [out_channel // oc_bn, oc_bn]], - ["tile_oh", "ot", oh_factor], - ["tile_ow", "sp", [out_width // ow_factor, - ow_factor]],], - "t": ""} - return ConfigEntity.from_json_dict(cfg_dict) - + cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn]) + cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn]) + cfg["tile_oh"] = OtherOptionEntity(oh_factor) + cfg["tile_ow"] = SplitEntity([out_width // ow_factor, ow_factor]) + return raise ValueError("cannot decide default schedule for workload: {}".format(wkl)) @@ -148,8 +108,8 @@ def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, outpu def _schedule_conv_NCHWc(s, cfg, data, conv_out, last): # fetch schedule - ic_bn, oh_factor, ow_factor = (cfg["tile_ic"].size[-1], cfg["tile_oh"].val, - 
cfg["tile_ow"].size[-1]) + oh_factor, ow_factor = cfg["tile_oh"].val, cfg["tile_ow"].size[-1] + _, _, _, _, ic_bn = get_const_tuple(data.shape) # schedule data A = data @@ -201,57 +161,13 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last): return s -def _declaration_conv_NCHWc_int8(wkl, sch, data, kernel): - """ Declaration for int8 conv""" - out_dtype = wkl.out_dtype - HPAD, WPAD = wkl.hpad, wkl.wpad - HSTR, WSTR = wkl.hstride, wkl.wstride - - batch_size = data.shape[0] - out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1 - out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1 - - DOPAD = (HPAD != 0 or WPAD != 0) - if DOPAD: - data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad") - else: - data_pad = data - - oshape = (batch_size, wkl.out_filter//sch.oc_bn, out_height, out_width, sch.oc_bn) - - # Intel performs dot product of 2 "4" Int8 values - n_elems = 4 - assert sch.ic_bn%n_elems == 0 - ic_outer = tvm.reduce_axis((0, wkl.in_filter//(sch.ic_bn)), name='ic_outer') - ic_f_inner = tvm.reduce_axis((0, sch.ic_bn//n_elems), name='ic_f_inner') - ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner') - - # Reshaping kernel as the last 2 dimensions are 1x1 (k_h x k_w) - k_shape = kernel.shape - kernel = topi.reshape(kernel, (k_shape[0], k_shape[1], k_shape[2], k_shape[3], - k_shape[4] * k_shape[5] * k_shape[6])) - - conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block: - tvm.sum(data_pad[n, ic_outer, oh*HSTR, ow*WSTR, - ic_f_inner * n_elems + ic_s_inner].astype(out_dtype) * - kernel[oc_chunk, ic_outer, ic_f_inner, - oc_block, ic_s_inner].astype(out_dtype), - axis=[ic_outer, ic_f_inner, ic_s_inner]), - name='conv2d_NCHWc_int8', - tag="conv2d_NCHWc_int8") - - - return conv - - -def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last): +def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last): """ Defines the schedule for INT8 for intel machines Uses the Intel intrinsics to use INT8 operations More 
details - https://software.intel.com/en-us/articles/ lower-numerical-precision-deep-learning-inference-and-training """ - target = tvm.target.current_target(allow_none=False) int32_lanes = -1 if check_skylake(target): @@ -260,6 +176,10 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last): return s assert int32_lanes != -1 + oh_factor, ow_factor = cfg["tile_oh"].val, cfg["tile_ow"].size[-1] + _, _, _, _, ic_bn = get_const_tuple(data.shape) + _, _, _, _, oc_bn = get_const_tuple(conv_out.shape) + # schedule data A = data if isinstance(s[A].op, tvm.tensor.ComputeOp): @@ -271,8 +191,8 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last): CC = s.cache_write(C, 'global') batch, oc_chunk, oh, ow, oc_block = s[C].op.axis - oh_outer, oh_inner = s[C].split(oh, factor=sch.oh_factor) - ow_outer, ow_inner = s[C].split(ow, factor=sch.ow_factor) + oh_outer, oh_inner = s[C].split(oh, factor=oh_factor) + ow_outer, ow_inner = s[C].split(ow, factor=ow_factor) s[C].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) s[C].vectorize(oc_block) @@ -282,17 +202,17 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last): s[C].parallel(parallel_axis) _, oc_chunk, oh, ow, oc_block = s[CC].op.axis - ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis + kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis # Skylake and future processors have 16 vector lanes - assert sch.oc_bn % int32_lanes == 0 + assert oc_bn % int32_lanes == 0 oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes) - oh_outer, oh_inner = s[CC].split(oh, factor=sch.oh_factor) - ow_outer, ow_inner = s[CC].split(ow, factor=sch.ow_factor) + oh_outer, oh_inner = s[CC].split(oh, factor=oh_factor) + ow_outer, ow_inner = s[CC].split(ow, factor=ow_factor) - s[CC].reorder(oc_chunk, oh_outer, ow_outer, ic_outer, ic_f_inner, oh_inner, + s[CC].reorder(oc_chunk, oh_outer, ow_outer, kh, kw, ic_outer, ic_f_inner, oh_inner, ow_inner, 
oc_f_inner, oc_s_inner, ic_s_inner) s[CC].fuse(oc_chunk, oh_outer) @@ -303,8 +223,8 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last): if C != O: batch, oc_chunk, oh, ow, oc_block = s[O].op.axis - oh_outer, oh_inner = s[O].split(oh, factor=sch.oh_factor) - ow_outer, ow_inner = s[O].split(ow, factor=sch.ow_factor) + oh_outer, oh_inner = s[O].split(oh, factor=oh_factor) + ow_outer, ow_inner = s[O].split(ow, factor=ow_factor) s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block) parallel_axis = s[O].fuse(oc_chunk, oh_outer) diff --git a/topi/python/topi/x86/conv2d_avx_common.py b/topi/python/topi/x86/conv2d_avx_common.py index eaa3d15e64b0..e52722ed54a7 100644 --- a/topi/python/topi/x86/conv2d_avx_common.py +++ b/topi/python/topi/x86/conv2d_avx_common.py @@ -1,19 +1,15 @@ # pylint: disable=invalid-name,unused-variable,unused-argument,invalid-name """Conv2D schedule on for Intel CPU""" from __future__ import absolute_import as _abs -from collections import namedtuple import tvm -from tvm.autotvm.task import ConfigEntity +from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity from ..nn.util import infer_pad -from ..nn.pad import pad +from ..util import get_const_tuple from .tensor_intrin import dot_16x1x16_int8_int8_int32 from .check_targets import check_skylake -AVXConvCommonFwd = namedtuple('AVXConvCommonFwd', ['ic_bn', 'oc_bn', 'reg_n', 'unroll_kw']) - - -def _get_default_schedule(wkl, simd_width): +def _fallback_schedule(cfg, wkl, simd_width): HPAD, WPAD = wkl.hpad, wkl.wpad HSTR, WSTR = wkl.hstride, wkl.wstride out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1 @@ -36,42 +32,10 @@ def _get_default_schedule(wkl, simd_width): reg_n = n break - return AVXConvCommonFwd(ic_bn, oc_bn, reg_n, False) - - -def _fallback_schedule(wkl, simd_width): - batch_size, in_channel, height, width, _ = wkl[1] - out_channel, _, hkernel, wkernel, _ = wkl[2] - HPAD, WPAD = wkl[4] - HSTR, WSTR = wkl[3] - out_width = (width + 
2 * WPAD - wkernel) // WSTR + 1 - - oc_bn = 1 - for bn in range(simd_width, 0, -1): - if out_channel % bn == 0: - oc_bn = bn - break - - ic_bn = 1 - for bn in range(oc_bn, 0, -1): - if in_channel % bn == 0: - ic_bn = bn - break - - reg_n = 1 - for n in range(31, 0, -1): - if out_width % n == 0: - reg_n = n - break - - cfg_dict = {"i": -1, - "c": None, - "e": [["tile_ic", "sp", [in_channel // ic_bn, ic_bn]], - ["tile_oc", "sp", [out_channel // oc_bn, oc_bn]], - ["tile_ow", "sp", [out_width // reg_n, reg_n]], - ["unroll_kw", "ot", False]], - "t": ""} - return ConfigEntity.from_json_dict(cfg_dict) + cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn]) + cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn]) + cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n]) + cfg["unroll_kw"] = OtherOptionEntity(False) def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, last): @@ -147,8 +111,8 @@ def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, outpu def _schedule_conv_NCHWc(s, cfg, data, conv_out, last): # fetch schedule - ic_bn, reg_n, unroll_kw = (cfg["tile_ic"].size[-1], cfg["tile_ow"].size[-1], - cfg["unroll_kw"].val) + reg_n, unroll_kw = cfg["tile_ow"].size[-1], cfg["unroll_kw"].val + _, _, _, _, ic_bn = get_const_tuple(data.shape) # schedule data A = data @@ -197,52 +161,7 @@ def _schedule_conv_NCHWc(s, cfg, data, conv_out, last): return s -def _declaration_conv_NCHWc_int8(wkl, sch, data, kernel): - """ - This function sets up the compute for INT8 conv 2d - Inputs are in INT8 datatype - Output is in INT32 datatype - """ - - out_dtype = wkl.out_dtype - HPAD, WPAD = wkl.hpad, wkl.wpad - HSTR, WSTR = wkl.hstride, wkl.wstride - - batch_size = data.shape[0] - out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1 - out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1 - - # pack data - DOPAD = (HPAD != 0 or WPAD != 0) - if DOPAD: - data_pad = pad(data, (0, 0, HPAD, WPAD, 0), 
name="data_pad") - else: - data_pad = data - - # convolution - oshape = (batch_size, wkl.out_filter//sch.oc_bn, out_height, out_width, sch.oc_bn) - kh = tvm.reduce_axis((0, wkl.hkernel), name='kh') - kw = tvm.reduce_axis((0, wkl.wkernel), name='kw') - - # Intel performs dot product of 2 "4" Int8 values - # Current implementation requires ic_bn to be a multiple of 4 - n_elems = 4 - assert sch.ic_bn%n_elems == 0 - - ic_outer = tvm.reduce_axis((0, wkl.in_filter//(sch.ic_bn)), name='ic_outer') - ic_f_inner = tvm.reduce_axis((0, sch.ic_bn//n_elems), name='ic_f_inner') - ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner') - conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block: - tvm.sum(data_pad[n, ic_outer, oh*HSTR+kh, ow*WSTR+kw, - ic_f_inner * n_elems + ic_s_inner].astype(out_dtype) * - kernel[oc_chunk, ic_outer, kh, kw, ic_f_inner, - oc_block, ic_s_inner].astype(out_dtype), - axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]), - name='conv2d_NCHWc_int8', - tag="conv2d_NCHWc_int8") - return conv - -def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last): +def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last): """ Defines the schedule for INT8 for intel machines Uses the Intel intrinsics to use INT8 operations @@ -263,6 +182,10 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last): return s assert int32_lanes != -1 + reg_n, unroll_kw = cfg["tile_ow"].size[-1], cfg["unroll_kw"].val + _, _, _, _, ic_bn = get_const_tuple(data.shape) + _, _, _, _, oc_bn = get_const_tuple(conv_out.shape) + A = data if isinstance(s[A].op, tvm.tensor.ComputeOp): batch, ic_chunk, ih, iw, _ = s[A].op.axis @@ -274,7 +197,7 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last): CC = s.cache_write(C, 'global') _, oc_chunk, oh, ow, oc_block = s[C].op.axis - ow_chunk, ow_block = s[C].split(ow, factor=sch.reg_n) + ow_chunk, ow_block = s[C].split(ow, factor=reg_n) s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, 
oc_block) parallel_axis = s[C].fuse(oc_chunk, oh) s[C].vectorize(oc_block) @@ -285,14 +208,14 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last): _, oc_chunk, oh, ow, oc_block = s[CC].op.axis kh, kw, ic_outer, ic_f_inner, ic_s_inner = s[CC].op.reduce_axis - ow_chunk, ow_block = s[CC].split(ow, factor=sch.reg_n) + ow_chunk, ow_block = s[CC].split(ow, factor=reg_n) # Skylake and future processors have 16 vector lanes - assert sch.oc_bn % int32_lanes == 0 + assert oc_bn % int32_lanes == 0 oc_f_inner, oc_s_inner = s[CC].split(oc_block, factor=int32_lanes) - if sch.unroll_kw: + if unroll_kw: s[CC].reorder(oc_chunk, oh, ow_chunk, ic_outer, kh, ic_f_inner, kw, ow_block, oc_f_inner, oc_s_inner, ic_s_inner) s[CC].unroll(kw) @@ -308,7 +231,7 @@ def _schedule_conv_NCHWc_int8(s, wkl, sch, data, kernel, conv_out, last): if C != O: batch, oc_chunk, oh, ow, oc_block = s[O].op.axis - ow_chunk, ow_block = s[O].split(ow, factor=sch.reg_n) + ow_chunk, ow_block = s[O].split(ow, factor=reg_n) s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block) parallel_axis = s[O].fuse(oc_chunk, oh) s[C].compute_at(s[O], parallel_axis) diff --git a/topi/recipe/conv/test_conv_int8_intel.py b/topi/recipe/conv/test_conv_int8_intel.py index 863b3a6a41ab..593f913db15d 100644 --- a/topi/recipe/conv/test_conv_int8_intel.py +++ b/topi/recipe/conv/test_conv_int8_intel.py @@ -54,19 +54,11 @@ def get_shape(im_height, im_width, in_filter, out_filter, k_h, k_w, hpad, wpad, data_shape = (1, in_filter//NUM_VEC_LANES, im_height, im_width, NUM_VEC_LANES) if out_dtype == 'int32': - if k_h != 1: - kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, k_h, k_w, - NUM_VEC_LANES//4, NUM_VEC_LANES, 4) - else: - kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, NUM_VEC_LANES//4, - NUM_VEC_LANES, 4, k_h, k_w) + kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, k_h, k_w, + NUM_VEC_LANES//4, NUM_VEC_LANES, 4) elif out_dtype == 'float32': - if k_h 
!= 1: - kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, k_h, k_w, - NUM_VEC_LANES, NUM_VEC_LANES) - else: - kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, NUM_VEC_LANES, - NUM_VEC_LANES, k_h, k_w) + kernel_shape = (out_filter//NUM_VEC_LANES, in_filter//NUM_VEC_LANES, k_h, k_w, + NUM_VEC_LANES, NUM_VEC_LANES) out_height = (im_height + 2 * hpad - k_h) // hstride + 1 out_width = (im_width + 2 * wpad - k_w) // wstride + 1 o_shape = (1, out_filter//NUM_VEC_LANES, out_height, out_width, NUM_VEC_LANES) @@ -103,8 +95,7 @@ def run_inference(data_dtype, kernel_dtype, out_dtype, im_height, im_width, in_f with tvm.target.create(TARGET_NAME): - conv = topi.nn.conv2d_NCHWc(data, kernel, num_filter=out_filter, - kernel_size=(k_h, k_w), stride=hstride, + conv = topi.nn.conv2d_NCHWc(data, kernel, stride=hstride, padding=hpad, layout='NCHWc', out_layout='NCHWc', out_dtype=out_dtype) out = topi.nn.relu(conv) @@ -114,13 +105,7 @@ def run_inference(data_dtype, kernel_dtype, out_dtype, im_height, im_width, in_f LOGGER.debug(tvm.lower(sch, [data, kernel], simple_mode=True)) # Generate and run the optimized schedule - sconv = topi.generic.nn.schedule_conv2d_NCHWc(num_filter=out_filter, - kernel_size=(k_h, k_w), - strides=hstride, - padding=hpad, - layout='NCHWc', - out_layout='NCHWc', - outs=[out]) + sconv = topi.generic.nn.schedule_conv2d_NCHWc(outs=[out]) func = tvm.build(sconv, [data, kernel, out], target=TARGET_NAME, name='conv') func(data_array, kernel_array, c_sch) diff --git a/topi/tests/python/test_topi_conv2d_NCHWc.py b/topi/tests/python/test_topi_conv2d_NCHWc.py new file mode 100644 index 000000000000..38e6ad6d9e7c --- /dev/null +++ b/topi/tests/python/test_topi_conv2d_NCHWc.py @@ -0,0 +1,206 @@ +"""Test for NCHW[x]c convolution""" + +import numpy as np +import tvm +from tvm import autotvm +import topi +import topi.testing +from tvm.contrib.pickle_memoize import memoize +from topi.util import get_const_tuple + +from common import 
def _transform_data(data, bn):
    """Repack activations from NCHW into the NCHW[x]c blocked layout.

    data : 4-D ndarray in NCHW order; the channel count must be divisible by bn.
    bn   : channel block size (the trailing "c" dimension).
    Returns a 5-D ndarray of shape (N, C//bn, H, W, bn).
    """
    batch_size, channel, height, width = data.shape
    # Split the channel axis into (outer, inner) blocks in one reshape, then
    # move the inner block innermost: (N, C//bn, bn, H, W) -> (N, C//bn, H, W, bn).
    blocked = np.reshape(data, (batch_size, channel // bn, bn, height, width))
    return np.transpose(blocked, (0, 1, 3, 4, 2))

def _transform_kernel(kernel, ic_bn, oc_bn):
    """Repack a convolution kernel from OIHW into the OIHW[x]i[x]o blocked layout.

    kernel : 4-D ndarray in OIHW order; O must be divisible by oc_bn, I by ic_bn.
    Returns a 6-D ndarray of shape (O//oc_bn, I//ic_bn, KH, KW, ic_bn, oc_bn).
    """
    out_channel, in_channel, kh, kw = kernel.shape
    # Split both channel axes into (outer, inner) blocks with a single reshape,
    # then permute so the inner blocks come last, ic_bn before oc_bn.
    blocked = np.reshape(
        kernel, (out_channel // oc_bn, oc_bn, in_channel // ic_bn, ic_bn, kh, kw))
    return np.transpose(blocked, (0, 2, 4, 5, 3, 1))

def _transform_bias(bias, bn):
    """Repack a bias tensor from [num_filter, 1, 1] into [num_filter//bn, 1, 1, bn]."""
    num_filter, h, w = bias.shape
    blocked = np.reshape(bias, (num_filter // bn, bn, h, w))
    return np.transpose(blocked, (0, 2, 3, 1))

def verify_conv2d_NCHWc(batch, in_channel, in_size, num_filter, kernel, stride,
                        padding, dilation=1, add_bias=False, add_relu=False, dtype="float32"):
    """Build and run one conv2d_NCHWc workload, checking it against the plain
    NCHW python reference implementation.

    The reference data is generated in NCHW/OIHW, convolved with
    topi.testing.conv2d_nchw_python, then repacked into blocked layouts via the
    _transform_* helpers before comparison.  Only dilation == 1 is supported.
    """
    assert dilation == 1, "conv2d_NCHWc does not support dilation for now."
    print("Workload: (%d, %d, %d, %d, %d, %d, %d)" %
          (batch, in_channel, in_size, num_filter, kernel, stride, padding))

    in_height = in_width = in_size

    # For testing functionality we pick arbitrary block sizes that divide the
    # channel counts, regardless of performance: the largest divisor <= 16 of
    # the output channels, then the largest divisor <= oc_block of the input
    # channels.  bn == 1 always divides, so both searches are guaranteed to hit.
    oc_block = next(bn for bn in range(16, 0, -1) if num_filter % bn == 0)
    ic_block = next(bn for bn in range(oc_block, 0, -1) if in_channel % bn == 0)

    A = tvm.placeholder((batch, in_channel//ic_block, in_height, in_width, ic_block), name='A')
    W = tvm.placeholder((num_filter//oc_block, in_channel//ic_block, kernel, kernel,
                         ic_block, oc_block), name='W')
    bias = tvm.placeholder((num_filter//oc_block, 1, 1, oc_block), name='bias')

    @memoize("topi.tests.test_topi_conv2d_NCHWc.verify_conv2d_NCHWc")
    def get_ref_data():
        # Reference computation happens in plain NCHW/OIHW; the results are
        # repacked into the blocked layouts the NCHWc kernel consumes/produces.
        a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype(dtype)
        w_np = np.random.uniform(size=(num_filter, in_channel, kernel, kernel)).astype(dtype)
        b_np = np.random.uniform(size=(num_filter, 1, 1)).astype(dtype)
        c_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding)
        if add_bias:
            c_np += b_np
        if add_relu:
            c_np = np.maximum(c_np, 0)
        return (_transform_data(a_np, ic_block),
                _transform_kernel(w_np, ic_block, oc_block),
                _transform_bias(b_np, oc_block),
                _transform_data(c_np, oc_block))

    a_np, w_np, b_np, c_np = get_ref_data()

    def check_device(device):
        # Compile the NCHWc kernel for `device`, run it on the blocked inputs
        # and compare the result with the repacked reference output.
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            C = topi.nn.conv2d_NCHWc(A, W, (stride, stride), (padding, padding),
                                     layout='NCHW%dc'%ic_block,
                                     out_layout="NCHW%dc"%oc_block,
                                     out_dtype=dtype)
            if add_bias:
                C = topi.add(C, bias)
            if add_relu:
                C = topi.nn.relu(C)
            s = topi.generic.schedule_conv2d_NCHWc([C])

        a = tvm.nd.array(a_np, ctx)
        w = tvm.nd.array(w_np, ctx)
        b = tvm.nd.array(b_np, ctx)
        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
        # Both branches use the same generated kernel name; only the argument
        # list differs when a bias tensor is bound.
        func_name = "relu_%d_%d_%d_%d_%d_%d_%d_%d" % \
                    (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)
        if add_bias:
            func = tvm.build(s, [A, W, bias, C], device, name=func_name)
            func(a, w, b, c)
        else:
            func = tvm.build(s, [A, W, C], device, name=func_name)
            func(a, w, c)
        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)

    # test llvm only for now since conv2d_NCHWc implement is missing in other backend.
    for device in ["llvm"]:
        with autotvm.tophub.context(device):  # load tophub pre-tuned parameters
            check_device(device)


if __name__ == "__main__":
    # ResNet18 workloads
    verify_conv2d_NCHWc(1, 3, 224, 64, 7, 2, 3)
    verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1)
    verify_conv2d_NCHWc(1, 64, 56, 64, 1, 1, 0)
    verify_conv2d_NCHWc(1, 64, 56, 128, 3, 2, 1)
    verify_conv2d_NCHWc(1, 64, 56, 128, 1, 2, 0)
    verify_conv2d_NCHWc(1, 128, 28, 128, 3, 1, 1)
    verify_conv2d_NCHWc(1, 128, 28, 256, 3, 2, 1)
    verify_conv2d_NCHWc(1, 128, 28, 256, 1, 2, 0)
    verify_conv2d_NCHWc(1, 256, 14, 256, 3, 1, 1)
    verify_conv2d_NCHWc(1, 256, 14, 512, 3, 2, 1)
    verify_conv2d_NCHWc(1, 256, 14, 512, 1, 2, 0)
    verify_conv2d_NCHWc(1, 512, 7, 512, 3, 1, 1)

    # bias, relu
    verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1, add_relu=True)
    verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1, add_bias=True)
    verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1, add_bias=True, add_relu=True)

    # disable dilation test since it is not supported by NCHW[x]c conv for now.
+ # verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, 1, dilation=2) + + # batch size + verify_conv2d_NCHWc(4, 64, 56, 64, 3, 1, 1) + verify_conv2d_NCHWc(9, 64, 56, 64, 3, 1, 1) + + # weird workloads + verify_conv2d_NCHWc(2, 2, 2, 2, 2, 2, 2) + verify_conv2d_NCHWc(3, 3, 3, 3, 3, 3, 3) + verify_conv2d_NCHWc(4, 4, 4, 4, 4, 4, 4) + verify_conv2d_NCHWc(5, 5, 5, 5, 5, 5, 5) + verify_conv2d_NCHWc(6, 6, 6, 6, 6, 6, 6) + + # disable these tests due to some bugs of llvm with nvptx + # verify_conv2d_NCHWc(1, 1, 1, 1, 1, 1, 1, dilation=1) + # verify_conv2d_NCHWc(1, 1, 1, 1, 1, 1, 1, dilation=2) + # verify_conv2d_NCHWc(2, 13, 71, 59, 3, 1, 1) + + # inception v3 workloads + verify_conv2d_NCHWc(1, 3, 299, 32, 3, 2, 0) + verify_conv2d_NCHWc(1, 32, 149, 32, 3, 1, 0) + verify_conv2d_NCHWc(1, 32, 147, 64, 3, 1, 1) + verify_conv2d_NCHWc(1, 64, 73, 80, 1, 1, 0) + verify_conv2d_NCHWc(1, 80, 73, 192, 3, 1, 0) + verify_conv2d_NCHWc(1, 192, 35, 64, 1, 1, 0) + verify_conv2d_NCHWc(1, 192, 35, 48, 1, 1, 0) + verify_conv2d_NCHWc(1, 48, 35, 64, 5, 1, 2) + verify_conv2d_NCHWc(1, 64, 35, 96, 3, 1, 1) + verify_conv2d_NCHWc(1, 96, 35, 96, 3, 1, 1) + verify_conv2d_NCHWc(1, 192, 35, 32, 1, 1, 0) + verify_conv2d_NCHWc(1, 256, 35, 64, 1, 1, 0) + verify_conv2d_NCHWc(1, 256, 35, 48, 1, 1, 0) + verify_conv2d_NCHWc(1, 288, 35, 64, 1, 1, 0) + verify_conv2d_NCHWc(1, 288, 35, 48, 1, 1, 0) + verify_conv2d_NCHWc(1, 288, 35, 384, 3, 2, 0) + verify_conv2d_NCHWc(1, 96, 35, 96, 3, 2, 0) + verify_conv2d_NCHWc(1, 768, 17, 192, 1, 1, 0) + verify_conv2d_NCHWc(1, 768, 17, 128, 1, 1, 0) + verify_conv2d_NCHWc(1, 128, 17, 128, 1, 1, 0) + verify_conv2d_NCHWc(1, 128, 17, 192, 7, 1, 3) + verify_conv2d_NCHWc(1, 128, 17, 128, 7, 1, 3) + verify_conv2d_NCHWc(1, 128, 17, 192, 1, 1, 0) + verify_conv2d_NCHWc(1, 768, 17, 160, 1, 1, 0) + verify_conv2d_NCHWc(1, 160, 17, 160, 1, 1, 0) + verify_conv2d_NCHWc(1, 160, 17, 192, 7, 1, 3) + verify_conv2d_NCHWc(1, 160, 17, 160, 7, 1, 3) + verify_conv2d_NCHWc(1, 160, 17, 192, 1, 1, 0) + 
verify_conv2d_NCHWc(1, 192, 17, 192, 1, 1, 0) + verify_conv2d_NCHWc(1, 192, 17, 192, 7, 1, 3) + verify_conv2d_NCHWc(1, 192, 17, 320, 3, 2, 0) + verify_conv2d_NCHWc(1, 192, 17, 192, 3, 2, 0) + verify_conv2d_NCHWc(1, 1280, 8, 320, 1, 1, 0) + verify_conv2d_NCHWc(1, 1280, 8, 384, 1, 1, 0) + verify_conv2d_NCHWc(1, 384, 8, 384, 1, 1, 0) + verify_conv2d_NCHWc(1, 384, 8, 384, 3, 1, 1) + verify_conv2d_NCHWc(1, 1280, 8, 448, 1, 1, 0) + verify_conv2d_NCHWc(1, 448, 8, 384, 3, 1, 1) + verify_conv2d_NCHWc(1, 1280, 8, 192, 1, 1, 0) + verify_conv2d_NCHWc(1, 2048, 8, 320, 1, 1, 0) + verify_conv2d_NCHWc(1, 2048, 8, 384, 1, 1, 0) + verify_conv2d_NCHWc(1, 2048, 8, 448, 1, 1, 0) + verify_conv2d_NCHWc(1, 2048, 8, 192, 1, 1, 0) + verify_conv2d_NCHWc(1, 1024, 19, 84, 3, 1, 1) + verify_conv2d_NCHWc(1, 2048, 10, 126, 3, 1, 1) + verify_conv2d_NCHWc(1, 512, 5, 126, 3, 1, 1) + verify_conv2d_NCHWc(1, 256, 3, 126, 3, 1, 1) diff --git a/tutorials/autotvm/tune_nnvm_x86.py b/tutorials/autotvm/tune_nnvm_x86.py index efd1ee4e1a12..18f1117dc68a 100644 --- a/tutorials/autotvm/tune_nnvm_x86.py +++ b/tutorials/autotvm/tune_nnvm_x86.py @@ -14,7 +14,6 @@ import tvm from tvm import autotvm from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner -from topi.x86.conv2d import conv_NCHWc_arg_to_workload import tvm.contrib.graph_runtime as runtime ################################################################# @@ -118,17 +117,9 @@ def tune_kernels(tasks, prefix = "[Task %2d/%2d] " % (i+1, len(tasks)) # converting conv2d tasks to conv2d_NCHWc tasks - # data, kernel are tuples of ("TENSOR", shape, dtype) - data, kernel, strides, padding, layout, dtype = tsk.args - kernel_size = (kernel[1][2], kernel[1][3]) - data_plc = tvm.placeholder(data[1], name="data") - kernel_plc = tvm.placeholder(kernel[1], name="kernel") - args = [data_plc, kernel_plc, kernel[1][0], kernel_size, strides, - padding, layout, layout, dtype] - args = autotvm.task.nnvm_integration.serialize_args(args) - task = 
autotvm.task.create("topi_x86_conv2d_NCHWc", args=args, target=target) - task.workload = conv_NCHWc_arg_to_workload(data_plc, kernel_plc, kernel_size, - strides, padding, layout, layout, dtype) + task = autotvm.task.create("topi_x86_conv2d_NCHWc", args=tsk.args, + target=target, template_key='direct') + task.workload = tsk.workload # create tuner if tuner == 'xgb' or tuner == 'xgb-rank':