From 6537ef01aa4ea3b43fc70f042d506b96556f2d56 Mon Sep 17 00:00:00 2001 From: melody-ren Date: Tue, 3 Dec 2024 09:48:51 -0800 Subject: [PATCH] Make DecoderResult tuple-like in python (#16) Add python binding for DecoderResult to make it tuple-like Signed-off-by: Melody Ren --- .../qec/python/circuit_level_noise.py | 4 ++-- .../qec/python/code_capacity_noise.py | 4 ++-- .../examples/qec/python/pseudo_threshold.py | 4 ++-- libs/qec/python/bindings/py_decoder.cpp | 19 +++++++++++++++++- libs/qec/python/tests/test_decoder.py | 20 +++++++++---------- 5 files changed, 34 insertions(+), 17 deletions(-) diff --git a/docs/sphinx/examples/qec/python/circuit_level_noise.py b/docs/sphinx/examples/qec/python/circuit_level_noise.py index d65f456..43ddaee 100644 --- a/docs/sphinx/examples/qec/python/circuit_level_noise.py +++ b/docs/sphinx/examples/qec/python/circuit_level_noise.py @@ -55,8 +55,8 @@ for syndrome in syndromes: print("syndrome:", syndrome) # decode the syndrome - result = decoder.decode(syndrome) - data_prediction = np.array(result.result, dtype=np.uint8) + convergence, result = decoder.decode(syndrome) + data_prediction = np.array(result, dtype=np.uint8) # see if the decoded result anti-commutes with the observables print("decode result:", data_prediction) diff --git a/docs/sphinx/examples/qec/python/code_capacity_noise.py b/docs/sphinx/examples/qec/python/code_capacity_noise.py index 95deaa2..274515f 100644 --- a/docs/sphinx/examples/qec/python/code_capacity_noise.py +++ b/docs/sphinx/examples/qec/python/code_capacity_noise.py @@ -32,8 +32,8 @@ print(f"syndrome: {syndrome}") # Decode the syndrome to predict what happen to the data - result = decoder.decode(syndrome) - data_prediction = np.array(result.result, dtype=np.uint8) + convergence, result = decoder.decode(syndrome) + data_prediction = np.array(result, dtype=np.uint8) print(f"data_prediction: {data_prediction}") # See if this prediction flipped the observable diff --git a/docs/sphinx/examples/qec/python/pseudo_threshold.py b/docs/sphinx/examples/qec/python/pseudo_threshold.py index 1bc20cd..9f0c42c 100644 --- a/docs/sphinx/examples/qec/python/pseudo_threshold.py +++ b/docs/sphinx/examples/qec/python/pseudo_threshold.py @@ -28,8 +28,8 @@ # Calculate which syndromes are flagged. syndrome = Hz@data % 2 - result = decoder.decode(syndrome) - data_prediction = np.array(result.result) + convergence, result = decoder.decode(syndrome) + data_prediction = np.array(result) predicted_observable = observable@data_prediction % 2 diff --git a/libs/qec/python/bindings/py_decoder.cpp b/libs/qec/python/bindings/py_decoder.cpp index 710f841..8b6d591 100644 --- a/libs/qec/python/bindings/py_decoder.cpp +++ b/libs/qec/python/bindings/py_decoder.cpp @@ -96,7 +96,24 @@ void bindDecoder(py::module &mod) { Contains the sequence of corrections that should be applied to recover the original quantum state. The format depends on the specific decoder implementation. - )pbdoc"); + )pbdoc") + // Add tuple interface + .def("__len__", [](const decoder_result &) { return 2; }) + .def("__getitem__", + [](const decoder_result &r, size_t i) { + switch (i) { + case 0: + return py::cast(r.converged); + case 1: + return py::cast(r.result); + default: + throw py::index_error(); + } + }) + // Enable iteration protocol + .def("__iter__", [](const decoder_result &r) -> py::object { + return py::iter(py::make_tuple(r.converged, r.result)); + }); py::class_( qecmod, "Decoder", "Represents a decoder for quantum error correction") diff --git a/libs/qec/python/tests/test_decoder.py b/libs/qec/python/tests/test_decoder.py index a79bc0b..5480cad 100644 --- a/libs/qec/python/tests/test_decoder.py +++ b/libs/qec/python/tests/test_decoder.py @@ -37,7 +37,7 @@ def test_decoder_result_structure(): def test_decoder_result_values(): decoder = qec.get_decoder('example_byod', H) result = decoder.decode(create_test_syndrome()) - + assert result.converged is True assert all(isinstance(x, float) for x in result.result) assert all(0 <= x <= 1 for x in result.result) @@ -54,12 +54,12 @@ def test_decoder_different_matrix_sizes(matrix_shape, syndrome_size): syndrome = np.random.random(syndrome_size).tolist() decoder = qec.get_decoder('example_byod', H) - result = decoder.decode(syndrome) + convergence, result = decoder.decode(syndrome) - assert len(result.result) == syndrome_size - assert result.converged is True - assert all(isinstance(x, float) for x in result.result) - assert all(0 <= x <= 1 for x in result.result) + assert len(result) == syndrome_size + assert convergence is True + assert all(isinstance(x, float) for x in result) + assert all(0 <= x <= 1 for x in result) # FIXME add this back # def test_decoder_error_handling(): @@ -80,13 +80,13 @@ def test_decoder_reproducibility(): decoder = qec.get_decoder('example_byod', H) np.random.seed(42) - result1 = decoder.decode(create_test_syndrome()) + convergence1, result1 = decoder.decode(create_test_syndrome()) np.random.seed(42) - result2 = decoder.decode(create_test_syndrome()) + convergence2, result2 = decoder.decode(create_test_syndrome()) - assert result1.result == result2.result - assert result1.converged == result2.converged + assert result1 == result2 + assert convergence1 == convergence2 def test_pass_weights(): error_probability = 0.1