Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 140 additions & 60 deletions scratch/lower_zair.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,30 @@
from dataclasses import dataclass, field
import json
import math
from typing import Any, Sequence
from kirin import ir
from dataclasses import dataclass, field
from typing import Any, Generic, Sequence, TypeVar

from kirin import ir, types
from kirin.dialects import func, ilist, py
from kirin.dialects.ilist import IList
from kirin.dialects import py, func, ilist
from bloqade.geometry import grid
from bloqade.shuttle.dialects import gate, filled

from bloqade.shuttle.dialects import filled, gate, init, spec
from bloqade.shuttle.prelude import move


def _simple_region() -> ir.Region:
return ir.Region(ir.Block())


@dataclass
class ShuttleBuilder:
move_kernel: ir.Method[[IList[tuple[int, int]], IList[tuple[int, int]]], None]
num_qubits: int

spec_mapping: dict[int, str]
move_kernel: ir.Method[
[IList[int, Any], IList[int, Any], IList[int, Any], IList[int, Any]], None
]

num_qubits: int = field(init=False)
body: ir.Region = field(default_factory=_simple_region, init=False)
grid_mapping: ir.SSAValue = None

def push_stmt(self, stmt: ir.Statement):
self.body.blocks[0].stmts.append(stmt)
Expand All @@ -27,87 +34,160 @@ def push_constant(self, value: Any) -> ir.SSAValue:
const_stmt = py.Constant(value)
return self.push_stmt(const_stmt).expect_one_result()

def construct_grid(
def get_zone(self, zone_id: int) -> ir.SSAValue:
return self.push_stmt(
spec.GetStaticTrap(zone_id=self.spec_mapping[zone_id])
).expect_one_result()

def get_slm(
self,
grid_id: int,
offset: tuple[int, int],
x_spacing: list[int],
y_spacing: list[int],
dim: tuple[int, int],
slm_id: int,
):
x_init = offset[0]
y_init = offset[1]

grid_ref = self.push_constant(
grid.Grid(
x_spacing=tuple(x_spacing), y_spacing=tuple(y_spacing), x_init=x_init, y_init=y_init
)
)

self.grid_mapping = grid_ref

return grid_ref
return self.push_stmt(
spec.GetStaticTrap(zone_id=self.spec_mapping[slm_id])
).expect_one_result()

def insert_move(
def lower_rearrange(
self,
srcs: Sequence[tuple[int, int, int, int]],
dsts: Sequence[tuple[int, int, int, int]],
begin_locs: Sequence[tuple[int, int, int, int]],
end_locs: Sequence[tuple[int, int, int, int]],
):
# ignoring zone mapping,
sorted_srcs = sorted(srcs, key=lambda x: x[0])
sorted_dsts = sorted(dsts, key=lambda x: x[0])

x_src = IList([src[2] for src in sorted_srcs])
y_src = IList([src[3] for src in sorted_srcs])
# assume only address qubits will move
x_dst = IList([dst[2] * 2 for dst in sorted_dsts])
y_dst = IList([dst[3] * 2 for dst in sorted_dsts])
# ignoring zone mapping, only include all unique x and y positions
x_src = IList(sorted(set(src[2] for src in begin_locs)))
y_src = IList(sorted(set(src[3] for src in begin_locs)))
x_dst = IList(sorted(set(dst[2] for dst in end_locs)))
y_dst = IList(sorted(set(dst[3] for dst in end_locs)))

x_src_ref = self.push_constant(x_src)
y_src_ref = self.push_constant(y_src)
x_dst_ref = self.push_constant(x_dst)
y_dst_ref = self.push_constant(y_dst)

self.push_stmt(
func.Invoke(
inputs=(x_src_ref, y_src_ref, x_dst_ref, y_dst_ref),
callee=self.move_kernel,
kwargs=(),
)
)

def entangle(self, grid_id: int):
self.push_stmt(gate.TopHatCZ(self.grid_mapping[grid_id]))

def r_gate(
def rydberg(self, zone_id: int):
self.push_stmt(gate.TopHatCZ(self.get_zone(zone_id)))

def lower_r_gate(
self,
axis_angle: float,
rotation_angle: float,
locs: Sequence[tuple[int, int, int]],
locs: Sequence[tuple[int, int, int, int]],
):
filled_locs = {}
filled_locs: dict[int, list[tuple[int, int]]] = {}

for _, x, y in locs:
filled_locs.setdefault(0, []).append((x, y))
for _, grid_id, x, y in locs:
filled_locs.setdefault(grid_id, []).append((x, y))

filled_loc_refs: dict[int, ir.SSAValue] = {}
filled_loc_refs: list[ir.SSAValue] = []

for grid_id, locs in filled_locs.items():
locs_ref = self.push_constant(locs)
filled_loc_refs[grid_id] = self.push_stmt(
filled.Fill(self.grid_mapping[grid_id], locs_ref)
).expect_one_result()
for grid_id, coords in filled_locs.items():
locs_ref = self.push_constant(ilist.IList(coords))
filled_loc_refs.append(
self.push_stmt(
filled.Fill(self.get_slm(grid_id), locs_ref)
).expect_one_result()
)

axis_angle_ref = self.push_constant(axis_angle / (2 * math.pi))
rotation_angle_ref = self.push_constant(rotation_angle / (2 * math.pi))

for filled_ref in filled_loc_refs.values():
for filled_ref in filled_loc_refs:
self.push_stmt(gate.LocalR(axis_angle_ref, rotation_angle_ref, filled_ref))

def lower_rz_gate(
self, rotation_angle: float, locations: Sequence[tuple[int, int, int, int]]
):
filled_locs: dict[int, list[tuple[int, int]]] = {}

for _, grid_id, x, y in locations:
filled_locs.setdefault(grid_id, []).append((x, y))

filled_loc_refs: list[ir.SSAValue] = []

for grid_id, coords in filled_locs.items():
locs_ref = self.push_constant(ilist.IList(coords))
filled_loc_refs.append(
self.push_stmt(
filled.Fill(self.get_slm(grid_id), locs_ref)
).expect_one_result()
)

rotation_angle_ref = self.push_constant(rotation_angle / (2 * math.pi))

for filled_ref in filled_loc_refs:
self.push_stmt(gate.LocalRz(rotation_angle_ref, filled_ref))

def lower_h(self, locs: Sequence[tuple[int, int, int, int]]):
assert len(locs) == self.num_qubits, "H gate must be applied to all qubits"
self.push_stmt(gate.GlobalRz(0.25))
self.push_stmt(gate.GlobalR(0, 0.5))
self.push_stmt(gate.GlobalRz(0.25))
quarter_rotation = self.push_constant(0.25)
zero = self.push_constant(0.0)
half_rotation = self.push_constant(0.5)
if len(locs) == self.num_qubits:

self.push_stmt(gate.GlobalRz(quarter_rotation))
self.push_stmt(gate.GlobalR(zero, half_rotation))
self.push_stmt(gate.GlobalRz(quarter_rotation))

def lower_init(self, locs: Sequence[tuple[int, int, int, int]]):
filled_locs: dict[int, list[tuple[int, int]]] = {}

for _, grid_id, x, y in locs:
filled_locs.setdefault(grid_id, []).append((x, y))


filled_loc_refs: list[ir.SSAValue] = []

for grid_id, coords in filled_locs.items():
locs_ref = self.push_constant(ilist.IList(coords))
filled_loc_refs.append(
self.push_stmt(
filled.Fill(self.get_slm(grid_id), locs_ref)
).expect_one_result()
)

locations = self.push_constant(ilist.IList(filled_loc_refs))
self.push_stmt(init.Fill(locations))

def lower_instruction(self, instruction: dict[str, Any]):
match instruction:
case {"type": "init", "locs": locs}:
self.lower_init(locs)
case {"type": "1qGate", "unitary": "ry", "locs": locs}:
raise NotImplementedError
case {"type": "1qGate", "unitary": "h", "locs": locs}:
self.lower_h(locs)
case {"type": "rydberg", "zone_id": zone_id}:
self.rydberg(zone_id)
case {
"type": "rearrangeJob",
"begin_locs": begin_locs,
"end_locs": end_locs,
}:
self.lower_rearrange(begin_locs, end_locs)

def lower(self, program: dict[str, Any]) -> ir.Method:
"""Entry point for lowering a ZAIR program

Args:
program (dict[str, Any]): JSON representation of the ZAIR program

Returns:
ir.Method: Lowered IR method
"""
sym_name = program["name"]
signature = func.Signature((), types.NoneType)

for inst in program["instructions"]:
self.lower_instruction(inst)

code = func.Function(
sym_name=sym_name,
signature=signature,
body=self.body,
)
return ir.Method(None, None, sym_name, [], move, code)
51 changes: 28 additions & 23 deletions scratch/qcrank.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,29 @@
import json
from typing import Any, Literal, TypeVar

import numpy as np
from bloqade.geometry.dialects import grid
from kirin.dialects import ilist
from lower_zair import ShuttleBuilder

from bloqade.shuttle import action, gate, init, measure, schedule, spec
from bloqade.shuttle.stdlib.layouts.two_col_zone import rearrange
from bloqade.shuttle.prelude import move, tweezer
from bloqade.shuttle.stdlib.layouts.two_col_zone import rearrange
from bloqade.shuttle.visualizer import MatplotlibRenderer, PathVisualizer
from lower_zair import ShuttleBuilder
import json
import numpy as np


def run_qcrank(filename: str):
with open(filename, 'r') as f:
with open(filename, "r") as f:
compiled_qcrank = json.load(f)

arch_filename = compiled_qcrank["architecture_spec_path"]

with open(arch_filename, 'r') as f:
with open(arch_filename, "r") as f:
architecture_spec = json.load(f)

# set architecture
# assume single entagnlement zone
entanglement_zone_spec = architecture_spec["entanglement_zones"][0]
entanglement_zone_spec = architecture_spec["entanglement_zones"][0]
slms = entanglement_zone_spec["slms"]
assert len(slms) == 2
slm0 = slms[0]
Expand Down Expand Up @@ -50,21 +51,25 @@ def run_qcrank(filename: str):
x_spacing.append(dis_trap)
x_spacing.append(dis_site)
x_spacing = x_spacing

inst_init = compiled_qcrank["instructions"][0]
init_quibt_location = inst_init["init_locs"]

shuttle_builder = ShuttleBuilder(num_qubits = len(init_quibt_location), move_kernel=rearrange)
shuttle_builder.construct_grid(entanglement_zone_spec['zone_id'],
entanglement_zone_spec["offset"],
x_spacing,
y_spacing,
entanglement_zone_spec["dimension"])

shuttle_builder = ShuttleBuilder(
num_qubits=len(init_quibt_location), move_kernel=rearrange
)
shuttle_builder.construct_grid(
entanglement_zone_spec["zone_id"],
entanglement_zone_spec["offset"],
x_spacing,
y_spacing,
entanglement_zone_spec["dimension"],
)

spec_value = spec.ArchSpec(
layout=spec.Layout(
static_traps={
"mem": shuttle_builder.grid_mapping,
"mem": shuttle_builder.spec_mapping,
},
fillable=set(["mem"]),
)
Expand All @@ -78,11 +83,11 @@ def run_qcrank(filename: str):
y *= 2
else:
x *= 2
grid_init_quibt_location.append((x,y))
grid_init_quibt_location.append((x, y))

def main():
# init.fill([spec.get_static_trap(zone_id="mem")])
init.fill(shuttle_builder.grid_mapping) # !
init.fill(shuttle_builder.spec_mapping) # !
insts = compiled_qcrank["instructions"][1:]
for inst in insts:
if inst["type"] == "1qGate":
Expand All @@ -93,18 +98,18 @@ def main():
locs = [(loc[0], 2 * loc[2], loc[3])]
else:
locs = [(loc[0], loc[2], loc[3])]
shuttle_builder.r_gate(0, rotation_angle, locs)
shuttle_builder.lower_r_gate(0, rotation_angle, locs)
elif inst["unitary"] == "h":
shuttle_builder.lower_h(inst["locs"])

elif inst["type"] == "rydberg":
shuttle_builder.entangle(inst["zone_id"])
elif inst["type"] == "rearrangeJob":
shuttle_builder.insert_move(inst["begin_locs"], inst["end_locs"])
return measure.measure((shuttle_builder.grid_mapping,))
return measure.measure((shuttle_builder.spec_mapping,))

return main, spec_value


if __name__ == "__main__":
filename = "scratch/qcr_4a8d_quera_circ_code.json"
filename = "scratch/qcr_4a8d_quera_circ_code.json"