feat: support postselection masks in detector sampler#163
Conversation
Add optional postselection_mask parameter to CompiledDetectorSampler.sample(). A shot is discarded if any masked detector fires. Direct (classical) postselected detectors skip the JAX sampling loop for discarded shots; their component-only outputs are filled with False. - sample_program() accepts survivor_mask to conditionally skip JAX - _sample_batches() validates mask length, computes survivor mask from direct detector bits, and threads it through - All sample() overloads pass postselection_mask through - Validation: wrong mask length or empty mask raises ValueError - 10 new unit tests covering noiseless, noisy, reference, and edge cases
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: c1a15803d4
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| direct_bits = ( | ||
| f_params_np[:, self._program.direct_f_indices].astype(np.bool_) | ||
| ^ np.asarray(self._program.direct_flips, dtype=np.bool_) | ||
| ) |
There was a problem hiding this comment.
Apply detector reference before direct postselection
When use_detector_reference_sample=True, sample() later XORs the returned detector columns with the noiseless reference, but the direct postselection check here uses the raw direct detector bits. For any direct detector whose reference bit is 1, otherwise clean shots are marked as discarded before the XOR; their JAX-component outputs are filled with False, and after the XOR the returned detector column no longer identifies them as discarded, so reference-sample postselection can silently corrupt observables/component detector columns.
Useful? React with 👍 / 👎.
| ) | ||
| predet_bits = direct_bits[:, direct_detector_indices] | ||
| # For each shot: fire if ANY masked direct detector fired | ||
| relevant = predet_bits[:, det_mask_input[direct_is_detector[direct_detector_indices]]] |
There was a problem hiding this comment.
Select direct masks by detector output index
In circuits that have a mix of direct detectors and component-sampled detectors, direct_is_detector[direct_detector_indices] has length equal to the number of direct detector outputs, not num_detectors, so using it as a boolean index into det_mask_input raises an IndexError whenever those counts differ. This means passing any postselection_mask to a mixed direct/non-direct detector circuit can fail before sampling; select det_mask_input by the corresponding detector output ids from output_order instead.
Useful? React with 👍 / 👎.
|
Superseded by #165 (rebased, with 3 additional bugfixes: JAX list indexing crash, boolean index mismatch, reference shot discard). Closing to avoid confusion. |
Description
Implements the postselection masks feature for the detector sampler, as part of unitaryHACK 2026.
Adds an optional
postselection_maskparameter toCompiledDetectorSampler.sample(). When a detector is masked, any shot where that detector fires is discarded — its non-direct outputs are filled withFalse. Direct (classical) postselected detectors skip the JAX sampling loop for discarded shots, providing a performance optimization.Changes
sample_program(): Accepts optionalsurvivor_maskparameter. When provided, only survivor shots run through JAX; discarded shots get all-Falseoutputs._sample_batches(): Validates postselection_mask length, precomputes direct-detector mappings, computes survivor mask from direct detector bits, and threads it through tosample_program().CompiledDetectorSampler.sample(): All overloads acceptpostselection_maskand pass it through.ValueErrorfor wrong-length or empty masks.Testing
10 new unit tests covering:
separate_observables,append_observables, reference sampleFalsemask is a no-opAll 48 unit tests pass.
Closes #41