Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions tensorflow_quantum/core/ops/tfq_ps_decompose_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,11 @@ class TfqPsDecomposeOp : public tensorflow::OpKernel {
new_op_map["exponent_scalar"].mutable_arg_value()->set_float_value(
cur_exponent_scalar * -0.5);
new_op_map["exponent"].set_symbol(symbol);
// Copy over control metadata.
new_op_map["control_qubits"].mutable_arg_value()->set_string_value(
cur_op_map["control_qubits"].arg_value().string_value());
new_op_map["control_values"].mutable_arg_value()->set_string_value(
cur_op_map["control_values"].arg_value().string_value());
// Step 4. add qubits.
*new_op.mutable_qubits() = {cur_op_qubits.begin(), cur_op_qubits.end()};
return new_op;
Expand Down Expand Up @@ -215,6 +220,11 @@ class TfqPsDecomposeOp : public tensorflow::OpKernel {
}
// Step 4. add qubits.
*new_op.mutable_qubits() = {cur_op_qubits.begin(), cur_op_qubits.end()};
// Copy over control metadata.
new_op_map["control_qubits"].mutable_arg_value()->set_string_value(
cur_op_map["control_qubits"].arg_value().string_value());
new_op_map["control_values"].mutable_arg_value()->set_string_value(
cur_op_map["control_values"].arg_value().string_value());
return new_op;
}

Expand Down Expand Up @@ -251,6 +261,11 @@ class TfqPsDecomposeOp : public tensorflow::OpKernel {
}
*new_op.mutable_qubits() = {cur_op_qubits.begin() + use_target,
cur_op_qubits.end() - !use_target};
// Copy over control metadata.
new_op_map["control_qubits"].mutable_arg_value()->set_string_value(
cur_op_map["control_qubits"].arg_value().string_value());
new_op_map["control_values"].mutable_arg_value()->set_string_value(
cur_op_map["control_values"].arg_value().string_value());
return new_op;
}

Expand Down Expand Up @@ -290,6 +305,11 @@ class TfqPsDecomposeOp : public tensorflow::OpKernel {
}
// Step 4. add qubits.
*new_op.mutable_qubits() = {cur_op_qubits.begin(), cur_op_qubits.end()};
// Copy over control metadata.
new_op_map["control_qubits"].mutable_arg_value()->set_string_value(
cur_op_map["control_qubits"].arg_value().string_value());
new_op_map["control_values"].mutable_arg_value()->set_string_value(
cur_op_map["control_values"].arg_value().string_value());
return new_op;
}
};
Expand Down
157 changes: 137 additions & 20 deletions tensorflow_quantum/core/serialize/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,51 @@ def _symbol_extractor(x):
"information.")


def _serialize_controls(gate):
"""Helper to serialize control qubits if applicable."""
if hasattr(gate, '_tfq_control_qubits'):
return ','.join(
v2.qubit_to_proto_id(q) for q in gate._tfq_control_qubits)
return ''


def _serialize_control_vals(gate):
"""Helper to serialize control values if applicable.."""
if hasattr(gate, '_tfq_control_values'):
return ','.join(str(v[0]) for v in gate._tfq_control_values)
return ''


class DelayedAssignmentGate(cirq.Gate):
"""Class to do control qubit assignment before sub_gate qubit assignment."""

def __init__(self, gate_callable, control_qubits, control_values):
self._gate_callable = gate_callable
self._control_qubits = control_qubits
self._control_values = control_values

def _qid_shape_(self):
raise ValueError("Called qid_shape on workaround class.")

# pylint: disable=invalid-name
def on(self, *qubits):
"""Returns gate_callable on qubits controlled by contol_qubits."""
return self._gate_callable(*qubits).controlled_by(
*self._control_qubits, control_values=self._control_values)

# pylint: enable=invalid-name


def _optional_control_promote(gate, qubits_message, values_message):
"""Optionally promote to controlled gate based on serialized control msg."""
if qubits_message == '' and values_message == '':
return gate
qbs = [v2.qubit_from_proto_id(qb) for qb in qubits_message.split(',')]
vals = [int(cv) for cv in values_message.split(',')]

return DelayedAssignmentGate(gate, qbs, vals)


def _eigen_gate_serializer(gate_type, serialized_id):
"""Make standard serializer for eigen gates."""

Expand All @@ -124,7 +169,14 @@ def _eigen_gate_serializer(gate_type, serialized_id):
cirq.google.SerializingArg(
serialized_name="global_shift",
serialized_type=float,
op_getter=lambda x: float(x.gate._global_shift))
op_getter=lambda x: float(x.gate._global_shift)),
cirq.google.SerializingArg(serialized_name="control_qubits",
serialized_type=str,
op_getter=lambda x: _serialize_controls(x)),
cirq.google.SerializingArg(
serialized_name="control_values",
serialized_type=str,
op_getter=lambda x: _serialize_control_vals(x))
]
return cirq.google.GateOpSerializer(gate_type=gate_type,
serialized_gate_id=serialized_id,
Expand All @@ -135,26 +187,35 @@ def _eigen_gate_serializer(gate_type, serialized_id):
def _eigen_gate_deserializer(gate_type, serialized_id):
"""Make standard deserializer for eigen gates."""

def _scalar_combiner(exponent, global_shift, exponent_scalar):
def _scalar_combiner(exponent, global_shift, exponent_scalar,
control_qubits, control_values):
"""This is a workaround to support symbol scalar multiplication.
In the future we should likely get rid of this in favor of proper
expression parsing once cirq supports it. See cirq.op_serializer
and cirq's program protobuf for details. This is needed for things
like cirq.rx('alpha').
"""
if exponent_scalar == 1.0:
return gate_type(exponent=_round(exponent),
global_shift=_round(global_shift))
return gate_type(exponent=_round(exponent) * _round(exponent_scalar),
global_shift=_round(global_shift))
return _optional_control_promote(
gate_type(exponent=_round(exponent),
global_shift=_round(global_shift)), control_qubits,
control_values)
return _optional_control_promote(
gate_type(exponent=_round(exponent) * _round(exponent_scalar),
global_shift=_round(global_shift)), control_qubits,
control_values)

args = [
cirq.google.DeserializingArg(serialized_name="exponent",
constructor_arg_name="exponent"),
cirq.google.DeserializingArg(serialized_name="global_shift",
constructor_arg_name="global_shift"),
cirq.google.DeserializingArg(serialized_name="exponent_scalar",
constructor_arg_name="exponent_scalar")
constructor_arg_name="exponent_scalar"),
cirq.google.DeserializingArg(serialized_name="control_qubits",
constructor_arg_name="control_qubits"),
cirq.google.DeserializingArg(serialized_name="control_values",
constructor_arg_name="control_values")
]
return cirq.google.GateOpDeserializer(serialized_gate_id=serialized_id,
gate_constructor=_scalar_combiner,
Expand All @@ -181,6 +242,13 @@ def _fsim_gate_serializer():
serialized_name="phi_scalar",
serialized_type=float,
op_getter=lambda x: _scalar_extractor(x.gate.phi)),
cirq.google.SerializingArg(serialized_name="control_qubits",
serialized_type=str,
op_getter=lambda x: _serialize_controls(x)),
cirq.google.SerializingArg(
serialized_name="control_values",
serialized_type=str,
op_getter=lambda x: _serialize_control_vals(x))
]
return cirq.google.GateOpSerializer(gate_type=cirq.FSimGate,
serialized_gate_id="FSIM",
Expand All @@ -191,12 +259,15 @@ def _fsim_gate_serializer():
def _fsim_gate_deserializer():
"""Make standard deserializer for fsim gate."""

def _scalar_combiner(theta, theta_scalar, phi, phi_scalar):
def _scalar_combiner(theta, theta_scalar, phi, phi_scalar, control_qubits,
control_values):
"""This is a workaround to support symbol scalar multiplication.
See `_eigen_gate_deserializer` for details.
"""
return cirq.FSimGate(theta=_round(theta) * _round(theta_scalar),
phi=_round(phi) * _round(phi_scalar))
return _optional_control_promote(
cirq.FSimGate(theta=_round(theta) * _round(theta_scalar),
phi=_round(phi) * _round(phi_scalar)), control_qubits,
control_values)

args = [
cirq.google.DeserializingArg(serialized_name="theta",
Expand All @@ -207,6 +278,10 @@ def _scalar_combiner(theta, theta_scalar, phi, phi_scalar):
constructor_arg_name="theta_scalar"),
cirq.google.DeserializingArg(serialized_name="phi_scalar",
constructor_arg_name="phi_scalar"),
cirq.google.DeserializingArg(serialized_name="control_qubits",
constructor_arg_name="control_qubits"),
cirq.google.DeserializingArg(serialized_name="control_values",
constructor_arg_name="control_values")
]
return cirq.google.GateOpDeserializer(serialized_gate_id="FSIM",
gate_constructor=_scalar_combiner,
Expand All @@ -228,7 +303,14 @@ def _identity_check(x):
args = [
cirq.google.SerializingArg(serialized_name="unused",
serialized_type=bool,
op_getter=_identity_check)
op_getter=_identity_check),
cirq.google.SerializingArg(serialized_name="control_qubits",
serialized_type=str,
op_getter=lambda x: _serialize_controls(x)),
cirq.google.SerializingArg(
serialized_name="control_values",
serialized_type=str,
op_getter=lambda x: _serialize_control_vals(x))
]
return cirq.google.GateOpSerializer(gate_type=cirq.IdentityGate,
serialized_gate_id="I",
Expand All @@ -240,11 +322,15 @@ def _identity_gate_deserializer():
"""Make a standard deserializer for the single qubit identity."""
args = [
cirq.google.DeserializingArg(serialized_name="unused",
constructor_arg_name="unused")
constructor_arg_name="unused"),
cirq.google.DeserializingArg(serialized_name="control_qubits",
constructor_arg_name="control_qubits"),
cirq.google.DeserializingArg(serialized_name="control_values",
constructor_arg_name="control_values")
]

def _cirq_i_workaround(unused):
return cirq.I
def _cirq_i_workaround(unused, control_qubits, control_values):
return _optional_control_promote(cirq.I, control_qubits, control_values)

return cirq.google.GateOpDeserializer(serialized_gate_id="I",
gate_constructor=_cirq_i_workaround,
Expand Down Expand Up @@ -274,7 +360,14 @@ def _phased_eigen_gate_serializer(gate_type, serialized_id):
cirq.google.SerializingArg(
serialized_name="global_shift",
serialized_type=float,
op_getter=lambda x: float(x.gate.global_shift))
op_getter=lambda x: float(x.gate.global_shift)),
cirq.google.SerializingArg(serialized_name="control_qubits",
serialized_type=str,
op_getter=lambda x: _serialize_controls(x)),
cirq.google.SerializingArg(
serialized_name="control_values",
serialized_type=str,
op_getter=lambda x: _serialize_control_vals(x))
]
return cirq.google.GateOpSerializer(gate_type=gate_type,
serialized_gate_id=serialized_id,
Expand All @@ -286,7 +379,8 @@ def _phased_eigen_gate_deserializer(gate_type, serialized_id):
"""Make a standard deserializer for phased eigen gates."""

def _scalar_combiner(exponent, global_shift, exponent_scalar,
phase_exponent, phase_exponent_scalar):
phase_exponent, phase_exponent_scalar, control_qubits,
control_values):
"""This is a workaround to support symbol scalar multiplication.
In the future we should likely get rid of this in favor of proper
expression parsing once cirq supports it. See cirq.op_serializer
Expand All @@ -302,10 +396,14 @@ def _scalar_combiner(exponent, global_shift, exponent_scalar,
if global_shift != 0:
# needed in case this specific phasedeigengate doesn't
# have a global_phase in constructor.
return gate_type(exponent=exponent,
global_shift=_round(global_shift),
phase_exponent=phase_exponent)
return gate_type(exponent=exponent, phase_exponent=phase_exponent)
return _optional_control_promote(
gate_type(exponent=exponent,
global_shift=_round(global_shift),
phase_exponent=phase_exponent), control_qubits,
control_values)
return _optional_control_promote(
gate_type(exponent=exponent, phase_exponent=phase_exponent),
control_qubits, control_values)

args = [
cirq.google.DeserializingArg(serialized_name="phase_exponent",
Expand All @@ -319,6 +417,10 @@ def _scalar_combiner(exponent, global_shift, exponent_scalar,
constructor_arg_name="exponent_scalar"),
cirq.google.DeserializingArg(serialized_name="global_shift",
constructor_arg_name="global_shift"),
cirq.google.DeserializingArg(serialized_name="control_qubits",
constructor_arg_name="control_qubits"),
cirq.google.DeserializingArg(serialized_name="control_values",
constructor_arg_name="control_values")
]
return cirq.google.GateOpDeserializer(serialized_gate_id=serialized_id,
gate_constructor=_scalar_combiner,
Expand Down Expand Up @@ -434,6 +536,21 @@ def serialize_circuit(circuit_inp):
old_moment.operations))
circuit[moment_ind] = new_moment

# Demote cirq.controlled_operations (controlled gates) to their sub_gate
# types with _tfq_control_qubits and _tfq_control_values fields so that
# the gates can still get picked up by the serializer. There would be no way
# to discern controlledgates from one another otherwise. This
# "momentary demotion" occurs with the help of the DelayedAssignmentGate.
for i, moment in enumerate(circuit):
for op in moment:
if isinstance(op,
cirq.ops.controlled_operation.ControlledOperation):
tfq_compatible = op.sub_operation
tfq_compatible._tfq_control_qubits = op.controls
tfq_compatible._tfq_control_values = op.control_values
dropped_moment = moment.without_operations_touching(op.qubits)
circuit[i] = dropped_moment.with_operation(tfq_compatible)

return SERIALIZER.serialize(circuit)


Expand Down
Loading