Skip to content

Support postselection masks in detector sampler#165

Open
kish-00 wants to merge 2 commits into
QuEraComputing:mainfrom
kish-00:pr/postselection-masks
Open

Support postselection masks in detector sampler#165
kish-00 wants to merge 2 commits into
QuEraComputing:mainfrom
kish-00:pr/postselection-masks

Conversation

@kish-00

@kish-00 kish-00 commented Jun 15, 2026

Copy link
Copy Markdown

Adds postselection_mask parameter to CompiledDetectorSampler.sample() as proposed in #41. When a shot has a direct detector that fires under the mask, the JAX loop is skipped for that shot.

API

def sample(self, shots, *, ..., postselection_mask: np.ndarray | None = None):
  • postselection_mask: boolean array of length num_detectors. Shot discarded if any masked direct detector fires.
  • Shape unchanged: (shots, num_outputs). Discarded shots have direct detector columns real, others False.

Performance

Direct detectors (NumPy CPU) that fire skip the expensive JAX loop.

Testing

10 new unit tests: basic mask filtering, all-direct optimization, edge cases, observable integration.

Closes #41

This PR is a unitaryHACK 2026 submission.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: c92ff0e244

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread src/tsim/sampler.py Outdated
)
predet_bits = direct_bits[:, direct_detector_indices]
# For each shot: fire if ANY masked direct detector fired
relevant = predet_bits[:, det_mask_input[direct_is_detector[direct_detector_indices]]]

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Index direct masks by detector ids

When a circuit has a mix of direct detector outputs and component-sampled detector outputs, this boolean index has length equal to only the number of direct detectors, but it is applied to det_mask_input whose length is the total detector count. In that common mixed case NumPy raises a boolean-dimension mismatch before sampling, so postselection cannot be used; the mask needs to be selected by the direct outputs' original detector indices from prog.output_order, not by direct_is_detector again.

Useful? React with 👍 / 👎.

Comment thread src/tsim/sampler.py
fired = np.zeros(batch_size, dtype=np.bool_)
survivor_mask = jnp.asarray(~fired)
else:
survivor_mask = jnp.ones(batch_size, dtype=jnp.bool_)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Apply masks to non-direct detectors

When the selected postselection detectors are produced by compiled components rather than direct columns, this branch marks every shot as a survivor and there is no later check of the returned detector samples against postselection_mask. As a result, shots where a non-direct postselected detector fires are returned unchanged, so the new mask silently has no effect for those detectors instead of discarding/filling those shots.

Useful? React with 👍 / 👎.

@kish-00

kish-00 commented Jun 16, 2026

Copy link
Copy Markdown
Author

Fix commits are now on the latest PR head (2d686eb). Current Actions runs for CI, Lint, and Documentation are completed as action_required with zero jobs/logs, and rerun is blocked for this account with “Must have admin rights to Repository.” Could a maintainer approve/re-run the required checks?

kish-00 added 2 commits June 17, 2026 11:09
Add optional postselection_mask parameter to CompiledDetectorSampler.sample().
A shot is discarded if any masked detector fires. Direct (classical)
postselected detectors skip the JAX sampling loop for discarded shots;
their component-only outputs are filled with False.

- sample_program() accepts survivor_mask to conditionally skip JAX
- _sample_batches() validates mask length, computes survivor mask
  from direct detector bits, and threads it through
- All sample() overloads pass postselection_mask through
- Validation: wrong mask length or empty mask raises ValueError
- 10 new unit tests covering noiseless, noisy, reference, and edge cases
Bug 1: JAX rejects Python list indexing on 1D arrays, causing
TypeError when output_order (JAX array) was indexed with
direct_detector_indices (Python list). Fix: use np.asarray().

Bug 2: predet_bits column indexing used
det_mask_input[direct_is_detector[direct_detector_indices]] which
only works by accident when num_detectors == len(direct_detector_indices).
Fix: use output_order to correctly map direct positions to detector IDs.

Bug 3: The noiseless reference shot (f_params=0, sample index 0)
could be discarded by postselection when direct_flips had True
for a postselected detector, producing a zero reference instead
of the true noiseless outcome. Fix: force fired[0] = False when
compute_reference is active.
@kish-00 kish-00 force-pushed the pr/postselection-masks branch from 2d686eb to 212bded Compare June 17, 2026 08:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

feat: Support postselection masks

1 participant