
Simplify the if statement in backprop depthwise convolution #294

@weitliu


Hi everyone. I am working on the backward method (gradient with respect to the input) for depthwise convolution. The implementation I currently have in mind uses a lot of tvm.select calls. Is there any way to simplify the code?

def trans(b, i, j, c):
    # Zero out the Out_grad entries whose receptive field cannot cover
    # input pixel (i, j); everything else passes through unchanged.
    Out_grad_cond = tvm.compute(
        (batch, in_h, in_w, out_c),
        lambda bo, io, jo, co: tvm.select(
            tvm.all(
                io >= tvm.select(0 < (i - filter_h + pad_h + stride_h) / stride_h,
                                 (i - filter_h + pad_h + stride_h) / stride_h,
                                 tvm.const(0)),
                io < tvm.select(0 < (i + pad_h) / stride_h + 1 - out_h,
                                tvm.const(out_h - 1),
                                (i + pad_h) / stride_h),
                jo >= tvm.select(0 < (j - filter_w + pad_w + stride_w) / stride_w,
                                 (j - filter_w + pad_w + stride_w) / stride_w,
                                 tvm.const(0)),
                jo < tvm.select(0 < (j + pad_w) / stride_w + 1 - out_w,
                                tvm.const(out_w - 1),
                                (j + pad_w) / stride_w)),
            Out_grad[b, i, j, c],
            tvm.const(0.0)))

    # reduce_axis ranges are half-open, so (0, out_h) covers all out_h rows.
    di = tvm.reduce_axis((0, out_h), name='di')
    dj = tvm.reduce_axis((0, out_w), name='dj')
    dc = tvm.reduce_axis((0, channel_multiplier), name='dc')

    return tvm.sum(
        Out_grad_cond[b, di, dj, c * channel_multiplier + dc] *
        Filter[i + pad_h - di * stride_h, j + pad_w - dj * stride_w, c, dc],
        axis=[di, dj, dc])

In_grad = tvm.compute(
    (batch, in_h, in_w, in_c),
    lambda b, i, j, c: trans(b, i, j, c),
    name='In_grad')
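For reference, here is what the compute above is expressing, written out in plain NumPy rather than TVM. This is only an illustrative sketch (the function names and layout assumptions are mine, not TVM API): for each input pixel (i, j) it sums Out_grad over exactly those output positions whose receptive field covers (i, j), which is the same bounds check the tvm.select expressions encode. The sketch can be sanity-checked against a naive forward pass via the adjoint identity <Out_grad, conv(x)> == <In_grad, x>.

```python
import numpy as np

def depthwise_conv2d_nhwc(inp, filt, stride, pad):
    """Naive forward depthwise conv, NHWC layout, HWCM filter layout."""
    b, ih, iw, ic = inp.shape
    fh, fw, _, cm = filt.shape
    oh = (ih + 2 * pad - fh) // stride + 1
    ow = (iw + 2 * pad - fw) // stride + 1
    padded = np.pad(inp, ((0, 0), (pad, pad), (pad, pad), (0, 0)))
    out = np.zeros((b, oh, ow, ic * cm))
    for io in range(oh):
        for jo in range(ow):
            for c in range(ic):
                patch = padded[:, io*stride:io*stride+fh,
                               jo*stride:jo*stride+fw, c]
                for dc in range(cm):
                    out[:, io, jo, c*cm+dc] = np.sum(
                        patch * filt[:, :, c, dc], axis=(1, 2))
    return out

def depthwise_conv2d_backward_input(out_grad, filt, in_shape, stride, pad):
    """Gradient w.r.t. the input: for each input pixel (i, j), accumulate
    Out_grad over the output positions (di, dj) whose receptive field
    covers (i, j) -- the valid (di, dj) range is what the tvm.select
    bounds compute."""
    b, ih, iw, ic = in_shape
    fh, fw, _, cm = filt.shape
    _, oh, ow, _ = out_grad.shape
    in_grad = np.zeros(in_shape)
    for i in range(ih):
        for j in range(iw):
            for c in range(ic):
                acc = np.zeros(b)
                for di in range(oh):
                    for dj in range(ow):
                        # Filter tap that connects output (di, dj)
                        # to input (i, j); skip if out of range.
                        fi = i + pad - di * stride
                        fj = j + pad - dj * stride
                        if 0 <= fi < fh and 0 <= fj < fw:
                            for dc in range(cm):
                                acc += (out_grad[:, di, dj, c*cm+dc]
                                        * filt[fi, fj, c, dc])
                in_grad[:, i, j, c] = acc
    return in_grad
```

Because convolution is linear in the input, the backward pass is its adjoint, so np.sum(g * forward(x)) should equal np.sum(backward(g) * x) for random x and g; that check catches off-by-one errors in the bounds.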

Thank you!
