Motivation
The graph runtime now supports heterogeneous execution by annotating nodes with different device ids. A compiler pass is needed to expose this annotation to the frontend, so that users have the flexibility to assign each operator to "the best" device. This RFC proposes adding annotation to Relay as a standalone pass. The prototype implementation is here.
Action items
The main design items are as follows:
- Each operator is attached with a fallback attribute that indicates whether it should fall back to the fallback device.
- Each CallNode is attached with a device_id attribute that indicates which device it is annotated with (0 by default, meaning no annotation is required).
- Annotation can be performed in several ways. Currently, users can optionally pass a Dict[op_name, device] map to build, or enable fallback for an operator by calling set_fallback on it. More sophisticated annotation schemes (e.g. ones that use cost functions taking device communication and data-transfer overhead into account) could be explored in the future.
- Copy ops are needed to move data across devices, and these ops can be treated specially during compilation. For example, they do not need fcompute or fschedule. They can also be skipped during lowering, since the actual data copy is performed by the runtime.
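To make the copy-op item concrete, here is a minimal Python sketch of how a pass might insert copy nodes once every call carries a device_id. The Var and Call classes, the op name "device_copy", and the insert_copies function are hypothetical stand-ins, not Relay's actual classes or API. The sketch simply compares device_id values on each producer/consumer edge and treats the expression as a tree; a real pass would also memoize shared subexpressions.

```python
from dataclasses import dataclass, field
from typing import List, Union


# Hypothetical toy IR standing in for Relay's Var and CallNode.
@dataclass
class Var:
    name: str
    device_id: int = 0


@dataclass
class Call:
    op_name: str
    args: List["Expr"] = field(default_factory=list)
    device_id: int = 0  # 0 means "no annotation required", as in the RFC


Expr = Union[Var, Call]

# Hypothetical name of the copy op. It needs no fcompute/fschedule because
# the runtime performs the actual transfer.
COPY_OP = "device_copy"


def insert_copies(expr: Expr) -> Expr:
    """Return a new expression with a copy call inserted on every edge
    whose producer and consumer have different device_id values."""
    if isinstance(expr, Var):
        return expr
    new_args = []
    for arg in expr.args:
        new_arg = insert_copies(arg)
        if new_arg.device_id != expr.device_id:
            # The copy node is placed on the consumer's device.
            new_arg = Call(COPY_OP, [new_arg], expr.device_id)
        new_args.append(new_arg)
    return Call(expr.op_name, new_args, expr.device_id)
```

For example, if add runs on device 2 while its inputs x and y and its consumer exp are on device 1, the pass wraps both inputs of add and the input of exp in a device_copy call. Because the copy call carries no compute of its own, lowering can skip it and leave the transfer to the runtime.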
Proposed APIs
- The build API looks like the following:
def build(func, target=None, target_host=None, params=None, op_name_device=None, fallback_device=None):
Heterogeneous compilation is enabled when target is a dict mapping devices to targets.
- The annotation API looks like the following:
def annotate_ops(expr, op_name_dev_map, fallback_device):
During annotation, the device_id of a CallNode is set to fallback_device if its operator is registered with fallback, or if the map does not explicitly specify a device for that operator.
- The return value of the Plan API in graph_plan_memory needs a small change: in addition to the list of storage_id, it must also return the corresponding device_id for each entry.
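The annotation rule and the extended Plan return value can be illustrated with a small Python sketch on a toy IR. Everything here is hypothetical: Call, FALLBACK_OPS, set_fallback, annotate_ops, and plan_memory only mimic the behaviour described in this RFC and are not the TVM/Relay implementation. The device ids 1 and 2 are chosen to mirror DLPack's device-type codes for CPU and CUDA GPU. plan_memory is deliberately naive (one storage id per node, no buffer reuse) because only the shape of its return value matters here.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Set, Tuple


# Hypothetical minimal call node; device_id 0 means "not annotated".
# Calls without args stand in for graph inputs.
@dataclass
class Call:
    op_name: str
    args: List["Call"] = field(default_factory=list)
    device_id: int = 0


# Hypothetical registry of ops registered with the fallback attribute.
FALLBACK_OPS: Set[str] = set()


def set_fallback(op_name: str) -> None:
    """Hypothetical equivalent of the RFC's set_fallback."""
    FALLBACK_OPS.add(op_name)


def annotate_ops(expr: Call, op_name_dev_map: Dict[str, int],
                 fallback_device: int) -> Call:
    """Assign a device_id to every call, bottom-up: fallback ops and ops
    missing from the map go to fallback_device, others follow the map."""
    new_args = [annotate_ops(a, op_name_dev_map, fallback_device)
                for a in expr.args]
    if expr.op_name in FALLBACK_OPS or expr.op_name not in op_name_dev_map:
        device = fallback_device
    else:
        device = op_name_dev_map[expr.op_name]
    return Call(expr.op_name, new_args, device)


def plan_memory(expr: Call) -> Tuple[List[int], List[int]]:
    """Naive stand-in for graph_plan_memory: one storage id per node in
    post-order, plus the parallel list of device ids."""
    storage_ids: List[int] = []
    device_ids: List[int] = []

    def visit(node: Call) -> None:
        for a in node.args:
            visit(a)
        storage_ids.append(len(storage_ids))
        device_ids.append(node.device_id)

    visit(expr)
    return storage_ids, device_ids
```

For example, after set_fallback("argsort"), annotating argsort(relu(conv2d)) with annotate_ops(expr, {"conv2d": 2, "argsort": 2}, fallback_device=1) places conv2d on device 2. argsort falls back to device 1 despite the map entry, and so does relu because it is unmapped. plan_memory then returns storage ids [0, 1, 2] alongside device ids [2, 1, 1].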