Support postselection masks in detector sampler#165
Conversation
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: c92ff0e244
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| ) | ||
| predet_bits = direct_bits[:, direct_detector_indices] | ||
| # For each shot: fire if ANY masked direct detector fired | ||
| relevant = predet_bits[:, det_mask_input[direct_is_detector[direct_detector_indices]]] |
There was a problem hiding this comment.
Index direct masks by detector ids
When a circuit has a mix of direct detector outputs and component-sampled detector outputs, this boolean index has length equal to only the number of direct detectors, but it is applied to det_mask_input whose length is the total detector count. In that common mixed case NumPy raises a boolean-dimension mismatch before sampling, so postselection cannot be used; the mask needs to be selected by the direct outputs' original detector indices from prog.output_order, not by direct_is_detector again.
Useful? React with 👍 / 👎.
| fired = np.zeros(batch_size, dtype=np.bool_) | ||
| survivor_mask = jnp.asarray(~fired) | ||
| else: | ||
| survivor_mask = jnp.ones(batch_size, dtype=jnp.bool_) |
There was a problem hiding this comment.
Apply masks to non-direct detectors
When the selected postselection detectors are produced by compiled components rather than direct columns, this branch marks every shot as a survivor and there is no later check of the returned detector samples against postselection_mask. As a result, shots where a non-direct postselected detector fires are returned unchanged, so the new mask silently has no effect for those detectors instead of discarding/filling those shots.
Useful? React with 👍 / 👎.
|
Fix commits are now on the latest PR head ( |
Add optional postselection_mask parameter to CompiledDetectorSampler.sample(). A shot is discarded if any masked detector fires. Direct (classical) postselected detectors skip the JAX sampling loop for discarded shots; their component-only outputs are filled with False. - sample_program() accepts survivor_mask to conditionally skip JAX - _sample_batches() validates mask length, computes survivor mask from direct detector bits, and threads it through - All sample() overloads pass postselection_mask through - Validation: wrong mask length or empty mask raises ValueError - 10 new unit tests covering noiseless, noisy, reference, and edge cases
Bug 1: JAX rejects Python list indexing on 1D arrays, causing TypeError when output_order (JAX array) was indexed with direct_detector_indices (Python list). Fix: use np.asarray(). Bug 2: predet_bits column indexing used det_mask_input[direct_is_detector[direct_detector_indices]] which only works by accident when num_detectors == len(direct_detector_indices). Fix: use output_order to correctly map direct positions to detector IDs. Bug 3: The noiseless reference shot (f_params=0, sample index 0) could be discarded by postselection when direct_flips had True for a postselected detector, producing a zero reference instead of the true noiseless outcome. Fix: force fired[0] = False when compute_reference is active.
2d686eb to
212bded
Compare
Adds
postselection_maskparameter toCompiledDetectorSampler.sample()as proposed in #41. When a shot has a direct detector that fires under the mask, the JAX loop is skipped for that shot.API
Performance
Direct detectors (NumPy CPU) that fire skip the expensive JAX loop.
Testing
10 new unit tests: basic mask filtering, all-direct optimization, edge cases, observable integration.
Closes #41
This PR is a unitaryHACK 2026 submission.