Skip to content

Commit 6425284

Browse files
rafaelhaCopilot
andauthored
Fix order of observable annotations not being reflected in sampler output (#127)
## Summary - Fix `parse_stim_circuit` crash on empty `DETECTOR` / `OBSERVABLE_INCLUDE` annotations; they now produce deterministic-zero detector/observable bits, matching Stim. - Fix sparse and out-of-order `OBSERVABLE_INCLUDE` indices: sampler output now has one column per id in `range(num_observables)`, in sorted order, with missing ids as deterministic-zero columns. --------- Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
1 parent 559359d commit 6425284

7 files changed

Lines changed: 154 additions & 11 deletions

File tree

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99

1010
### Fixed
1111
- `MR`, `MRX`, and `MRY` no longer double-count their measurement flip probability as both a pre-measurement Pauli error and a measurement-result flip.
12+
- Out-of-order `OBSERVABLE_INCLUDE` indices now produce the correct sampler column order and output shape. Missing indices below the maximum mentioned id appear as deterministic-zero columns, and columns are emitted in sorted logical-index order.
13+
- Empty `DETECTOR` and `OBSERVABLE_INCLUDE` annotations (without targets) no longer crash the parser; they now produce zero detector/observable bits, matching Stim semantics.
14+
1215

1316
### Added
1417
- `TPP` and `TPP_DAG` instructions — applies exp(-i pi/8 P) or exp(+i pi/8 P) (up to global phase) for a Pauli product P, i.e., phases the -1 eigenspace of P by exp(i pi/4) or exp(-i pi/4).

src/tsim/core/graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def build_sampling_graph(
231231
g.remove_vertex(vertices.pop())
232232

233233
labels = [f"det[{i}]" for i in range(len(built.detectors))] + [
234-
f"obs[{i}]" for i in built.observables_dict
234+
f"obs[{i}]" for i in sorted(built.observables_dict)
235235
]
236236
for label in labels:
237237
vs = annotation_to_vertex[label]

src/tsim/core/instructions.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,12 +1114,26 @@ def ry(b: GraphRepresentation, qubit: int) -> None:
11141114
# =============================================================================
11151115

11161116

1117-
def detector(b: GraphRepresentation, rec: list[int], *args) -> None:
1118-
"""Add detector annotation that XORs the given measurement record bits."""
1119-
row = min({b.graph.row(b.rec[r]) for r in rec}) - 0.5
1117+
def _annotation_row(b: GraphRepresentation, rec: list[int]) -> float:
1118+
"""Pick a fresh row for a detector/observable vertex.
1119+
1120+
Empty `rec` is valid Stim and represents a deterministic-zero annotation;
1121+
fall back to a row above existing annotations rather than `min()` on an
1122+
empty set.
1123+
"""
11201124
d_rows = {b.graph.row(d) for d in b.detectors + b.observables}
1125+
if rec:
1126+
row: float = min(b.graph.row(b.rec[r]) for r in rec) - 0.5
1127+
else:
1128+
row = (max(d_rows) + 1) if d_rows else 0
11211129
while row in d_rows:
11221130
row += 1
1131+
return row
1132+
1133+
1134+
def detector(b: GraphRepresentation, rec: list[int], *args) -> None:
1135+
"""Add detector annotation that XORs the given measurement record bits."""
1136+
row = _annotation_row(b, rec)
11231137
v0 = b.graph.add_vertex(
11241138
VertexType.X, qubit=-1, row=row, phase=f"det[{len(b.detectors)}]"
11251139
)
@@ -1133,10 +1147,7 @@ def observable_include(b: GraphRepresentation, rec: list[int], idx: int) -> None
11331147
idx = int(idx)
11341148

11351149
if idx not in b.observables_dict:
1136-
row = min({b.graph.row(b.rec[r]) for r in rec}) - 0.5
1137-
d_rows = {b.graph.row(d) for d in b.detectors + b.observables}
1138-
while row in d_rows:
1139-
row += 1
1150+
row = _annotation_row(b, rec)
11401151
v0 = b.graph.add_vertex(VertexType.X, qubit=-1, row=row, phase=f"obs[{idx}]")
11411152
b.observables_dict[idx] = v0
11421153

src/tsim/core/parse.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,4 +261,13 @@ def parse_stim_circuit(
261261
gate_func(b, *chunk, *args)
262262

263263
finalize_correlated_error(b)
264+
265+
# Materialize every observable id from 0..num_observables-1 so missing
266+
# indices appear as deterministic-zero outputs and downstream iteration
267+
# is in sorted index order, matching Stim semantics.
268+
for i in range(stim_circuit.num_observables):
269+
if i not in b.observables_dict:
270+
observable_include(b, [], i)
271+
b.observables_dict = {i: b.observables_dict[i] for i in sorted(b.observables_dict)}
272+
264273
return b

src/tsim/utils/diagram.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -499,9 +499,10 @@ def render_pyzx_d3(stim_circ: stim.Circuit, kwargs: dict[str, Any]) -> GraphS:
499499
return g
500500

501501
g = g.clone()
502-
max_row = max(g.row(v) for v in built.last_vertex.values())
503-
for q in built.last_vertex:
504-
g.set_row(built.last_vertex[q], max_row)
502+
if built.last_vertex:
503+
max_row = max(g.row(v) for v in built.last_vertex.values())
504+
for q in built.last_vertex:
505+
g.set_row(built.last_vertex[q], max_row)
505506

506507
for v in list(g.vertices()):
507508
phase_vars = g._phaseVars[v]

test/unit/core/test_parse.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -697,6 +697,38 @@ def test_spp_multiple_products(self):
697697
assert len(b.rec) == 0
698698

699699

700+
class TestParseEmptyAnnotations:
701+
"""Empty DETECTOR / OBSERVABLE_INCLUDE annotations are valid Stim and represent
702+
deterministic-zero detector/observable bits. They must parse without crashing
703+
and produce a single annotation vertex with no record edges."""
704+
705+
def test_empty_detector_alone(self):
706+
b = parse_stim_circuit(stim.Circuit("DETECTOR"))
707+
assert len(b.detectors) == 1
708+
v = b.detectors[0]
709+
assert b.graph.type(v) == VertexType.X
710+
assert list(b.graph.neighbors(v)) == []
711+
712+
def test_empty_observable_alone(self):
713+
b = parse_stim_circuit(stim.Circuit("OBSERVABLE_INCLUDE(0)"))
714+
assert set(b.observables_dict) == {0}
715+
v = b.observables_dict[0]
716+
assert b.graph.type(v) == VertexType.X
717+
assert list(b.graph.neighbors(v)) == []
718+
719+
def test_empty_detector_after_measurement(self):
720+
b = parse_stim_circuit(stim.Circuit("M 0\nDETECTOR rec[-1]\nDETECTOR"))
721+
assert len(b.detectors) == 2
722+
# First detector has the measurement edge, second has no edges.
723+
assert len(list(b.graph.neighbors(b.detectors[0]))) == 1
724+
assert list(b.graph.neighbors(b.detectors[1])) == []
725+
726+
def test_empty_detector_with_args(self):
727+
"""DETECTOR(coords...) with empty rec must also parse."""
728+
b = parse_stim_circuit(stim.Circuit("DETECTOR(1, 2)"))
729+
assert len(b.detectors) == 1
730+
731+
700732
class TestParseMPPCancellation:
701733
"""Tests for MPP with duplicate/anticommuting Pauli targets.
702734
@@ -853,3 +885,30 @@ def test_tpp_anti_hermitian_raises(self):
853885
"""TPP Z0*X0 = iY is anti-Hermitian and should raise."""
854886
with pytest.raises(ValueError, match="anti-Hermitian"):
855887
parse_stim_circuit(stim.Circuit("SPP[T] Z0*X0"))
888+
889+
890+
class TestParseSparseObservables:
891+
"""OBSERVABLE_INCLUDE indices are sparse and out-of-order in real circuits.
892+
Stim defines num_observables = max_index + 1 and emits one column per id
893+
in sorted order, with missing ids as deterministic-zero columns."""
894+
895+
def test_sparse_observable_index_pads_missing(self):
896+
circuit = stim.Circuit("M 0\nOBSERVABLE_INCLUDE(2) rec[-1]")
897+
b = parse_stim_circuit(circuit)
898+
assert set(b.observables_dict) == {0, 1, 2}
899+
# Missing ids 0 and 1 have no record edges (deterministic zero).
900+
assert list(b.graph.neighbors(b.observables_dict[0])) == []
901+
assert list(b.graph.neighbors(b.observables_dict[1])) == []
902+
# Index 2 has the measurement edge.
903+
assert len(list(b.graph.neighbors(b.observables_dict[2]))) == 1
904+
905+
def test_observables_dict_is_sorted_after_out_of_order(self):
906+
circuit = stim.Circuit(
907+
"M 0\nM 1\nOBSERVABLE_INCLUDE(2) rec[-2]\nOBSERVABLE_INCLUDE(0) rec[-1]"
908+
)
909+
b = parse_stim_circuit(circuit)
910+
assert list(b.observables_dict) == [0, 1, 2]
911+
912+
def test_no_observables_remains_empty(self):
913+
b = parse_stim_circuit(stim.Circuit("M 0"))
914+
assert b.observables_dict == {}

test/unit/test_sampler.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,3 +363,63 @@ def test_reference_sample_defaults_unchanged():
363363
use_observable_reference_sample=False,
364364
)
365365
assert np.array_equal(d1, d2)
366+
367+
368+
def test_detector_sampler_empty_annotations():
369+
"""Empty DETECTOR / OBSERVABLE_INCLUDE produce deterministic-zero columns
370+
that sit alongside non-empty ones, matching Stim semantics."""
371+
c = Circuit("""
372+
X 0
373+
M 0 1
374+
DETECTOR rec[-2]
375+
DETECTOR
376+
OBSERVABLE_INCLUDE(0) rec[-1]
377+
OBSERVABLE_INCLUDE(1)
378+
""")
379+
assert c.num_detectors == 2
380+
assert c.num_observables == 2
381+
382+
sampler = c.compile_detector_sampler(seed=0)
383+
det, obs = sampler.sample(4, separate_observables=True)
384+
assert det.shape == (4, 2)
385+
assert obs.shape == (4, 2)
386+
# First detector reads the X-induced 1; second detector and both observables
387+
# are deterministically zero.
388+
assert np.all(det[:, 0])
389+
assert not np.any(det[:, 1])
390+
assert not np.any(obs)
391+
392+
393+
def test_detector_sampler_sparse_observable_index():
394+
"""OBSERVABLE_INCLUDE(2) only must produce 3 observable columns, with the
395+
measured value landing in column 2 and columns 0/1 deterministically zero."""
396+
c = Circuit("X 0\nM 0\nOBSERVABLE_INCLUDE(2) rec[-1]")
397+
assert c.num_observables == 3
398+
399+
sampler = c.compile_detector_sampler(seed=0)
400+
samples = sampler.sample(4, append_observables=True)
401+
assert samples.shape == (4, 3)
402+
assert not np.any(samples[:, 0])
403+
assert not np.any(samples[:, 1])
404+
assert np.all(samples[:, 2])
405+
406+
407+
def test_detector_sampler_out_of_order_observable_indices():
408+
"""Out-of-order OBSERVABLE_INCLUDE must produce columns in sorted index order."""
409+
# rec[-2] is qubit-0 (=1 after X), rec[-1] is qubit-1 (=0).
410+
# Map: obs[2] = rec[-2] = 1, obs[0] = rec[-1] = 0, obs[1] is unmentioned = 0.
411+
c = Circuit("""
412+
X 0
413+
M 0 1
414+
OBSERVABLE_INCLUDE(2) rec[-2]
415+
OBSERVABLE_INCLUDE(0) rec[-1]
416+
""")
417+
assert c.num_observables == 3
418+
419+
sampler = c.compile_detector_sampler(seed=0)
420+
_, obs = sampler.sample(2, separate_observables=True)
421+
assert obs.shape == (2, 3)
422+
# Sorted-by-index order: [obs0, obs1, obs2] = [0, 0, 1]
423+
assert not np.any(obs[:, 0])
424+
assert not np.any(obs[:, 1])
425+
assert np.all(obs[:, 2])

0 commit comments

Comments
 (0)