48 changes: 42 additions & 6 deletions py/torch_tensorrt/dynamo/conversion/impl/dynamic_block_quantize.py
@@ -1,9 +1,7 @@
from typing import Optional, Union

import numpy as np
import tensorrt as trt
import torch
from tensorrt import ITensor as TRTTensor
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
from torch.fx.node import Target
from torch_tensorrt._utils import is_tensorrt_version_supported
@@ -14,6 +12,9 @@
set_layer_name,
)

import tensorrt as trt
from tensorrt import ITensor as TRTTensor

if is_tensorrt_version_supported("10.8.0"):

def quantize(
@@ -33,11 +34,28 @@ def quantize(
Adds quantize and dequantize ops (QDQ) which quantize to FP4 based
on the output_type set and dequantize them back.
"""
if len(input_tensor.shape) not in (2, 3):
raise ValueError(
f"dynamic_block_quantize converter received an input of {input_tensor.shape} shape. Supported shapes: 2D or 3D"
)
# Save original shape before any reshape so we can restore it.
original_shape = tuple(input_tensor.shape)
# FP4 block quantization requires 2D or 3D inputs. For higher-rank
# tensors (e.g. a patch-embed reshape to (B, C, pH, pW, kH, kW)) we
# flatten all leading dimensions into one, quantize in 2D, then
# restore the original shape on the output.
needs_reshape = len(original_shape) > 3

with unset_fake_temporarily():
if needs_reshape:
last_dim = original_shape[-1]
is_weight = ".weight_quantizer" in name
if is_weight:
# torch.Tensor path: plain reshape
input_tensor = input_tensor.reshape(-1, last_dim)
else:
# TRTTensor path: insert a shuffle (reshape) layer
reshape_layer = ctx.net.add_shuffle(input_tensor)
reshape_layer.reshape_dims = (-1, last_dim)
reshape_layer.name = f"{name}_reshape_to_2d"
input_tensor = reshape_layer.get_output(0)

axis = -1
global_scale = _calculate_global_scale(ctx, name, amax)
if ".weight_quantizer" in name:
@@ -64,6 +82,24 @@ def quantize(
raise ValueError(
f"quantizer received an input of {name}. Supported values: weight_quantizer | input_quantizer"
)

if needs_reshape:
restore_dims = list(original_shape)
# TRT reshape_dims allows at most one -1 (inferred dimension).
# More than one dynamic dim requires shape-tensor API which is
# not yet implemented here.
dynamic_count = sum(1 for d in restore_dims if d == -1)
if dynamic_count > 1:
raise ValueError(
f"dynamic_block_quantize: cannot restore tensor to shape "
f"{original_shape} — found {dynamic_count} dynamic dimensions "
f"(TRT reshape supports at most one inferred dimension)"
)
restore_layer = ctx.net.add_shuffle(output)
restore_layer.reshape_dims = tuple(restore_dims)
restore_layer.name = f"{name}_reshape_from_2d"
output = restore_layer.get_output(0)

return output

def _dynamic_double_quantize(
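Reviewer note: a minimal sketch of the shape bracketing this hunk introduces, written as plain PyTorch rather than the TRT network API. The helper name and the standalone framing are illustrative, not part of this PR; the identity lambda stands in for the FP4 QDQ path, which only accepts 2D/3D inputs.

import torch
from typing import Callable

def _flatten_quantize_restore(
    x: torch.Tensor, quantize_2d: Callable[[torch.Tensor], torch.Tensor]
) -> torch.Tensor:
    """Illustrative mirror of the converter's flatten/quantize/restore logic."""
    original_shape = tuple(x.shape)
    needs_reshape = len(original_shape) > 3
    if needs_reshape:
        # Collapse all leading dims; keep the quantized (block) axis last.
        x = x.reshape(-1, original_shape[-1])
    out = quantize_2d(x)
    if needs_reshape:
        # Mirror the TRT constraint: reshape_dims allows at most one
        # inferred (-1) dimension. (Vacuous for concrete torch shapes,
        # but this is the check the converter performs.)
        if sum(1 for d in original_shape if d == -1) > 1:
            raise ValueError("at most one inferred dimension is supported")
        out = out.reshape(original_shape)
    return out

# Example: a 5-D activation passes through with its shape preserved.
x = torch.randn(2, 3, 4, 5, 16)
y = _flatten_quantize_restore(x, lambda t: t)  # identity stands in for QDQ
assert y.shape == x.shape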
134 changes: 134 additions & 0 deletions tests/py/dynamo/models/test_models_export.py
@@ -345,6 +345,140 @@ def calibrate_loop(model):
assert torch.allclose(expected_output, outputs_trt, rtol=0.3, atol=0.3)


@unittest.skipIf(
torch.cuda.get_device_capability() < (10, 0),
"FP4 quantization requires compute capability 10.0 or later",
)
@unittest.skipIf(
not importlib.util.find_spec("modelopt"),
"ModelOpt is required to run this test",
)
@unittest.skipIf(
platform.system() != "Linux",
"modelopt is only supported on Linux",
)
@pytest.mark.unit
def test_nvfp4_nd_input_quantizer(ir):
"""Regression test for #4201: dynamic_block_quantize must handle N-D (>3D)
input tensors that arise when a reshape precedes the input_quantizer."""
import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.utils import export_torch_mode

dtype = torch.float16

class PatchEmbedLike(torch.nn.Module):
"""Mimics the reshape->linear pattern from Qwen3-VL patch_embed.proj
that triggered the original crash: input is reshaped to 5-D before
the FP4 input_quantizer fires."""

def __init__(self) -> None:
super().__init__()
# kernel decomposition: (C, kH, kW) -> embed_dim
self.proj = torch.nn.Linear(
in_features=3 * 2 * 16, # C * kH * kW = 96
out_features=32,
bias=False,
dtype=dtype,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: (N, 3, 2, 16) — already 4-D from an upstream reshape
x = x.reshape(x.shape[0], -1) # (N, 96)
return self.proj(x)

def calibrate_loop(model: torch.nn.Module) -> None:
model(input_tensor)

# 4-D input: (batch, C, kH, kW) where C*kH*kW must be divisible by
# the FP4 block size (16).
input_tensor = torch.randn(64, 3, 2, 16, dtype=dtype).cuda()
model = PatchEmbedLike().eval().cuda()

quant_cfg = mtq.NVFP4_DEFAULT_CFG
mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)

with torch.no_grad():
with export_torch_mode():
exp_program = torch.export.export(model, (input_tensor,), strict=False)
trt_model = torchtrt.dynamo.compile(
exp_program,
inputs=[input_tensor],
min_block_size=1,
cache_built_engines=False,
reuse_cached_engines=False,
use_explicit_typing=True,
)
expected = model(input_tensor)
outputs_trt = trt_model(input_tensor)
abs_diff = torch.abs(expected - outputs_trt)
assert torch.allclose(
expected, outputs_trt, rtol=0.3, atol=0.3
), f"max abs diff: {abs_diff.max().item()}"


@unittest.skipIf(
torch.cuda.get_device_capability() < (8, 9),
"FP8 quantization requires compute capability 8.9 or later",
)
@unittest.skipIf(
not importlib.util.find_spec("modelopt"),
"ModelOpt is required to run this test",
)
@unittest.skipIf(
platform.system() != "Linux",
"modelopt is only supported on Linux",
)
@pytest.mark.unit
def test_fp8_nd_input_quantizer(ir):
"""FP8 analogue of ``test_nvfp4_nd_input_quantizer``: the input_quantizer must
handle N-D (>3D) input tensors flowing through a reshape into a Linear."""
import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.utils import export_torch_mode

dtype = torch.float16

class PatchEmbedLike(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.proj = torch.nn.Linear(
in_features=3 * 2 * 16,
out_features=32,
bias=False,
dtype=dtype,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: (N, 3, 2, 16) — 4-D from an upstream reshape
x = x.reshape(x.shape[0], -1)
return self.proj(x)

def calibrate_loop(model: torch.nn.Module) -> None:
model(input_tensor)

input_tensor = torch.randn(64, 3, 2, 16, dtype=dtype).cuda()
model = PatchEmbedLike().eval().cuda()

mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop=calibrate_loop)

with torch.no_grad():
with export_torch_mode():
exp_program = torch.export.export(model, (input_tensor,), strict=False)
trt_model = torchtrt.dynamo.compile(
exp_program,
inputs=[input_tensor],
min_block_size=1,
cache_built_engines=False,
reuse_cached_engines=False,
use_explicit_typing=True,
)
expected = model(input_tensor)
outputs_trt = trt_model(input_tensor)
abs_diff = torch.abs(expected - outputs_trt)
assert torch.allclose(
expected, outputs_trt, rtol=5e-2, atol=5e-2
), f"max abs diff: {abs_diff.max().item()}"


@unittest.skipIf(
torch.cuda.get_device_capability() < (8, 9),
"FP8 quantization requires compute capability 8.9 or later",
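Reviewer note: a quick sanity check on the shapes both new tests rely on. This is a side sketch, not test code from the PR; the block size of 16 is taken from the tests' own comments, not read from modelopt.

import torch

# Block size assumed from the tests' comments ("divisible by the FP4 block size (16)").
FP4_BLOCK_SIZE = 16

batch, c, kh, kw = 64, 3, 2, 16
inner = c * kh * kw  # 96: the in_features PatchEmbedLike passes to nn.Linear
assert inner % FP4_BLOCK_SIZE == 0, "inner dim must align to FP4 blocks"

# After the module's reshape, the quantizer sees a 2-D (batch, inner) tensor,
# which is the shape the pre-fix converter required outright.
x = torch.randn(batch, c, kh, kw)
assert x.reshape(x.shape[0], -1).shape == (batch, inner)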