2 changes: 2 additions & 0 deletions docs/langref/relay_op.rst
@@ -148,6 +148,8 @@ This level supports backpropagation of broadcast operators. It is temporary.
tvm.relay.broadcast_to_like
tvm.relay.collapse_sum_like
tvm.relay.slice_like
tvm.relay.device_copy
tvm.relay.annotation.on_device


Level 1 Definitions
31 changes: 31 additions & 0 deletions include/tvm/relay/attrs/annotation.h
@@ -0,0 +1,31 @@
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/attrs/annotation.h
* \brief Attribute for annotation operators.
*/
#ifndef TVM_RELAY_ATTRS_ANNOTATION_H_
#define TVM_RELAY_ATTRS_ANNOTATION_H_

#include <tvm/attrs.h>
#include <string>

namespace tvm {
namespace relay {

/*!
* \brief Options for the device annotation operators.
*/
struct OnDeviceAttrs : public tvm::AttrsNode<OnDeviceAttrs> {
int device_type;

TVM_DECLARE_ATTRS(OnDeviceAttrs, "relay.attrs.OnDeviceAttrs") {
TVM_ATTR_FIELD(device_type)
.describe(
"The virutal device/context type that an expression is annotated with.")
.set_default(0);
}
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_ANNOTATION_H_
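
For reviewers, a minimal usage sketch of the annotation this attribute backs, from the Python side. It assumes the relay.annotation.on_device binding added later in this PR accepts a device string or TVMContext and records its device type in OnDeviceAttrs:

import tvm
from tvm import relay

x = relay.var("x", shape=(3, 4))
y = relay.var("y", shape=(3, 4))
# Pin the addition to the GPU; unannotated operators later fall back
# to the device passed to RewriteAnnotatedOps.
add_gpu = relay.annotation.on_device(relay.add(x, y), tvm.gpu())
func = relay.Function([x, y], relay.subtract(add_gpu, y))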
36 changes: 36 additions & 0 deletions include/tvm/relay/attrs/device_copy.h
@@ -0,0 +1,36 @@
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/attrs/device_copy.h
* \brief Attribute for the device copy operator.
*/
#ifndef TVM_RELAY_ATTRS_DEVICE_COPY_H_
#define TVM_RELAY_ATTRS_DEVICE_COPY_H_

#include <tvm/attrs.h>
#include <string>

namespace tvm {
namespace relay {

/*!
* \brief Options for the device copy operators.
*/
struct DeviceCopyAttrs : public tvm::AttrsNode<DeviceCopyAttrs> {
int dst_dev_type;
int src_dev_type;

TVM_DECLARE_ATTRS(DeviceCopyAttrs, "relay.attrs.DeviceCopyAttrs") {
TVM_ATTR_FIELD(src_dev_type)
.describe(
"The virutal device/context type where the op copies data from.")
.set_default(0);
TVM_ATTR_FIELD(dst_dev_type)
.describe(
"The virutal device/context type where the op copies data to.")
.set_default(0);
}
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_DEVICE_COPY_H_
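
The copy operator these attributes describe is normally inserted by the RewriteAnnotatedOps pass rather than written by hand, but it can also be constructed directly. A sketch, assuming the relay.device_copy binding in this PR accepts device strings (or TVMContext values) for the source and destination:

from tvm import relay

x = relay.var("x", shape=(3, 4))
# Move a tensor produced on the CPU over to the GPU so that a
# GPU-annotated consumer can read it.
copied = relay.device_copy(x, "cpu", "cuda")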
1 change: 0 additions & 1 deletion include/tvm/relay/attrs/transform.h
@@ -164,7 +164,6 @@ struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> {
}
};


struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
std::string src_layout;
std::string dst_layout;
15 changes: 15 additions & 0 deletions include/tvm/relay/pass.h
@@ -188,6 +188,21 @@ Expr ForwardRewrite(const Expr& expr,
std::function<NodeRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);

/*!
* \brief Rewrite the annotated program.
* \param expr The expression.
* \param fallback_device The fallback device which is the default device for
* operators without annotation.
* \return The updated program.
*/
Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device);

/*!
* \brief Collect the device mapping information of each expression.
* \param expr The expression.
* \return The device mapping.
*/
Map<Expr, Integer> CollectDeviceInfo(const Expr& expr);

/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
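
Together the two new passes form a pipeline: RewriteAnnotatedOps inserts device_copy nodes on every edge that crosses devices, and CollectDeviceInfo then reports which device each expression landed on. A Python sketch of that flow; the binding names below are assumptions based on these C++ declarations, and device type 1 is kDLCPU:

import tvm
from tvm import relay

x = relay.var("x", shape=(3, 4))
add_gpu = relay.annotation.on_device(relay.add(x, x), tvm.gpu())
func = relay.Function([x], relay.subtract(add_gpu, x))
func = relay.ir_pass.infer_type(func)
# Insert device_copy nodes; unannotated ops fall back to device type 1 (CPU).
func = relay.ir_pass.rewrite_annotated_ops(func, 1)
# Map each sub-expression to the device type it was assigned.
device_map = relay.ir_pass.collect_device_info(func)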
1 change: 1 addition & 0 deletions python/tvm/__init__.py
@@ -25,6 +25,7 @@
from .ndarray import vpi, rocm, opengl, ext_dev

from ._ffi.runtime_ctypes import TypeCode
from ._ffi.ndarray import TVMContext
from ._ffi.function import Function
from ._ffi.base import TVMError, __version__
from .api import *
1 change: 1 addition & 0 deletions python/tvm/relay/__init__.py
@@ -18,6 +18,7 @@
from .op.tensor import *
from .op.transform import *
from . import nn
from . import annotation
from . import vision
from . import image
from . import frontend
4 changes: 4 additions & 0 deletions python/tvm/relay/annotation.py
@@ -0,0 +1,4 @@
# pylint: disable=wildcard-import, unused-import, unused-wildcard-import
"""Annotation related operators."""
# Re-export in a specific file name so that autodoc can pick it up
from .op.annotation import *
11 changes: 6 additions & 5 deletions python/tvm/relay/backend/_backend.py
@@ -52,20 +52,21 @@ def build(funcs, target, target_host=None):

Parameters
----------
funcs : List[tvm.LoweredFunc]
The list of lowered functions.
funcs : List[tvm.LoweredFunc] or Dict[str, List[tvm.LoweredFunc]]
A list of lowered functions, or a dictionary mapping each target to
its list of lowered functions.


target : tvm.Target
The target to run the code on.

target_host : tvm.Target
The host target.

Returns
-------
module : tvm.Module
The runtime module.
"""
if target_host == "":
target_host = None
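
To make the two accepted shapes of funcs concrete, a small sketch that lowers one schedule and hands it to tvm.build in both forms (assuming tvm.build accepts the per-target dictionary, as the codegen comment in graph_runtime_codegen.py below states):

import tvm

n = tvm.var("n")
A = tvm.placeholder((n,), name="A")
B = tvm.compute(A.shape, lambda i: A[i] + 1.0, name="B")
s = tvm.create_schedule(B.op)
f = tvm.lower(s, [A, B], name="addone")

# Homogeneous: a flat list of lowered functions.
mod = tvm.build([f], target="llvm")
# Heterogeneous: lowered functions grouped by target string.
mod = tvm.build({"llvm": [f]}, target_host="llvm")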
83 changes: 68 additions & 15 deletions python/tvm/relay/backend/graph_runtime_codegen.py
@@ -20,13 +20,15 @@

from __future__ import absolute_import
import json
from collections import defaultdict
import attr
from . import _backend
from . import compile_engine
from ..op import Op
from ..expr import Function, GlobalVar
from ..expr_functor import ExprFunctor
from ..ty import TupleType, TensorType
from ... import target as _target


@attr.s
@@ -105,9 +107,9 @@ def __init__(self, mod, target):
self.nodes = []
self.var_map = {}
self.params = {}
self.storage_map = None
self.storage_device_map = None
self.compile_engine = compile_engine.get()
self.lowered_funcs = set()
self.lowered_funcs = defaultdict(set)
self._name_map = {}

def add_node(self, node, expr):
@@ -129,10 +131,20 @@ def add_node(self, node, expr):
"""
checked_type = expr.checked_type
# setup storage ids
assert expr in self.storage_map
node.attrs["storage_id"] = [
x.value for x in self.storage_map[expr]
]
assert expr in self.storage_device_map
storage_device_info = self.storage_device_map[expr]
assert len(storage_device_info) == 2
node.attrs["storage_id"] = [x.value for x in storage_device_info[0]]
device_types = [x.value for x in storage_device_info[1]]
num_unknown_devices = device_types.count(0)
if num_unknown_devices != 0 and num_unknown_devices != len(device_types):
raise RuntimeError("The graph contains nodes that are not annotated "
"for heterogeneous execution. All nodes must "
"be annotated.")

# Add the `device_index` attribute when the graph is annotated.
if num_unknown_devices == 0:
node.attrs["device_index"] = device_types

node_id = len(self.nodes)
self.nodes.append(node)
@@ -232,9 +244,25 @@ def visit_call(self, call):
"TVM only support calls to primitive functions " +
"(i.e functions composed of fusable operator invocations)")

cached_func = self.compile_engine.lower(func, self.target)
assert call in self.storage_device_map
device_types = self.storage_device_map[call][1]
call_dev_type = device_types[0].value
if isinstance(self.target, (str, _target.Target)):
# homogeneous execution.
cached_func = self.compile_engine.lower(func, self.target)
self.target = {0: str(self.target)}
elif isinstance(self.target, dict):
# heterogeneous execution.
if call_dev_type not in self.target:
raise Exception("No target is provided for device " +
"{0}".format(call_dev_type))
cached_func = self.compile_engine.lower(func,
self.target[call_dev_type])
else:
raise ValueError("self.target must be the type of str," +
"tvm.target.Target, or dict of int to str")
for loweredf in cached_func.funcs:
self.lowered_funcs.add(loweredf)
self.lowered_funcs[self.target[call_dev_type]].add(loweredf)

inputs = []
# flatten tuple in the call.
@@ -284,20 +312,25 @@ def _get_json(self):
num_entry = 0
shapes = []
storage_ids = []
device_types = []
dltypes = []
node_row_ptr = [0]
for node in self.nodes:
assert node.num_outputs == len(node.attrs["shape"])
shapes += node.attrs["shape"]
dltypes += node.attrs["dtype"]
storage_ids += node.attrs["storage_id"]
if "device_index" in node.attrs:
device_types += node.attrs["device_index"]
num_entry += node.num_outputs
node_row_ptr.append(num_entry)

# Compute "attrs" field.
attrs = {}
attrs["shape"] = ["list_shape", shapes]
attrs["storage_id"] = ["list_int", storage_ids]
if device_types:
attrs["device_index"] = ["list_int", device_types]
attrs["dltype"] = ["list_str", dltypes]

json_dict = {
@@ -313,11 +346,24 @@ def debug_dump_memory_plan(self, func):
def debug_dump_memory_plan(self, func):
"""Debug function to dump memory plan."""
def _annotate(expr):
if expr in self.storage_map:
return str(self.storage_map[expr])
if expr in self.storage_device_map:
storage_device_info = self.storage_device_map[expr]
assert len(storage_device_info) == 2
return str(storage_device_info[0])
return ""
return func.astext(show_meta_data=False, annotate=_annotate)

def debug_dump_device_annotation(self, func):
"""Debug function to dump device annotation result."""
def _annotate(expr):
if expr in self.storage_device_map:
storage_device_info = self.storage_device_map[expr]
assert len(storage_device_info) == 2
return str(storage_device_info[1])
return ""
return func.astext(show_meta_data=False, annotate=_annotate)


def codegen(self, func):
"""Compile a single function into a graph.

@@ -331,24 +377,31 @@
graph_json : str
The graph json that can be consumed by runtime.

lowered_funcs : List[tvm.LoweredFunc]
lowered_funcs : List[tvm.LoweredFunc] or Dict[str, List[tvm.LoweredFunc]]
The lowered functions.

params : Dict[str, tvm.nd.NDArray]
Additional constant parameters.
"""
self.storage_map = _backend.GraphPlanMemory(func)
self.storage_device_map = _backend.GraphPlanMemory(func)
# First we convert all the parameters into input nodes.
for param in func.params:
node = InputNode(param.name_hint, {})
self.var_map[param] = self.add_node(
node, param)
self.var_map[param] = self.add_node(node, param)

# Then we compile the body into a graph which can depend
# on input variables.
self.heads = self.visit(func.body)
graph_json = self._get_json()
lowered_funcs = list(self.lowered_funcs)

# Return the lowered functions as a list for homogeneous compilation.
# Otherwise, for heterogeneous compilation, return a dictionary mapping
# each target to its list of lowered functions. Both forms are
# acceptable to tvm.build.
if not isinstance(self.target, dict):
lowered_funcs = list(list(self.lowered_funcs.values())[0])
else:
lowered_funcs = {k: list(v) for k, v in self.lowered_funcs.items()}
return graph_json, lowered_funcs, self.params

def _get_unique_name(self, name):
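
Putting the pieces together, the codegen above could be driven roughly as follows. This is a sketch under several assumptions: the pass binding names used earlier, the GraphRuntimeCodegen constructor shown in this file (relay.build constructs it with mod=None), and DLDeviceType numbering (1 = CPU, 2 = GPU):

import tvm
from tvm import relay
from tvm.relay.backend.graph_runtime_codegen import GraphRuntimeCodegen

x = relay.var("x", shape=(3, 4))
y = relay.var("y", shape=(3, 4))
add_gpu = relay.annotation.on_device(relay.add(x, y), tvm.gpu())
func = relay.Function([x, y], relay.subtract(add_gpu, y))
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.rewrite_annotated_ops(func, 1)  # assumed binding name
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.fuse_ops(func, opt_level=2)
func = relay.ir_pass.infer_type(func)

# One target per device type; this drives the dispatch in visit_call.
codegen = GraphRuntimeCodegen(mod=None, target={1: "llvm", 2: "cuda"})
graph_json, lowered_funcs, params = codegen.codegen(func)
# lowered_funcs is {"llvm": [...], "cuda": [...]} here; with a single
# target string it would be a plain list instead.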