1 change: 0 additions & 1 deletion core/runtime/TRTEngine.h
@@ -167,7 +167,6 @@ struct TRTEngine : torch::CustomClassHolder {
int64_t get_streamable_device_memory_budget();
int64_t get_automatic_device_memory_budget();
std::vector<at::Tensor> infer_outputs(std::vector<std::vector<int64_t>> input_shapes);
void set_pre_allocated_outputs(bool enable);
**Collaborator:** Why was this removed?

**Collaborator (author):** This was never used. The setter is here:

    self.engine.set_output_tensors_as_unowned(enabled)
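
A minimal usage sketch of the retained API (assuming `engine` is a bound `torch.classes.tensorrt.Engine` instance and that both methods are exposed to Python):

    engine.set_output_tensors_as_unowned(True)   # mark engine outputs as unowned
    assert engine.are_output_tensors_unowned()   # assumed companion query binding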

void set_output_tensors_as_unowned(bool enable);
bool are_output_tensors_unowned();
TorchTRTRuntimeStates runtime_states;
1 change: 1 addition & 0 deletions core/runtime/register_jit_hooks.cpp
@@ -101,6 +101,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
: TRTEngine::ResourceAllocationStrategy::kStatic);
})
.def_readwrite("use_pre_allocated_outputs", &TRTEngine::use_pre_allocated_outputs)
.def_readwrite("pre_allocated_outputs", &TRTEngine::pre_allocated_outputs)
.def_readwrite("use_output_allocator_outputs", &TRTEngine::use_output_allocator_outputs)
.def_property(
"device_memory_budget",
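
Since `def_readwrite` exposes plain attributes on the TorchBind object, the new field can be read and written from Python. A hedged sketch (attribute names come from the diff; obtaining `engine` is assumed):

    # `engine` is assumed to be a torch.classes.tensorrt.Engine held by a compiled module
    engine.use_pre_allocated_outputs = True   # existing toggle
    outputs = engine.pre_allocated_outputs    # buffer list newly exposed by this change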
5 changes: 2 additions & 3 deletions docsrc/contributors/complex_number_support.rst
@@ -128,9 +128,8 @@ runtime modules handle the conversion:
* ``prepare_inputs`` (``dynamo/utils.py``) — builds the ``Input`` spec with the
``view_as_real`` shape/dtype but retains the original complex tensor in
``inp.torch_tensor`` for tracing.
* ``_PythonTorchTensorRTModule.forward`` — applies ``torch.view_as_real(i).contiguous()``
for each complex input before feeding it to the engine.
* ``_TorchTensorRTModule.forward`` — same ``view_as_real`` conversion.
* ``TorchTensorRTModule.forward`` — applies ``torch.view_as_real(i).contiguous()``
for each complex input before feeding tensors to ``execute_engine`` / ``execute_engine_python``.
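
A minimal sketch of the ``view_as_real`` conversion applied above (shapes illustrative; only ``torch`` is required):

.. code-block:: python

    import torch

    x = torch.randn(2, 3, dtype=torch.complex64)   # complex input
    x_real = torch.view_as_real(x).contiguous()    # shape (2, 3, 2), dtype float32
    # x_real is what is handed to the engine in place of x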

Key Implementation Invariants
-------------------------------
4 changes: 2 additions & 2 deletions docsrc/contributors/cuda_graphs.rst
@@ -93,8 +93,8 @@ Subsequent inference launches the instantiated graph instead of calling
Graph Storage
^^^^^^^^^^^^^

Each runtime module (both C++ ``TorchTensorRTModule`` and Python
``PythonTorchTensorRTModule``) stores a ``cudaGraphExec_t`` instance. When
``TorchTensorRTModule`` (C++ or Python execution path) may record a CUDA graph for
engine execution when CUDA graphs are enabled at runtime. When
``use_cuda_graph=True`` is set at compile time the runtime records one graph
per engine for the first input shape encountered.
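
The record-then-replay cycle follows the public ``torch.cuda.CUDAGraph`` API; a hedged sketch of the pattern described above (a plain ``nn.Linear`` stands in for the TRT engine):

.. code-block:: python

    import torch

    model = torch.nn.Linear(16, 16).cuda()
    static_in = torch.randn(8, 16, device="cuda")

    # Warm up on a side stream before capture, per the PyTorch CUDA graphs docs
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            model(static_in)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out = model(static_in)   # kernels recorded, not executed

    # Replay: refill the captured input buffer in place, then launch the graph
    static_in.copy_(torch.randn_like(static_in))
    g.replay()   # static_out now holds results for the new input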

20 changes: 3 additions & 17 deletions docsrc/contributors/runtime.rst
@@ -10,14 +10,8 @@ infrastructure for inference.
Dynamo Runtime (Primary Path)
-------------------------------

Two runtime backends are available. The backend is selected via the
``use_python_runtime`` compilation setting.

C++ Runtime (default)
^^^^^^^^^^^^^^^^^^^^^^^

The C++ runtime is more performant, fully serializable, and supports advanced features
like CUDAGraphs and multi-device safety.
The Dynamo runtime is fully serializable and supports advanced features like
CUDAGraphs and multi-device safety.

TensorRT engines are stored as ``torch.classes.tensorrt.Engine`` — a C++ TorchBind
class that holds the serialized engine bytes plus metadata:
@@ -41,14 +35,6 @@ This op pops inputs and the engine off the PyTorch dispatcher stack, runs the te
through TensorRT, and pushes output tensors back. The compiled ``torch.fx.Graph``
stores engine objects as attributes, making the whole module portable.
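
Schematically, invoking the op looks like the following (a sketch; the exact op schema is internal, but it takes the input tensors plus the TorchBind engine and returns output tensors):

.. code-block:: python

    # `engine` is the torch.classes.tensorrt.Engine attribute stored on the graph
    outputs = torch.ops.tensorrt.execute_engine(list(inputs), engine)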

Python Runtime
^^^^^^^^^^^^^^^

The Python runtime uses TensorRT's Python API directly for inference. It is useful when
a C++ build is not available (e.g. in some CI environments) and is simpler to instrument
for debugging. It does not support serialization to ``ExportedProgram``; the compiled
graph is Python-only.

Serialization Options
----------------------

@@ -59,7 +45,7 @@ The default serialization path for the Dynamo AOT workflow. The compiled
``torch.fx.GraphModule`` is wrapped in a
`torch.export.ExportedProgram <https://docs.pytorch.org/docs/stable/user_guide/torch_compiler/export.html>`_
container. TensorRT engines are stored as tensor attributes in the package; PyTrees
capture input/output structure. Requires the C++ runtime and supports Python execution.
capture input/output structure.

.. code-block:: python
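
    # A plausible body for this elided example (save/load APIs as used later in
    # these docs; variable names trt_gm/inputs are assumed):
    import torch_tensorrt

    torch_tensorrt.save(trt_gm, "model.ep", arg_inputs=inputs)
    trt_gm = torch_tensorrt.load("model.ep").module()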
8 changes: 3 additions & 5 deletions docsrc/debugging/troubleshooting.rst
@@ -126,8 +126,6 @@ Runtime Errors
the engine. Upgrade TRT or rebuild with ``version_compatible=True``.
* The GPU compute capability is lower than on the build machine. Rebuild with
``hardware_compatible=True`` (requires Ampere or newer).
* The ``.ep`` file was generated with ``use_python_runtime=True`` which is not
serializable. Rebuild with the default C++ runtime.

**Shape mismatch at runtime / "Invalid input shape"**

@@ -153,9 +151,9 @@ Runtime Errors
The model contains data-dependent-shape ops (``nonzero``, ``unique``,
``masked_select``, etc.) which require TRT's output allocator.

* Use ``PythonTorchTensorRTModule`` (``use_python_runtime=True``) — it
activates the dynamic output allocator automatically via
``requires_output_allocator=True``.
* Use :class:`~torch_tensorrt.runtime.TorchTensorRTModule` (or a compiled graph that wraps it)
  with ``requires_output_allocator=True`` so the runtime can use TRT's output allocator
  when the engine needs dynamic output allocation (see the sketch after this list).
* See :ref:`cuda_graphs` for ``DynamicOutputAllocator`` details.
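
For the first option, a sketch (assuming the module constructor accepts the flag, as the class documentation describes, and that ``engine_bytes`` holds a serialized engine):

.. code-block:: python

    from torch_tensorrt.dynamo.runtime import TorchTensorRTModule

    module = TorchTensorRTModule(
        serialized_engine=engine_bytes,
        input_binding_names=["x"],
        output_binding_names=["output"],
        requires_output_allocator=True,   # engine has data-dependent output shapes
    )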

----
13 changes: 10 additions & 3 deletions docsrc/py_api/runtime.rst
@@ -27,13 +27,20 @@ Functions

.. autofunction:: enable_output_allocator

Runtime backend
---------------

Execution uses the C++ runtime engine when it is installed in the build; otherwise the
Python runtime engine is used. There is no separate process-wide backend switch
in ``torch_tensorrt.runtime``.
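
For example, the context managers listed above wrap a compiled module directly; a hedged sketch (assuming ``enable_output_allocator`` takes the module, and that ``trt_module`` and ``x`` already exist):

.. code-block:: python

    import torch_tensorrt

    with torch_tensorrt.runtime.enable_output_allocator(trt_module):
        out = trt_module(x)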

Classes
---------

.. autoclass:: TorchTensorRTModule
:members:
:special-members: __init__
:show-inheritance:

.. autoclass:: PythonTorchTensorRTModule
:members:
:special-members: __init__
Single runtime module for TensorRT engines. Dispatches to the C++ or Python execution
implementation depending on whether the C++ extension is available. See :ref:`python_runtime`.
1 change: 0 additions & 1 deletion docsrc/tutorials/deployment/cross_compile_windows.rst
@@ -26,7 +26,6 @@ Requirements
The following features are **disabled** during cross-compilation (they are not
available in the Windows TRT runtime or require OS-specific binaries):

* Python runtime (``use_python_runtime`` is forced to ``False``)
* Lazy engine initialization (``lazy_engine_init`` is forced to ``False``)
* Engine caching (``cache_built_engines`` / ``reuse_cached_engines`` disabled)

5 changes: 0 additions & 5 deletions docsrc/tutorials/deployment/distributed_inference.rst
@@ -99,7 +99,6 @@ handles DTensor inputs correctly:
options={
"use_distributed_mode_trace": True,
"use_explicit_typing": True, # enabled_precisions deprecated
"use_python_runtime": True,
"min_block_size": 1,
},
)
@@ -152,10 +151,6 @@ Compilation Settings for Distributed Workloads
- ``False``
- Use ``aot_autograd`` for tracing instead of the default path. Required when the
model contains DTensor or other distributed tensors.
* - ``use_python_runtime``
- ``None`` (auto)
- Use the Python runtime. Often set to ``True`` for tensor-parallel models that run
inside an existing distributed process group.
* - ``use_explicit_typing``
- ``True``
- Respect dtypes set in model/inputs (recommended). Use ``model.half()`` or
4 changes: 2 additions & 2 deletions docsrc/tutorials/runtime_opt/index.rst
@@ -2,12 +2,12 @@ Runtime Optimization
=====================

Optimize inference throughput and latency: CUDA Graphs for kernel-replay,
pre-allocated output buffers, and the Python runtime module.
pre-allocated output buffers, and choosing the Python vs C++ TRT execution path.

.. toctree::
:maxdepth: 1

cuda_graphs
Example: Torch Export with Cudagraphs <../_rendered_examples/dynamo/torch_export_cudagraphs>
Example: Pre-allocated output buffer <../_rendered_examples/dynamo/pre_allocated_output_example>
python_runtime
Python vs C++ runtime <python_runtime>
133 changes: 71 additions & 62 deletions docsrc/tutorials/runtime_opt/python_runtime.rst
@@ -1,23 +1,29 @@
.. _python_runtime:

Python Runtime
==============
Python vs C++ runtime
=====================

Torch-TensorRT provides two runtime backends for executing compiled TRT engines
inside a PyTorch graph:
Torch-TensorRT uses a single module type, :class:`~torch_tensorrt.runtime.TorchTensorRTModule`,
to run TensorRT engines inside PyTorch. The **execution path** (which code actually drives
TensorRT execution) is selected automatically:

* **C++ runtime** (default) — ``TorchTensorRTModule`` backed by a C++ TorchBind class.
Fully serializable, supports CUDAGraphs, multi-device safe.
* **Python runtime** — ``PythonTorchTensorRTModule`` backed entirely by the TRT Python
API. Simpler to instrument for debugging but **not serializable** to
``ExportedProgram``.
* **C++ path** — ``torch.classes.tensorrt.Engine`` and ``torch.ops.tensorrt.execute_engine``.
  Used when the Torch-TensorRT C++ extension (``libtorchtrt`` / runtime ``.so``) is loaded; it is
  TorchScript-friendly and integrates with the full C++ runtime stack.
* **Python path** — Internal ``TRTEngine`` (``torch_tensorrt.dynamo.runtime._TRTEngine``)
plus ``tensorrt::execute_engine`` registered from Python when the C++ runtime is not
available (use ``PYTHON_ONLY=1`` when building Torch-TensorRT). Useful for minimal installs and for Python-level debugging.

Both the C++ and Python paths are invoked through the same ``TorchTensorRTModule`` class,
which dispatches to the appropriate runtime engine based on the build of Torch-TensorRT (Full build or PYTHON_ONLY build).
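
One way to check which path is active (a hypothetical probe, not an official API):

.. code-block:: python

    import torch
    import torch_tensorrt  # loads the C++ extension when present

    try:
        torch.classes.tensorrt.Engine   # resolves only if the C++ runtime is registered
        print("C++ execution path")
    except RuntimeError:
        print("Python execution path (PYTHON_ONLY build)")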

----

When to Use the Python Runtime
--------------------------------
When the Python runtime is used
-------------------------------

Use ``use_python_runtime=True`` when:
The Python engine implementation is chosen automatically when the C++ Torch-TensorRT library
is not installed (i.e., when Torch-TensorRT was built with ``PYTHON_ONLY=1``). You may still prefer that setup when:

* You need to run on a machine where the C++ Torch-TensorRT library is not installed
(e.g., a minimal CI container with only the Python wheel).
@@ -27,74 +33,77 @@ Use ``use_python_runtime=True`` when:

Use the default C++ runtime in all other cases, especially:

* When saving a compiled module to disk (``torch_tensorrt.save()``).
* When using CUDAGraphs for low-latency inference.
* In production deployments.

----

Enabling the Python Runtime
-----------------------------
Compile and run
-----------------

.. code-block:: python
Use ``torch_tensorrt.dynamo.compile``, ``torch.compile(..., backend="tensorrt", ...)``, or
construct :class:`~torch_tensorrt.runtime.TorchTensorRTModule` directly. The module picks C++
vs Python execution based on the build of Torch-TensorRT (Full build or Python-only build).

import torch_tensorrt
----

trt_gm = torch_tensorrt.dynamo.compile(
exported_program,
arg_inputs=inputs,
use_python_runtime=True,
)
Serialization
---------------

Or via ``torch.compile``:
``TorchTensorRTModule`` is serializable in both the C++ and Python paths.

.. code-block:: python

    torch_tensorrt.save(trt_module, trt_ep_path, retrace=True)
    trt_module = torch_tensorrt.load(trt_ep_path).module()

.. code-block:: python
Cross-serialization (Python and C++)
-------------------------------------

trt_model = torch.compile(
model,
backend="tensorrt",
options={"use_python_runtime": True},
)
One of the key features of ``TorchTensorRTModule`` is seamless cross-serialization:
**you can serialize an engine using the Python runtime and load it using the C++ runtime, or vice versa**.
The engine file format and all core metadata are fully compatible across runtimes and platforms, ensuring flexibility for production and development workflows.

----
For example, you can:

Limitations
-----------
- **Build and serialize in Python**, then deploy by loading the module in a C++-enabled environment (e.g. in TorchScript or when the C++ extension is present):

.. code-block:: python

* **Not serializable**: ``PythonTorchTensorRTModule`` cannot be saved via
``torch_tensorrt.save()`` as an ``ExportedProgram`` or loaded back. The module is
Python-only in-process.
# In an environment with only Python runtime (PYTHON_ONLY=1)
torch_tensorrt.save(trt_module, "trt_module.ep")

.. code-block:: python
# --- Later, or on a different machine with C++ runtime enabled ---
trt_module = torch_tensorrt.load("trt_module.ep").module()
output = trt_module(input)

# This will raise an error with use_python_runtime=True:
torch_tensorrt.save(trt_gm, "model.ep", arg_inputs=inputs)
- **Build in C++ runtime environment**, save the engine, and then load it in a Python-only deployment or debugging context, with no changes needed.

* **No C++ deployment**: The compiled module cannot be exported to AOTInductor or used
in a C++ application without re-compiling with the C++ runtime.
This interoperability allows you to train, compile, and debug using the Python path,
then deploy for maximum performance using the C++ runtime, or test and profile using Python tools with modules built from C++.
**No extra conversion is required, and the serialization format is shared across both paths.**

----

Limitations
-----------
* **CUDAGraphs**: Whole-graph CUDAGraphs work with the Python runtime, but the
per-submodule CUDAGraph recording in ``CudaGraphsTorchTensorRTModule`` is
only available with the C++ runtime.

----

``PythonTorchTensorRTModule`` Direct Instantiation
----------------------------------------------------
``TorchTensorRTModule`` from raw engine bytes
---------------------------------------------

You can instantiate ``PythonTorchTensorRTModule`` directly from raw engine bytes,
for example when integrating a TRT engine built outside of Torch-TensorRT:
You can build a module directly from a serialized TensorRT engine (for example, an engine
produced outside Torch-TensorRT):

.. code-block:: python

from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule
from torch_tensorrt.dynamo.runtime import TorchTensorRTModule
from torch_tensorrt.dynamo._settings import CompilationSettings

# Load raw engine bytes (e.g., from trtexec output or torch_tensorrt.dynamo.convert_*)
with open("model.engine", "rb") as f:
engine_bytes = f.read()

module = PythonTorchTensorRTModule(
module = TorchTensorRTModule(
serialized_engine=engine_bytes,
input_binding_names=["x"],
output_binding_names=["output"],
@@ -104,23 +113,22 @@ for example when integrating a TRT engine built outside of Torch-TensorRT:

output = module(torch.randn(1, 3, 224, 224).cuda())

**Constructor arguments:**
**Constructor arguments** (see class docstring for full detail):

``serialized_engine`` (``bytes``)
The raw serialized TRT engine bytes.
Raw serialized TRT engine.

``input_binding_names`` (``List[str]``)
``input_binding_names`` / ``output_binding_names`` (``List[str]``)
TRT input binding names in the order they are passed to ``forward()``.

``output_binding_names`` (``List[str]``)
TRT output binding names in the order they should be returned.
TRT output binding names in the order they are returned from ``forward()``.

``name`` (``str``, optional)
Human-readable name for the module (used in logging).
Name for logging and serialization.

``settings`` (``CompilationSettings``, optional)
The compilation settings used to build the engine. Used to determine device
placement and other runtime behaviors.
``settings`` (:class:`~torch_tensorrt.dynamo._settings.CompilationSettings`, optional)
Device and runtime options (must match how the engine was built).

``weight_name_map`` (``dict``, optional)
Mapping of TRT weight names to PyTorch state dict names. Required for refit
@@ -132,9 +140,10 @@ for example when integrating a TRT engine built outside of Torch-TensorRT:

----

Runtime Selection Logic
------------------------
Runtime selection summary
-------------------------

When ``use_python_runtime`` is ``None`` (auto-select), Torch-TensorRT tries to import
the C++ TorchBind class. If the C++ extension is not available it silently falls back to
the Python runtime. Pass ``True`` or ``False`` to force a specific runtime.
* ``TorchTensorRTModule`` uses the C++ engine path when the Torch-TensorRT extension is loaded;
otherwise it uses the Python ``TRTEngine`` path.
* If the C++ extension is **not** built, only the Python path is available.
* To use the Python runtime, set ``PYTHON_ONLY=1`` when building Torch-TensorRT.
8 changes: 1 addition & 7 deletions docsrc/user_guide/compilation/compilation_settings.rst
@@ -59,12 +59,6 @@ Core Parameters
* - ``device``
- current CUDA device
- :class:`torch_tensorrt.Device` specifying the GPU to compile for.
* - ``use_python_runtime``
- ``False`` (auto)
- ``False`` uses the C++ runtime (recommended — serializable, CUDAGraphs,
multi-device safe). ``True`` forces the Python runtime (simpler to instrument
for debugging but not serializable to ``ExportedProgram``). ``None`` selects C++
if available.
* - ``pass_through_build_failures``
- ``False``
- When ``True``, TRT engine build errors raise exceptions rather than fall back to PyTorch.
@@ -372,7 +366,7 @@ Compilation Workflow
- ``False``
- Defer TRT engine deserialization until all engines have been built.
Works around resource constraints and builder overhead but engines
may be less well tuned to their deployment resource availablity
may be less well tuned to their deployment resource availability
* - ``debug``
- ``False``
- Enable verbose TRT builder logs at ``DEBUG`` level.