Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Capture] Add finite differences jvps #6853

Open
wants to merge 26 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
9bc4fa9
restructure qnode_prim access
albi3ro Jan 16, 2025
b6ce7d3
fix improt
albi3ro Jan 16, 2025
7ee4757
add backprop validation and some jvp structure
albi3ro Jan 16, 2025
d6155de
add finite difference derivatives
albi3ro Jan 16, 2025
a26231e
changelog
albi3ro Jan 17, 2025
d246020
Merge branch 'master' into finite-diff-capture
albi3ro Jan 22, 2025
1ed5e9e
move workflow capture tests
albi3ro Jan 22, 2025
cac689e
adding tests
albi3ro Jan 22, 2025
300fdca
Merge branch 'master' into finite-diff-capture
albi3ro Jan 22, 2025
e056a67
somehow file was properly moved
albi3ro Jan 22, 2025
f34f1ed
Merge branch 'finite-diff-capture' of https://github.com/PennyLaneAI/…
albi3ro Jan 22, 2025
39f371d
fixing test
albi3ro Jan 22, 2025
17b9fcd
minor clean up
albi3ro Jan 23, 2025
0df100c
add finite_diff_jvp to gradients module
albi3ro Jan 23, 2025
ab2e6c6
adding tests for finite_difF_jvp
albi3ro Jan 23, 2025
50425b1
one additional tesT
albi3ro Jan 23, 2025
8f28f06
Merge branch 'master' into finite-diff-capture
albi3ro Jan 23, 2025
fd19014
add strategy and approx_order
albi3ro Jan 23, 2025
dd4a49b
minor efficiency rewriting
albi3ro Jan 23, 2025
09a2671
Apply suggestions from code review
albi3ro Jan 24, 2025
b87978a
responding to feedback
albi3ro Jan 24, 2025
e6ba05a
Apply suggestions from code review
albi3ro Jan 24, 2025
c53f10d
mergeing
albi3ro Jan 24, 2025
a2d622f
Apply suggestions from code review
albi3ro Jan 27, 2025
b045726
Merge branch 'master' into finite-diff-capture
albi3ro Jan 27, 2025
297b88d
Merge branch 'master' into finite-diff-capture
albi3ro Jan 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@
* An informative error is raised when a `QNode` with `diff_method=None` is differentiated.
[(#6770)](https://github.com/PennyLaneAI/pennylane/pull/6770)

* With program capture enabled, `QNode`s can now be differentiated with `diff_method="finite-diff"`.
[(#6853)](https://github.com/PennyLaneAI/pennylane/pull/6853)

* The requested `diff_method` is now validated when program capture is enabled.
[(#6852)](https://github.com/PennyLaneAI/pennylane/pull/6852)

Expand Down
55 changes: 51 additions & 4 deletions pennylane/workflow/_capture_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,6 @@ def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts, batch_di
raise NotImplementedError(
"Overriding shots is not yet supported with the program capture execution."
)
if qnode_kwargs["diff_method"] not in {"backprop", "best"}:
raise NotImplementedError("Only backpropagation derivatives are supported at this time.")

consts = args[:n_consts]
non_const_args = args[n_consts:]
Expand Down Expand Up @@ -251,6 +249,7 @@ def _qnode_batching_rule(
"using parameter broadcasting to a quantum operation that supports batching.",
UserWarning,
)

# To resolve this ambiguity, we might add more properties to the AbstractOperator
# class to indicate which operators support batching and check them here.
# As above, at this stage we raise a warning and give the user full flexibility.
Expand Down Expand Up @@ -291,7 +290,50 @@ def _backprop(args, tangents, **impl_kwargs):
return jax.jvp(partial(qnode_prim.impl, **impl_kwargs), args, tangents)


diff_method_map = {"backprop": _backprop}
def _finite_diff(args, tangents, **impl_kwargs):
    """Compute the results and JVPs of the qnode primitive using first-order,
    forward-mode finite differences.

    Args:
        args: the positional arguments to ``qnode_prim.bind`` (consts followed by
            the traced arguments).
        tangents: tangent vectors matching ``args``. Entries that are ``ad.Zero``
            are treated as non-differentiated and skipped.

    Keyword Args:
        **impl_kwargs: parameters forwarded verbatim to ``qnode_prim.bind``. Must
            contain ``qnode_kwargs`` with a ``gradient_kwargs`` dictionary.

    Returns:
        tuple: ``(results, jvps)`` where ``results`` are the unshifted primitive
        outputs and ``jvps`` has one entry per result.

    Raises:
        NotImplementedError: if ``approx_order`` is not 1 or ``strategy`` is not
            ``"forward"``.
        ValueError: if an unrecognized gradient keyword argument was provided.
    """
    gradient_kwargs = impl_kwargs["qnode_kwargs"]["gradient_kwargs"]
    h = gradient_kwargs.get("h", 1e-6)
    if gradient_kwargs.get("approx_order", 1) != 1:
        raise NotImplementedError("only approx_order=1 is currently supported.")
    if gradient_kwargs.get("strategy", "forward") != "forward":
        raise NotImplementedError("only strategy='forward' is currently supported.")
    available_kwargs = {"h", "strategy", "approx_order"}
    if any(kwarg not in available_kwargs for kwarg in gradient_kwargs):
        raise ValueError(
            f"The only available gradient kwargs for finite diff are {available_kwargs}. Got {gradient_kwargs}."
        )

    # Unshifted evaluation; reused as the reference point for every shifted execution.
    res1 = qnode_prim.bind(*args, **impl_kwargs)

    jvps = [0 for _ in res1]
    for i, t in enumerate(tangents):
        if isinstance(t, ad.Zero):
            # argument i is not being differentiated
            continue
        shifted_args = list(args)

        # Flatten the argument so each element can be shifted independently;
        # the original shape is restored before rebinding.
        shape = getattr(shifted_args[i], "shape", ())
        flat_arg = jax.numpy.reshape(shifted_args[i], -1)
        flat_t = jax.numpy.reshape(t, -1)

        if getattr(flat_arg, "dtype", None) == jax.numpy.float32:
            # float32 precision is close to the default step size, so the
            # difference quotient can be dominated by rounding error.
            warn(
                "Detected float32 parameter with finite differences. Recommend use of float64 with finite diff.",
                UserWarning,
            )

        for element_idx, element in enumerate(flat_arg):
            arg = flat_arg.at[element_idx].set(element + h)
            shifted_args[i] = jax.numpy.reshape(arg, shape)
            res2 = qnode_prim.bind(*shifted_args, **impl_kwargs)

            # forward difference: df/dx_j ≈ (f(x + h e_j) - f(x)) / h,
            # accumulated against the tangent component for element j
            for result_idx, (r1, r2) in enumerate(zip(res1, res2)):
                jvps[result_idx] += flat_t[element_idx] * (r2 - r1) / h

    return res1, jvps


# Maps each supported ``diff_method`` string to the function that computes its JVP.
diff_method_map = {"backprop": _backprop, "finite-diff": _finite_diff}


def _resolve_diff_method(diff_method: str, device) -> str:
Expand Down Expand Up @@ -405,7 +447,12 @@ def f(x):

execute_kwargs = copy(qnode.execute_kwargs)
mcm_config = asdict(execute_kwargs.pop("mcm_config"))
qnode_kwargs = {"diff_method": qnode.diff_method, **execute_kwargs, **mcm_config}
qnode_kwargs = {
"diff_method": qnode.diff_method,
**execute_kwargs,
"gradient_kwargs": qnode.gradient_kwargs,
**mcm_config,
}

flat_args = jax.tree_util.tree_leaves(args)

Expand Down
236 changes: 236 additions & 0 deletions tests/capture/workflow/test_capture_finite_diff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
# Copyright 2018-2025 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file contains tests for using finite difference derivatives
with program capture enabled.
"""

import pytest

import pennylane as qml

pytestmark = [pytest.mark.jax, pytest.mark.usefixtures("enable_disable_plxpr")]

jax = pytest.importorskip("jax")
jnp = pytest.importorskip("jax.numpy")


class TestValidation:
    """Tests for the validation of gradient keyword arguments with finite differences."""

    def test_approx_order_unsupported_error(self):
        """Test that a NotImplementedError is raised for higher approx_order."""

        dev = qml.device("default.qubit", wires=2)

        @qml.qnode(dev, diff_method="finite-diff", approx_order=2)
        def circuit(_):
            return qml.expval(qml.Z(0))

        with pytest.raises(NotImplementedError, match="only approx_order=1"):
            jax.grad(circuit)(0.5)

    def test_strategy_unsupported_error(self):
        """Test that a NotImplementedError is raised for different strategies."""

        dev = qml.device("default.qubit", wires=2)

        @qml.qnode(dev, diff_method="finite-diff", strategy="backward")
        def circuit(_):
            return qml.expval(qml.Z(0))

        with pytest.raises(NotImplementedError, match="only strategy='forward'"):
            jax.grad(circuit)(0.5)

    def test_unsupported_kwarg(self):
        """Test that an error is raised for unsupported gradient kwargs."""

        dev = qml.device("default.qubit", wires=2)

        @qml.qnode(dev, diff_method="finite-diff", something="value")
        def circuit(_):
            return qml.expval(qml.Z(0))

        expected_message = "The only available gradient kwargs for finite diff are"
        with pytest.raises(ValueError, match=expected_message):
            jax.grad(circuit)(0.5)

    def test_warning_float32(self):
        """Test that a warning is raised if trainable inputs are float32."""

        dev = qml.device("default.qubit", wires=1)

        @qml.qnode(dev, diff_method="finite-diff")
        def circuit(x):
            qml.RX(x, 0)
            return qml.expval(qml.Z(0))

        x32 = jnp.array(0.5, dtype=jnp.float32)
        with pytest.warns(UserWarning, match="Detected float32 parameter with finite differences."):
            jax.grad(circuit)(x32)


class TestGradients:
    """Tests that correct gradient values are produced with finite differences."""

    @pytest.mark.parametrize("grad_f", (jax.grad, jax.jacobian))
    def test_simple_circuit(self, grad_f):
        """Test accurate results for a simple, single parameter circuit."""

        @qml.qnode(qml.device("default.qubit", wires=1), diff_method="finite-diff")
        def circuit(x):
            qml.RX(x, 0)
            return qml.expval(qml.Z(0))

        x = 0.5
        result = grad_f(circuit)(x)

        # d/dx <Z> for RX(x) is -sin(x)
        assert qml.math.allclose(result, -jnp.sin(x))

    @pytest.mark.parametrize("argnums", ((0,), (1,), (0, 1)))
    def test_multi_inputs(self, argnums):
        """Test gradients can be computed with multiple scalar inputs."""

        @qml.qnode(qml.device("default.qubit", wires=2), diff_method="finite-diff")
        def circuit(x, y):
            qml.RX(x, 0)
            qml.RY(y, 1)
            qml.CNOT((0, 1))
            return qml.expval(qml.Z(1))

        x = 1.2
        y = jnp.array(2.0)
        grad = jax.grad(circuit, argnums=argnums)(x, y)

        grad_x = -jnp.sin(x) * jnp.cos(y)
        grad_y = -jnp.cos(x) * jnp.sin(y)
        g = [grad_x, grad_y]
        expected_grad = [g[i] for i in argnums]

        assert qml.math.allclose(grad, expected_grad)

    def test_array_input(self):
        """Test that we can differentiate a circuit with an array input."""

        @qml.qnode(qml.device("default.qubit", wires=3), diff_method="finite-diff")
        def circuit(x):
            qml.RX(x[0], 0)
            qml.RX(x[1], 1)
            qml.RX(x[2], 2)
            return qml.expval(qml.Z(0) @ qml.Z(1) @ qml.Z(2))

        x = jnp.array([0.5, 1.0, 1.5])
        grad = jax.grad(circuit)(x)
        assert grad.shape == (3,)

        grad0 = -jnp.sin(x[0]) * jnp.cos(x[1]) * jnp.cos(x[2])
        assert qml.math.allclose(grad[0], grad0)
        grad1 = jnp.cos(x[0]) * -jnp.sin(x[1]) * jnp.cos(x[2])
        assert qml.math.allclose(grad[1], grad1)
        grad2 = jnp.cos(x[0]) * jnp.cos(x[1]) * -jnp.sin(x[2])
        assert qml.math.allclose(grad[2], grad2)

    def test_jacobian_multiple_outputs(self):
        """Test that finite diff can handle multiple outputs."""

        @qml.qnode(qml.device("default.qubit", wires=1), diff_method="finite-diff")
        def circuit(x):
            qml.RX(x, 0)
            return (
                qml.probs(wires=0),
                qml.expval(qml.Z(0)),
                qml.expval(qml.Y(0)),
                qml.expval(qml.X(0)),
            )

        x = jnp.array(-0.65)
        jac = jax.jacobian(circuit)(x)

        # probs = [cos(x/2)**2, sin(x/2)**2]
        probs_jac = [-jnp.cos(x / 2) * jnp.sin(x / 2), jnp.sin(x / 2) * jnp.cos(x / 2)]
        assert qml.math.allclose(jac[0], probs_jac)

        assert qml.math.allclose(jac[1], -jnp.sin(x))
        assert qml.math.allclose(jac[2], -jnp.cos(x))
        assert qml.math.allclose(jac[3], 0)

    def test_classical_control_flow(self):
        """Test that classical control flow can exist inside the circuit."""

        @qml.qnode(qml.device("default.qubit", wires=4), diff_method="finite-diff")
        def circuit(x):
            @qml.for_loop(3)
            def f(i):
                qml.cond(i < 2, qml.RX, false_fn=qml.RZ)(x[i], i)

            f()
            return [qml.expval(qml.Z(i)) for i in range(3)]

        x = jnp.array([0.2, 0.6, 1.0])
        jac = jax.jacobian(circuit)(x)

        assert qml.math.allclose(jac[0][0], -jnp.sin(x[0]))
        assert qml.math.allclose(jac[0][1:], 0)

        assert qml.math.allclose(jac[1][0], 0)
        assert qml.math.allclose(jac[1][1], -jnp.sin(x[1]))
        assert qml.math.allclose(jac[1][2], 0)

        # i = 2 applies RZ. grad should be zero
        assert qml.math.allclose(jac[2], jnp.zeros(3))

    def test_pre_and_postprocessing(self):
        """Test that we can chain together pre and post processing."""

        @qml.qnode(qml.device("default.qubit", wires=4), diff_method="finite-diff")
        def circuit(x):
            qml.RX(x, 0)
            return qml.expval(qml.Z(0))

        def workflow(y):
            return 2 * circuit(y**2)

        x = jnp.array(-0.9)
        jac = jax.jacobian(workflow)(x)

        # res = 2*cos(y**2)
        # dres = 2 * -sin(y**2) * 2 *y
        expected = 2 * -jnp.sin(x**2) * 2 * x
        assert qml.math.allclose(jac, expected)

    def test_hessian(self):
        """Test that higher order derivatives like the hessian can be computed."""

        @qml.qnode(qml.device("default.qubit", wires=4), diff_method="finite-diff")
        def circuit(x):
            qml.RX(x, 0)
            return qml.expval(qml.Z(0))

        hess = jax.grad(jax.grad(circuit))(0.5)
        # nested finite differences amplify rounding error, so loosen the tolerance
        assert qml.math.allclose(hess, -jnp.cos(0.5), atol=5e-4)  # gets noisy

    @pytest.mark.parametrize("argnums", ((0,), (0, 1)))
    def test_jaxpr_contents(self, argnums):
        """Make some tests on the captured jaxpr to assert we are doing the correct thing."""

        @qml.qnode(qml.device("default.qubit", wires=1), diff_method="finite-diff", h=1e-4)
        def circuit(x, y):
            qml.RX(x, 0)
            qml.RY(y, 0)
            return qml.expval(qml.Z(0))

        jaxpr = jax.make_jaxpr(jax.grad(circuit, argnums=argnums))(0.5, 1.2)

        # one unshifted execution plus one shifted execution per differentiated scalar
        qnode_eqns = [eqn for eqn in jaxpr.eqns if eqn.primitive.name == "qnode"]
        assert len(qnode_eqns) == 1 + len(argnums)

        for eqn in jaxpr.eqns:
            if eqn.primitive.name == "add":
                # only addition eqns are adding h to var
                assert eqn.invars[1].val == 1e-4
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def circuit(x):
assert eqn0.params["device"] == dev
assert eqn0.params["qnode"] == circuit
assert eqn0.params["shots"] == qml.measurements.Shots(None)
expected_kwargs = {"diff_method": "best"}
expected_kwargs = {"diff_method": "best", "gradient_kwargs": {}}
expected_kwargs.update(circuit.execute_kwargs)
expected_kwargs.update(asdict(expected_kwargs.pop("mcm_config")))
assert eqn0.params["qnode_kwargs"] == expected_kwargs
Expand Down Expand Up @@ -287,6 +287,7 @@ def circuit():
"device_vjp": False,
"mcm_method": None,
"postselect_mode": None,
"gradient_kwargs": {},
}
assert jaxpr.eqns[0].params["qnode_kwargs"] == expected

Expand Down
Loading