Skip to content

feat: support postselection masks in detector sampler#163

Closed
kish-00 wants to merge 1 commit into
QuEraComputing:mainfrom
kish-00:feat/postselection-masks
Closed

feat: support postselection masks in detector sampler#163
kish-00 wants to merge 1 commit into
QuEraComputing:mainfrom
kish-00:feat/postselection-masks

Conversation

@kish-00

@kish-00 kish-00 commented Jun 14, 2026

Copy link
Copy Markdown

Description

Implements the postselection masks feature for the detector sampler, as part of unitaryHACK 2026.

Adds an optional postselection_mask parameter to CompiledDetectorSampler.sample(). When a detector is masked, any shot where that detector fires is discarded — its non-direct outputs are filled with False. Direct (classical) postselected detectors skip the JAX sampling loop for discarded shots, providing a performance optimization.

Changes

  • sample_program(): Accepts optional survivor_mask parameter. When provided, only survivor shots run through JAX; discarded shots get all-False outputs.
  • _sample_batches(): Validates postselection_mask length, precomputes direct-detector mappings, computes survivor mask from direct detector bits, and threads it through to sample_program().
  • CompiledDetectorSampler.sample(): All overloads accept postselection_mask and pass it through.
  • Validation: Raises ValueError for wrong-length or empty masks.

Testing

10 new unit tests covering:

  • Noiseless circuit (no detectors fire — output unchanged)
  • Noisy repetition code circuit (shape preserved)
  • Wrong-length mask raises
  • Empty mask raises
  • Combined with separate_observables, append_observables, reference sample
  • All-False mask is a no-op

All 48 unit tests pass.

Closes #41

Add optional postselection_mask parameter to CompiledDetectorSampler.sample().
A shot is discarded if any masked detector fires. Direct (classical)
postselected detectors skip the JAX sampling loop for discarded shots;
their component-only outputs are filled with False.

- sample_program() accepts survivor_mask to conditionally skip JAX
- _sample_batches() validates mask length, computes survivor mask
  from direct detector bits, and threads it through
- All sample() overloads pass postselection_mask through
- Validation: wrong mask length or empty mask raises ValueError
- 10 new unit tests covering noiseless, noisy, reference, and edge cases

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: c1a15803d4

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread src/tsim/sampler.py
Comment on lines +392 to +395
direct_bits = (
f_params_np[:, self._program.direct_f_indices].astype(np.bool_)
^ np.asarray(self._program.direct_flips, dtype=np.bool_)
)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Apply detector reference before direct postselection

When use_detector_reference_sample=True, sample() later XORs the returned detector columns with the noiseless reference, but the direct postselection check here uses the raw direct detector bits. For any direct detector whose reference bit is 1, otherwise clean shots are marked as discarded before the XOR; their JAX-component outputs are filled with False, and after the XOR the returned detector column no longer identifies them as discarded, so reference-sample postselection can silently corrupt observables/component detector columns.

Useful? React with 👍 / 👎.

Comment thread src/tsim/sampler.py
)
predet_bits = direct_bits[:, direct_detector_indices]
# For each shot: fire if ANY masked direct detector fired
relevant = predet_bits[:, det_mask_input[direct_is_detector[direct_detector_indices]]]

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Select direct masks by detector output index

In circuits that have a mix of direct detectors and component-sampled detectors, direct_is_detector[direct_detector_indices] has length equal to the number of direct detector outputs, not num_detectors, so using it as a boolean index into det_mask_input raises an IndexError whenever those counts differ. This means passing any postselection_mask to a mixed direct/non-direct detector circuit can fail before sampling; select det_mask_input by the corresponding detector output ids from output_order instead.

Useful? React with 👍 / 👎.

@kish-00

kish-00 commented Jun 16, 2026

Copy link
Copy Markdown
Author

Superseded by #165 (rebased, with 3 additional bugfixes: JAX list indexing crash, boolean index mismatch, reference shot discard). Closing to avoid confusion.

@kish-00 kish-00 closed this Jun 16, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

feat: Support postselection masks

1 participant