Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions topi/python/topi/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .reduction import schedule_reduce
from .softmax import schedule_softmax
from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
from .dense import schedule_dense
from .pooling import schedule_pool, schedule_global_pool
from .extern import schedule_extern
from .nn import schedule_lrn, schedule_l2_normalize
Expand Down
13 changes: 7 additions & 6 deletions topi/python/topi/rocm/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,19 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-variable
# pylint: disable=invalid-name, unused-variable, unused-argument
"""Schedule for dense operator"""
from __future__ import absolute_import as _abs
import tvm
from tvm import autotvm
from tvm.contrib import rocblas
import topi
from ..nn.dense import dense, dense_default
from .. import tag
from .. import generic

@dense.register("rocm")
def dense_rocm(data, weight, bias=None, out_dtype=None):
@autotvm.register_topi_compute(dense, "rocm", "direct")
def dense_rocm(cfg, data, weight, bias=None, out_dtype=None):
"""Dense operator for rocm backend.

Parameters
Expand Down Expand Up @@ -67,8 +68,8 @@ def dense_rocm(data, weight, bias=None, out_dtype=None):
return dense_default(data, weight, bias, out_dtype)


@generic.schedule_dense.register(["rocm"])
def schedule_dense(outs):
@autotvm.register_topi_schedule(generic.schedule_dense, "rocm", "direct")
def schedule_dense(cfg, outs):
"""Schedule for dense operator.

Parameters
Expand All @@ -85,4 +86,4 @@ def schedule_dense(outs):
target = tvm.target.current_target()
if target.target_name == "rocm" and "rocblas" in target.libs:
return generic.schedule_extern(outs)
return topi.cuda.schedule_dense(outs)
return topi.cuda.schedule_dense(cfg, outs)