Numba: fuse AdvancedSubtensor->Elemwise->AdvancedIncSubtensor#2015

Draft
ricardoV94 wants to merge 7 commits into pymc-devs:main from ricardoV94:gather_scatter_fusion

ricardoV94 commented Mar 29, 2026

Summary

Introduce IndexedElemwise, an OpFromGraph that wraps AdvancedSubtensor + Elemwise + AdvancedIncSubtensor subgraphs so that the Numba backend can generate a single loop with indirect indexing. The fused loop avoids materializing the AdvancedSubtensor input arrays and writes directly into the output buffer, doing the work of AdvancedIncSubtensor in the same pass instead of looping again over the intermediate Elemwise output.

Commit 1 fuses indexed reads (AdvancedSubtensor1 on inputs).
Commit 2 fuses indexed updates (AdvancedIncSubtensor1 on outputs).
Commit 3 extends this to AdvancedSubtensor inputs, with 1D indices on arbitrary consecutive axes.

Motivation

In hierarchical models with mu = beta[group_idx] * x + ..., the logp+gradient graph combines indexed reads and indexed updates around the same Elemwise: the forward pass expands group-level parameters via an advanced subtensor, and the gradient accumulates back into the source via an advanced inc-subtensor.
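A minimal NumPy sketch of that pattern follows. The sizes and the names `beta`, `group_idx`, `x` and `y` are illustrative assumptions, not taken from the PR's tests. The forward pass gathers one `beta` per observation, and the gradient with respect to `beta` scatters the per-observation terms back into the group slots.

```python
import numpy as np

rng = np.random.default_rng(0)
n_groups, n_obs = 4, 50
beta = rng.normal(size=n_groups)
group_idx = rng.integers(n_groups, size=n_obs)
x = rng.normal(size=n_obs)
y = rng.normal(size=n_obs)

# Forward: indexed read (gather) expands group-level parameters per observation
mu = beta[group_idx] * x
loss = ((mu - y) ** 2).sum()

# Backward: d(loss)/d(mu), then the chain rule through mu = beta[group_idx] * x
dloss_dmu = 2.0 * (mu - y)
per_obs = dloss_dmu * x

# Indexed update (scatter-add): group_idx repeats, so contributions must accumulate
grad_beta = np.zeros_like(beta)
np.add.at(grad_beta, group_idx, per_obs)
```

Note that `grad_beta[group_idx] += per_obs` would not accumulate over repeated indices, which is why the sketch uses `np.add.at`.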

A simple example

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_mode

numba_mode = get_mode("NUMBA")
numba_mode_before = numba_mode.excluding("fuse_indexed_elemwise")

x = pt.vector("x")
idx = pt.vector("idx", dtype=int)
value = pt.vector("value")

out = ((x[idx] - value) ** 2).sum()
grad_wrt_x = pt.grad(out, x)
fn_before = pytensor.function([x, value, idx], [out, grad_wrt_x], mode=numba_mode_before, trust_input=True)
fn_before.dprint(print_op_info=True, print_destroy_map=True)
# Sum{axes=None} [id A] 5
#  └─ Composite{...}.0 [id B] d={0: [0]} 1
#     ├─ AdvancedSubtensor1 [id C] 0
#     │  ├─ x [id D]
#     │  └─ idx [id E]
#     └─ value [id F]
# AdvancedIncSubtensor1{inplace,inc} [id G] d={0: [0]} 4
#  ├─ Alloc [id H] 3
#  │  ├─ [0.] [id I]
#  │  └─ Shape_i{0} [id J] 2
#  │     └─ x [id D]
#  ├─ Composite{...}.1 [id B] d={0: [0]} 1
#  │  └─ ···
#  └─ idx [id E]

# Inner graphs:

# Composite{...} [id B] d={0: [0]}
#  ← sqr [id K] 'o0'
#     └─ sub [id L] 't5'
#        ├─ i0 [id M]
#        └─ i1 [id N]
#  ← mul [id O] 'o1'
#     ├─ 2.0 [id P]
#     └─ sub [id L] 't5'
#        └─ ···

fn = pytensor.function([x, value, idx], [out, grad_wrt_x], mode=numba_mode, trust_input=True)
fn.dprint(print_op_info=True, print_destroy_map=True)

# Sum{axes=None} [id A] 3
#  └─ IndexedElemwise{Composite{...}}.0 [id B] d={1: [3]} 2
#     ├─ x [id C] (indexed read (idx_0))
#     ├─ value [id D]
#     ├─ idx [id E] (idx_0)
#     └─ Alloc [id F] 1 (buf_0)
#        ├─ [0.] [id G]
#        └─ Shape_i{0} [id H] 0
#           └─ x [id C]
# IndexedElemwise{Composite{...}}.1 [id B] d={1: [3]} 2 (indexed inc (buf_0, idx_0))
#  └─ ···

# Inner graphs:

# IndexedElemwise{Composite{...}} [id B] d={1: [3]}
#  ← Composite{...}.0 [id I]
#     ├─ AdvancedSubtensor1 [id J]
#     │  ├─ *0-<Vector(float64, shape=(?,))> [id K]
#     │  └─ *2-<Vector(int64, shape=(?,))> [id L]
#     └─ *1-<Vector(float64, shape=(?,))> [id M]
#  ← AdvancedIncSubtensor1{inplace,inc} [id N] d={0: [0]}
#     ├─ *3-<Vector(float64, shape=(?,))> [id O]
#     ├─ Composite{...}.1 [id I]
#     │  └─ ···
#     └─ *2-<Vector(int64, shape=(?,))> [id L]

# Composite{...} [id I]
#  ← sqr [id P] 'o0'
#     └─ sub [id Q] 't0'
#        ├─ i0 [id R]
#        └─ i1 [id S]
#  ← mul [id T] 'o1'
#     ├─ 2.0 [id U]
#     └─ sub [id Q] 't0'
#        └─ ···

x_test = np.arange(15, dtype="float64")
idx_test = np.random.randint(15, size=(10_000,))
value_test = np.random.normal(size=idx_test.shape)

logp_before, dlogp_before = fn_before(x_test, value_test, idx_test)
logp, dlogp = fn(x_test, value_test, idx_test)
np.testing.assert_allclose(logp_before, logp)
np.testing.assert_allclose(dlogp_before, dlogp)

%timeit fn_before(x_test, value_test, idx_test)  # 29.4 μs ± 2.57 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
%timeit fn(x_test, value_test, idx_test)  # 13.8 μs ± 136 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

The next step would be to also fuse the sum directly into the Elemwise, so that we end up with a single loop over the data. This matters because a sum can easily break the fusion: we don't fuse when the Elemwise output is needed elsewhere (such as in a sum).

@ricardoV94 ricardoV94 force-pushed the gather_scatter_fusion branch 2 times, most recently from 6d875d8 to 0ad6e2e Compare March 29, 2026 18:14
@ricardoV94 ricardoV94 changed the title Numba: fuse AdvancedSubtensor+Elemwise Numba: fuse AdvancedSubtensor->Elemwise->AdvancedIncSubtensor Mar 29, 2026
@ricardoV94 ricardoV94 force-pushed the gather_scatter_fusion branch 9 times, most recently from 41869a4 to a07997b Compare April 2, 2026 11:10
@ricardoV94 ricardoV94 force-pushed the gather_scatter_fusion branch 6 times, most recently from 2b65554 to f939f9c Compare April 5, 2026 20:32
@ricardoV94 ricardoV94 force-pushed the gather_scatter_fusion branch 5 times, most recently from c06fc82 to 0057a48 Compare May 24, 2026 00:17
…ariables

When a shared variable's update is deleted but the variable is still
destroyed (mutated inplace) by a node in the copied graph, the shared
variable storage will still be mutated. Emit a UserWarning in this case.

Extend FusionOptimizer to merge independent subgraphs that share
inputs but have no producer-consumer edge (siblings like f(x) and g(x)).
The eager expansion only walks producer-consumer edges, missing these.
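
As a rough illustration of what merging siblings buys, here is a plain NumPy/Python sketch, not the rewrite itself. `f` and `g` are hypothetical stand-ins for two Elemwise subgraphs that read the same input but do not feed each other. Run separately they traverse `x` twice; merged, one loop produces both outputs.

```python
import numpy as np

def separate(x):
    f = np.exp(x)  # first pass over x
    g = x * x      # second, independent pass over x
    return f, g

def merged(x):
    # One loop computes both sibling outputs, loading each x[i] only once
    f = np.empty_like(x)
    g = np.empty_like(x)
    for i in range(x.shape[0]):
        xi = x[i]
        f[i] = np.exp(xi)
        g[i] = xi * xi
    return f, g
```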

Also extract InplaceGraphOptimizer.try_inplace_on_node helper and
_insert_sorted_subgraph to deduplicate insertion-point logic.

The inplace_pattern loop used `input_type.layout` leaked from the
preceding core-input-types loop instead of `output_type.layout`.

Extend the IndexedElemwise fusion to also absorb
AdvancedIncSubtensor1 (indexed set/inc) on the output side.

Before (3 nodes):
  temp = Elemwise(x[idx], y)               # shape (919,)
  result = IncSubtensor(target, temp, idx)  # target shape (85,)

After (1 fused loop, target is an input):
  for k in range(919):
      target[idx[k]] += scalar_fn(x[idx[k]], y[k])
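
The equivalence of the two forms can be checked with a small NumPy sketch. `scalar_fn` is a hypothetical stand-in for the Elemwise body, and the length of `x` is an assumption (the message only gives the shapes of `temp` and `target`).

```python
import numpy as np

def scalar_fn(a, b):
    # Hypothetical scalar function standing in for the fused Elemwise body
    return 2.0 * (a - b)

def unfused(target, x, y, idx):
    gathered = x[idx]              # indexed read, materialized as a (919,) array
    temp = scalar_fn(gathered, y)  # Elemwise output, another (919,) array
    out = target.copy()
    np.add.at(out, idx, temp)      # indexed inc; accumulates over repeated indices
    return out

def fused(target, x, y, idx):
    # The commit describes an inplace update; a copy is used here for comparison
    out = target.copy()
    for k in range(idx.shape[0]):
        out[idx[k]] += scalar_fn(x[idx[k]], y[k])
    return out

rng = np.random.default_rng(0)
x = rng.normal(size=85)  # assumed length
idx = rng.integers(85, size=919)
y = rng.normal(size=919)
target = np.zeros(85)

np.testing.assert_allclose(fused(target, x, y, idx), unfused(target, x, y, idx))
```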

- FuseIndexedElemwise now detects AdvancedIncSubtensor1 consumers
- Reject fusion when val broadcasts against target's non-indexed axes
- store_core_outputs supports inc mode via o[...] += val
- Inner fgraph always uses inplace IncSubtensor
- op_debug_information shows buf_N / idx_N linkage

Support AdvancedSubtensor on any axis (not just axis 0) and multi-index
patterns like x[idx_row, idx_col] where multiple 1D index arrays address
consecutive source axes.  Generalize writes (AdvancedIncSubtensor) to
match.
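
In NumPy terms, the index patterns named above look as follows. This is only a sketch of the indexing semantics with made-up shapes; it says nothing about how the rewrite detects them.

```python
import numpy as np

x = np.arange(2 * 3 * 4).reshape(2, 3, 4)
idx = np.array([2, 0, 2, 1, 1])
idx_row = np.array([0, 1, 1, 0, 1])
idx_col = np.array([2, 0, 1, 1, 2])

# Advanced index on a non-leading axis, with a full-slice prefix
a = x[:, idx]               # shape (2, 5, 4)

# Multi-index on consecutive axes: pointwise pairs (idx_row[k], idx_col[k])
b = x[idx_row, idx_col]     # shape (5, 4)

# Non-consecutive advanced indices: NumPy moves the indexed dimension to the front
c = x[idx_row, :, idx_col]  # shape (5, 3)

# Matching write: scatter-add through the same pair of indices
target = np.zeros((2, 3, 4))
np.add.at(target, (idx_row, idx_col), b)
```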

Reads:
- Add undo_take_dimshuffle_for_fusion pre-fusion rewrite
- _get_indexed_read_info handles AdvancedSubtensor with consecutive
  tensor indices, full-slice prefix/suffix
- Reject boolean indices and non-consecutive advanced indices

Writes:
- _get_indexed_update_info mirrors _get_indexed_read_info for
  AdvancedIncSubtensor
- find_indexed_update_consumers detects both AdvancedIncSubtensor1
  and AdvancedIncSubtensor
- Broadcast guard generalized for non-axis-0 indexed axes
- Indexed update construction supports AdvancedIncSubtensor (inplace)

Dispatch + codegen:
- indexed_inputs encoding: ((positions, axis, idx_bc), ...)
- input_read_spec uses tuple of (idx_k, axis) pairs per input
- n_index_loop_dims = max(idx.ndim for group)

Support multidimensional (e.g. 2D matrix) and 0-d integer indices in
IndexedElemwise fusion, for both reads and writes.

ND indices:
- Add undo_take_reshape_for_fusion: undoes the Reshape+flatten pattern
  that transform_take applies for ND indices, recovering the original
  AdvancedSubtensor(source, mat_idx) form for fusion.
  Handles both axis=0 and axis>0 (with DimShuffle wrapping).
- idx_load_axes: tuple of tuples, each index array loads from idx_ndim
  loop counters
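
The Reshape+flatten form and the direct ND-index form are equivalent at the NumPy level, as the sketch below checks. Treating this as a picture of what `transform_take` produces is an assumption; only the NumPy identity itself is verified here, and the array names are illustrative.

```python
import numpy as np

x = np.arange(12.0).reshape(4, 3)

# axis=0: index with a 2D matrix of row positions
mat_idx = np.array([[0, 3], [2, 2], [1, 0]])  # shape (3, 2)
direct = x[mat_idx]                           # shape (3, 2, 3)
flat = x[mat_idx.ravel()].reshape(mat_idx.shape + x.shape[1:])

# axis=1: index with a 2D matrix of column positions
col_idx = np.array([[0, 2], [1, 1]])          # shape (2, 2)
direct_ax1 = x[:, col_idx]                    # shape (4, 2, 2)
flat_ax1 = np.take(x, col_idx.ravel(), axis=1).reshape(x.shape[:1] + col_idx.shape)
```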

0-d indices:
- Accept 0-d tensor indices (e.g. x[scalar_idx, vec_idx]) which are
  valid AdvancedSubtensor inputs that broadcast with higher-dim indices.
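
For reference, this is how a 0-d index combines with a vector index in NumPy. It is a sketch of the indexing semantics only, with illustrative names and shapes.

```python
import numpy as np

x = np.arange(12.0).reshape(3, 4)
scalar_idx = np.array(1)  # 0-d integer array
vec_idx = np.array([3, 0, 0, 2])

# The 0-d index broadcasts against the vector index
out = x[scalar_idx, vec_idx]  # shape (4,)

# Same result with the broadcast written out explicitly
explicit = x[np.broadcast_to(scalar_idx, vec_idx.shape), vec_idx]
```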
@ricardoV94 ricardoV94 force-pushed the gather_scatter_fusion branch from 0057a48 to b917ef0 Compare May 24, 2026 00:36