Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create read only simulation state vector #6248

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 64 additions & 7 deletions cirq-core/cirq/sim/state_vector_simulation_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
"""Objects and methods for acting efficiently on a state vector."""
from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, TYPE_CHECKING, Union

from dataclasses import dataclass

import numpy as np

from cirq import linalg, protocols, qis, sim
Expand Down Expand Up @@ -306,7 +308,50 @@ def supports_factor(self) -> bool:
return True


class StateVectorSimulationState(SimulationState[_BufferedStateVector]):
@dataclass(frozen=True)
class _ReadOnlyStateVector(qis.QuantumStateRepresentation):
"""A readonly state vector that represents the final simulation state."""

_state_vector: np.ndarray

def copy(self, deep_copy_buffers: bool = True) -> '_ReadOnlyStateVector':
"""Creates a copy of the object.
Args:
deep_copy_buffers: If True, buffers will also be deep-copied.
Otherwise the copy will share a reference to the original object's
buffers.
Returns:
A copied instance.
"""
return _ReadOnlyStateVector(
self._state_vector.copy() if deep_copy_buffers else self._state_vector
)

def measure(
self, axes: Sequence[int], seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None
) -> List[int]:
raise TypeError(
'Measurement is not supported by _ReadOnlyStateVector. '
'You can call to_mutable_state() and run measurement on the '
'resultant state object.'
)

def to_mutable_state(self, qid_shape: Optional[Tuple[int, ...]] = None) -> _BufferedStateVector:
return _BufferedStateVector.create(
initial_state=self._state_vector.copy(), qid_shape=qid_shape
)

def apply_unitary(self, action: Any, axes: Sequence[int]) -> bool:
return NotImplemented

def apply_mixture(self, action: Any, axes: Sequence[int], prng) -> Optional[int]:
return NotImplemented

def apply_channel(self, action: Any, axes: Sequence[int], prng) -> Optional[int]:
return NotImplemented


class StateVectorSimulationState(SimulationState[_BufferedStateVector | _ReadOnlyStateVector]):
"""State and context for an operation acting on a state vector.

There are two common ways to act on this object:
Expand All @@ -326,6 +371,7 @@ def __init__(
initial_state: Union[np.ndarray, 'cirq.STATE_VECTOR_LIKE'] = 0,
dtype: Type[np.complexfloating] = np.complex64,
classical_data: Optional['cirq.ClassicalDataStore'] = None,
is_read_only: bool = False,
):
"""Inits StateVectorSimulationState.

Expand All @@ -347,13 +393,22 @@ def __init__(
`target_tenson` is None.
classical_data: The shared classical data container for this
simulation.
is_read_only: Whether the state vector is immutable (_ReadOnlyStateVector)
or mutable (_BufferedStateVector).
"""
state = _BufferedStateVector.create(
initial_state=initial_state,
qid_shape=tuple(q.dimension for q in qubits) if qubits is not None else None,
dtype=dtype,
buffer=available_buffer,
)
if is_read_only:
initial_state = initial_state.astype(dtype=dtype)
if qubits is not None:
qid_shape = tuple(q.dimension for q in qubits)
initial_state = initial_state.reshape(qid_shape)
state = _ReadOnlyStateVector(initial_state)
else:
state = _BufferedStateVector.create(
initial_state=initial_state,
qid_shape=tuple(q.dimension for q in qubits) if qubits is not None else None,
dtype=dtype,
buffer=available_buffer,
)
super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data)

def add_qubits(self, qubits: Sequence['cirq.Qid']):
Expand All @@ -365,6 +420,8 @@ def add_qubits(self, qubits: Sequence['cirq.Qid']):
)

def remove_qubits(self, qubits: Sequence['cirq.Qid']):
if isinstance(self._state, _ReadOnlyStateVector):
return NotImplemented
ret = super().remove_qubits(qubits)
if ret is not NotImplemented:
return ret
Expand Down
28 changes: 28 additions & 0 deletions cirq-core/cirq/sim/state_vector_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

import numpy as np

import cirq
Expand Down Expand Up @@ -207,3 +209,29 @@ def text(self, to_print):
p = FakePrinter()
result._repr_pretty_(p, True)
assert p.text_pretty == 'StateVectorTrialResult(...)'


def test_readonlystatevector():
q = cirq.NamedQubit('q')
final_state_vector = cirq.StateVectorSimulationState(
initial_state=np.array([0, 1]), is_read_only=True, qubits=[q]
)

assert not np.may_share_memory(
final_state_vector._state._state_vector, final_state_vector._state.copy()._state_vector
)
assert np.may_share_memory(
final_state_vector._state._state_vector,
final_state_vector._state.copy(deep_copy_buffers=False)._state_vector,
)

with pytest.raises(TypeError, match='Measurement is not supported by _ReadOnlyStateVector..*'):
_ = final_state_vector.measure([q], 0, '', {})

state = final_state_vector._state
assert state.to_mutable_state().measure([0], 0) == [1]

assert state.apply_unitary(0, []) is NotImplemented
assert state.apply_mixture(0, [], 0) is NotImplemented
assert state.apply_channel(0, [], 0) is NotImplemented
assert final_state_vector.remove_qubits([q]) is NotImplemented