Skip to content

Commit

Permalink
Merge #64
Browse files Browse the repository at this point in the history
64: replace `ir::Type::Void` by an option r=ulysseB a=ulysseB

We currently represented instructions that did not produce a value with the `ir::Type::Void` variant. However, this produced a lot of special cases where we add to test if the type was void. Instead we now use an `Option<ir::Type>` in the places where the type could before be void. This will simplify future improvements to the value system, in particular, #41 .

Co-authored-by: Ulysse Beaugnon <[email protected]>
  • Loading branch information
bors[bot] and Ulysse Beaugnon committed Jul 17, 2018
2 parents af045df + 3e3885a commit a6dd14e
Show file tree
Hide file tree
Showing 11 changed files with 60 additions and 73 deletions.
8 changes: 5 additions & 3 deletions src/codegen/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ pub struct Instruction<'a> {
instruction: &'a ir::Instruction<'a>,
instantiation_dims: Vec<(ir::dim::Id, u32)>,
mem_flag: Option<search_space::InstFlag>,
t: ir::Type,
t: Option<ir::Type>,
}

impl<'a> Instruction<'a> {
Expand All @@ -281,7 +281,9 @@ impl<'a> Instruction<'a> {
}).collect();
let mem_flag = instruction.as_mem_inst()
.map(|inst| space.domain().get_inst_flag(inst.id()));
let t = unwrap!(space.ir_instance().device().lower_type(instruction.t(), space));
let t = instruction.t().map(|t| {
unwrap!(space.ir_instance().device().lower_type(t, space))
});
Instruction { instruction, instantiation_dims, mem_flag, t }
}

Expand All @@ -297,7 +299,7 @@ impl<'a> Instruction<'a> {
}

/// Returns the type of the instruction.
pub fn t(&self) -> ir::Type { self.t }
pub fn t(&self) -> Option<ir::Type> { self.t }

/// Returns the operator computed by the instruction.
pub fn operator(&self) -> &ir::Operator { self.instruction.operator() }
Expand Down
6 changes: 4 additions & 2 deletions src/codegen/namer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ impl<'a, 'b> NameMap<'a, 'b> {
for inst in function.cfg().instructions() {
if let Some((inst_id, dim_map)) = inst.as_reduction() {
name_map.decl_alias(inst, inst_id, dim_map);
} else if inst.t() != Type::Void {
} else if inst.t().is_some() {
name_map.decl_inst(inst);
}
}
Expand Down Expand Up @@ -201,7 +201,9 @@ impl<'a, 'b> NameMap<'a, 'b> {
fn decl_inst(&mut self, inst: &Instruction) {
let (dim_ids, dim_sizes) = self.inst_name_dims(inst);
let num_name = dim_sizes.iter().product();
let names = (0 .. num_name).map(|_| self.gen_name(inst.t())).collect_vec();
let names = (0 .. num_name).map(|_| {
self.gen_name(unwrap!(inst.t()))
}).collect_vec();
let array = NDArray::new(dim_sizes, names);
assert!(self.insts.insert(inst.id(), (dim_ids, array)).is_none());
}
Expand Down
9 changes: 5 additions & 4 deletions src/codegen/printer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,8 @@ pub trait Printer {
fn inst(&mut self, inst: &Instruction, namer: &mut NameMap, fun: &Function ) {
match *inst.operator() {
op::BinOp(op, ref lhs, ref rhs, round) => {
self.print_binop(op, inst.t(), round, &namer.name_inst(inst),
let t = unwrap!(inst.t());
self.print_binop(op, t, round, &namer.name_inst(inst),
&namer.name_op(lhs), &namer.name_op(rhs))
}
op::Mul(ref lhs, ref rhs, round, return_type) => {
Expand All @@ -324,14 +325,14 @@ pub trait Printer {
let low_mlhs_type = Self::lower_type(mul_lhs.t(), fun);
let low_arhs_type = Self::lower_type(add_rhs.t(), fun);
let mode = MulMode::from_type(low_mlhs_type, low_arhs_type);
self.print_mad(inst.t(), round, mode, &namer.name_inst(inst),
self.print_mad(unwrap!(inst.t()), round, mode, &namer.name_inst(inst),
&namer.name_op(mul_lhs),
&namer.name_op(mul_rhs),
&namer.name_op(add_rhs))
},
op::Mov(ref op) => {

self.print_mov(inst.t(), &namer.name_inst(inst), &namer.name_op(op))
let t = unwrap!(inst.t());
self.print_mov(t, &namer.name_inst(inst), &namer.name_op(op))
},
op::Ld(ld_type, ref addr, _) => {
self.print_ld(ld_type, unwrap!(inst.mem_flag()),
Expand Down
48 changes: 24 additions & 24 deletions src/device/cuda/gpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,44 +239,44 @@ impl Gpu {
inst: &ir::Instruction,
ctx: &device::Context) -> HwPressure {
use ir::Operator::*;
let t = self.lower_type(inst.t(), space).unwrap_or_else(|| inst.t());
let t = inst.t().map(|t| self.lower_type(t, space).unwrap_or(t));
match (inst.operator(), t) {
(&BinOp(ir::BinOp::Add, ..), Type::F(32)) |
(&BinOp(ir::BinOp::Sub, ..), Type::F(32)) => self.add_f32_inst.into(),
(&BinOp(ir::BinOp::Add, ..), Type::F(64)) |
(&BinOp(ir::BinOp::Sub, ..), Type::F(64)) => self.add_f64_inst.into(),
(&BinOp(ir::BinOp::Add, ..), Type::I(32)) |
(&BinOp(ir::BinOp::Sub, ..), Type::I(32)) => self.add_i32_inst.into(),
(&BinOp(ir::BinOp::Add, ..), Type::I(64)) |
(&BinOp(ir::BinOp::Sub, ..), Type::I(64)) => self.add_i64_inst.into(),
(&Mul(..), Type::F(32)) => self.mul_f32_inst.into(),
(&Mul(..), Type::F(64)) => self.mul_f64_inst.into(),
(&Mul(..), Type::I(32)) |
(&Mul(..), Type::PtrTo(_)) => self.mul_i32_inst.into(),
(&Mul(ref op, _, _, _), Type::I(64)) => {
(&BinOp(ir::BinOp::Add, ..), Some(Type::F(32))) |
(&BinOp(ir::BinOp::Sub, ..), Some(Type::F(32))) => self.add_f32_inst.into(),
(&BinOp(ir::BinOp::Add, ..), Some(Type::F(64))) |
(&BinOp(ir::BinOp::Sub, ..), Some(Type::F(64))) => self.add_f64_inst.into(),
(&BinOp(ir::BinOp::Add, ..), Some(Type::I(32))) |
(&BinOp(ir::BinOp::Sub, ..), Some(Type::I(32))) => self.add_i32_inst.into(),
(&BinOp(ir::BinOp::Add, ..), Some(Type::I(64))) |
(&BinOp(ir::BinOp::Sub, ..), Some(Type::I(64))) => self.add_i64_inst.into(),
(&Mul(..), Some(Type::F(32))) => self.mul_f32_inst.into(),
(&Mul(..), Some(Type::F(64))) => self.mul_f64_inst.into(),
(&Mul(..), Some(Type::I(32))) |
(&Mul(..), Some(Type::PtrTo(_))) => self.mul_i32_inst.into(),
(&Mul(ref op, _, _, _), Some(Type::I(64))) => {
let op_t = self.lower_type(op.t(), space).unwrap_or_else(|| op.t());
if op_t == Type::I(64) {
self.mul_i64_inst.into()
} else {
self.mul_wide_inst.into()
}
},
(&Mad(..), Type::F(32)) => self.mad_f32_inst.into(),
(&Mad(..), Type::F(64)) => self.mad_f64_inst.into(),
(&Mad(..), Type::I(32)) |
(&Mad(..), Type::PtrTo(_)) => self.mad_i32_inst.into(),
(&Mad(ref op, _, _, _), Type::I(64)) => {
(&Mad(..), Some(Type::F(32))) => self.mad_f32_inst.into(),
(&Mad(..), Some(Type::F(64))) => self.mad_f64_inst.into(),
(&Mad(..), Some(Type::I(32))) |
(&Mad(..), Some(Type::PtrTo(_))) => self.mad_i32_inst.into(),
(&Mad(ref op, _, _, _), Some(Type::I(64))) => {
let op_t = self.lower_type(op.t(), space).unwrap_or_else(|| op.t());
if op_t == Type::I(64) {
self.mad_i64_inst.into()
} else {
self.mad_wide_inst.into()
}
},
(&BinOp(ir::BinOp::Div, ..), Type::F(32)) => self.div_f32_inst.into(),
(&BinOp(ir::BinOp::Div, ..), Type::F(64)) => self.div_f64_inst.into(),
(&BinOp(ir::BinOp::Div, ..), Type::I(32)) => self.div_i32_inst.into(),
(&BinOp(ir::BinOp::Div, ..), Type::I(64)) => self.div_i64_inst.into(),
(&BinOp(ir::BinOp::Div, ..), Some(Type::F(32))) => self.div_f32_inst.into(),
(&BinOp(ir::BinOp::Div, ..), Some(Type::F(64))) => self.div_f64_inst.into(),
(&BinOp(ir::BinOp::Div, ..), Some(Type::I(32))) => self.div_i32_inst.into(),
(&BinOp(ir::BinOp::Div, ..), Some(Type::I(64))) => self.div_i64_inst.into(),
(&Ld(..), _) | (&TmpLd(..), _) => {
let flag = space.domain().get_inst_flag(inst.id());
let mem_info = mem_model::analyse(space, self, inst, dim_sizes, ctx);
Expand Down Expand Up @@ -331,7 +331,7 @@ impl device::Device for Gpu {
fn is_valid_type(&self, t: &Type) -> bool {
match *t {
Type::I(i) | Type::F(i) => i == 32 || i == 64,
Type::Void | Type::PtrTo(_) => true,
Type::PtrTo(_) => true,
}
}

Expand Down
14 changes: 6 additions & 8 deletions src/device/cuda/printer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ impl CudaPrinter {
/// Prints a `Type` for the host.
fn host_type(t: &Type) -> &'static str {
match *t {
Type::Void => "void",
Type::PtrTo(..) => "CUdeviceptr",
Type::F(32) => "float",
Type::F(64) => "double",
Expand Down Expand Up @@ -290,7 +289,6 @@ impl Printer for CudaPrinter {
/// Print a type in the backend
fn get_type(t: Type) -> String {
match t {
Type::Void => panic!("void type cannot be printed"),
Type::I(1) => "pred".to_string(),
Type::I(size) => format!("s{size}", size = size),
Type::F(size) => format!("f{size}", size = size),
Expand All @@ -302,7 +300,7 @@ impl Printer for CudaPrinter {
fn print_binop(&mut self, op: ir::BinOp,
return_type: Type,
rounding: op::Rounding,
return_id: &str, lhs: &str, rhs: &str) {
return_id: &str, lhs: &str, rhs: &str) {
let op = Self::binary_op(op);
let rounding = Self::rounding(rounding);
let ret_type = Self::get_type(return_type);
Expand Down Expand Up @@ -342,27 +340,27 @@ impl Printer for CudaPrinter {
operator, t, return_id, mlhs, mrhs, arhs));
}

/// Print return_id = op
/// Print return_id = op
fn print_mov(&mut self, return_type: Type, return_id: &str, op: &str) {
unwrap!(writeln!(self.out_function, "mov.{} {}, {};",
Self::get_type(return_type), return_id, op));
}

/// Print return_id = load [addr]
/// Print return_id = load [addr]
fn print_ld(&mut self, return_type: Type, flag: InstFlag, return_id: &str, addr: &str) {
let operator = Self::ld_operator(flag);
unwrap!(writeln!(self.out_function, "{}.{} {}, [{}];",
operator, Self::get_type(return_type), return_id, addr));
}

/// Print store val [addr]
/// Print store val [addr]
fn print_st(&mut self, val_type: Type, mem_flag: InstFlag, addr: &str, val: &str) {
let operator = Self::st_operator(mem_flag);
unwrap!(writeln!(self.out_function, "{}.{} [{}], {};",
operator, Self::get_type(val_type), addr, val));
}

/// Print if (cond) store val [addr]
/// Print if (cond) store val [addr]
fn print_cond_st(&mut self, val_type: Type,
mem_flag: InstFlag,
cond: &str, addr: &str, val: &str) {
Expand Down Expand Up @@ -435,7 +433,7 @@ impl Printer for CudaPrinter {
let dst = (0..size).map(|i| {
namer.indexed_inst_name(inst, dim.id(), i).to_string()
}).collect_vec().join(", ");
let t = Self::get_type(inst.t());
let t = Self::get_type(unwrap!(inst.t()));
unwrap!(writeln!(self.out_function, "{}.{} {{{}}}, [{}];",
operator, t, dst, namer.name_op(addr)))
},
Expand Down
2 changes: 1 addition & 1 deletion src/device/x86/cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ impl device::Device for Cpu {
fn is_valid_type(&self, t: &Type) -> bool {
match *t {
Type::I(i) | Type::F(i) => i == 32 || i == 64,
Type::Void | Type::PtrTo(_) => true,
Type::PtrTo(_) => true,
}
}

Expand Down
1 change: 0 additions & 1 deletion src/device/x86/printer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,6 @@ impl Printer for X86printer {

fn get_type(t: Type) -> String {
match t {
Type::Void => String::from("void"),
//Type::PtrTo(..) => " uint8_t *",
Type::PtrTo(..) => String::from("intptr_t"),
Type::F(32) => String::from("float"),
Expand Down
2 changes: 1 addition & 1 deletion src/ir/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ impl<'a> Instruction<'a> {
pub fn operands(&self) -> Vec<&Operand<'a>> { self.operator.operands() }

/// Returns the type of the value produced by an instruction.
pub fn t(&self) -> Type { self.operator.t() }
pub fn t(&self) -> Option<Type> { self.operator.t() }

/// Returns the operator of the instruction.
pub fn operator(&self) -> &Operator { &self.operator }
Expand Down
8 changes: 3 additions & 5 deletions src/ir/operand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,16 @@ impl<'a> Operand<'a> {
pub fn new_inst(inst: &Instruction, dim_map: DimMap, mut scope: DimMapScope)
-> Operand<'a> {
// A temporary arry can only be generated if the type size is known.
assert_ne!(inst.t(), Type::Void);
if scope == DimMapScope::Global && inst.t().len_byte().is_none() {
if scope == DimMapScope::Global && unwrap!(inst.t()).len_byte().is_none() {
scope = DimMapScope::Thread
}
Inst(inst.id(), inst.t(), dim_map, scope)
Inst(inst.id(), unwrap!(inst.t()), dim_map, scope)
}

/// Creates a reduce operand from an instruction and a set of dimensions to reduce on.
pub fn new_reduce(init: &Instruction, dim_map: DimMap, dims: Vec<ir::dim::Id>)
-> Operand<'a> {
assert_ne!(init.t(), Type::Void);
Reduce(init.id(), init.t(), dim_map, dims)
Reduce(init.id(), unwrap!(init.t()), dim_map, dims)
}

/// Creates a new Int operand and checks its number of bits.
Expand Down
27 changes: 9 additions & 18 deletions src/ir/operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,10 @@ pub enum Operator<'a> {
impl<'a> Operator<'a> {
/// Ensures the types of the operands are valid.
pub fn type_check(&self, device: &Device) {
assert!(device.is_valid_type(&self.t()));
if let Some(t) = self.t() { assert!(device.is_valid_type(&t)); }
// Check operand types.
for operand in self.operands() {
let t = operand.t();
assert_ne!(t, Type::Void);
assert!(device.is_valid_type(&t));
assert!(device.is_valid_type(&operand.t()));
}
match *self {
BinOp(_, ref lhs, ref rhs, rounding) => {
Expand All @@ -102,29 +100,22 @@ impl<'a> Operator<'a> {
_ => panic!(),
}
},
Ld(ref t, ref addr, ref pattern) => {
assert_ne!(*t, Type::Void);
assert_eq!(addr.t(), Type::PtrTo(pattern.mem_block()));
},
Ld(_, ref addr, ref pattern) =>
assert_eq!(addr.t(), Type::PtrTo(pattern.mem_block())),
St(ref addr, _, _, ref pattern) =>
assert_eq!(addr.t(), Type::PtrTo(pattern.mem_block())),
TmpLd(ref t, _) => assert_ne!(*t, Type::Void),
Cast(ref src, ref t) => {
assert_ne!(src.t(), Type::Void);
assert_ne!(*t, Type::Void);
},
Mov(..) | TmpSt(..) => (),
TmpLd(..) | Cast(..) | Mov(..) | TmpSt(..) => (),
}
}

/// Returns the type of the value produced.
pub fn t(&self) -> Type {
pub fn t(&self) -> Option<Type> {
match *self {
BinOp(_, ref op, ..) |
Mov(ref op) |
Mad(_, _, ref op, _) => op.t(),
Ld(t, ..) | TmpLd(t, _) | Cast(_, t) | Mul(.., t) => t,
St(..) | TmpSt(..) => Type::Void,
Mad(_, _, ref op, _) => Some(op.t()),
Ld(t, ..) | TmpLd(t, _) | Cast(_, t) | Mul(.., t) => Some(t),
St(..) | TmpSt(..) => None,
}
}

Expand Down
8 changes: 2 additions & 6 deletions src/ir/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ use utils::*;
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
/// Values and intructions types.
pub enum Type {
/// Type for instructions that do not produce a value.
Void,
/// Type for integer values, with a fixed number of bits.
I(u16),
/// Type for floating point values, with a fixed number of bits.
Expand All @@ -21,23 +19,22 @@ impl Type {
pub fn is_integer(&self) -> bool {
match *self {
Type::I(_) | Type::PtrTo(_) => true,
Type::Void | Type::F(_) => false,
Type::F(_) => false,
}
}

/// Returns true if the type is a float.
pub fn is_float(&self) -> bool {
match *self {
Type::F(_) => true,
Type::Void | Type::I(_) | Type::PtrTo(..) => false,
Type::I(_) | Type::PtrTo(..) => false,
}
}

/// Returns the number of bytes of the type.
pub fn len_byte(&self) -> Option<u32> {
match *self {
Type::I(i) | Type::F(i) => Some(u32::from(div_ceil(i, 8))),
Type::Void => Some(0),
Type::PtrTo(_) => None
}
}
Expand All @@ -46,7 +43,6 @@ impl Type {
impl fmt::Display for Type {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Type::Void => write!(f, "void"),
Type::I(s) => write!(f, "i{}", s),
Type::F(s) => write!(f, "f{}", s),
Type::PtrTo(mem) => write!(f, "ptr to {:?}", mem),
Expand Down

0 comments on commit a6dd14e

Please sign in to comment.