Problem encountered when trying to implement argmax in topi #503

@sxjscience

Description

@tqchen
To compute an argmax, we essentially need a tuple reduction over (index, value) pairs:

T_idx, T_val = tvm.compute((m, ), lambda i: argmax((idx[i, k], val[i, k]), axis=k))
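To make the intended semantics concrete, here is a plain-Python reference model (not TVM code) of `argmax` as a commutative reducer over `(index, value)` pairs, mirroring the combiner-plus-identity structure such a reducer is defined by. The names `argmax_identity`, `argmax_combine` and `argmax_rows` are hypothetical, and the tie-breaking rule (the smaller index wins) is an assumption, since the issue does not specify one.

```python
import math
from typing import List, Tuple

Pair = Tuple[int, float]  # (index, value)


def argmax_identity() -> Pair:
    # Identity element: index -1 paired with the smallest possible value,
    # so any finite candidate replaces it.
    return (-1, -math.inf)


def argmax_combine(x: Pair, y: Pair) -> Pair:
    # Larger value wins; on a tie the smaller index wins. This is a max under
    # a total order, so it is commutative and associative and the result does
    # not depend on reduction order (assumed tie-break, hypothetical helper).
    if x[1] > y[1] or (x[1] == y[1] and x[0] < y[0]):
        return x
    return y


def argmax_rows(idx: List[List[int]], val: List[List[float]]) -> Tuple[List[int], List[float]]:
    # Reference semantics of
    #   T_idx, T_val = compute((m,), lambda i: argmax((idx[i, k], val[i, k]), axis=k))
    t_idx, t_val = [], []
    for i in range(len(val)):
        acc = argmax_identity()
        for k in range(len(val[i])):
            acc = argmax_combine(acc, (idx[i][k], val[i][k]))
        t_idx.append(acc[0])
        t_val.append(acc[1])
    return t_idx, t_val


idx = [[0, 1, 2], [0, 1, 2]]
val = [[1.0, 5.0, 3.0], [7.0, 2.0, 7.0]]
print(argmax_rows(idx, val))  # ([1, 0], [5.0, 7.0])
```

In the second row the value 7.0 appears at k = 0 and k = 2; the assumed tie-break keeps index 0.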

We can avoid reading a separate idx tensor by using the reduce axis variable itself as the index:

T_idx, T_val = tvm.compute((m, ), lambda i: argmax((k.var, val[i, k]), axis=k))
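As a plain-Python sketch (not TVM code) of what the `k.var` form computes: the index half of each candidate pair is just the loop position `k`, so no `idx` tensor is needed. The function name `argmax_rows_from_axis` is hypothetical, and ties are assumed to resolve to the smaller index.

```python
import math
from typing import List, Tuple


def argmax_rows_from_axis(val: List[List[float]]) -> Tuple[List[int], List[float]]:
    # Mirrors argmax((k.var, val[i, k]), axis=k): the index component of each
    # candidate is the reduce-axis variable k itself.
    t_idx, t_val = [], []
    for row in val:
        best = (-1, -math.inf)  # identity element: (index, value)
        for k, v in enumerate(row):
            # k increases monotonically, so a strict '>' keeps the earlier
            # (smaller) index when values tie.
            if v > best[1]:
                best = (k, v)
        t_idx.append(best[0])
        t_val.append(best[1])
    return t_idx, t_val


print(argmax_rows_from_axis([[1.0, 5.0, 3.0], [7.0, 2.0, 7.0]]))  # ([1, 0], [5.0, 7.0])
```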

However, the following does not work. This case arises whenever we want to compute the index as an expression over multiple reduce axes:

T_idx, T_val = tvm.compute((m, ), lambda i: argmax((k.var + k.var, val[i, k]), axis=k))
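`k.var + k.var` above is a minimal reproduction; a realistic multi-axis case is an index computed from two reduce variables. The plain-Python sketch below (not TVM code) shows the result such a compute is meant to produce, assuming a row-major flattened index `k1 * n2 + k2`. That formula, the tie-break, and the name `argmax_over_two_axes` are all hypothetical choices for illustration.

```python
import math
from typing import List, Tuple


def argmax_over_two_axes(val: List[List[List[float]]]) -> Tuple[List[int], List[float]]:
    # For each i, reduce over two axes (k1, k2). The index component is an
    # expression of both reduce variables: an assumed row-major flat index.
    t_idx, t_val = [], []
    for plane in val:  # plane has shape (n1, n2)
        n2 = len(plane[0])
        best = (-1, -math.inf)  # identity element: (index, value)
        for k1, row in enumerate(plane):
            for k2, v in enumerate(row):
                flat = k1 * n2 + k2  # index computed from both reduce variables
                # Row-major loop order visits flat indices in increasing order,
                # so a strict '>' keeps the smallest flat index on ties.
                if v > best[1]:
                    best = (flat, v)
        t_idx.append(best[0])
        t_val.append(best[1])
    return t_idx, t_val


val = [[[1.0, 2.0], [9.0, 0.0]], [[4.0, 4.0], [3.0, 1.0]]]
print(argmax_over_two_axes(val))  # ([2, 0], [9.0, 4.0])
```

The flat index can be turned back into per-axis coordinates with `divmod(flat, n2)`; for the first plane, `divmod(2, 2)` gives `(1, 0)`.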

Is there a way to solve this problem?
