Skip to content

Commit

Permalink
Add preliminary RISC-V vector support (Assembly only)
Browse files Browse the repository at this point in the history
Signed-off-by: Patrick O'Neill <[email protected]>
  • Loading branch information
patrick-rivos authored and rbertran committed Aug 28, 2023
1 parent b3164c6 commit ab246db
Show file tree
Hide file tree
Showing 11 changed files with 2,853 additions and 39 deletions.
8 changes: 5 additions & 3 deletions src/microprobe/code/ins.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# Built-in modules
import copy
from itertools import product
from typing import TYPE_CHECKING, Callable, List
from typing import TYPE_CHECKING, Callable, Dict, List

# Third party modules
import six
Expand Down Expand Up @@ -1604,7 +1604,8 @@ def __init__(self):
self._generic_type = None
self._label = None
self._mem_operands = []
self._operands = RejectingOrderedDict()
self._operands: Dict[str,
InstructionOperandValue] = RejectingOrderedDict()

def set_arch_type(self, instrtype):
"""
Expand All @@ -1613,7 +1614,8 @@ def set_arch_type(self, instrtype):
"""
self._arch_type = instrtype
self._operands = RejectingOrderedDict()
self._operands: Dict[str,
InstructionOperandValue] = RejectingOrderedDict()
self._mem_operands = []
self._allowed_regs = []
self._address = None
Expand Down
52 changes: 36 additions & 16 deletions src/microprobe/passes/initialization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
"""

# Futures
from __future__ import absolute_import, print_function
from __future__ import absolute_import, print_function, annotations

# Built-in modules
from typing import TYPE_CHECKING

# Third party modules
from six.moves import zip
Expand All @@ -36,6 +37,10 @@

# Local modules

# Type hinting
if TYPE_CHECKING:
from microprobe.code.benchmark import Benchmark
from microprobe.target import Target

# Constants
LOG = get_logger(__name__)
Expand Down Expand Up @@ -222,6 +227,7 @@ def __init__(self, *args, **kwargs):
skip_unknown = kwargs.get("skip_unknown", False)
warn_unknown = kwargs.get("warn_unknown", False)
self._force_code = kwargs.get("force_code", False)
self.lmul = kwargs.get("lmul", 1)

if len(args) == 1:
self._reg_dict = dict([
Expand Down Expand Up @@ -250,7 +256,7 @@ def __init__(self, *args, **kwargs):
self._fp_value,
v_value)

def __call__(self, building_block, target):
def __call__(self, building_block: Benchmark, target: Target):
"""
:param building_block:
Expand All @@ -259,26 +265,26 @@ def __call__(self, building_block, target):
"""
if not self._skip_unknown:
for register_name in self._reg_dict:
if register_name not in list(target.registers.keys()):
if register_name not in list(target.isa.registers.keys()):
raise MicroprobeCodeGenerationError(
"Unknown register name: '%s'. Unable to set it" %
register_name)

if self._warn_unknown:
for register_name in self._reg_dict:
if register_name not in list(target.registers.keys()):
if register_name not in list(target.isa.registers.keys()):
print_warning(
"Unknown register name: '%s'. Unable to set it" %
register_name)

regs = sorted(target.registers.values(),
regs = sorted(target.isa.registers.values(),
key=lambda x: self._priolist.index(x.name)
if x.name in self._priolist else 314159)

#
# Make sure scratch registers are set last
#
for reg in target.scratch_registers:
for reg in target.isa.scratch_registers:
if reg in regs:
regs.remove(reg)
regs.append(reg)
Expand All @@ -294,25 +300,39 @@ def __call__(self, building_block, target):
self._reg_dict.pop(reg.name)
force_direct = True

if (reg in building_block.context.reserved_registers and
not self._force_reserved):
if reg.name == "LMUL":
building_block.add_init(
target.isa.set_register(reg, self.lmul,
building_block.context))
building_block.context.set_register_value(reg, self.lmul)
continue

all_vec_regs = set([f"V{i}" for i in range(0, 32)])
lmul_allowed_regs = set([f"V{i}" for i in range(0, 32, self.lmul)])

if reg.name in all_vec_regs - lmul_allowed_regs:
# Skip vector registers ignored by lmul
continue

if (reg in building_block.context.reserved_registers
and not self._force_reserved):
LOG.debug("Skip reserved - %s", reg)
continue
elif (reg in target.control_registers and
(value is None or self._skip_control)):
elif (reg in target.isa.control_registers
and (value is None or self._skip_control)):
LOG.debug("Skip control - %s", reg)
continue

if value is None:
if reg.used_for_vector_arithmetic:
if reg.type.used_for_vector_arithmetic:
if self._vect_value is not None:
value = self._vect_value
elemsize = self._vect_elemsize
else:
LOG.debug("Skip no vector default value provided - %s",
reg)
continue
elif reg.used_for_float_arithmetic:
elif reg.type.used_for_float_arithmetic:
if self._fp_value is not None:
value = self._fp_value
else:
Expand All @@ -332,10 +352,10 @@ def __call__(self, building_block, target):
if isinstance(value, int):
value = value & ((2**reg.size)-1)

if reg.used_for_float_arithmetic:
if reg.type.used_for_float_arithmetic:
value = ieee_float_to_int64(float(value))

elif reg.used_for_vector_arithmetic:
elif reg.type.used_for_vector_arithmetic:
if isinstance(value, float):
if elemsize != 64:
raise MicroprobeCodeGenerationError(
Expand All @@ -360,13 +380,13 @@ def __call__(self, building_block, target):
else:
LOG.debug("Direct set of '%s' to '0x%x'", reg, value)
except MicroprobeCodeGenerationError:
building_block.add_init(target.set_register(
building_block.add_init(target.isa.set_register(
reg, value, building_block.context))
LOG.debug("Set '%s' to '0x%x'", reg, value)
except MicroprobeDuplicatedValueError:
LOG.debug("Skip already set - %s", reg)
else:
building_block.add_init(target.set_register(
building_block.add_init(target.isa.set_register(
reg, value, building_block.context))
building_block.context.set_register_value(reg, value)

Expand Down
7 changes: 7 additions & 0 deletions src/microprobe/target/isa/instruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -1777,6 +1777,13 @@ def assembly(self, args, dissabled_fields=None):
"," + field.name + ")",
"," + next_operand_value().representation + ")", 1)

elif assembly_str.find(" " + field.name + ".t") >= 0:
assembly_str = assembly_str.replace(
", " + field.name + ".t",
", " + next_operand_value().representation + ".t",
1,
)

else:
LOG.debug(
"%s",
Expand Down
104 changes: 97 additions & 7 deletions src/microprobe/target/isa/operand.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@
"""

# Futures
from __future__ import absolute_import, print_function
from __future__ import absolute_import, print_function, annotations

# Built-in modules
import abc
import os
import random
from typing import Dict, List, TYPE_CHECKING, cast

# Third party modules
import six
Expand All @@ -39,6 +40,10 @@
from microprobe.utils.typeguard_decorator import typeguard_testsuite
from microprobe.utils.yaml import read_yaml

# Type hinting
if TYPE_CHECKING:
from microprobe.code.context import Context

# Constants
SCHEMA = os.path.join(os.path.dirname(os.path.abspath(__file__)), "schemas",
"operand.yaml")
Expand Down Expand Up @@ -285,7 +290,7 @@ class OperandDescriptor:
"""

def __init__(self, mtype, is_input, is_output):
def __init__(self, mtype: Operand, is_input, is_output):
"""
:param mtype:
Expand All @@ -312,7 +317,7 @@ def is_output(self):
"""Is output flag (:class:`~.bool`) """
return self._is_output

def set_type(self, new_type):
def set_type(self, new_type: Operand):
"""
:param new_type:
Expand Down Expand Up @@ -616,7 +621,14 @@ def copy(self):
raise NotImplementedError

@abc.abstractmethod
def values(self):
def values(self) -> List[Register]:
"""Return the possible value of the operand."""
raise NotImplementedError

# TODO: Consider making filtered_values into values.
def filtered_values(
self, context: Context, fieldname: str
) -> List[Register]:
"""Return the possible value of the operand."""
raise NotImplementedError

Expand Down Expand Up @@ -767,8 +779,14 @@ class OperandReg(Operand):
"""

def __init__(self, name, descr, regs, address_base, address_index,
floating_point, vector):
def __init__(self,
name: str,
descr: str,
regs: List[Register] | Dict[Register, List[Register]],
address_base,
address_index: int,
floating_point: bool | None,
vector: bool | None):
"""
:param name:
Expand All @@ -783,7 +801,7 @@ def __init__(self, name, descr, regs, address_base, address_index,
super(OperandReg, self).__init__(name, descr)

if isinstance(regs, list):
self._regs = OrderedDict()
self._regs: Dict[Register, List[Register]] = OrderedDict()
for reg in regs:
self._regs[reg] = [reg]
else:
Expand All @@ -809,6 +827,53 @@ def values(self):
"""
return list(self._regs.keys())

def filtered_values(self, context: Context, fieldname: str):
lmul = cast(int | None, context.get_registername_value("LMUL"))

if lmul is None or not fieldname.startswith("v"):
return self.values()
elif fieldname in ["vd", "vmd", "vrs1", "vrs2", "vmask"]:
lmul *= 1
elif fieldname in ["vdd", "vdmd", "vdrs1", "vdrs2", "vnd", "vnmd"]:
lmul *= 2
elif fieldname in []:
lmul *= 4
elif fieldname in []:
lmul *= 8
else:
raise ValueError(f"Unhandled LMUL operand name: {fieldname}")

regs = list(self._regs.keys())

class LMULRegs:
lmul1 = regs
lmul2 = [
reg
for reg in self._regs.keys()
if reg.name in set([f"V{i}" for i in range(0, 32, 2)])
]
lmul4 = [
reg
for reg in self._regs.keys()
if reg.name in set([f"V{i}" for i in range(0, 32, 4)])
]
lmul8 = [
reg
for reg in self._regs.keys()
if reg.name in set([f"V{i}" for i in range(0, 32, 8)])
]

if lmul == 1:
return LMULRegs.lmul1
elif lmul == 2:
return LMULRegs.lmul2
elif lmul == 4:
return LMULRegs.lmul4
elif lmul == 8:
return LMULRegs.lmul8
else:
raise ValueError(f"Unhandled LMUL value: {lmul}")

def representation(self, value):
"""
Expand Down Expand Up @@ -927,6 +992,11 @@ def values(self):
]
return self._computed_values

def filtered_values(
self, context: Context, fieldname: str
):
return super().filtered_values(context, fieldname)

def set_valid_values(self, values):
"""
Expand Down Expand Up @@ -1083,6 +1153,11 @@ def values(self):
"""
return self._values

def filtered_values(
self, context: Context, fieldname: str
):
return super().filtered_values(context, fieldname)

def representation(self, value):
"""
Expand Down Expand Up @@ -1177,6 +1252,11 @@ def values(self):
"""
return [self._value]

def filtered_values(
self, context: Context, fieldname: str
):
return super().filtered_values(context, fieldname)

def representation(self, value):
"""
Expand Down Expand Up @@ -1285,6 +1365,11 @@ def values(self):
"""
return [self._reg]

def filtered_values(
self, context: Context, fieldname: str
):
return super().filtered_values(context, fieldname)

def random_value(self):
"""Return a random possible value for the operand.
Expand Down Expand Up @@ -1393,6 +1478,11 @@ def values(self):
"""
return [self._mindispl << self._shift]

def filtered_values(
self, context: Context, fieldname: str
):
return super().filtered_values(context, fieldname)

def random_value(self):
"""Return a random possible value for the operand.
Expand Down
Loading

0 comments on commit ab246db

Please sign in to comment.