diff --git a/scratch/lower_zair.py b/scratch/lower_zair.py index 01077015..1935da42 100644 --- a/scratch/lower_zair.py +++ b/scratch/lower_zair.py @@ -1,11 +1,15 @@ -from dataclasses import dataclass, field +import json import math -from typing import Any, Sequence -from kirin import ir +from dataclasses import dataclass, field +from typing import Any, Generic, Sequence, TypeVar + +from kirin import ir, types +from kirin.dialects import func, ilist, py from kirin.dialects.ilist import IList -from kirin.dialects import py, func, ilist -from bloqade.geometry import grid -from bloqade.shuttle.dialects import gate, filled + +from bloqade.shuttle.dialects import filled, gate, init, spec +from bloqade.shuttle.prelude import move + def _simple_region() -> ir.Region: return ir.Region(ir.Block()) @@ -13,11 +17,14 @@ def _simple_region() -> ir.Region: @dataclass class ShuttleBuilder: - move_kernel: ir.Method[[IList[tuple[int, int]], IList[tuple[int, int]]], None] - num_qubits: int + spec_mapping: dict[int, str] + move_kernel: ir.Method[ + [IList[int, Any], IList[int, Any], IList[int, Any], IList[int, Any]], None + ] + + num_qubits: int = field(init=False) body: ir.Region = field(default_factory=_simple_region, init=False) - grid_mapping: ir.SSAValue = None def push_stmt(self, stmt: ir.Statement): self.body.blocks[0].stmts.append(stmt) @@ -27,47 +34,35 @@ def push_constant(self, value: Any) -> ir.SSAValue: const_stmt = py.Constant(value) return self.push_stmt(const_stmt).expect_one_result() - def construct_grid( + def get_zone(self, zone_id: int) -> ir.SSAValue: + return self.push_stmt( + spec.GetStaticTrap(zone_id=self.spec_mapping[zone_id]) + ).expect_one_result() + + def get_slm( self, - grid_id: int, - offset: tuple[int, int], - x_spacing: list[int], - y_spacing: list[int], - dim: tuple[int, int], + slm_id: int, ): - x_init = offset[0] - y_init = offset[1] - - grid_ref = self.push_constant( - grid.Grid( - x_spacing=tuple(x_spacing), y_spacing=tuple(y_spacing), x_init=x_init, y_init=y_init - ) - ) - - self.grid_mapping = grid_ref - - return grid_ref + return self.push_stmt( + spec.GetStaticTrap(zone_id=self.spec_mapping[slm_id]) + ).expect_one_result() - def insert_move( + def lower_rearrange( self, - srcs: Sequence[tuple[int, int, int, int]], - dsts: Sequence[tuple[int, int, int, int]], + begin_locs: Sequence[tuple[int, int, int, int]], + end_locs: Sequence[tuple[int, int, int, int]], ): - # ignoring zone mapping, - sorted_srcs = sorted(srcs, key=lambda x: x[0]) - sorted_dsts = sorted(dsts, key=lambda x: x[0]) - - x_src = IList([src[2] for src in sorted_srcs]) - y_src = IList([src[3] for src in sorted_srcs]) - # assume only address qubits will move - x_dst = IList([dst[2] * 2 for dst in sorted_dsts]) - y_dst = IList([dst[3] * 2 for dst in sorted_dsts]) + # ignoring zone mapping, only include all unique x and y positions + x_src = IList(sorted(set(src[2] for src in begin_locs))) + y_src = IList(sorted(set(src[3] for src in begin_locs))) + x_dst = IList(sorted(set(dst[2] for dst in end_locs))) + y_dst = IList(sorted(set(dst[3] for dst in end_locs))) x_src_ref = self.push_constant(x_src) y_src_ref = self.push_constant(y_src) x_dst_ref = self.push_constant(x_dst) y_dst_ref = self.push_constant(y_dst) - + self.push_stmt( func.Invoke( inputs=(x_src_ref, y_src_ref, x_dst_ref, y_dst_ref), @@ -75,39 +70,124 @@ def insert_move( kwargs=(), ) ) - - def entangle(self, grid_id: int): - self.push_stmt(gate.TopHatCZ(self.grid_mapping[grid_id])) - def r_gate( + def rydberg(self, zone_id: int): + self.push_stmt(gate.TopHatCZ(self.get_zone(zone_id))) + + def lower_r_gate( self, axis_angle: float, rotation_angle: float, - locs: Sequence[tuple[int, int, int]], + locs: Sequence[tuple[int, int, int, int]], ): - filled_locs = {} + filled_locs: dict[int, list[tuple[int, int]]] = {} - for _, x, y in locs: - filled_locs.setdefault(0, []).append((x, y)) + for _, grid_id, x, y in locs: + filled_locs.setdefault(grid_id, []).append((x, y)) - filled_loc_refs: dict[int, ir.SSAValue] = {} + filled_loc_refs: list[ir.SSAValue] = [] - for grid_id, locs in filled_locs.items(): - locs_ref = self.push_constant(locs) - filled_loc_refs[grid_id] = self.push_stmt( - filled.Fill(self.grid_mapping[grid_id], locs_ref) - ).expect_one_result() + for grid_id, coords in filled_locs.items(): + locs_ref = self.push_constant(ilist.IList(coords)) + filled_loc_refs.append( + self.push_stmt( + filled.Fill(self.get_slm(grid_id), locs_ref) + ).expect_one_result() + ) axis_angle_ref = self.push_constant(axis_angle / (2 * math.pi)) rotation_angle_ref = self.push_constant(rotation_angle / (2 * math.pi)) - for filled_ref in filled_loc_refs.values(): + for filled_ref in filled_loc_refs: self.push_stmt(gate.LocalR(axis_angle_ref, rotation_angle_ref, filled_ref)) + def lower_rz_gate( + self, rotation_angle: float, locations: Sequence[tuple[int, int, int, int]] + ): + filled_locs: dict[int, list[tuple[int, int]]] = {} + + for _, grid_id, x, y in locations: + filled_locs.setdefault(grid_id, []).append((x, y)) + + filled_loc_refs: list[ir.SSAValue] = [] + + for grid_id, coords in filled_locs.items(): + locs_ref = self.push_constant(ilist.IList(coords)) + filled_loc_refs.append( + self.push_stmt( + filled.Fill(self.get_slm(grid_id), locs_ref) + ).expect_one_result() + ) + + rotation_angle_ref = self.push_constant(rotation_angle / (2 * math.pi)) + + for filled_ref in filled_loc_refs: + self.push_stmt(gate.LocalRz(rotation_angle_ref, filled_ref)) + def lower_h(self, locs: Sequence[tuple[int, int, int, int]]): - assert len(locs) == self.num_qubits, "H gate must be applied to all qubits" - self.push_stmt(gate.GlobalRz(0.25)) - self.push_stmt(gate.GlobalR(0, 0.5)) - self.push_stmt(gate.GlobalRz(0.25)) + quarter_rotation = self.push_constant(0.25) + zero = self.push_constant(0.0) + half_rotation = self.push_constant(0.5) + if len(locs) == self.num_qubits: + + self.push_stmt(gate.GlobalRz(quarter_rotation)) + self.push_stmt(gate.GlobalR(zero, half_rotation)) + self.push_stmt(gate.GlobalRz(quarter_rotation)) + + def lower_init(self, locs: Sequence[tuple[int, int, int, int]]): + filled_locs: dict[int, list[tuple[int, int]]] = {} + + for _, grid_id, x, y in locs: + filled_locs.setdefault(grid_id, []).append((x, y)) - + filled_loc_refs: list[ir.SSAValue] = [] + + for grid_id, coords in filled_locs.items(): + locs_ref = self.push_constant(ilist.IList(coords)) + filled_loc_refs.append( + self.push_stmt( + filled.Fill(self.get_slm(grid_id), locs_ref) + ).expect_one_result() + ) + + locations = self.push_constant(ilist.IList(filled_loc_refs)) + self.push_stmt(init.Fill(locations)) + + def lower_instruction(self, instruction: dict[str, Any]): + match instruction: + case {"type": "init", "locs": locs}: + self.lower_init(locs) + case {"type": "1qGate", "unitary": "ry", "locs": locs}: + raise NotImplementedError + case {"type": "1qGate", "unitary": "h", "locs": locs}: + self.lower_h(locs) + case {"type": "rydberg", "zone_id": zone_id}: + self.rydberg(zone_id) + case { + "type": "rearrangeJob", + "begin_locs": begin_locs, + "end_locs": end_locs, + }: + self.lower_rearrange(begin_locs, end_locs) + + def lower(self, program: dict[str, Any]) -> ir.Method: + """Entry point for lowering a ZAIR program + + Args: + program (dict[str, Any]): JSON representation of the ZAIR program + + Returns: + ir.Method: Lowered IR method + """ + sym_name = program["name"] + signature = func.Signature((), types.NoneType) + + for inst in program["instructions"]: + self.lower_instruction(inst) + + code = func.Function( + sym_name=sym_name, + signature=signature, + body=self.body, + ) + return ir.Method(None, None, sym_name, [], move, code) diff --git a/scratch/qcrank.py b/scratch/qcrank.py index 77069629..bf20db13 100644 --- a/scratch/qcrank.py +++ b/scratch/qcrank.py @@ -1,28 +1,29 @@ +import json from typing import Any, Literal, TypeVar +import numpy as np from bloqade.geometry.dialects import grid from kirin.dialects import ilist +from lower_zair import ShuttleBuilder from bloqade.shuttle import action, gate, init, measure, schedule, spec -from bloqade.shuttle.stdlib.layouts.two_col_zone import rearrange from bloqade.shuttle.prelude import move, tweezer +from bloqade.shuttle.stdlib.layouts.two_col_zone import rearrange from bloqade.shuttle.visualizer import MatplotlibRenderer, PathVisualizer -from lower_zair import ShuttleBuilder -import json -import numpy as np + def run_qcrank(filename: str): - with open(filename, 'r') as f: + with open(filename, "r") as f: compiled_qcrank = json.load(f) - + arch_filename = compiled_qcrank["architecture_spec_path"] - with open(arch_filename, 'r') as f: + with open(arch_filename, "r") as f: architecture_spec = json.load(f) # set architecture # assume single entagnlement zone - entanglement_zone_spec = architecture_spec["entanglement_zones"][0] + entanglement_zone_spec = architecture_spec["entanglement_zones"][0] slms = entanglement_zone_spec["slms"] assert len(slms) == 2 slm0 = slms[0] @@ -50,21 +51,25 @@ def run_qcrank(filename: str): x_spacing.append(dis_trap) x_spacing.append(dis_site) x_spacing = x_spacing - + inst_init = compiled_qcrank["instructions"][0] init_quibt_location = inst_init["init_locs"] - - shuttle_builder = ShuttleBuilder(num_qubits = len(init_quibt_location), move_kernel=rearrange) - shuttle_builder.construct_grid(entanglement_zone_spec['zone_id'], - entanglement_zone_spec["offset"], - x_spacing, - y_spacing, - entanglement_zone_spec["dimension"]) + + shuttle_builder = ShuttleBuilder( + num_qubits=len(init_quibt_location), move_kernel=rearrange + ) + shuttle_builder.construct_grid( + entanglement_zone_spec["zone_id"], + entanglement_zone_spec["offset"], + x_spacing, + y_spacing, + entanglement_zone_spec["dimension"], + ) spec_value = spec.ArchSpec( layout=spec.Layout( static_traps={ - "mem": shuttle_builder.grid_mapping, + "mem": shuttle_builder.spec_mapping, }, fillable=set(["mem"]), ) @@ -78,11 +83,11 @@ def run_qcrank(filename: str): y *= 2 else: x *= 2 - grid_init_quibt_location.append((x,y)) - + grid_init_quibt_location.append((x, y)) + def main(): # init.fill([spec.get_static_trap(zone_id="mem")]) - init.fill(shuttle_builder.grid_mapping) # ! + init.fill(shuttle_builder.spec_mapping) # ! insts = compiled_qcrank["instructions"][1:] for inst in insts: if inst["type"] == "1qGate": @@ -93,7 +98,7 @@ def main(): locs = [(loc[0], 2 * loc[2], loc[3])] else: locs = [(loc[0], loc[2], loc[3])] - shuttle_builder.r_gate(0, rotation_angle, locs) + shuttle_builder.lower_r_gate(0, rotation_angle, locs) elif inst["unitary"] == "h": shuttle_builder.lower_h(inst["locs"]) @@ -101,10 +106,10 @@ def main(): shuttle_builder.entangle(inst["zone_id"]) elif inst["type"] == "rearrangeJob": shuttle_builder.insert_move(inst["begin_locs"], inst["end_locs"]) - return measure.measure((shuttle_builder.grid_mapping,)) + return measure.measure((shuttle_builder.spec_mapping,)) return main, spec_value if __name__ == "__main__": - filename = "scratch/qcr_4a8d_quera_circ_code.json" \ No newline at end of file + filename = "scratch/qcr_4a8d_quera_circ_code.json"