diff --git a/docsrc/tutorials/resource_memory/engine_cache.rst b/docsrc/tutorials/resource_memory/engine_cache.rst index a49b8ed7e6..499a674262 100644 --- a/docsrc/tutorials/resource_memory/engine_cache.rst +++ b/docsrc/tutorials/resource_memory/engine_cache.rst @@ -129,6 +129,38 @@ The timing cache is always active and persisted at ``timing_cache_path``: The default path is ``/tmp/torch_tensorrt_engine_cache/timing_cache.bin``. +.. note:: + + The timing cache is **not used with TensorRT-RTX**, which does not perform + autotuning. For TensorRT-RTX, see the *Runtime Cache* section below. + +Runtime Cache (TensorRT-RTX) +----------------------------- + +TensorRT-RTX uses JIT compilation at inference time. The **runtime cache** stores +these compilation results so that kernels and execution graphs are not recompiled +on subsequent runs. This is analogous to the timing cache but operates at inference +time rather than build time. + +The runtime cache is automatically created when using TensorRT-RTX and can be +persisted to disk via ``runtime_cache_path``: + +.. code-block:: python + + trt_gm = torch_tensorrt.dynamo.compile( + exported_program, + arg_inputs=inputs, + runtime_cache_path="/data/trt_cache/runtime_cache.bin", + use_python_runtime=True, + ) + +The default path is +``/tmp/torch_tensorrt_engine_cache/runtime_cache.bin``. + +The cache is saved to disk when the module is destroyed (garbage collected) and +loaded on subsequent compilations with the same path. File locking is used to +prevent corruption when multiple processes share the same cache file. + ---- Custom Cache Backends diff --git a/docsrc/user_guide/compilation/compilation_settings.rst b/docsrc/user_guide/compilation/compilation_settings.rst index 2c32bb81c0..74439d756b 100644 --- a/docsrc/user_guide/compilation/compilation_settings.rst +++ b/docsrc/user_guide/compilation/compilation_settings.rst @@ -372,7 +372,7 @@ Compilation Workflow - ``False`` - Defer TRT engine deserialization until all engines have been built. Works around resource contraints and builder overhad but engines - may be less well tuned to their deployment resource availablity + may be less well tuned to their deployment resource availability * - ``debug`` - ``False`` - Enable verbose TRT builder logs at ``DEBUG`` level. @@ -402,7 +402,13 @@ Engine Caching - ``/tmp/torch_tensorrt_engine_cache/timing_cache.bin`` - Path for TRT's timing cache file. The timing cache records kernel timing data across sessions, speeding up subsequent engine builds for similar subgraphs even - when the engine cache itself is cold. + when the engine cache itself is cold. Not used for TensorRT-RTX (no autotuning). + * - ``runtime_cache_path`` + - ``/tmp/torch_tensorrt_engine_cache/runtime_cache.bin`` + - Path for the TensorRT-RTX runtime cache file. The runtime cache stores JIT + compilation results at inference time, preventing repeated compilation of + kernels and graphs across sessions. Uses file locking for concurrent access + safety. Only used with TensorRT-RTX; ignored for standard TensorRT. 
---- diff --git a/py/requirements.txt b/py/requirements.txt index 3d9dc09297..cd6b2341c4 100644 --- a/py/requirements.txt +++ b/py/requirements.txt @@ -6,4 +6,5 @@ torch>=2.12.0.dev,<2.13.0 --extra-index-url https://pypi.ngc.nvidia.com pyyaml dllist +filelock setuptools \ No newline at end of file diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index bc3cdc5721..b6e691f607 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -92,6 +92,7 @@ def cross_compile_for_windows( dryrun: bool = _defaults.DRYRUN, hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE, timing_cache_path: str = _defaults.TIMING_CACHE_PATH, + runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH, lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT, cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES, reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES, @@ -170,7 +171,8 @@ def cross_compile_for_windows( enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the graph easier to convert to TensorRT, potentially increasing the amount of graphs run in TensorRT. dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer) - timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation + timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX. + runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT. lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime. cache_built_engines (bool): Whether to save the compiled TRT engines to storage reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage @@ -334,6 +336,7 @@ def cross_compile_for_windows( "dryrun": dryrun, "hardware_compatible": hardware_compatible, "timing_cache_path": timing_cache_path, + "runtime_cache_path": runtime_cache_path, "lazy_engine_init": lazy_engine_init, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, @@ -366,6 +369,12 @@ def cross_compile_for_windows( f"arg: {key} is not supported for cross compilation for windows feature, hence it is disabled." ) + if "runtime_cache_path" in compilation_options: + compilation_options.pop("runtime_cache_path") + logger.warning( + "runtime_cache_path is a JIT-time API and is not applicable to cross compilation for windows. Ignoring." 
+ ) + settings = CompilationSettings(**compilation_options) logger.info("Compilation Settings: %s\n", settings) exported_program = pre_export_lowering(exported_program, settings) @@ -438,6 +447,7 @@ def compile( dryrun: bool = _defaults.DRYRUN, hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE, timing_cache_path: str = _defaults.TIMING_CACHE_PATH, + runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH, lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT, cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES, reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES, @@ -531,7 +541,8 @@ def compile( enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the graph easier to convert to TensorRT, potentially increasing the amount of graphs run in TensorRT. dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer) - timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation + timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX. + runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT. lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime. cache_built_engines (bool): Whether to save the compiled TRT engines to storage reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage @@ -738,6 +749,7 @@ def compile( "dryrun": dryrun, "hardware_compatible": hardware_compatible, "timing_cache_path": timing_cache_path, + "runtime_cache_path": runtime_cache_path, "lazy_engine_init": lazy_engine_init, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, @@ -1150,6 +1162,7 @@ def convert_exported_program_to_serialized_trt_engine( dryrun: bool = _defaults.DRYRUN, hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE, timing_cache_path: str = _defaults.TIMING_CACHE_PATH, + runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH, lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT, cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES, reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES, @@ -1224,7 +1237,8 @@ def convert_exported_program_to_serialized_trt_engine( enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the graph easier to convert to TensorRT, potentially increasing the amount of graphs run in TensorRT. dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer) - timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation + timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. 
Not used for TensorRT-RTX. + runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT. lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime. cache_built_engines (bool): Whether to save the compiled TRT engines to storage reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage @@ -1397,6 +1411,7 @@ def convert_exported_program_to_serialized_trt_engine( "dryrun": dryrun, "hardware_compatible": hardware_compatible, "timing_cache_path": timing_cache_path, + "runtime_cache_path": runtime_cache_path, "lazy_engine_init": lazy_engine_init, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, @@ -1413,6 +1428,11 @@ def convert_exported_program_to_serialized_trt_engine( "use_distributed_mode_trace": use_distributed_mode_trace, "decompose_attention": decompose_attention, } + if "runtime_cache_path" in compilation_options: + compilation_options.pop("runtime_cache_path") + logger.warning( + "runtime_cache_path is a JIT-time API and is not applicable to serialized engine export. Ignoring." + ) settings = CompilationSettings(**compilation_options) logger.info("Compilation Settings: %s\n", settings) diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 2e838cd28c..b6d40a01bb 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -40,6 +40,9 @@ TIMING_CACHE_PATH = os.path.join( tempfile.gettempdir(), "torch_tensorrt_engine_cache", "timing_cache.bin" ) +RUNTIME_CACHE_PATH = os.path.join( + tempfile.gettempdir(), "torch_tensorrt_engine_cache", "runtime_cache.bin" +) LAZY_ENGINE_INIT = False CACHE_BUILT_ENGINES = False REUSE_CACHED_ENGINES = False diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index e3f2f1bc37..b53718d526 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -43,6 +43,7 @@ REFIT_IDENTICAL_ENGINE_WEIGHTS, REQUIRE_FULL_COMPILATION, REUSE_CACHED_ENGINES, + RUNTIME_CACHE_PATH, SPARSE_WEIGHTS, STRIP_ENGINE_WEIGHTS, TILING_OPTIMIZATION_LEVEL, @@ -96,7 +97,8 @@ class CompilationSettings: TRT Engines. Prints detailed logs of the graph structure and nature of partitioning. Optionally saves the output to a file if a string path is specified hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer) - timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation + timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX (no autotuning). + runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. The cache is loaded on engine setup and saved on module cleanup. Uses file locking for concurrent access safety. Not used for standard TensorRT. cache_built_engines (bool): Whether to save the compiled TRT engines to storage reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage use_strong_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. 
This is useful when users have mixed precision graphs. @@ -149,6 +151,7 @@ class CompilationSettings: dryrun: Union[bool, str] = DRYRUN hardware_compatible: bool = HARDWARE_COMPATIBLE timing_cache_path: str = TIMING_CACHE_PATH + runtime_cache_path: str = RUNTIME_CACHE_PATH lazy_engine_init: bool = LAZY_ENGINE_INIT cache_built_engines: bool = CACHE_BUILT_ENGINES reuse_cached_engines: bool = REUSE_CACHED_ENGINES diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 87ac3cbcd0..6c27b9bdfe 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -379,7 +379,14 @@ def _create_timing_cache( """ Create a timing cache to enable faster build time for TRT engines. By default the timing_cache_path="/tmp/timing_cache.bin" + Skipped for TensorRT-RTX since it does not use autotuning. """ + if ENABLED_FEATURES.tensorrt_rtx: + _LOGGER.info( + "Skipping timing cache creation for TensorRT-RTX (no autotuning)" + ) + return + buffer = b"" if os.path.isfile(timing_cache_path): # Load from existing cache @@ -394,8 +401,12 @@ def _save_timing_cache( timing_cache_path: str, ) -> None: """ - This is called after a TensorRT engine is built. Save the timing cache + This is called after a TensorRT engine is built. Save the timing cache. + Skipped for TensorRT-RTX since it does not use autotuning. """ + if ENABLED_FEATURES.tensorrt_rtx: + return + timing_cache = builder_config.get_timing_cache() os.makedirs(os.path.dirname(timing_cache_path), exist_ok=True) with open(timing_cache_path, "wb") as timing_cache_file: diff --git a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py index d3ef7e0a41..d9d505a5cf 100644 --- a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py @@ -110,7 +110,8 @@ def __init__( enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the graph easier to convert to TensorRT, potentially increasing the amount of graphs run in TensorRT. dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer) - timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation + timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX. + runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT. lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime. 
enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels **kwargs: Any, diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 31182bbe21..9d122446fe 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -1,14 +1,17 @@ from __future__ import annotations import logging +import os from contextlib import nullcontext from typing import Any, Dict, List, Optional, Sequence, Tuple +import tensorrt as trt import torch import torch_tensorrt from torch.nn import Module from torch_tensorrt._Device import Device from torch_tensorrt._enums import Platform, dtype +from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt.dynamo._defaults import DEBUG_LOGGING_DIR from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig @@ -21,8 +24,6 @@ multi_gpu_device_check, ) -import tensorrt as trt - logger = logging.getLogger(__name__) @@ -226,6 +227,12 @@ def __init__( # If the output tensor is not owned by the engine (output_tensors_are_unowned=True), we need to create a new output tensor in each forward pass self.output_tensors_are_unowned = False self.symbolic_shape_expressions = symbolic_shape_expressions + + # Runtime cache state (TensorRT-RTX only) + self.runtime_config: Any = None + self.runtime_cache: Any = None + self.runtime_cache_path = settings.runtime_cache_path + if self.serialized_engine is not None and not self.settings.lazy_engine_init: self.setup_engine() @@ -257,7 +264,7 @@ def set_device_memory_budget(self, budget_bytes: int) -> int: if self.context is not None: del self.context budget_bytes = self._set_device_memory_budget(budget_bytes) - self.context = self.engine.create_execution_context() + self.context = self._create_context() self.runtime_states.context_changed = True return budget_bytes @@ -290,7 +297,11 @@ def setup_engine(self) -> None: self.engine = runtime.deserialize_cuda_engine(self.serialized_engine) if self.settings.enable_weight_streaming: self.set_default_device_memory_budget() - self.context = self.engine.create_execution_context() + + if ENABLED_FEATURES.tensorrt_rtx: + self._setup_runtime_config() + + self.context = self._create_context() assert self.context is not None, "Failed to create execution context" assert self.engine.num_io_tensors == ( len(self.input_names) + len(self.output_names) @@ -324,6 +335,67 @@ def setup_engine(self) -> None: for input_name in self.input_names } + def _setup_runtime_config(self) -> None: + """Create a RuntimeConfig with runtime cache for TensorRT-RTX. + + The runtime cache stores JIT compilation results to avoid repeated + compilation of kernels/graphs across inference runs. 
+ """ + self.runtime_config = self.engine.create_runtime_config() + self.runtime_config.set_execution_context_allocation_strategy( + trt.ExecutionContextAllocationStrategy.STATIC + ) + self.runtime_cache = self.runtime_config.create_runtime_cache() + self._load_runtime_cache() + self.runtime_config.set_runtime_cache(self.runtime_cache) + logger.info("TensorRT-RTX runtime cache configured") + + def _create_context(self) -> trt.IExecutionContext: + """Create an execution context, using RuntimeConfig for RTX.""" + if ENABLED_FEATURES.tensorrt_rtx and self.runtime_config is not None: + return self.engine.create_execution_context(self.runtime_config) + return self.engine.create_execution_context() + + def _load_runtime_cache(self) -> None: + """Load runtime cache from disk if it exists (with shared file lock).""" + if self.runtime_cache is None: + return + if not os.path.isfile(self.runtime_cache_path): + logger.debug(f"No existing runtime cache at {self.runtime_cache_path}") + return + try: + from filelock import FileLock + + lock = FileLock(self.runtime_cache_path + ".lock") + with lock.acquire(timeout=10): + with open(self.runtime_cache_path, "rb") as f: + data = f.read() + if data: + self.runtime_cache.deserialize(data) + logger.info(f"Loaded runtime cache from {self.runtime_cache_path}") + except Exception as e: + logger.warning(f"Failed to load runtime cache: {e}") + + def _save_runtime_cache(self) -> None: + """Save runtime cache to disk (with exclusive file lock).""" + if self.runtime_cache is None: + return + try: + host_mem = self.runtime_cache.serialize() + if host_mem is None: + return + os.makedirs(os.path.dirname(self.runtime_cache_path), exist_ok=True) + + from filelock import FileLock + + lock = FileLock(self.runtime_cache_path + ".lock") + with lock.acquire(timeout=10): + with open(self.runtime_cache_path, "wb") as f: + f.write(memoryview(host_mem)) + logger.info(f"Saved runtime cache to {self.runtime_cache_path}") + except Exception as e: + logger.warning(f"Failed to save runtime cache: {e}") + def _check_initialized(self) -> None: if not self.initialized: raise RuntimeError("PythonTorchTensorRTModule is not initialized.") @@ -357,6 +429,8 @@ def __getstate__(self) -> Dict[str, Any]: state = self.__dict__.copy() state.pop("engine", None) state.pop("context", None) + state.pop("runtime_config", None) + state.pop("runtime_cache", None) return state def __setstate__(self, state: Dict[str, Any]) -> None: @@ -376,6 +450,7 @@ def _reset_captured_graph(self) -> None: self.cudagraph = None def __del__(self) -> None: + self._save_runtime_cache() self._reset_captured_graph() def setup_input_tensors( @@ -771,7 +846,7 @@ def disable_profiling(self) -> None: self._check_initialized() torch.cuda.synchronize() del self.context - self.context = self.engine.create_execution_context() + self.context = self._create_context() self.profiling_enabled = False def get_layer_info(self) -> str: diff --git a/setup.py b/setup.py index 8b377fc651..dba7d5ec6e 100644 --- a/setup.py +++ b/setup.py @@ -823,6 +823,7 @@ def get_requirements(): base_requirements = [ "packaging>=23", "typing-extensions>=4.7.0", + "filelock", "dllist", "psutil", ] diff --git a/tests/py/dynamo/models/test_runtime_cache_models.py b/tests/py/dynamo/models/test_runtime_cache_models.py new file mode 100644 index 0000000000..aecb2fbaa3 --- /dev/null +++ b/tests/py/dynamo/models/test_runtime_cache_models.py @@ -0,0 +1,329 @@ +import gc +import importlib +import os +import shutil +import tempfile +import time +import unittest + +import 
torch +import torch_tensorrt as torchtrt +from torch.testing._internal.common_utils import TestCase, run_tests +from torch_tensorrt._features import ENABLED_FEATURES +from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "Runtime cache is only available with TensorRT-RTX", +) +@unittest.skipIf( + not importlib.util.find_spec("torchvision"), + "torchvision is not installed", +) +class TestRuntimeCacheModels(TestCase): + """End-to-end model tests with runtime cache enabled.""" + + def setUp(self): + self.cache_dir = tempfile.mkdtemp() + self.cache_path = os.path.join(self.cache_dir, "runtime_cache.bin") + + def tearDown(self): + shutil.rmtree(self.cache_dir, ignore_errors=True) + torch._dynamo.reset() + + def test_resnet18_with_runtime_cache(self): + import torchvision.models as models + + model = models.resnet18(pretrained=True).eval().cuda() + input_tensor = torch.randn(1, 3, 224, 224).cuda() + + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=[torchtrt.Input(input_tensor.shape, dtype=torch.float32)], + enabled_precisions={torch.float32}, + use_python_runtime=True, + min_block_size=1, + runtime_cache_path=self.cache_path, + ) + + ref_output = model(input_tensor) + trt_output = compiled(input_tensor) + + cos_sim = cosine_similarity(ref_output, trt_output) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + f"ResNet18 cosine similarity {cos_sim} below threshold {COSINE_THRESHOLD}", + ) + + # Verify runtime cache is saved on cleanup + del compiled + gc.collect() + self.assertTrue( + os.path.isfile(self.cache_path), + "Runtime cache should be saved after ResNet18 inference", + ) + + def test_resnet18_cache_reuse(self): + """Compile + infer twice with same cache path. 
Second run should load cached data.""" + import torchvision.models as models + + model = models.resnet18(pretrained=True).eval().cuda() + input_tensor = torch.randn(1, 3, 224, 224).cuda() + ref_output = model(input_tensor) + + compile_kwargs = { + "ir": "dynamo", + "inputs": [torchtrt.Input(input_tensor.shape, dtype=torch.float32)], + "enabled_precisions": {torch.float32}, + "use_python_runtime": True, + "min_block_size": 1, + "runtime_cache_path": self.cache_path, + } + + # First compilation — cold cache + compiled1 = torchtrt.compile(model, **compile_kwargs) + _ = compiled1(input_tensor) + del compiled1 + gc.collect() + torch._dynamo.reset() + self.assertTrue(os.path.isfile(self.cache_path)) + cache_size_1 = os.path.getsize(self.cache_path) + + # Second compilation — warm cache + compiled2 = torchtrt.compile(model, **compile_kwargs) + output2 = compiled2(input_tensor) + + cos_sim = cosine_similarity(ref_output, output2) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + f"ResNet18 (cached) cosine similarity {cos_sim} below threshold", + ) + + del compiled2 + gc.collect() + cache_size_2 = os.path.getsize(self.cache_path) + # Cache should exist and be non-empty after both runs + self.assertGreater(cache_size_1, 0) + self.assertGreater(cache_size_2, 0) + + def test_mobilenet_v2_with_runtime_cache(self): + import torchvision.models as models + + model = models.mobilenet_v2(pretrained=True).eval().cuda() + input_tensor = torch.randn(1, 3, 224, 224).cuda() + + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=[torchtrt.Input(input_tensor.shape, dtype=torch.float32)], + enabled_precisions={torch.float32}, + use_python_runtime=True, + min_block_size=1, + runtime_cache_path=self.cache_path, + ) + + ref_output = model(input_tensor) + trt_output = compiled(input_tensor) + + cos_sim = cosine_similarity(ref_output, trt_output) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + f"MobileNetV2 cosine similarity {cos_sim} below threshold {COSINE_THRESHOLD}", + ) + + del compiled + gc.collect() + self.assertTrue(os.path.isfile(self.cache_path)) + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "Runtime cache is only available with TensorRT-RTX", +) +class TestRuntimeCacheDynamicShapes(TestCase): + """Tests runtime cache with dynamic input shapes.""" + + def setUp(self): + self.cache_dir = tempfile.mkdtemp() + self.cache_path = os.path.join(self.cache_dir, "runtime_cache.bin") + + def tearDown(self): + shutil.rmtree(self.cache_dir, ignore_errors=True) + torch._dynamo.reset() + + def test_dynamic_batch_with_cache(self): + class ConvModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, 3, padding=1) + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(self.conv(x)) + + model = ConvModel().eval().cuda() + + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=[ + torchtrt.Input( + min_shape=(1, 3, 32, 32), + opt_shape=(4, 3, 32, 32), + max_shape=(8, 3, 32, 32), + dtype=torch.float32, + ) + ], + enabled_precisions={torch.float32}, + use_python_runtime=True, + min_block_size=1, + runtime_cache_path=self.cache_path, + ) + + # Test with batch size 1 + input_bs1 = torch.randn(1, 3, 32, 32).cuda() + ref_bs1 = model(input_bs1) + out_bs1 = compiled(input_bs1) + cos_sim_1 = cosine_similarity(ref_bs1, out_bs1) + self.assertTrue( + cos_sim_1 > COSINE_THRESHOLD, + f"BS=1 cosine similarity {cos_sim_1} below threshold", + ) + + # Test with batch size 4 + input_bs4 = torch.randn(4, 3, 32, 32).cuda() + ref_bs4 = 
model(input_bs4) + out_bs4 = compiled(input_bs4) + cos_sim_4 = cosine_similarity(ref_bs4, out_bs4) + self.assertTrue( + cos_sim_4 > COSINE_THRESHOLD, + f"BS=4 cosine similarity {cos_sim_4} below threshold", + ) + + # Verify cache is saved + del compiled + gc.collect() + self.assertTrue(os.path.isfile(self.cache_path)) + + def test_cache_valid_across_shapes(self): + """Save cache from one shape, load and verify it works with another shape in range.""" + + class SimpleConv(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 8, 3, padding=1) + + def forward(self, x): + return self.conv(x) + + model = SimpleConv().eval().cuda() + + compile_kwargs = { + "ir": "dynamo", + "inputs": [ + torchtrt.Input( + min_shape=(1, 3, 16, 16), + opt_shape=(2, 3, 16, 16), + max_shape=(4, 3, 16, 16), + dtype=torch.float32, + ) + ], + "enabled_precisions": {torch.float32}, + "use_python_runtime": True, + "min_block_size": 1, + "runtime_cache_path": self.cache_path, + } + + # First run with batch=2 — saves cache + compiled1 = torchtrt.compile(model, **compile_kwargs) + input_bs2 = torch.randn(2, 3, 16, 16).cuda() + _ = compiled1(input_bs2) + del compiled1 + gc.collect() + torch._dynamo.reset() + self.assertTrue(os.path.isfile(self.cache_path)) + + # Second run with batch=3 — loads same cache + compiled2 = torchtrt.compile(model, **compile_kwargs) + input_bs3 = torch.randn(3, 3, 16, 16).cuda() + ref_bs3 = model(input_bs3) + out_bs3 = compiled2(input_bs3) + + cos_sim = cosine_similarity(ref_bs3, out_bs3) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + f"Cross-shape cache reuse cosine similarity {cos_sim} below threshold", + ) + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "Runtime cache is only available with TensorRT-RTX", +) +class TestRuntimeCachePerformance(TestCase): + """Informational timing tests for runtime cache warm-up behavior.""" + + def setUp(self): + self.cache_dir = tempfile.mkdtemp() + self.cache_path = os.path.join(self.cache_dir, "runtime_cache.bin") + + def tearDown(self): + shutil.rmtree(self.cache_dir, ignore_errors=True) + torch._dynamo.reset() + + def test_warmup_timing(self): + """Measure cold vs warm cache inference time. 
Informational only — no strict pass/fail.""" + + class MLP(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(256, 512) + self.fc2 = torch.nn.Linear(512, 256) + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.fc2(self.relu(self.fc1(x))) + + model = MLP().eval().cuda() + input_tensor = torch.randn(16, 256).cuda() + + compile_kwargs = { + "ir": "dynamo", + "inputs": [torchtrt.Input(input_tensor.shape, dtype=torch.float32)], + "enabled_precisions": {torch.float32}, + "use_python_runtime": True, + "min_block_size": 1, + "runtime_cache_path": self.cache_path, + } + + # Cold cache compilation + inference + compiled1 = torchtrt.compile(model, **compile_kwargs) + torch.cuda.synchronize() + start = time.perf_counter() + _ = compiled1(input_tensor) + torch.cuda.synchronize() + cold_time = time.perf_counter() - start + del compiled1 + gc.collect() + torch._dynamo.reset() + + # Warm cache compilation + inference + compiled2 = torchtrt.compile(model, **compile_kwargs) + torch.cuda.synchronize() + start = time.perf_counter() + _ = compiled2(input_tensor) + torch.cuda.synchronize() + warm_time = time.perf_counter() - start + + print(f"\n Cold cache first inference: {cold_time*1000:.1f}ms") + print(f" Warm cache first inference: {warm_time*1000:.1f}ms") + print(f" Speedup: {cold_time/warm_time:.2f}x") + + # No strict assertion — just log for visibility + self.assertTrue(True, "Timing test completed (informational)") + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/runtime/test_000_runtime_cache.py b/tests/py/dynamo/runtime/test_000_runtime_cache.py new file mode 100644 index 0000000000..bad67db24c --- /dev/null +++ b/tests/py/dynamo/runtime/test_000_runtime_cache.py @@ -0,0 +1,287 @@ +import gc +import logging +import os +import shutil +import tempfile +import unittest + +import torch +import torch_tensorrt as torchtrt +from torch.testing._internal.common_utils import TestCase, run_tests +from torch_tensorrt._features import ENABLED_FEATURES +from torch_tensorrt.dynamo._defaults import RUNTIME_CACHE_PATH, TIMING_CACHE_PATH +from torch_tensorrt.dynamo._settings import CompilationSettings + + +class SimpleModel(torch.nn.Module): + def forward(self, x): + return torch.relu(x) + 1.0 + + +class TwoLayerModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(8, 8) + + def forward(self, x): + return torch.relu(self.linear(x)) + + +def _compile_simple(runtime_cache_path=None): + """Helper: compile SimpleModel with Python runtime, return (compiled_module, inputs).""" + model = SimpleModel().eval().cuda() + inputs = [torch.randn(2, 3).cuda()] + kwargs = { + "ir": "dynamo", + "inputs": inputs, + "enabled_precisions": {torch.float32}, + "use_python_runtime": True, + "min_block_size": 1, + } + if runtime_cache_path is not None: + kwargs["runtime_cache_path"] = runtime_cache_path + compiled = torchtrt.compile(model, **kwargs) + torch._dynamo.reset() + return compiled, inputs + + +def _find_python_trt_module(compiled): + """Walk the compiled graph module to find PythonTorchTensorRTModule instances.""" + from torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule import ( + PythonTorchTensorRTModule, + ) + + for name, mod in compiled.named_modules(): + if isinstance(mod, PythonTorchTensorRTModule): + return mod + return None + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "Runtime cache is only available with TensorRT-RTX", +) +class TestRuntimeCacheSetup(TestCase): + """Tests 
that runtime config and cache are correctly created for RTX.""" + + def test_runtime_config_created(self): + compiled, _ = _compile_simple() + mod = _find_python_trt_module(compiled) + self.assertIsNotNone( + mod, "No PythonTorchTensorRTModule found in compiled model" + ) + self.assertIsNotNone(mod.runtime_config, "runtime_config should be set for RTX") + self.assertIsNotNone(mod.runtime_cache, "runtime_cache should be set for RTX") + + def test_context_created_successfully(self): + compiled, inputs = _compile_simple() + mod = _find_python_trt_module(compiled) + self.assertIsNotNone(mod.context, "execution context should be created") + # Verify inference works + output = compiled(*[inp.clone() for inp in inputs]) + self.assertEqual(output.shape, inputs[0].shape) + + def test_runtime_cache_path_default(self): + compiled, _ = _compile_simple() + mod = _find_python_trt_module(compiled) + self.assertEqual(mod.runtime_cache_path, RUNTIME_CACHE_PATH) + + def test_runtime_cache_path_custom(self): + cache_dir = tempfile.mkdtemp() + try: + custom_path = os.path.join(cache_dir, "my_cache.bin") + compiled, _ = _compile_simple(runtime_cache_path=custom_path) + mod = _find_python_trt_module(compiled) + self.assertEqual(mod.runtime_cache_path, custom_path) + finally: + shutil.rmtree(cache_dir, ignore_errors=True) + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "Runtime cache is only available with TensorRT-RTX", +) +class TestRuntimeCachePersistence(TestCase): + """Tests that runtime cache is correctly saved to and loaded from disk.""" + + def setUp(self): + self.cache_dir = tempfile.mkdtemp() + self.cache_path = os.path.join(self.cache_dir, "runtime_cache.bin") + + def tearDown(self): + shutil.rmtree(self.cache_dir, ignore_errors=True) + + def test_cache_saved_on_del(self): + compiled, inputs = _compile_simple(runtime_cache_path=self.cache_path) + # Run inference to populate the cache + _ = compiled(*[inp.clone() for inp in inputs]) + self.assertFalse( + os.path.isfile(self.cache_path), + "Cache should not exist before module cleanup", + ) + del compiled + gc.collect() + self.assertTrue( + os.path.isfile(self.cache_path), + "Cache file should be created after module cleanup", + ) + + def test_cache_file_nonempty(self): + compiled, inputs = _compile_simple(runtime_cache_path=self.cache_path) + _ = compiled(*[inp.clone() for inp in inputs]) + del compiled + gc.collect() + self.assertGreater( + os.path.getsize(self.cache_path), + 0, + "Cache file should have nonzero size", + ) + + def test_cache_roundtrip(self): + """Compile, infer, save. 
Then compile again with same cache path and verify correctness.""" + model = SimpleModel().eval().cuda() + inputs = [torch.randn(2, 3).cuda()] + ref_output = model(*inputs) + + # First compilation — populates and saves cache + compiled1, _ = _compile_simple(runtime_cache_path=self.cache_path) + _ = compiled1(*[inp.clone() for inp in inputs]) + del compiled1 + gc.collect() + self.assertTrue(os.path.isfile(self.cache_path)) + + # Second compilation — should load cached data + compiled2, _ = _compile_simple(runtime_cache_path=self.cache_path) + output = compiled2(*[inp.clone() for inp in inputs]) + max_diff = float(torch.max(torch.abs(ref_output - output))) + self.assertAlmostEqual( + max_diff, 0, places=3, msg="Output mismatch after cache roundtrip" + ) + + def test_save_creates_directory(self): + nested_path = os.path.join(self.cache_dir, "a", "b", "c", "runtime_cache.bin") + compiled, inputs = _compile_simple(runtime_cache_path=nested_path) + _ = compiled(*[inp.clone() for inp in inputs]) + del compiled + gc.collect() + self.assertTrue( + os.path.isfile(nested_path), + "Save should create intermediate directories", + ) + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "Runtime cache is only available with TensorRT-RTX", +) +class TestRuntimeCacheConcurrency(TestCase): + """Tests that file locking works for concurrent access.""" + + def setUp(self): + self.cache_dir = tempfile.mkdtemp() + self.cache_path = os.path.join(self.cache_dir, "runtime_cache.bin") + + def tearDown(self): + shutil.rmtree(self.cache_dir, ignore_errors=True) + + def test_filelock_works(self): + """Verify that filelock can be acquired on the cache path after save.""" + compiled, inputs = _compile_simple(runtime_cache_path=self.cache_path) + _ = compiled(*[inp.clone() for inp in inputs]) + del compiled + gc.collect() + self.assertTrue(os.path.isfile(self.cache_path)) + # Verify we can acquire a lock on the same path (no deadlock) + from filelock import FileLock + + lock = FileLock(self.cache_path + ".lock") + with lock.acquire(timeout=5): + data = open(self.cache_path, "rb").read() + self.assertGreater(len(data), 0) + + def test_sequential_save_load(self): + """Two modules saving and loading from the same path should not corrupt data.""" + # First module saves + compiled1, inputs = _compile_simple(runtime_cache_path=self.cache_path) + _ = compiled1(*[inp.clone() for inp in inputs]) + del compiled1 + gc.collect() + size1 = os.path.getsize(self.cache_path) + + # Second module saves (overwrites) + compiled2, inputs = _compile_simple(runtime_cache_path=self.cache_path) + _ = compiled2(*[inp.clone() for inp in inputs]) + del compiled2 + gc.collect() + size2 = os.path.getsize(self.cache_path) + + self.assertGreater(size1, 0) + self.assertGreater(size2, 0) + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "Timing cache skip is only relevant for TensorRT-RTX", +) +class TestTimingCacheSkipped(TestCase): + """Tests that timing cache is correctly skipped for RTX builds.""" + + def setUp(self): + # Clean up any pre-existing timing cache + if os.path.isfile(TIMING_CACHE_PATH): + os.remove(TIMING_CACHE_PATH) + + def test_no_timing_cache_file(self): + compiled, inputs = _compile_simple() + _ = compiled(*[inp.clone() for inp in inputs]) + self.assertFalse( + os.path.isfile(TIMING_CACHE_PATH), + "Timing cache should NOT be created for RTX builds", + ) + + def test_timing_cache_skip_logged(self): + with self.assertLogs( + "torch_tensorrt.dynamo.conversion._TRTInterpreter", level="INFO" + ) as cm: + compiled, 
inputs = _compile_simple() + _ = compiled(*[inp.clone() for inp in inputs]) + self.assertTrue( + any("Skipping timing cache" in msg for msg in cm.output), + f"Expected 'Skipping timing cache' log message, got: {cm.output}", + ) + + +@unittest.skipIf( + ENABLED_FEATURES.tensorrt_rtx, + "This test verifies standard TRT behavior (non-RTX)", +) +class TestNonRTXUnchanged(TestCase): + """Tests that standard TRT behavior is unaffected by the runtime cache changes.""" + + def test_no_runtime_config_for_standard_trt(self): + compiled, _ = _compile_simple() + mod = _find_python_trt_module(compiled) + if mod is not None: + self.assertIsNone( + mod.runtime_config, + "runtime_config should be None for standard TRT", + ) + self.assertIsNone( + mod.runtime_cache, + "runtime_cache should be None for standard TRT", + ) + + def test_timing_cache_still_created(self): + # Clean up any pre-existing timing cache + if os.path.isfile(TIMING_CACHE_PATH): + os.remove(TIMING_CACHE_PATH) + compiled, inputs = _compile_simple() + _ = compiled(*[inp.clone() for inp in inputs]) + self.assertTrue( + os.path.isfile(TIMING_CACHE_PATH), + "Timing cache should still be created for standard TRT", + ) + + +if __name__ == "__main__": + run_tests()
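
For reference, a minimal end-to-end sketch of the workflow this change documents and tests, assuming a TensorRT-RTX build of Torch-TensorRT with the Python runtime; the toy model and the cache path below are illustrative only and not part of the patch:

    import gc

    import torch
    import torch_tensorrt

    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).eval().cuda()
    inputs = (torch.randn(2, 8).cuda(),)
    exported_program = torch.export.export(model, inputs)

    # First compile: JIT compilation results accumulate in the runtime cache and
    # are written to disk when the module is garbage collected (__del__ calls
    # _save_runtime_cache()).
    trt_gm = torch_tensorrt.dynamo.compile(
        exported_program,
        arg_inputs=inputs,
        use_python_runtime=True,
        min_block_size=1,
        runtime_cache_path="/tmp/trt_cache/runtime_cache.bin",  # illustrative path
    )
    trt_gm(*inputs)
    del trt_gm
    gc.collect()

    # Second compile with the same path: setup_engine() deserializes the cache,
    # so kernels/graphs are not recompiled at inference time.
    trt_gm = torch_tensorrt.dynamo.compile(
        exported_program,
        arg_inputs=inputs,
        use_python_runtime=True,
        min_block_size=1,
        runtime_cache_path="/tmp/trt_cache/runtime_cache.bin",
    )
    trt_gm(*inputs)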