121 changes: 98 additions & 23 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -3964,11 +3964,7 @@ def aten_ops_linear(
def scaled_dot_product_attention_validator(
node: Node, settings: Optional[CompilationSettings] = None
) -> bool:
if node.kwargs.get("enable_gqa", False):
_LOGGER.debug(
"enable_gqa is not yet supported by the converter. Please try setting decompose_attention=True in the compilation settings."
)
return False
enable_gqa = node.kwargs.get("enable_gqa", False)

query_shape, key_shape, value_shape = None, None, None
if "val" in node.args[0].meta:
@@ -3977,15 +3973,51 @@ def scaled_dot_product_attention_validator(
key_shape = node.args[1].meta["val"].size()
if "val" in node.args[2].meta:
value_shape = node.args[2].meta["val"].size()
if (
query_shape != key_shape
or query_shape != value_shape
or key_shape != value_shape
):

if key_shape != value_shape:
_LOGGER.debug(
"query, key, and value have different shapes. Please try setting decompose_attention=True in the compilation settings."
"key and value have different shapes, which is not supported. Please try setting decompose_attention=True in the compilation settings."
)
return False

if query_shape is not None and key_shape is not None:
if len(query_shape) != len(key_shape):
_LOGGER.debug("query and key have different ranks, which is not supported.")
return False
ndim = len(query_shape)
if enable_gqa:
# IAttentionLayer natively supports GQA: Q and K/V may differ in the number
# of heads (dim 1) as long as Hq is divisible by Hkv.
# Check batch (dim 0) and head_dim (last dim) match; skip seq (dim -2)
# and head (dim 1) dims.
head_dim = ndim - 1
seq_dim = ndim - 2
heads_dim = 1
for i in range(ndim):
if i in (seq_dim, heads_dim):
continue
if query_shape[i] != key_shape[i]:
_LOGGER.debug(
f"query and key mismatch on dim {i} with enable_gqa=True."
)
return False
num_q_heads = query_shape[1]
num_kv_heads = key_shape[1]
if num_q_heads % num_kv_heads != 0:
_LOGGER.debug(
f"enable_gqa=True but num_q_heads={num_q_heads} is not divisible "
f"by num_kv_heads={num_kv_heads}."
)
return False
else:
# IAttentionLayer supports decode-phase (seq_q != seq_k).
# Check all dims except the seq dim.
seq_dim = ndim - 2
if any(query_shape[i] != key_shape[i] for i in range(ndim) if i != seq_dim):
_LOGGER.debug(
"query and key have incompatible shapes (batch, heads, or head_dim mismatch). Please try setting decompose_attention=True in the compilation settings."
)
return False
return True
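
Not part of the diff, but for readers of the new validator logic above: the shape rules it enforces can be expressed as a standalone predicate. A minimal sketch (the helper name shapes_compatible is hypothetical, not from this PR):

from typing import Sequence

def shapes_compatible(
    query_shape: Sequence[int], key_shape: Sequence[int], enable_gqa: bool
) -> bool:
    # Mirrors the checks in scaled_dot_product_attention_validator.
    if len(query_shape) != len(key_shape):
        return False
    ndim = len(query_shape)
    seq_dim, heads_dim = ndim - 2, 1
    if enable_gqa:
        # Batch and head_dim must match; head counts may differ if Hq % Hkv == 0.
        if any(
            query_shape[i] != key_shape[i]
            for i in range(ndim)
            if i not in (seq_dim, heads_dim)
        ):
            return False
        return query_shape[heads_dim] % key_shape[heads_dim] == 0
    # MHA / decode phase: only the sequence dim may differ.
    return all(query_shape[i] == key_shape[i] for i in range(ndim) if i != seq_dim)

# GQA prefill (32 query heads sharing 8 KV heads) is accepted:
assert shapes_compatible((2, 32, 512, 128), (2, 8, 512, 128), enable_gqa=True)
# Decode phase (one query token against a 512-token KV cache) is accepted:
assert shapes_compatible((2, 32, 1, 128), (2, 32, 512, 128), enable_gqa=False)
# Hq not divisible by Hkv is rejected:
assert not shapes_compatible((2, 32, 512, 128), (2, 7, 512, 128), enable_gqa=True)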


@@ -4032,15 +4064,42 @@ def scaled_dot_product_flash_attention_validator(
key_shape = node.args[1].meta["val"].size()
if "val" in node.args[2].meta:
value_shape = node.args[2].meta["val"].size()
if (
query_shape != key_shape
or query_shape != value_shape
or key_shape != value_shape
):
if key_shape != value_shape:
_LOGGER.debug(
"query, key, and value have different shapes. Please try setting decompose_attention=True in the compilation settings."
"key and value have different shapes, which is not supported. Please try setting decompose_attention=True in the compilation settings."
)
return False
if query_shape is not None and key_shape is not None:
if len(query_shape) != len(key_shape):
_LOGGER.debug("query and key have different ranks, which is not supported.")
return False
ndim = len(query_shape)
seq_dim = ndim - 2
heads_dim = 1
num_q_heads = query_shape[heads_dim]
num_kv_heads = key_shape[heads_dim]
is_gqa = num_q_heads != num_kv_heads
if is_gqa:
# GQA: IAttentionLayer natively handles Hq != Hkv.
# Require batch/head_dim to match and Hq divisible by Hkv.
for i in range(ndim):
if i in (seq_dim, heads_dim):
continue
if query_shape[i] != key_shape[i]:
_LOGGER.debug(f"GQA: query and key mismatch on dim {i}.")
return False
if num_q_heads % num_kv_heads != 0:
_LOGGER.debug(
f"GQA: num_q_heads={num_q_heads} not divisible by num_kv_heads={num_kv_heads}."
)
return False
else:
# MHA / decode-phase: seq may differ, all other dims must match.
if any(query_shape[i] != key_shape[i] for i in range(ndim) if i != seq_dim):
_LOGGER.debug(
"query and key have incompatible shapes (batch, heads, or head_dim mismatch). Please try setting decompose_attention=True in the compilation settings."
)
return False
return True
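
Not part of the diff: the is_gqa branch above detects grouping purely from the head counts. The grouping it implies, assuming the usual repeat_interleave-style expansion, is integer division of the query-head index by the group size:

# Each KV head serves Hq // Hkv consecutive query heads (illustrative values).
Hq, Hkv = 32, 8            # e.g. a Llama-style 32 query / 8 KV head layout
group = Hq // Hkv          # 4 query heads per KV head
kv_head_for_q = [qh // group for qh in range(Hq)]
print(kv_head_for_q[:8])   # [0, 0, 0, 0, 1, 1, 1, 1]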


@@ -4086,15 +4145,31 @@ def scaled_dot_product_efficient_attention_validator(
key_shape = node.args[1].meta["val"].size()
if "val" in node.args[2].meta:
value_shape = node.args[2].meta["val"].size()
if (
query_shape != key_shape
or query_shape != value_shape
or key_shape != value_shape
):
if key_shape != value_shape:
_LOGGER.debug(
"query, key, and value have different shapes. Please try setting decompose_attention=True in the compilation settings."
"key and value have different shapes, which is not supported. Please try setting decompose_attention=True in the compilation settings."
)
return False
# GQA (Hq != Hkv) is intentionally not supported here.
# PyTorch's eager _scaled_dot_product_efficient_attention kernel rejects
# non-equal head counts at runtime, so no valid reference output exists for
# comparison. In practice, GQA models on CUDA dispatch to
# _scaled_dot_product_flash_attention (FP16/BF16) or decompose into
# matmul+_safe_softmax (FP32) — this op never appears with GQA shapes in
# a real FX graph. GQA is handled by the flash attention validator instead.
#
# IAttentionLayer does support decode-phase (seq_q != seq_k), so only the
# sequence dimension is skipped in the shape check below.
if query_shape is not None and key_shape is not None:
if len(query_shape) != len(key_shape) or any(
query_shape[i] != key_shape[i]
for i in range(len(query_shape))
if i != len(query_shape) - 2 # skip the seq dim
):
_LOGGER.debug(
"query and key have incompatible shapes (batch, heads, or head_dim mismatch). Please try setting decompose_attention=True in the compilation settings."
)
return False
return True
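
Not part of the diff. The comment above notes that eager GQA never reaches this op; the GQA semantics the other validators admit can be checked against plain SDPA by repeating the KV heads. A sketch, assuming a PyTorch build whose F.scaled_dot_product_attention exposes the enable_gqa flag:

import torch
import torch.nn.functional as F

B, Hq, Hkv, S, D = 2, 8, 2, 16, 32
q = torch.randn(B, Hq, S, D)
k = torch.randn(B, Hkv, S, D)
v = torch.randn(B, Hkv, S, D)

# GQA output should match MHA after expanding each KV head to its query group.
out_gqa = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
out_ref = F.scaled_dot_product_attention(
    q,
    k.repeat_interleave(Hq // Hkv, dim=1),
    v.repeat_interleave(Hq // Hkv, dim=1),
)
torch.testing.assert_close(out_gqa, out_ref)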


@@ -12,32 +12,57 @@
def force_causal_efficient_attention(
gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
"""Force efficient-attention calls to causal mode when enabled in settings."""
"""Force efficient-attention calls to causal mode when enabled in settings.

For square attention (seq_q == seq_k): replaces attn_bias with is_causal=True
so IAttentionLayer can use its native causal path.

For decode-phase attention (seq_q != seq_k): skip the transformation.
Applying is_causal=True is semantically wrong here — it creates a lower-
triangular mask aligned to position 0, so the query attends only to k[0]
instead of all past keys. The node is left unchanged and passed to
IAttentionLayer, which supports non-square Q/K natively.
"""
if not settings.attn_bias_is_causal:
return gm

changed = False
for node in gm.graph.nodes:
if (
node.target
== torch.ops.aten._scaled_dot_product_efficient_attention.default
!= torch.ops.aten._scaled_dot_product_efficient_attention.default
):
continue

attn_bias = node.args[3] if len(node.args) > 3 else None
if attn_bias is None:
continue

query_node, key_node = node.args[0], node.args[1]
query_meta = query_node.meta.get("val") if hasattr(query_node, "meta") else None
key_meta = key_node.meta.get("val") if hasattr(key_node, "meta") else None
if (
query_meta is not None
and key_meta is not None
and query_meta.size(2) != key_meta.size(2)
):
attn_bias = node.args[3] if len(node.args) > 3 else None
if attn_bias is None:
continue
node.args = (
node.args[0],
node.args[1],
node.args[2],
None,
False,
0.0,
True,
)
changed = True
logger.debug(
f"The args of node {node} was changed to causal mode. Now the node's arguments are: {node.args}"
f"Skipping causal force for node {node}: seq_q={query_meta.size(2)} "
f"!= seq_k={key_meta.size(2)} (decode-phase, IAttentionLayer handles it)"
)
continue

node.args = (
node.args[0],
node.args[1],
node.args[2],
None,
False,
0.0,
True,
)
changed = True
logger.debug(f"Node {node} changed to causal mode: {node.args}")

if changed:
gm = clean_up_graph_after_modifications(gm)
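
Not part of the diff: a small torch snippet makes the docstring's decode-phase argument concrete. is_causal corresponds to a lower-triangular mask aligned to the top-left corner, so for seq_q=1 it admits only key position 0:

import torch

seq_q, seq_k = 1, 4
# Mask that is_causal would imply (lower-triangular, top-left aligned):
print(torch.ones(seq_q, seq_k).tril())   # tensor([[1., 0., 0., 0.]])  -> only k[0] visible
# Mask a decode step actually needs (attend to every cached key position):
print(torch.ones(seq_q, seq_k))          # tensor([[1., 1., 1., 1.]])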
56 changes: 2 additions & 54 deletions tests/py/dynamo/conversion/harness.py
@@ -6,14 +6,12 @@
from typing import Any, Callable, List, Optional, Sequence, Tuple

import torch
import torch_tensorrt
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
from torch.fx.passes.shape_prop import ShapeProp
from torch.testing._internal.common_utils import TestCase
from torch_tensorrt import Input
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import dtype
from torch_tensorrt.dynamo import _defaults
from torch_tensorrt.dynamo._defaults import default_device
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo._tracer import get_dynamic_shapes_args
@@ -109,58 +107,6 @@ def get_use_dynamo_tracer(use_dynamo_tracer: Any) -> bool:
return False


# this method is only used in our converter test to infer the module output dtypes via dummy inference
Contributor Author: This is a duplicate definition of line 42

# which is due to fx.symbolic_trace does not have the meta['val'] info in the node
# TODO: lan to remove this once our converter test is moved from fx.symbolic_trace to dynamo trace
def infer_module_output_dtypes_for_test(
module: torch.fx.GraphModule,
inputs: Sequence[Input],
device: Device,
kwarg_inputs: Optional[dict[str, Any]] = None,
truncate_double: bool = False,
) -> List[dtype]:
"""
This function performs model inference to determine the output dtypes
and truncates them accordingly. inputs can be either arg_inputs or flattened input list.
If it is flattened list, kwarg_inputs should be None, as it is already included in the flattened input.
"""
# TODO: We can also determine output dtypes from the module.graph based on node metadata.
# However, our converter tests use fx.symbolic_trace which sometimes does not provide metadata,
# so we stick to the model inference approach currently.
with unset_fake_temporarily():
# Get the device on which the model exists
# For large models, this can be done on CPU to save GPU memory allocation for TRT.
device = get_model_device(module)
torch_inputs = get_torch_inputs(inputs, device)
if kwarg_inputs is None:
kwarg_inputs = {}
torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device)
module_outputs = module(*torch_inputs, **torch_kwarg_inputs)
if not isinstance(module_outputs, (list, tuple)):
module_outputs = [module_outputs]

# Int64 outputs can sometimes be generated from within other operators
# such as aten.sum - such outputs can be truncated
output_dtypes = []
for output in module_outputs:
output_ = output
# We don't need to check if output is nested here because the input module will be flattened
if not isinstance(output, torch.Tensor):
if isinstance(output, str):
raise ValueError(
f"Received an output type {type(output)} that's not in the acceptable datatypes (https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)"
)
else:
output_ = torch.tensor(output)

if truncate_double and output_.dtype == dtype.float64:
output_dtypes.append(dtype.float32)
else:
output_dtypes.append(dtype._from(output_.dtype))

return output_dtypes


def fetch_attr(mod, target):
"""
Fetch an attribute from the ``Module`` hierarchy of ``mod.module``.
@@ -422,6 +368,7 @@ def run_test(
immutable_weights=True,
use_explicit_typing=False,
decompose_attention=False,
attn_bias_is_causal=True,
):
# TODO: lan to remove this and set use_dynamo_traccer to True by default
# once all the converter test files are moved to use_dynamo_tracer
@@ -434,6 +381,7 @@
immutable_weights=immutable_weights,
use_explicit_typing=use_explicit_typing,
decompose_attention=decompose_attention,
attn_bias_is_causal=attn_bias_is_causal,
)

mod = self.generate_graph(
Empty file added tests/py/dynamo/hlo/__init__.py