[TOPI] depthwise-conv2d in NCHW[x]c layout for x86 #2045
tqchen merged 11 commits into apache:master
Conversation

cc @FrozenGene, what is the relation of this PR to #2028?

@tqchen this is for x86; #2028 is for ARM. I'm a little curious about #2028: I tested the ARM schedule on CPU (basically c5.9xlarge). @FrozenGene's branch got ~21.4 ms (I added the Intel CPU schedule registrations manually, since #2028 removes them), while the previous ARM schedule (in current tvm) got ~2.2 ms. I know it is not fair to benchmark an ARM schedule on an x86 CPU, and it is not tuned, but I believe the current ARM depthwise schedule is also not tuned on x86 CPU, so such a large performance drop looks somewhat weird. @merrymercy @FrozenGene could you double check? I can also help to test it on ARM once I get a device. Another comment: on x86, the most efficient layout for normal conv2d turns out to be NCHW[x]c, so by keeping depthwise-conv in the same layout, we can get rid of layout transformations between layers. Thus I believe that on x86 the solution in this PR would be better than the NCHW ARM schedule.
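For reference, the NCHW → NCHW[x]c packing discussed above can be sketched in plain NumPy. This is a minimal illustration of the layout, not the PR's actual helper; the function name and the blocking factor `bn` are placeholders:

```python
import numpy as np

def nchw_to_nchwc(data, bn):
    """Illustrative NCHW -> NCHW[x]c packing (not the PR's code).

    Splits the channel axis into chunks of size `bn` and moves the
    inner channel block to the innermost (vectorized) position.
    """
    n, c, h, w = data.shape
    assert c % bn == 0, "channel must be divisible by the block size"
    # (n, c, h, w) -> (n, c//bn, bn, h, w) -> (n, c//bn, h, w, bn)
    return data.reshape(n, c // bn, bn, h, w).transpose(0, 1, 3, 4, 2)
```

Because both normal conv2d and depthwise conv2d then consume and produce the same 5-D layout, no transform ops are needed between consecutive layers.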

We should add dilation arguments as in #1970, so we don't have to convert logs when we start to optimize for dilation.

As @merrymercy said, we haven't uploaded any tuned config logs for #2028. I want to know how you compare: do you use these two schedules to tune on x86 CPU and then run, or just run without tuning? If you tune, as #2028 said, you should set the XGBTuner constructor's feature_type argument to 'knob', i.e. XGBTuner(tsk, loss_type='rank', feature_type='knob').

@FrozenGene I ran without tuning. My question is rather about the default/fallback schedule; it looks like the previous one is far better. When running the previous one, I also got warnings like,

Let's continue this discussion in #2028

@merrymercy Thanks, I'll check out and verify.
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
channels = attrs.get_int("channels")
Just for curiosity: is `channels` here "out channels"? If so, it would be better to name it appropriately for clarity.
from ..util import simplify

# workload description of depthwise-conv2d
Workload = namedtuple('Workload',
Do we want dilation in here? Or is that for a separate PR?

The Workload here is for getting the default schedule; since dilation so far does not impact how we calculate configs, I'd rather keep it simple for now.

Makes sense. This is resolved from my side.
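For illustration, a workload record of the kind discussed above could look as follows. The field names here are assumptions for the sketch (modeled on how other TOPI x86 workloads describe shape, padding, and stride), not necessarily the exact fields in this PR:

```python
from collections import namedtuple

# Hypothetical field set for a depthwise-conv2d workload description;
# dilation is deliberately omitted, matching the discussion above.
Workload = namedtuple('Workload',
                      ['in_dtype', 'out_dtype', 'height', 'width',
                       'channel', 'multiplier', 'hkernel', 'wkernel',
                       'hpad', 'wpad', 'hstride', 'wstride'])

# Example: 112x112x32 input, 3x3 depthwise kernel, pad 1, stride 1.
wkl = Workload('float32', 'float32', 112, 112, 32, 1, 3, 3, 1, 1, 1, 1)
```

Such a record is only used to look up a default schedule config, which is why fields that don't affect config selection can be left out.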
data_pad = data

# depthconv stage
di = tvm.reduce_axis((0, filter_height), name='di')
Are kh, kw better names for di, dj? Just trying to be consistent with other Intel CPU schedules.
dj = tvm.reduce_axis((0, filter_width), name='dj')
Output = tvm.compute(
    (batch, out_channel_chunk, out_height, out_width, out_channel_block),
    lambda b, oco, i, j, oci: tvm.sum(
Same as above: oh, ow for i, j?
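As a reference for the compute rule above (a sum over the `di`/`dj` reduce axes in NCHW[x]c layout), here is a NumPy sketch of depthwise conv2d with channel_multiplier == 1. The helper name, the pre-padded input, and the packed kernel layout `(channel_chunk, kh, kw, channel_block)` are assumptions for the sketch, not the PR's code:

```python
import numpy as np

def depthwise_conv2d_nchwc(data, kernel, stride=1):
    """Reference depthwise conv2d, NCHW[x]c layout, multiplier == 1.

    data:   (batch, channel_chunk, height, width, channel_block),
            assumed already padded
    kernel: (channel_chunk, kh, kw, channel_block), hypothetical packing
    """
    b, cc, h, w, cb = data.shape
    _, kh, kw, _ = kernel.shape
    oh = (h - kh) // stride + 1
    ow = (w - kw) // stride + 1
    out = np.zeros((b, cc, oh, ow, cb), dtype=data.dtype)
    for di in range(kh):        # the `di` reduce axis
        for dj in range(kw):    # the `dj` reduce axis
            window = data[:, :, di:di + stride * oh:stride,
                                dj:dj + stride * ow:stride, :]
            out += window * kernel[:, di, dj, :].reshape(1, cc, 1, 1, cb)
    return out
```

Unlike normal conv2d there is no reduction over input channels, so the channel-block axis stays elementwise and vectorizes cleanly.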
s[C].vectorize(ic_block)
parallel_axis = s[C].fuse(ic_chunk, oh)
s[C].parallel(parallel_axis)
s[C].unroll(ow_block)
For curiosity, do we need this if we are unrolling the s[CC] block later?

No, we don't; I'll also remove the vectorize above.
_, ic_chunk, oh, ow, ic_block = s[C].op.axis
ow_chunk, ow_block = s[C].split(ow, factor=tile_ow)
s[C].reorder(ic_chunk, oh, ow_chunk, ow_block, ic_block)
s[C].vectorize(ic_block)
For curiosity, do we need this if we are vectorizing ic_block for s[CC] later?
                              dtype=DepthwiseConv2d.dtype), ctx)
relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx)
# launch kernel 1 (depthwise_conv2d)
timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=1)
If 'number' here means how many times we run the experiment, then we should use a higher number.

Well, since here we only measure functionality, I'd rather remove the time_evaluator and just do f1(...) instead.
def _transform_data(data, bn):
    # NCHW -> NCHW[x]c
    batch_size, channel, height, width = data.shape
    data = np.transpose(data, (0, 2, 3, 1))
First reshape and then transpose? You only need one transpose here.
# channel, channel_multiplier, kh, kw -> out_channel_chunk, kh, kw, out_channel_block
channel, channel_multiplier, kh, kw = kernel.shape
out_channel = channel * channel_multiplier
kernel = np.transpose(kernel, (2, 3, 0, 1))

@merrymercy @anijain2305 @kevinthesun Please review again.

Can you add depthwise convolution support to the tune_nnvm_x86 tutorial?

@merrymercy tutorial updated.

@tqchen could you help merge if it looks good?

Thanks, @yizhi @anijain2305 @merrymercy @kevinthesun @FrozenGene! This is now merged.
Improves mobilenet1.0 from ~2.2 ms (autotvm-tuned ARM CPU schedule) to 1.5 ms on EC2 c5.9xlarge (18 physical-core Intel Skylake CPU).

Reviewers @merrymercy @kevinthesun, please review.