
Commit

[WIP] Migrate to cubecl IR refactor (#2418)
wingertge authored Oct 30, 2024
1 parent 69856a9 commit 5730f02
Showing 25 changed files with 302 additions and 382 deletions.
128 changes: 78 additions & 50 deletions Cargo.lock

Some generated files are not rendered by default.

8 changes: 4 additions & 4 deletions Cargo.toml
@@ -152,14 +152,14 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.2", features = ["alloc"] }

### For the main burn branch. ###
# cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "63da837b5ae78ff1a3b7363fe3af2b02c2bc864f" }
# cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "63da837b5ae78ff1a3b7363fe3af2b02c2bc864f" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "99404b1e29946832a42b72a5c26d4cf42c67692e" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "99404b1e29946832a42b72a5c26d4cf42c67692e" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
### For the release. ###
cubecl = { version="0.3.0", default-features = false }
cubecl-common = { version="0.3.0", default-features = false }
# cubecl = { version = "0.3.0", default-features = false }
# cubecl-common = { version = "0.3.0", default-features = false }

### For xtask crate ###
tracel-xtask = { version = "~1.1" }
55 changes: 18 additions & 37 deletions crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs
@@ -1,6 +1,8 @@
use cubecl::{
cpa,
ir::{Elem, IntKind, KernelDefinition, Scope, Variable, Visibility},
ir::{
Builtin, Elem, IntKind, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility,
},
CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
OutputInfo,
};
@@ -39,7 +41,7 @@ impl<E: JitElement> Conv2dTransposeComputeShader<E> {
let weight = self.weight;
let bias = self.bias;
let output = self.output;
let id = Variable::AbsolutePos;
let id = Variable::builtin(Builtin::AbsolutePos);

let input_stride_0 = scope.create_local(Elem::UInt);
let input_stride_1 = scope.create_local(Elem::UInt);
@@ -92,34 +94,13 @@ impl<E: JitElement> Conv2dTransposeComputeShader<E> {
cpa!(scope, kernel_size_0 = shape(weight, 2u32));
cpa!(scope, kernel_size_1 = shape(weight, 3u32));

let conv_stride_0 = Variable::GlobalScalar {
id: 0,
elem: Elem::UInt,
};
let conv_stride_1 = Variable::GlobalScalar {
id: 1,
elem: Elem::UInt,
};
let dilation_0 = Variable::GlobalScalar {
id: 2,
elem: Elem::UInt,
};
let dilation_1 = Variable::GlobalScalar {
id: 3,
elem: Elem::UInt,
};
let padding_0 = Variable::GlobalScalar {
id: 4,
elem: Elem::UInt,
};
let padding_1 = Variable::GlobalScalar {
id: 5,
elem: Elem::UInt,
};
let groups = Variable::GlobalScalar {
id: 6,
elem: Elem::UInt,
};
let conv_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt));
let conv_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt));
let dilation_0 = Variable::new(VariableKind::GlobalScalar(2), Item::new(Elem::UInt));
let dilation_1 = Variable::new(VariableKind::GlobalScalar(3), Item::new(Elem::UInt));
let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(Elem::UInt));
let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(Elem::UInt));
let groups = Variable::new(VariableKind::GlobalScalar(6), Item::new(Elem::UInt));

let stride_0_i = scope.create_local(Elem::Int(IntKind::I32));
let stride_1_i = scope.create_local(Elem::Int(IntKind::I32));
@@ -222,9 +203,9 @@ impl<E: JitElement> Conv2dTransposeComputeShader<E> {
cpa!(scope, index_input_b = b * input_stride_0);
cpa!(scope, index_weight_oc = oc * weight_stride_1);

let prod = scope.create_local(output.item());
let prod_tmp = scope.create_local(output.item());
let sum = scope.create_local(output.item());
let prod = scope.create_local(output.item);
let prod_tmp = scope.create_local(output.item);
let sum = scope.create_local(output.item);
cpa!(scope, sum = bias[oc_out]);

let kh = scope.create_local(Elem::UInt);
@@ -314,10 +295,10 @@ impl<R: JitRuntime, E: JitElement> Kernel for Conv2dTransposeEagerKernel<R, E> {
let mut scope = Scope::root();
let item = E::cube_elem().into();

let input = Variable::GlobalInputArray { id: 0, item };
let weight = Variable::GlobalInputArray { id: 1, item };
let bias = Variable::GlobalInputArray { id: 2, item };
let output = Variable::GlobalOutputArray { id: 0, item };
let input = Variable::new(VariableKind::GlobalInputArray(0), item);
let weight = Variable::new(VariableKind::GlobalInputArray(1), item);
let bias = Variable::new(VariableKind::GlobalInputArray(2), item);
let output = Variable::new(VariableKind::GlobalOutputArray(0), item);

scope.write_global_custom(output);

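Migration pattern, in brief: the cubecl IR refactor replaces the old enum-variant constructors on Variable with a VariableKind plus Item pair, and item becomes a plain field rather than a method (output.item instead of output.item()). The sketch below is not part of the commit; it only condenses the before/after shapes that appear in transpose_direct.rs above, assumes the cubecl rev pinned in Cargo.toml, and uses a hypothetical function name purely for illustration.

// Sketch only (not part of this commit): the Variable construction
// pattern this diff migrates to, assuming the cubecl rev pinned above.
use cubecl::ir::{Builtin, Elem, Item, Variable, VariableKind};

// Hypothetical helper, for illustration only.
fn ir_variables_after_refactor() -> (Variable, Variable, Variable, Variable) {
    // Builtins: Variable::AbsolutePos -> Variable::builtin(Builtin::AbsolutePos)
    let id = Variable::builtin(Builtin::AbsolutePos);

    // Scalars: Variable::GlobalScalar { id: 0, elem: Elem::UInt }
    //   -> Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt))
    let conv_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt));

    // Arrays: Variable::GlobalInputArray { id: 0, item }
    //   -> Variable::new(VariableKind::GlobalInputArray(0), item)
    let item = Item::new(Elem::UInt);
    let input = Variable::new(VariableKind::GlobalInputArray(0), item);
    let output = Variable::new(VariableKind::GlobalOutputArray(0), item);

    (id, conv_stride_0, input, output)
}

The Cargo.toml change above pins the cubecl revision that provides this constructor API.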
