Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exploring how to handle dynamic decompositions with PLxPR enabled #6859

Open
wants to merge 20 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
696acf3
E.C.
PietropaoloFrisoni Jan 20, 2025
9cdfae9
Creating an empty `DynamicDecomposeInterpreter` c;ass
PietropaoloFrisoni Jan 21, 2025
1e138fa
Sbattendo la testa contro il muro tante volte
PietropaoloFrisoni Jan 21, 2025
76c9250
Current prototype version
PietropaoloFrisoni Jan 22, 2025
c821ac9
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Jan 22, 2025
cee6ec4
Fixing one more problem
PietropaoloFrisoni Jan 22, 2025
a208dc5
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Jan 22, 2025
d16a9e0
Moving tests to separate file
PietropaoloFrisoni Jan 23, 2025
a8b9283
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Jan 23, 2025
18bc43c
Pylint fixes (although premature)
PietropaoloFrisoni Jan 23, 2025
e2e8fd0
Removing reundandt tuple calls
PietropaoloFrisoni Jan 23, 2025
0abd620
Tests with dynamic wires
PietropaoloFrisoni Jan 23, 2025
1e3ffb6
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Jan 23, 2025
1ae399b
Adding Autograph test
PietropaoloFrisoni Jan 23, 2025
9a54c3e
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Jan 24, 2025
c5f2ae5
Removing unused parameters and adding a few tests
PietropaoloFrisoni Jan 24, 2025
497440c
Adding a few more tests
PietropaoloFrisoni Jan 24, 2025
c7da133
Removing import
PietropaoloFrisoni Jan 24, 2025
2f0417c
Pylint
PietropaoloFrisoni Jan 24, 2025
3c8bc37
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Jan 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 87 additions & 1 deletion pennylane/transforms/decompose.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import warnings
from collections.abc import Callable, Generator, Iterable
from functools import lru_cache, partial
from typing import Optional
from typing import Optional, Sequence

import pennylane as qml
from pennylane.transforms.core import transform
Expand Down Expand Up @@ -187,6 +187,92 @@
DecomposeInterpreter, decompose_plxpr_to_plxpr = _get_plxpr_decompose()


@lru_cache
def _get_plxpr_dynamic_decompose(): # pylint: disable=missing-docstring
try:
# pylint: disable=import-outside-toplevel
# pylint: disable=unused-import
import jax
from pennylane.capture.primitives import AbstractMeasurement, AbstractOperator
except ImportError: # pragma: no cover
return None, None

# pylint: disable=redefined-outer-name

class DynamicDecomposeInterpreter(qml.capture.PlxprInterpreter):
"""
Experimental Plxpr Interpreter for applying a dynamic decomposition to operations program capture is enabled.

"""

def eval_dynamic_decomposition(
self, jaxpr_decomp: "jax.core.Jaxpr", consts: Sequence, *args
):
"""
Evaluate a dynamic decomposition of a Jaxpr.

Args:
jaxpr_decomp (jax.core.Jaxpr): the Jaxpr to evaluate
consts (Sequence): the constants to use in the evaluation
*args: the arguments to use in the evaluation

"""

for arg, invar in zip(args, jaxpr_decomp.invars, strict=True):
self._env[invar] = arg
for const, constvar in zip(consts, jaxpr_decomp.constvars, strict=True):
self._env[constvar] = const

Check warning on line 224 in pennylane/transforms/decompose.py

View check run for this annotation

Codecov / codecov/patch

pennylane/transforms/decompose.py#L224

Added line #L224 was not covered by tests

for inner_eqn in jaxpr_decomp.eqns:

custom_handler = self._primitive_registrations.get(inner_eqn.primitive, None)

if custom_handler:
invals = [self.read(invar) for invar in inner_eqn.invars]
outvals = custom_handler(self, *invals, **inner_eqn.params)

elif isinstance(inner_eqn.outvars[0].aval, AbstractOperator):
outvals = super().interpret_operation_eqn(inner_eqn)
elif isinstance(inner_eqn.outvars[0].aval, AbstractMeasurement):
outvals = super().interpret_measurement_eqn(inner_eqn)
else:
invals = [self.read(invar) for invar in inner_eqn.invars]
outvals = inner_eqn.primitive.bind(*invals, **inner_eqn.params)

if not inner_eqn.primitive.multiple_results:
outvals = [outvals]

for inner_outvar, inner_outval in zip(inner_eqn.outvars, outvals, strict=True):
self._env[inner_outvar] = inner_outval

def interpret_operation_eqn(self, eqn: "jax.core.JaxprEqn"):
"""
Interpret an equation corresponding to an operator.

Args:
eqn (jax.core.JaxprEqn): a jax equation for an operator.
"""

invals = (self.read(invar) for invar in eqn.invars)
with qml.QueuingManager.stop_recording():
op = eqn.primitive.impl(*invals, **eqn.params)

if hasattr(op, "_compute_plxpr_decomposition"):

jaxpr_decomp = op._plxpr_decomposition()
args = (*op.parameters, *op.wires, *op.hyperparameters)
return self.eval_dynamic_decomposition(
jaxpr_decomp.jaxpr, jaxpr_decomp.consts, *args
)

return super().interpret_operation_eqn(eqn)

return DynamicDecomposeInterpreter


DynamicDecomposeInterpreter = _get_plxpr_dynamic_decompose()


@partial(transform, plxpr_transform=decompose_plxpr_to_plxpr)
def decompose(tape, gate_set=None, max_expansion=None):
"""Decomposes a quantum circuit into a user-specified gate set.
Expand Down
1 change: 1 addition & 0 deletions tests/capture/transforms/test_capture_decompose.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
)
from pennylane.transforms.decompose import DecomposeInterpreter, decompose_plxpr_to_plxpr


pytestmark = [pytest.mark.jax, pytest.mark.usefixtures("enable_disable_plxpr")]


Expand Down
Loading
Loading