Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IR Refactor #199

Merged
merged 13 commits into from
Oct 29, 2024
33 changes: 16 additions & 17 deletions crates/cubecl-core/src/codegen/integrator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use super::Compiler;
use crate::{
ir::{
Binding, CubeDim, Elem, Item, KernelDefinition, Location, ReadingStrategy, Scope, Variable,
Vectorization, Visibility,
VariableKind, Vectorization, Visibility,
},
Runtime,
};
Expand Down Expand Up @@ -399,15 +399,15 @@ impl KernelIntegrator {
size: None,
});
self.expansion.scope.write_global(
Variable::Local {
id: local,
Variable::new(
VariableKind::Local {
id: local,

depth: self.expansion.scope.depth,
},
item,
depth: self.expansion.scope.depth,
},
Variable::GlobalOutputArray {
id: index,
item: item_adapted,
},
),
Variable::new(VariableKind::GlobalOutputArray(index), item_adapted),
position,
);
index += 1;
Expand All @@ -419,15 +419,14 @@ impl KernelIntegrator {
position,
} => {
self.expansion.scope.write_global(
Variable::Local {
id: local,
Variable::new(
VariableKind::Local {
id: local,
depth: self.expansion.scope.depth,
},
item,
depth: self.expansion.scope.depth,
},
Variable::GlobalInputArray {
id: input,
item: bool_item(item),
},
),
Variable::new(VariableKind::GlobalInputArray(input), bool_item(item)),
position,
);
}
Expand Down
5 changes: 2 additions & 3 deletions crates/cubecl-core/src/frontend/branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -396,8 +396,7 @@ pub fn if_else_expr_expand<C: CubePrimitive>(
None => {
let mut then_child = context.child();
let ret = then_block(&mut then_child);
let out: ExpandElementTyped<C> =
context.create_local_variable(ret.expand.item()).into();
let out: ExpandElementTyped<C> = context.create_local_variable(ret.expand.item).into();
assign::expand(&mut then_child, ret, out.clone());

IfElseExprExpand::Runtime {
Expand Down Expand Up @@ -502,7 +501,7 @@ pub fn switch_expand_expr<I: Int, C: CubePrimitive>(
) -> SwitchExpandExpr<I, C> {
let mut default_child = context.child();
let default = default_block(&mut default_child);
let out: ExpandElementTyped<C> = context.create_local_variable(default.expand.item()).into();
let out: ExpandElementTyped<C> = context.create_local_variable(default.expand.item).into();
assign::expand(&mut default_child, default, out.clone());

SwitchExpandExpr {
Expand Down
72 changes: 43 additions & 29 deletions crates/cubecl-core/src/frontend/cmma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
use std::marker::PhantomData;

use crate::{
ir::{self, Operation},
ir::{self, Instruction, Operation},
unexpanded,
};

Expand Down Expand Up @@ -239,10 +239,10 @@ pub mod fill {
value: ExpandElementTyped<C>,
) {
let value: ExpandElement = value.into();
context.register(Operation::CoopMma(ir::CoopMma::Fill {
mat: *mat.elem,
value: *value,
}));
context.register(Instruction::new(
ir::CoopMma::Fill { value: *value },
*mat.elem,
));
}
}

Expand Down Expand Up @@ -271,12 +271,14 @@ pub mod load {
"Loading accumulator requires explicit layout. Use `load_with_layout` instead."
);

context.register(Operation::CoopMma(ir::CoopMma::Load {
mat: *mat.elem,
value: *value.expand,
stride: *stride,
layout: None,
}));
context.register(Instruction::new(
ir::CoopMma::Load {
value: *value.expand,
stride: *stride,
layout: None,
},
*mat.elem,
));
}
}

Expand Down Expand Up @@ -307,12 +309,14 @@ pub mod load_with_layout {
) {
let stride: ExpandElement = stride.into();

context.register(Operation::CoopMma(ir::CoopMma::Load {
mat: *mat.elem,
value: *value.expand,
stride: *stride,
layout: Some(layout),
}));
context.register(Instruction::new(
ir::CoopMma::Load {
value: *value.expand,
stride: *stride,
layout: Some(layout),
},
*mat.elem,
));
}
}

Expand Down Expand Up @@ -342,12 +346,14 @@ pub mod store {
) {
let stride: ExpandElement = stride.into();

context.register(Operation::CoopMma(ir::CoopMma::Store {
output: *output.expand,
mat: *mat.elem,
stride: *stride,
layout,
}));
context.register(Instruction::new(
ir::CoopMma::Store {
mat: *mat.elem,
stride: *stride,
layout,
},
*output.expand,
));
}
}

Expand All @@ -374,11 +380,19 @@ pub mod execute {
mat_c: MatrixExpand,
mat_d: MatrixExpand,
) {
context.register(Operation::CoopMma(ir::CoopMma::Execute {
mat_a: *mat_a.elem,
mat_b: *mat_b.elem,
mat_c: *mat_c.elem,
mat_d: *mat_d.elem,
}));
context.register(Instruction::new(
ir::CoopMma::Execute {
mat_a: *mat_a.elem,
mat_b: *mat_b.elem,
mat_c: *mat_c.elem,
},
*mat_d.elem,
));
}
}

impl From<ir::CoopMma> for Operation {
fn from(value: ir::CoopMma) -> Self {
Operation::CoopMma(value)
}
}
46 changes: 27 additions & 19 deletions crates/cubecl-core/src/frontend/container/array/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ mod vectorization {
vectorization_factor: u32,
) -> <Self as CubeType>::ExpandType {
let size = size.value();
let size = match size {
crate::ir::Variable::ConstantScalar(value) => value.as_u32(),
let size = match size.kind {
crate::ir::VariableKind::ConstantScalar(value) => value.as_u32(),
_ => panic!("Shared memory need constant initialization value"),
};
context
Expand All @@ -131,7 +131,7 @@ mod vectorization {
.expect("Vectorization must be comptime")
.as_u32();
let var = self.expand.clone();
let item = Item::vectorized(var.item().elem(), NonZero::new(factor as u8));
let item = Item::vectorized(var.item.elem(), NonZero::new(factor as u8));

let new_var = if factor == 1 {
let new_var = context.create_local_binding(item);
Expand Down Expand Up @@ -160,6 +160,8 @@ mod vectorization {

/// Module that contains the implementation details of the metadata functions.
mod metadata {
use crate::ir::Instruction;

use super::*;

impl<E: CubeType> Array<E> {
Expand All @@ -174,10 +176,12 @@ mod metadata {
// Expand method of [len](Array::len).
pub fn __expand_len_method(self, context: &mut CubeContext) -> ExpandElementTyped<u32> {
let out = context.create_local_binding(Item::new(Elem::UInt));
context.register(Metadata::Length {
var: self.expand.into(),
out: out.clone().into(),
});
context.register(Instruction::new(
Metadata::Length {
var: self.expand.into(),
},
out.clone().into(),
));
out.into()
}
}
Expand All @@ -186,7 +190,7 @@ mod metadata {
/// Module that contains the implementation details of the index functions.
mod indexation {
use crate::{
ir::{BinaryOperator, Operator},
ir::{BinaryOperator, Instruction, Operator},
prelude::{CubeIndex, CubeIndexMut},
};

Expand Down Expand Up @@ -224,12 +228,14 @@ mod indexation {
context: &mut CubeContext,
i: ExpandElementTyped<u32>,
) -> ExpandElementTyped<E> {
let out = context.create_local_binding(self.expand.item());
context.register(Operator::UncheckedIndex(BinaryOperator {
out: *out,
lhs: *self.expand,
rhs: i.expand.consume(),
}));
let out = context.create_local_binding(self.expand.item);
context.register(Instruction::new(
Operator::UncheckedIndex(BinaryOperator {
lhs: *self.expand,
rhs: i.expand.consume(),
}),
*out,
));
out.into()
}

Expand All @@ -239,11 +245,13 @@ mod indexation {
i: ExpandElementTyped<u32>,
value: ExpandElementTyped<E>,
) {
context.register(Operator::UncheckedIndexAssign(BinaryOperator {
out: *self.expand,
lhs: i.expand.consume(),
rhs: value.expand.consume(),
}));
context.register(Instruction::new(
Operator::UncheckedIndexAssign(BinaryOperator {
lhs: i.expand.consume(),
rhs: value.expand.consume(),
}),
*self.expand,
));
}
}
}
Expand Down
10 changes: 6 additions & 4 deletions crates/cubecl-core/src/frontend/container/line/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::num::NonZero;

use crate::{
ir::{ConstantScalarValue, Item},
prelude::{assign, CubeContext, ExpandElement},
prelude::{CubeContext, ExpandElement},
unexpanded,
};

Expand Down Expand Up @@ -40,6 +40,8 @@ mod new {

/// Module that contains the implementation details of the fill function.
mod fill {
use crate::prelude::cast;

use super::*;

impl<P: CubePrimitive + Into<ExpandElementTyped<P>>> Line<P> {
Expand Down Expand Up @@ -76,10 +78,10 @@ mod fill {
context: &mut CubeContext,
value: ExpandElementTyped<P>,
) -> Self {
let length = self.expand.item().vectorization;
let length = self.expand.item.vectorization;
let output = context.create_local_binding(Item::vectorized(P::as_elem(), length));

assign::expand::<P>(context, value, output.clone().into());
cast::expand::<P>(context, value, output.clone().into());

output.into()
}
Expand Down Expand Up @@ -153,7 +155,7 @@ mod size {
/// Comptime version of [size](Line::size).
pub fn size(&self) -> u32 {
self.expand
.item()
.item
.vectorization
.unwrap_or(NonZero::new(1).unwrap())
.get() as u32
Expand Down
28 changes: 16 additions & 12 deletions crates/cubecl-core/src/frontend/container/shared_memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ impl<T: CubePrimitive + Clone> SharedMemory<T> {
/// Module that contains the implementation details of the index functions.
mod indexation {
use crate::{
ir::{BinaryOperator, Operator},
ir::{BinaryOperator, Instruction, Operator},
prelude::{CubeIndex, CubeIndexMut},
unexpanded,
};
Expand Down Expand Up @@ -129,12 +129,14 @@ mod indexation {
context: &mut CubeContext,
i: ExpandElementTyped<u32>,
) -> ExpandElementTyped<E> {
let out = context.create_local_binding(self.expand.item());
context.register(Operator::UncheckedIndex(BinaryOperator {
out: *out,
lhs: *self.expand,
rhs: i.expand.consume(),
}));
let out = context.create_local_binding(self.expand.item);
context.register(Instruction::new(
Operator::UncheckedIndex(BinaryOperator {
lhs: *self.expand,
rhs: i.expand.consume(),
}),
*out,
));
out.into()
}

Expand All @@ -144,11 +146,13 @@ mod indexation {
i: ExpandElementTyped<u32>,
value: ExpandElementTyped<E>,
) {
context.register(Operator::UncheckedIndexAssign(BinaryOperator {
out: *self.expand,
lhs: i.expand.consume(),
rhs: value.expand.consume(),
}));
context.register(Instruction::new(
Operator::UncheckedIndexAssign(BinaryOperator {
lhs: i.expand.consume(),
rhs: value.expand.consume(),
}),
*self.expand,
));
}
}
}
Loading