Skip to content

Commit

Permalink
Make DecoderResult tuple-like in python (NVIDIA#16)
Browse files Browse the repository at this point in the history
Add python binding for DecoderResult to make it tuple-like

Signed-off-by: Melody Ren <[email protected]>
  • Loading branch information
melody-ren authored Dec 3, 2024
1 parent a096e55 commit 6537ef0
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 17 deletions.
4 changes: 2 additions & 2 deletions docs/sphinx/examples/qec/python/circuit_level_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@
for syndrome in syndromes:
print("syndrome:", syndrome)
# decode the syndrome
result = decoder.decode(syndrome)
data_prediction = np.array(result.result, dtype=np.uint8)
convergence, result = decoder.decode(syndrome)
data_prediction = np.array(result, dtype=np.uint8)

# see if the decoded result anti-commutes with the observables
print("decode result:", data_prediction)
Expand Down
4 changes: 2 additions & 2 deletions docs/sphinx/examples/qec/python/code_capacity_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
print(f"syndrome: {syndrome}")

# Decode the syndrome to predict what happen to the data
result = decoder.decode(syndrome)
data_prediction = np.array(result.result, dtype=np.uint8)
convergence, result = decoder.decode(syndrome)
data_prediction = np.array(result, dtype=np.uint8)
print(f"data_prediction: {data_prediction}")

# See if this prediction flipped the observable
Expand Down
4 changes: 2 additions & 2 deletions docs/sphinx/examples/qec/python/pseudo_threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
# Calculate which syndromes are flagged.
syndrome = Hz@data % 2

result = decoder.decode(syndrome)
data_prediction = np.array(result.result)
convergence, result = decoder.decode(syndrome)
data_prediction = np.array(result)

predicted_observable = observable@data_prediction % 2

Expand Down
19 changes: 18 additions & 1 deletion libs/qec/python/bindings/py_decoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,24 @@ void bindDecoder(py::module &mod) {
Contains the sequence of corrections that should be applied to recover
the original quantum state. The format depends on the specific decoder
implementation.
)pbdoc");
)pbdoc")
// Add tuple interface
.def("__len__", [](const decoder_result &) { return 2; })
.def("__getitem__",
[](const decoder_result &r, size_t i) {
switch (i) {
case 0:
return py::cast(r.converged);
case 1:
return py::cast(r.result);
default:
throw py::index_error();
}
})
// Enable iteration protocol
.def("__iter__", [](const decoder_result &r) -> py::object {
return py::iter(py::make_tuple(r.converged, r.result));
});

py::class_<decoder, PyDecoder>(
qecmod, "Decoder", "Represents a decoder for quantum error correction")
Expand Down
20 changes: 10 additions & 10 deletions libs/qec/python/tests/test_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_decoder_result_structure():
def test_decoder_result_values():
decoder = qec.get_decoder('example_byod', H)
result = decoder.decode(create_test_syndrome())

assert result.converged is True
assert all(isinstance(x, float) for x in result.result)
assert all(0 <= x <= 1 for x in result.result)
Expand All @@ -54,12 +54,12 @@ def test_decoder_different_matrix_sizes(matrix_shape, syndrome_size):
syndrome = np.random.random(syndrome_size).tolist()

decoder = qec.get_decoder('example_byod', H)
result = decoder.decode(syndrome)
convergence, result = decoder.decode(syndrome)

assert len(result.result) == syndrome_size
assert result.converged is True
assert all(isinstance(x, float) for x in result.result)
assert all(0 <= x <= 1 for x in result.result)
assert len(result) == syndrome_size
assert convergence is True
assert all(isinstance(x, float) for x in result)
assert all(0 <= x <= 1 for x in result)

# FIXME add this back
# def test_decoder_error_handling():
Expand All @@ -80,13 +80,13 @@ def test_decoder_reproducibility():
decoder = qec.get_decoder('example_byod', H)

np.random.seed(42)
result1 = decoder.decode(create_test_syndrome())
convergence1, result1 = decoder.decode(create_test_syndrome())

np.random.seed(42)
result2 = decoder.decode(create_test_syndrome())
convergence2, result2 = decoder.decode(create_test_syndrome())

assert result1.result == result2.result
assert result1.converged == result2.converged
assert result1 == result2
assert convergence1 == convergence2

def test_pass_weights():
error_probability = 0.1
Expand Down

0 comments on commit 6537ef0

Please sign in to comment.