diff --git a/Cargo.lock b/Cargo.lock
index 594c02e0cc..d0ffe6bbd5 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -524,7 +524,7 @@ dependencies = [
 name = "burn-common"
 version = "0.16.0"
 dependencies = [
- "cubecl-common",
+ "cubecl-common 0.4.0",
  "dashmap",
  "getrandom",
  "indicatif",
@@ -1459,15 +1459,14 @@ dependencies = [
 
 [[package]]
 name = "cubecl"
-version = "0.3.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "75e75c7e982b943380665c5901fe0b69d5df2627644e0e50199c52b64d8d5a1c"
+version = "0.4.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
 dependencies = [
  "cubecl-core",
  "cubecl-cuda",
  "cubecl-hip",
  "cubecl-linalg",
- "cubecl-runtime",
+ "cubecl-runtime 0.4.0",
  "cubecl-wgpu",
 ]
 
@@ -1489,16 +1488,32 @@ dependencies = [
  "web-time",
 ]
 
+[[package]]
+name = "cubecl-common"
+version = "0.4.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
+dependencies = [
+ "derive-new",
+ "embassy-futures",
+ "futures-lite",
+ "getrandom",
+ "log",
+ "portable-atomic",
+ "rand",
+ "serde",
+ "spin",
+ "web-time",
+]
+
 [[package]]
 name = "cubecl-core"
-version = "0.3.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ec33b64139d1dfc747df8aed5834d10c3c55c716f5219041c6eb17241c96c929"
+version = "0.4.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
 dependencies = [
  "bytemuck",
- "cubecl-common",
+ "cubecl-common 0.4.0",
  "cubecl-macros",
- "cubecl-runtime",
+ "cubecl-runtime 0.4.0",
  "derive-new",
  "half",
  "log",
@@ -1509,14 +1524,13 @@ dependencies = [
 
 [[package]]
 name = "cubecl-cpp"
-version = "0.3.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4ded461feb0ff342a4f675131dc0ae8ad94e58f66bad11e57f852cb7f190a731"
+version = "0.4.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
 dependencies = [
  "bytemuck",
- "cubecl-common",
+ "cubecl-common 0.4.0",
  "cubecl-core",
- "cubecl-runtime",
+ "cubecl-runtime 0.4.0",
  "derive-new",
  "half",
  "log",
@@ -1524,15 +1538,14 @@ dependencies = [
 
 [[package]]
 name = "cubecl-cuda"
-version = "0.3.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "88dfdfe616124d2abe5e82052ff56f86843c369440e181d6936f7409e161dd82"
+version = "0.4.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
 dependencies = [
  "bytemuck",
- "cubecl-common",
+ "cubecl-common 0.4.0",
  "cubecl-core",
  "cubecl-cpp",
- "cubecl-runtime",
+ "cubecl-runtime 0.4.0",
  "cudarc 0.12.1",
  "derive-new",
  "half",
@@ -1541,16 +1554,15 @@ dependencies = [
 
 [[package]]
 name = "cubecl-hip"
-version = "0.3.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "409e0e176152ab51a60bbebb940b7a72aba210cd42b5f8cd2e87e7d7e674a13a"
+version = "0.4.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
 dependencies = [
  "bytemuck",
- "cubecl-common",
+ "cubecl-common 0.4.0",
  "cubecl-core",
  "cubecl-cpp",
  "cubecl-hip-sys",
- "cubecl-runtime",
+ "cubecl-runtime 0.4.0",
  "derive-new",
  "half",
  "log",
@@ -1567,23 +1579,21 @@ dependencies = [
 
 [[package]]
 name = "cubecl-linalg"
-version = "0.3.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3c5634782d790e9b6562fc267ffd15e9a510b4d6ec32c144cd2b166af2ba0cfb"
+version = "0.4.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
 dependencies = [
  "bytemuck",
  "cubecl-core",
- "cubecl-runtime",
+ "cubecl-runtime 0.4.0",
  "half",
 ]
 
 [[package]]
 name = "cubecl-macros"
-version = "0.3.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d2d22663257d9cdbcd67f5048d6f4e6eb965dd87104c3a173a7b0ea0d720e99b"
+version = "0.4.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
 dependencies = [
- "cubecl-common",
+ "cubecl-common 0.4.0",
  "darling",
  "derive-new",
  "ident_case",
@@ -1595,11 +1605,10 @@ dependencies = [
 
 [[package]]
 name = "cubecl-opt"
-version = "0.3.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "59fba2561a6ceb99e9c5fe7313db0aeead02b848dc9cdacf8373e0fe98d3c247"
+version = "0.4.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
 dependencies = [
- "cubecl-common",
+ "cubecl-common 0.4.0",
  "cubecl-core",
  "float-ord",
  "log",
@@ -1618,7 +1627,28 @@ dependencies = [
  "async-channel",
  "async-lock",
  "cfg_aliases 0.2.1",
- "cubecl-common",
+ "cubecl-common 0.3.0",
+ "derive-new",
+ "dirs 5.0.1",
+ "hashbrown 0.14.5",
+ "log",
+ "md5",
+ "sanitize-filename",
+ "serde",
+ "serde_json",
+ "spin",
+ "wasm-bindgen-futures",
+]
+
+[[package]]
+name = "cubecl-runtime"
+version = "0.4.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
+dependencies = [
+ "async-channel",
+ "async-lock",
+ "cfg_aliases 0.2.1",
+ "cubecl-common 0.4.0",
  "derive-new",
  "dirs 5.0.1",
  "hashbrown 0.14.5",
@@ -1633,31 +1663,29 @@ dependencies = [
 
 [[package]]
 name = "cubecl-spirv"
-version = "0.3.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "835bc234cdd40fbb5e3e5e41bfb4a6e2ee2d7fd899b66d44dcbcb3a825ca1a59"
+version = "0.4.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
 dependencies = [
- "cubecl-common",
+ "cubecl-common 0.4.0",
  "cubecl-core",
  "cubecl-opt",
- "cubecl-runtime",
+ "cubecl-runtime 0.4.0",
  "hashbrown 0.14.5",
  "rspirv",
 ]
 
 [[package]]
 name = "cubecl-wgpu"
-version = "0.3.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6779f1072d70923758421c6214fd0cd19a6f25b91035a522f9cd9407d03b5cae"
+version = "0.4.0"
+source = "git+https://github.com/tracel-ai/cubecl?rev=99404b1e29946832a42b72a5c26d4cf42c67692e#99404b1e29946832a42b72a5c26d4cf42c67692e"
 dependencies = [
  "ash",
  "async-channel",
  "bytemuck",
  "cfg_aliases 0.2.1",
- "cubecl-common",
+ "cubecl-common 0.4.0",
  "cubecl-core",
- "cubecl-runtime",
+ "cubecl-runtime 0.4.0",
  "cubecl-spirv",
  "derive-new",
  "hashbrown 0.14.5",
@@ -3265,7 +3293,7 @@ dependencies = [
  "burn-candle",
  "burn-import",
  "console_error_panic_hook",
- "cubecl-runtime",
+ "cubecl-runtime 0.3.0",
  "js-sys",
  "log",
  "serde",
@@ -3793,7 +3821,7 @@ version = "0.16.0"
 dependencies = [
  "burn",
  "console_error_panic_hook",
- "cubecl-runtime",
+ "cubecl-runtime 0.3.0",
  "js-sys",
  "serde",
  "wasm-bindgen",
diff --git a/Cargo.toml b/Cargo.toml
index e77bbb875f..9226ee8535 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -152,14 +152,14 @@ ahash = { version = "0.8.11", default-features = false }
 portable-atomic-util = { version = "0.2.2", features = ["alloc"] }
 
 ### For the main burn branch. ###
-# cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "63da837b5ae78ff1a3b7363fe3af2b02c2bc864f" }
-# cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "63da837b5ae78ff1a3b7363fe3af2b02c2bc864f" }
+cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "99404b1e29946832a42b72a5c26d4cf42c67692e" }
+cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "99404b1e29946832a42b72a5c26d4cf42c67692e" }
 ### For local development. ###
 # cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
 # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
 ### For the release. ###
-cubecl = { version = "0.3.0", default-features = false }
-cubecl-common = { version = "0.3.0", default-features = false }
+# cubecl = { version = "0.3.0", default-features = false }
+# cubecl-common = { version = "0.3.0", default-features = false }
 
 ### For xtask crate ###
 tracel-xtask = { version = "~1.1" }
diff --git a/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs b/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs
index 195a7ef3da..4a03a5a839 100644
--- a/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs
+++ b/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs
@@ -1,6 +1,8 @@
 use cubecl::{
     cpa,
-    ir::{Elem, IntKind, KernelDefinition, Scope, Variable, Visibility},
+    ir::{
+        Builtin, Elem, IntKind, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility,
+    },
     CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
     OutputInfo,
 };
@@ -39,7 +41,7 @@ impl Conv2dTransposeComputeShader {
         let weight = self.weight;
         let bias = self.bias;
         let output = self.output;
-        let id = Variable::AbsolutePos;
+        let id = Variable::builtin(Builtin::AbsolutePos);
 
         let input_stride_0 = scope.create_local(Elem::UInt);
         let input_stride_1 = scope.create_local(Elem::UInt);
@@ -92,34 +94,13 @@ impl Conv2dTransposeComputeShader {
         cpa!(scope, kernel_size_0 = shape(weight, 2u32));
         cpa!(scope, kernel_size_1 = shape(weight, 3u32));
 
-        let conv_stride_0 = Variable::GlobalScalar {
-            id: 0,
-            elem: Elem::UInt,
-        };
-        let conv_stride_1 = Variable::GlobalScalar {
-            id: 1,
-            elem: Elem::UInt,
-        };
-        let dilation_0 = Variable::GlobalScalar {
-            id: 2,
-            elem: Elem::UInt,
-        };
-        let dilation_1 = Variable::GlobalScalar {
-            id: 3,
-            elem: Elem::UInt,
-        };
-        let padding_0 = Variable::GlobalScalar {
-            id: 4,
-            elem: Elem::UInt,
-        };
-        let padding_1 = Variable::GlobalScalar {
-            id: 5,
-            elem: Elem::UInt,
-        };
-        let groups = Variable::GlobalScalar {
-            id: 6,
-            elem: Elem::UInt,
-        };
+        let conv_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt));
+        let conv_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt));
+        let dilation_0 = Variable::new(VariableKind::GlobalScalar(2), Item::new(Elem::UInt));
+        let dilation_1 = Variable::new(VariableKind::GlobalScalar(3), Item::new(Elem::UInt));
+        let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(Elem::UInt));
+        let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(Elem::UInt));
+        let groups = Variable::new(VariableKind::GlobalScalar(6), Item::new(Elem::UInt));
 
         let stride_0_i = scope.create_local(Elem::Int(IntKind::I32));
         let stride_1_i = scope.create_local(Elem::Int(IntKind::I32));
@@ -222,9 +203,9 @@ impl Conv2dTransposeComputeShader {
         cpa!(scope, index_input_b = b * input_stride_0);
         cpa!(scope, index_weight_oc = oc * weight_stride_1);
 
-        let prod = scope.create_local(output.item());
-        let prod_tmp = scope.create_local(output.item());
-        let sum = scope.create_local(output.item());
+        let prod = scope.create_local(output.item);
+        let prod_tmp = scope.create_local(output.item);
+        let sum = scope.create_local(output.item);
         cpa!(scope, sum = bias[oc_out]);
 
         let kh = scope.create_local(Elem::UInt);
@@ -314,10 +295,10 @@ impl Kernel for Conv2dTransposeEagerKernel {
         let mut scope = Scope::root();
         let item = E::cube_elem().into();
 
-        let input = Variable::GlobalInputArray { id: 0, item };
-        let weight = Variable::GlobalInputArray { id: 1, item };
-        let bias = Variable::GlobalInputArray { id: 2, item };
-        let output = Variable::GlobalOutputArray { id: 0, item };
+        let input = Variable::new(VariableKind::GlobalInputArray(0), item);
+        let weight = Variable::new(VariableKind::GlobalInputArray(1), item);
+        let bias = Variable::new(VariableKind::GlobalInputArray(2), item);
+        let output = Variable::new(VariableKind::GlobalOutputArray(0), item);
 
         scope.write_global_custom(output);
diff --git a/crates/burn-jit/src/kernel/conv/conv_transpose3d.rs b/crates/burn-jit/src/kernel/conv/conv_transpose3d.rs
index fdbcffb74c..0e7a538124 100644
--- a/crates/burn-jit/src/kernel/conv/conv_transpose3d.rs
+++ b/crates/burn-jit/src/kernel/conv/conv_transpose3d.rs
@@ -1,6 +1,8 @@
 use cubecl::{
     cpa,
-    ir::{Elem, IntKind, KernelDefinition, Scope, Variable, Visibility},
+    ir::{
+        Builtin, Elem, IntKind, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility,
+    },
     CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
     OutputInfo,
 };
@@ -39,7 +41,7 @@ impl Conv3dTransposeComputeShader {
         let weight = self.weight;
         let bias = self.bias;
         let output = self.output;
-        let idx = Variable::AbsolutePos;
+        let idx = Variable::builtin(Builtin::AbsolutePos);
 
         let input_stride_0 = scope.create_local(Elem::UInt);
         let input_stride_1 = scope.create_local(Elem::UInt);
@@ -104,46 +106,16 @@ impl Conv3dTransposeComputeShader {
         cpa!(scope, kernel_size_1 = shape(weight, 3u32));
         cpa!(scope, kernel_size_2 = shape(weight, 4u32));
 
-        let conv_stride_0 = Variable::GlobalScalar {
-            id: 0,
-            elem: Elem::UInt,
-        };
-        let conv_stride_1 = Variable::GlobalScalar {
-            id: 1,
-            elem: Elem::UInt,
-        };
-        let conv_stride_2 = Variable::GlobalScalar {
-            id: 2,
-            elem: Elem::UInt,
-        };
-        let dilation_0 = Variable::GlobalScalar {
-            id: 3,
-            elem: Elem::UInt,
-        };
-        let dilation_1 = Variable::GlobalScalar {
-            id: 4,
-            elem: Elem::UInt,
-        };
-        let dilation_2 = Variable::GlobalScalar {
-            id: 5,
-            elem: Elem::UInt,
-        };
-        let padding_0 = Variable::GlobalScalar {
-            id: 6,
-            elem: Elem::UInt,
-        };
-        let padding_1 = Variable::GlobalScalar {
-            id: 7,
-            elem: Elem::UInt,
-        };
-        let padding_2 = Variable::GlobalScalar {
-            id: 8,
-            elem: Elem::UInt,
-        };
-        let groups = Variable::GlobalScalar {
-            id: 9,
-            elem: Elem::UInt,
-        };
+        let conv_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt));
+        let conv_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt));
+        let conv_stride_2 = Variable::new(VariableKind::GlobalScalar(2), Item::new(Elem::UInt));
+        let dilation_0 = Variable::new(VariableKind::GlobalScalar(3), Item::new(Elem::UInt));
+        let dilation_1 = Variable::new(VariableKind::GlobalScalar(4), Item::new(Elem::UInt));
+        let dilation_2 = Variable::new(VariableKind::GlobalScalar(5), Item::new(Elem::UInt));
+        let padding_0 = Variable::new(VariableKind::GlobalScalar(6), Item::new(Elem::UInt));
+        let padding_1 = Variable::new(VariableKind::GlobalScalar(7), Item::new(Elem::UInt));
+        let padding_2 = Variable::new(VariableKind::GlobalScalar(8), Item::new(Elem::UInt));
+        let groups = Variable::new(VariableKind::GlobalScalar(9), Item::new(Elem::UInt));
 
         let stride_0_i = scope.create_local(Elem::Int(IntKind::I32));
         let stride_1_i = scope.create_local(Elem::Int(IntKind::I32));
@@ -273,9 +245,9 @@ impl Conv3dTransposeComputeShader {
         cpa!(scope, index_input_b = b * input_stride_0);
         cpa!(scope, index_weight_oc = oc * weight_stride_1);
 
-        let prod = scope.create_local(output.item());
-        let prod_tmp = scope.create_local(output.item());
-        let sum = scope.create_local(output.item());
+        let prod = scope.create_local(output.item);
+        let prod_tmp = scope.create_local(output.item);
+        let sum = scope.create_local(output.item);
         cpa!(scope, sum = bias[oc_out]);
 
         let kd = scope.create_local(Elem::UInt);
@@ -391,10 +363,10 @@ impl Kernel for Conv3dTransposeEagerKernel {
         let mut scope = Scope::root();
         let item = E::cube_elem().into();
 
-        let input = Variable::GlobalInputArray { id: 0, item };
-        let weight = Variable::GlobalInputArray { id: 1, item };
-        let bias = Variable::GlobalInputArray { id: 2, item };
-        let output = Variable::GlobalOutputArray { id: 0, item };
+        let input = Variable::new(VariableKind::GlobalInputArray(0), item);
+        let weight = Variable::new(VariableKind::GlobalInputArray(1), item);
+        let bias = Variable::new(VariableKind::GlobalInputArray(2), item);
+        let output = Variable::new(VariableKind::GlobalOutputArray(0), item);
 
         scope.write_global_custom(output);
diff --git a/crates/burn-jit/src/kernel/index/flip.rs b/crates/burn-jit/src/kernel/index/flip.rs
index 41aee3b08f..6e60371363 100644
--- a/crates/burn-jit/src/kernel/index/flip.rs
+++ b/crates/burn-jit/src/kernel/index/flip.rs
@@ -4,7 +4,7 @@ use crate::{
 use burn_tensor::ElementConversion;
 use cubecl::{
     cpa,
-    ir::{Elem, KernelDefinition, Scope, Variable, Visibility},
+    ir::{Builtin, Elem, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility},
     CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
     OutputInfo,
 };
@@ -27,7 +27,7 @@ impl FlipComputeShader {
     pub fn expand(self, scope: &mut Scope) {
         let input = self.input;
         let output = self.output;
-        let id = Variable::AbsolutePos;
+        let id = Variable::builtin(Builtin::AbsolutePos);
 
         let offset_input = scope.zero(Elem::UInt);
         let offset_local = scope.create_local(Elem::UInt);
@@ -42,10 +42,10 @@ impl FlipComputeShader {
             cpa!(scope, shape = shape(output, i));
             cpa!(
                 scope,
-                flip = cast(Variable::GlobalScalar {
-                    id: i as u16,
-                    elem: Elem::UInt
-                })
+                flip = cast(Variable::new(
+                    VariableKind::GlobalScalar(i as u16),
+                    Item::new(Elem::UInt)
+                ))
             );
             cpa!(scope, flip_bool = flip == 1u32);
 
@@ -61,7 +61,7 @@ impl FlipComputeShader {
             cpa!(scope, offset_input += offset_local);
         }
 
-        let result = scope.create_local(input.item());
+        let result = scope.create_local(input.item);
         cpa!(scope, result = input[offset_input]);
         cpa!(scope, output[id] = result);
     }
@@ -72,8 +72,8 @@ impl Kernel for FlipEagerKernel {
         let mut scope = Scope::root();
         let item = E::cube_elem().into();
 
-        let input = Variable::GlobalInputArray { id: 0, item };
-        let output = Variable::GlobalOutputArray { id: 0, item };
+        let input = Variable::new(VariableKind::GlobalInputArray(0), item);
+        let output = Variable::new(VariableKind::GlobalOutputArray(0), item);
 
         scope.write_global_custom(output);
diff --git a/crates/burn-jit/src/kernel/index/repeat_dim.rs b/crates/burn-jit/src/kernel/index/repeat_dim.rs
index 628bd8efa2..d071f8dd3c 100644
--- a/crates/burn-jit/src/kernel/index/repeat_dim.rs
+++ b/crates/burn-jit/src/kernel/index/repeat_dim.rs
@@ -1,7 +1,7 @@
 use crate::{element::JitElement, kernel::Kernel, tensor::JitTensor, JitRuntime};
 use cubecl::{
     cpa,
-    ir::{Elem, KernelDefinition, Scope, Variable, Visibility},
+    ir::{Builtin, Elem, KernelDefinition, Scope, Variable, VariableKind, Visibility},
     CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
     OutputInfo,
 };
@@ -26,7 +26,7 @@ impl RepeatComputeShader {
     pub fn expand(self, scope: &mut Scope) {
         let input = self.input;
         let output = self.output;
-        let id = Variable::AbsolutePos;
+        let id = Variable::builtin(Builtin::AbsolutePos);
 
         let offset_input = scope.zero(Elem::UInt);
         let offset_local = scope.zero(Elem::UInt);
@@ -50,7 +50,7 @@ impl RepeatComputeShader {
             cpa!(scope, offset_input += offset_local);
         }
 
-        let result = scope.create_local(input.item());
+        let result = scope.create_local(input.item);
         cpa!(scope, result = input[offset_input]);
         cpa!(scope, output[id] = result);
     }
@@ -60,8 +60,8 @@ impl Kernel for RepeatEagerKernel {
         let mut scope = Scope::root();
         let item = E::cube_elem().into();
 
-        let input = Variable::GlobalInputArray { id: 0, item };
-        let output = Variable::GlobalOutputArray { id: 0, item };
+        let input = Variable::new(VariableKind::GlobalInputArray(0), item);
+        let output = Variable::new(VariableKind::GlobalOutputArray(0), item);
 
         scope.write_global_custom(output);
diff --git a/crates/burn-jit/src/kernel/index/slice.rs b/crates/burn-jit/src/kernel/index/slice.rs
index 30aa0aca00..b2093a4498 100644
--- a/crates/burn-jit/src/kernel/index/slice.rs
+++ b/crates/burn-jit/src/kernel/index/slice.rs
@@ -4,7 +4,7 @@ use crate::{
 use burn_tensor::{ElementConversion, Shape};
 use cubecl::{
     cpa,
-    ir::{Elem, KernelDefinition, Scope, Variable, Visibility},
+    ir::{Builtin, Elem, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility},
     CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
     OutputInfo,
 };
@@ -27,7 +27,7 @@ impl SliceComputeShader {
     pub fn expand(self, scope: &mut Scope) {
         let input = self.input;
         let output = self.output;
-        let id = Variable::AbsolutePos;
+        let id = Variable::builtin(Builtin::AbsolutePos);
 
         let offset_input = scope.zero(Elem::UInt);
         let offset_local = scope.create_local(Elem::UInt);
@@ -43,10 +43,10 @@ impl SliceComputeShader {
             cpa!(scope, shape_output = shape(output, i));
             cpa!(
                 scope,
-                range_start = cast(Variable::GlobalScalar {
-                    id: i as u16,
-                    elem: Elem::UInt
-                })
+                range_start = cast(Variable::new(
+                    VariableKind::GlobalScalar(i as u16),
+                    Item::new(Elem::UInt)
+                ))
             );
 
             cpa!(scope, offset_local = id / stride_output);
@@ -57,7 +57,7 @@ impl SliceComputeShader {
             cpa!(scope, offset_input += offset_local);
         }
 
-        let result = scope.create_local(input.item());
+        let result = scope.create_local(input.item);
         cpa!(scope, result = input[offset_input]);
         cpa!(scope, output[id] = result);
     }
@@ -68,8 +68,8 @@ impl Kernel for SliceEagerKernel {
         let mut scope = Scope::root();
         let item = E::cube_elem().into();
 
-        let input = Variable::GlobalInputArray { id: 0, item };
-        let output = Variable::GlobalOutputArray { id: 0, item };
+        let input = Variable::new(VariableKind::GlobalInputArray(0), item);
+        let output = Variable::new(VariableKind::GlobalOutputArray(0), item);
 
         scope.write_global_custom(output);
diff --git a/crates/burn-jit/src/kernel/index/slice_assign.rs b/crates/burn-jit/src/kernel/index/slice_assign.rs
index 42074616f8..fbfb270707 100644
--- a/crates/burn-jit/src/kernel/index/slice_assign.rs
+++ b/crates/burn-jit/src/kernel/index/slice_assign.rs
@@ -2,7 +2,7 @@ use crate::{element::JitElement, kernel::Kernel, tensor::JitTensor, JitRuntime};
 use burn_tensor::ElementConversion;
 use cubecl::{
     cpa,
-    ir::{Elem, KernelDefinition, Scope, Variable, Visibility},
+    ir::{Builtin, Elem, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility},
     CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
 };
 use std::{marker::PhantomData, ops::Range};
@@ -24,7 +24,7 @@ impl SliceAssignComputeShader {
     pub fn expand(self, scope: &mut Scope) {
         let input = self.input;
         let value = self.value;
-        let id = Variable::AbsolutePos;
+        let id = Variable::builtin(Builtin::AbsolutePos);
 
         let offset_input = scope.zero(Elem::UInt);
         let offset_value = scope.zero(Elem::UInt);
@@ -46,10 +46,10 @@ impl SliceAssignComputeShader {
             cpa!(scope, shape_input = shape(input, i));
             cpa!(
                 scope,
-                range_start = cast(Variable::GlobalScalar {
-                    id: i as u16,
-                    elem: Elem::UInt
-                })
+                range_start = cast(Variable::new(
+                    VariableKind::GlobalScalar(i as u16),
+                    Item::new(Elem::UInt)
+                ))
             );
 
             cpa!(scope, offset_local = id / stride_value);
@@ -66,7 +66,7 @@ impl SliceAssignComputeShader {
             cpa!(scope, offset_input += offset_local_input);
         }
 
-        let result = scope.create_local(input.item());
+        let result = scope.create_local(input.item);
         cpa!(scope, result = value[offset_value]);
         cpa!(scope, input[offset_input] = result);
     }
@@ -77,8 +77,8 @@ impl Kernel for SliceAssignEagerKernel {
         let mut scope = Scope::root();
         let item = E::cube_elem().into();
 
-        let input = Variable::GlobalInputArray { id: 0, item };
-        let value = Variable::GlobalInputArray { id: 1, item };
+        let input = Variable::new(VariableKind::GlobalInputArray(0), item);
+        let value = Variable::new(VariableKind::GlobalInputArray(1), item);
 
         scope.write_global_custom(input);
diff --git a/crates/burn-jit/src/kernel/interpolate/bicubic.rs b/crates/burn-jit/src/kernel/interpolate/bicubic.rs
index e583187b5d..aaf202cab0 100644
--- a/crates/burn-jit/src/kernel/interpolate/bicubic.rs
+++ b/crates/burn-jit/src/kernel/interpolate/bicubic.rs
@@ -1,6 +1,6 @@
 use cubecl::{
     cpa,
-    ir::{Elem, KernelDefinition, Scope, Variable, Visibility},
+    ir::{Builtin, Elem, KernelDefinition, Scope, Variable, VariableKind, Visibility},
     CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
     OutputInfo,
 };
@@ -24,7 +24,7 @@ impl InterpolateBicubicShader {
     pub(crate) fn expand(self, scope: &mut Scope) {
         let input = self.input;
         let output = self.output;
-        let id = Variable::AbsolutePos;
+        let id = Variable::builtin(Builtin::AbsolutePos);
 
         let elem = E::cube_elem();
         let input_stride_0 = scope.create_local(Elem::UInt);
@@ -181,10 +181,10 @@ impl InterpolateBicubicShader {
         let index_1 = scope.create_local(Elem::UInt);
         let index_2 = scope.create_local(Elem::UInt);
         let index_3 = scope.create_local(Elem::UInt);
-        let inp_0 = scope.create_local(input.item());
-        let inp_1 = scope.create_local(input.item());
-        let inp_2 = scope.create_local(input.item());
-        let inp_3 = scope.create_local(input.item());
+        let inp_0 = scope.create_local(input.item);
+        let inp_1 = scope.create_local(input.item);
+        let inp_2 = scope.create_local(input.item);
+        let inp_3 = scope.create_local(input.item);
 
         cpa!(scope, index_0 = index_base);
         cpa!(scope, index_0 += y0_stride);
@@ -276,7 +276,7 @@ impl InterpolateBicubicShader {
 
     fn min(scope: &mut Scope, a: Variable, b: Variable) -> Variable {
         let cond = scope.create_local(Elem::Bool);
-        let res = scope.create_local(a.item());
+        let res = scope.create_local(a.item);
 
         cpa!(scope, cond = a < b);
         cpa!(scope, if(cond).then(|scope|{
@@ -296,7 +296,7 @@ impl InterpolateBicubicShader {
         x3: Variable,
         t: Variable,
     ) -> Variable {
-        let item = x0.item();
+        let item = x0.item;
         let x = scope.create_local(item);
         let a: Variable = scope.create_with_value(-0.75, item);
         let one: Variable = scope.create_with_value(1, item);
@@ -327,7 +327,7 @@ impl InterpolateBicubicShader {
     }
 
     fn cubic_convolution1(scope: &mut Scope, x: Variable, a: Variable) -> Variable {
-        let item = x.item();
+        let item = x.item;
         let conv = scope.create_local(item);
         let tmp = scope.create_local(item);
         let one = scope.create_with_value(1, item);
@@ -346,7 +346,7 @@ impl InterpolateBicubicShader {
     }
 
     fn cubic_convolution2(scope: &mut Scope, x: Variable, a: Variable) -> Variable {
-        let item = x.item();
+        let item = x.item;
         let conv = scope.create_local(item);
         let tmp = scope.create_local(item);
         let four = scope.create_with_value(4, item);
@@ -372,8 +372,8 @@ impl Kernel for InterpolateBicubicEagerKernel {
         let mut scope = Scope::root();
         let item = E::cube_elem().into();
 
-        let input = Variable::GlobalInputArray { id: 0, item };
-        let output = Variable::GlobalOutputArray { id: 0, item };
+        let input = Variable::new(VariableKind::GlobalInputArray(0), item);
+        let output = Variable::new(VariableKind::GlobalOutputArray(0), item);
 
         scope.write_global_custom(output);
diff --git a/crates/burn-jit/src/kernel/interpolate/bilinear.rs b/crates/burn-jit/src/kernel/interpolate/bilinear.rs
@@ ... @@ impl Kernel for InterpolateBilinearEagerKernel
diff --git a/crates/burn-jit/src/kernel/interpolate/nearest.rs b/crates/burn-jit/src/kernel/interpolate/nearest.rs
@@ ... @@ impl InterpolateNearestShader {
     pub(crate) fn expand(self, scope: &mut Scope) {
         let input = self.input;
         let output = self.output;
-        let id = Variable::AbsolutePos;
+        let id = Variable::builtin(Builtin::AbsolutePos);
 
         let elem = E::cube_elem();
         let input_stride_0 = scope.create_local(Elem::UInt);
@@ -106,7 +106,7 @@ impl InterpolateNearestShader {
         let index = scope.create_local(Elem::UInt);
         let index_tmp = scope.create_local(Elem::UInt);
 
-        let val = scope.create_local(output.item());
+        let val = scope.create_local(output.item);
 
         cpa!(scope, index = b * input_stride_0);
         cpa!(scope, index_tmp = c * input_stride_1);
@@ -126,8 +126,8 @@ impl Kernel for InterpolateNearestEagerKernel {
         let mut scope = Scope::root();
         let item = E::cube_elem().into();
 
-        let input = Variable::GlobalInputArray { id: 0, item };
-        let output = Variable::GlobalOutputArray { id: 0, item };
+        let input = Variable::new(VariableKind::GlobalInputArray(0), item);
+        let output = Variable::new(VariableKind::GlobalOutputArray(0), item);
 
         scope.write_global_custom(output);
diff --git a/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs b/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs
@@ ... @@ impl InterpolateNearestBackwardShader {
     fn expand(self, scope: &mut Scope) {
         let grad = self.out_grad;
         let output = self.output;
-        let id = Variable::AbsolutePos;
+        let id = Variable::builtin(Builtin::AbsolutePos);
 
         let grad_stride_0 = scope.create_local(Elem::UInt);
         let grad_stride_1 = scope.create_local(Elem::UInt);
@@ -88,7 +88,7 @@ impl InterpolateNearestBackwardShader {
         let gw_start = Self::start_index(scope, ow, grad_shape_3, output_shape_3);
         let gw_end = Self::end_index(scope, ow, grad_shape_3, output_shape_3);
 
-        let result = scope.create_local(grad.item());
+        let result = scope.create_local(grad.item);
 
         let index_grad = scope.create_local(Elem::UInt);
         let index_grad_0 = scope.create_local(Elem::UInt);
@@ -99,7 +99,7 @@ impl InterpolateNearestBackwardShader {
         cpa!(scope, index_grad_0 = b * grad_stride_0);
         cpa!(scope, index_grad_1 = c * grad_stride_1);
 
-        let sum = scope.zero(output.item());
+        let sum = scope.zero(output.item);
 
         cpa!(
             scope,
@@ -184,8 +184,8 @@ impl Kernel for InterpolateNearestBackwardEagerKernel {
         let mut scope = Scope::root();
         let item = E::cube_elem().into();
 
-        let out_grad = Variable::GlobalInputArray { id: 0, item };
-        let output = Variable::GlobalOutputArray { id: 0, item };
+        let out_grad = Variable::new(VariableKind::GlobalInputArray(0), item);
+        let output = Variable::new(VariableKind::GlobalOutputArray(0), item);
 
         InterpolateNearestBackwardShader {
             out_grad,
diff --git a/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs b/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs
index 2abf8dce5a..a6e62c8464 100644
--- a/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs
+++ b/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs
@@ -7,7 +7,9 @@ use crate::{
 };
 use cubecl::{
     cpa,
-    ir::{Elem, IntKind, KernelDefinition, Scope, Variable, Visibility},
+    ir::{
+        Builtin, Elem, IntKind, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility,
+    },
     CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
     OutputInfo,
 };
@@ -32,7 +34,7 @@ impl AvgPool2dBackwardComputeShader {
     fn expand(self, scope: &mut Scope) {
         let grad = self.grad;
         let output = self.output;
-        let id = Variable::AbsolutePos;
+        let id = Variable::builtin(Builtin::AbsolutePos);
 
         let grad_stride_0 = scope.create_local(Elem::UInt);
         let grad_stride_1 = scope.create_local(Elem::UInt);
@@ -70,22 +72,10 @@ impl AvgPool2dBackwardComputeShader {
         cpa!(scope, output_shape_2 = shape(output, 2u32));
         cpa!(scope, output_shape_3 = shape(output, 3u32));
 
-        let pool_stride_0 = Variable::GlobalScalar {
-            id: 0,
-            elem: Elem::UInt,
-        };
-        let pool_stride_1 = Variable::GlobalScalar {
-            id: 1,
-            elem: Elem::UInt,
-        };
-        let padding_0 = Variable::GlobalScalar {
-            id: 4,
-            elem: Elem::UInt,
-        };
-        let padding_1 = Variable::GlobalScalar {
-            id: 5,
-            elem: Elem::UInt,
-        };
+        let pool_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt));
+        let pool_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt));
+        let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(Elem::UInt));
+        let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(Elem::UInt));
 
         let [kernel_size_0, kernel_size_1] = self.kernel_size;
         let b = scope.create_local(Elem::UInt);
@@ -116,9 +106,9 @@ impl AvgPool2dBackwardComputeShader {
         let index_tmp = scope.create_local(Elem::UInt);
         let index_base = scope.create_local(Elem::UInt);
 
-        let grad_accumulation = scope.zero(grad.item());
-        let result = scope.create_local(grad.item());
-        let count = scope.create_local(grad.item());
+        let grad_accumulation = scope.zero(grad.item);
+        let result = scope.create_local(grad.item);
+        let count = scope.create_local(grad.item);
 
         let count_include_pad = self.count_include_pad;
         if count_include_pad {
@@ -226,30 +216,12 @@ impl AvgPool2dBackwardComputeShader {
         output_stride_2: Variable,
         output_stride_3: Variable,
     ) -> (Variable, Variable, Variable, Variable) {
-        let pool_stride_0 = Variable::GlobalScalar {
-            id: 0,
-            elem: Elem::UInt,
-        };
-        let pool_stride_1 = Variable::GlobalScalar {
-            id: 1,
-            elem: Elem::UInt,
-        };
-        let dilation_0 = Variable::GlobalScalar {
-            id: 2,
-            elem: Elem::UInt,
-        };
-        let dilation_1 = Variable::GlobalScalar {
-            id: 3,
-            elem: Elem::UInt,
-        };
-        let padding_0 = Variable::GlobalScalar {
-            id: 4,
-            elem: Elem::UInt,
-        };
-        let padding_1 = Variable::GlobalScalar {
-            id: 5,
-            elem: Elem::UInt,
-        };
+        let pool_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt));
+        let pool_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt));
+        let dilation_0 = Variable::new(VariableKind::GlobalScalar(2), Item::new(Elem::UInt));
+        let dilation_1 = Variable::new(VariableKind::GlobalScalar(3), Item::new(Elem::UInt));
+        let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(Elem::UInt));
+        let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(Elem::UInt));
 
         let [kernel_size_0, kernel_size_1] = self.kernel_size;
 
@@ -350,8 +322,8 @@ impl Kernel for AvgPool2dBackwardEagerKernel {
         let mut scope = Scope::root();
         let item = E::cube_elem().into();
 
-        let grad = Variable::GlobalInputArray { id: 0, item };
-        let output = Variable::GlobalOutputArray { id: 0, item };
+        let grad = Variable::new(VariableKind::GlobalInputArray(0), item);
+        let output = Variable::new(VariableKind::GlobalOutputArray(0), item);
 
         scope.write_global_custom(output);
diff --git a/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs b/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs
index 8e4eb96271..58dd09454f 100644
--- a/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs
+++ b/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs
@@ -7,7 +7,9 @@ use crate::{
 };
 use cubecl::{
     cpa,
-    ir::{Elem, IntKind, Item, KernelDefinition, Scope, Variable, Visibility},
+    ir::{
+        Builtin, Elem, IntKind, Item, KernelDefinition, Scope, Variable, VariableKind, Visibility,
+    },
     CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
     OutputInfo,
 };
@@ -32,7 +34,7 @@ impl MaxPool2dBackwardComputeShader {
         let grad = self.grad;
         let output = self.output;
         let indices = self.indices;
-        let id = Variable::AbsolutePos;
+        let id = Variable::builtin(Builtin::AbsolutePos);
 
         let grad_stride_0 = scope.create_local(Elem::UInt);
         let grad_stride_1 = scope.create_local(Elem::UInt);
@@ -103,8 +105,8 @@ impl MaxPool2dBackwardComputeShader {
         let index_base = scope.create_local(Elem::UInt);
         let index_tmp = scope.create_local(Elem::UInt);
 
-        let grad_accumulation = scope.zero(grad.item());
-        let result = scope.create_local(grad.item());
+        let grad_accumulation = scope.zero(grad.item);
+        let result = scope.create_local(grad.item);
 
         let (oh_start, oh_end, ow_start, ow_end) = self.loop_ranges(
             scope,
@@ -160,30 +162,12 @@ impl MaxPool2dBackwardComputeShader {
         output_stride_2: Variable,
         output_stride_3: Variable,
     ) -> (Variable, Variable, Variable, Variable) {
-        let pool_stride_0 = Variable::GlobalScalar {
-            id: 0,
-            elem: Elem::UInt,
-        };
-        let pool_stride_1 = Variable::GlobalScalar {
-            id: 1,
-            elem: Elem::UInt,
-        };
-        let dilation_0 = Variable::GlobalScalar {
-            id: 2,
-            elem: Elem::UInt,
-        };
-        let dilation_1 = Variable::GlobalScalar {
-            id: 3,
-            elem: Elem::UInt,
-        };
-        let padding_0 = Variable::GlobalScalar {
-            id: 4,
-            elem: Elem::UInt,
-        };
-        let padding_1 = Variable::GlobalScalar {
-            id: 5,
-            elem: Elem::UInt,
-        };
+        let pool_stride_0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt));
+        let pool_stride_1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt));
+        let dilation_0 = Variable::new(VariableKind::GlobalScalar(2), Item::new(Elem::UInt));
+        let dilation_1 = Variable::new(VariableKind::GlobalScalar(3), Item::new(Elem::UInt));
+        let padding_0 = Variable::new(VariableKind::GlobalScalar(4), Item::new(Elem::UInt));
+        let padding_1 = Variable::new(VariableKind::GlobalScalar(5), Item::new(Elem::UInt));
 
         let [kernel_size_0, kernel_size_1] = self.kernel_size;
 
@@ -284,12 +268,12 @@ impl Kernel for MaxPool2dWithIndicesBackwardEagerKernel {
         let mut scope = Scope::root();
         let item = E::cube_elem().into();
 
-        let indices = Variable::GlobalInputArray {
-            id: 0,
-            item: Item::new(Elem::Int(IntKind::I32)),
-        };
-        let grad = Variable::GlobalInputArray { id: 1, item };
-        let output = Variable::GlobalOutputArray { id: 0, item };
+        let indices = Variable::new(
+            VariableKind::GlobalInputArray(0),
+            Item::new(Elem::Int(IntKind::I32)),
+        );
+        let grad = Variable::new(VariableKind::GlobalInputArray(1), item);
+        let output = Variable::new(VariableKind::GlobalOutputArray(0), item);
 
         scope.write_global_custom(output);
diff --git a/crates/burn-jit/src/kernel/prng/base.rs b/crates/burn-jit/src/kernel/prng/base.rs
index 3a3b9019ec..527c85d799 100644
--- a/crates/burn-jit/src/kernel/prng/base.rs
+++ b/crates/burn-jit/src/kernel/prng/base.rs
@@ -1,6 +1,6 @@
 use cubecl::{
     cpa,
-    ir::{Elem, Scope, Variable},
+    ir::{Builtin, Elem, Item, Scope, Variable, VariableKind},
     prelude::*,
     CubeCountSettings, Execution, InputInfo, OutputInfo,
 };
@@ -58,32 +58,17 @@ impl<P: Prng<E>, R: JitRuntime, E: JitElement> Kernel for PrngEagerKernel<P, R, E> {
         let mut scope = Scope::root();
         let item: Item = E::cube_elem().into();
 
-        let output = Variable::GlobalOutputArray { id: 0, item };
+        let output = Variable::new(VariableKind::GlobalOutputArray(0), item);
 
-        let seed0 = Variable::GlobalScalar {
-            id: 0,
-            elem: Elem::UInt,
-        };
-        let seed1 = Variable::GlobalScalar {
-            id: 1,
-            elem: Elem::UInt,
-        };
-        let seed2 = Variable::GlobalScalar {
-            id: 2,
-            elem: Elem::UInt,
-        };
-        let seed3 = Variable::GlobalScalar {
-            id: 3,
-            elem: Elem::UInt,
-        };
+        let seed0 = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt));
+        let seed1 = Variable::new(VariableKind::GlobalScalar(1), Item::new(Elem::UInt));
+        let seed2 = Variable::new(VariableKind::GlobalScalar(2), Item::new(Elem::UInt));
+        let seed3 = Variable::new(VariableKind::GlobalScalar(3), Item::new(Elem::UInt));
         let seeds = [seed0, seed1, seed2, seed3];
 
         let mut args = Vec::<Variable>::new();
         for i in 0..P::args_length() {
-            args.push(Variable::GlobalScalar {
-                id: i as u16,
-                elem: item.elem(),
-            });
+            args.push(Variable::new(VariableKind::GlobalScalar(i as u16), item));
         }
 
         PrngShader::<P, E>::new(output, N_VALUES_PER_THREAD, seeds, args).expand(&mut scope);
@@ -174,12 +159,12 @@ impl<P: Prng<E>, E: JitElement> PrngShader<P, E> {
         let n_values_per_thread: Variable = self.n_values_per_thread.into();
         let args = self.args;
 
-        let cube_dim_x = Variable::CubeDimX;
-        let cube_dim_y = Variable::CubeDimY;
-        let cube_pos_x = Variable::CubePosX;
-        let cube_pos_y = Variable::CubePosY;
-        let cube_count_y = Variable::CubeCountY;
-        let local_index = Variable::UnitPos;
+        let cube_dim_x = Variable::builtin(Builtin::CubeDimX);
+        let cube_dim_y = Variable::builtin(Builtin::CubeDimY);
+        let cube_pos_x = Variable::builtin(Builtin::CubePosX);
+        let cube_pos_y = Variable::builtin(Builtin::CubePosY);
+        let cube_count_y = Variable::builtin(Builtin::CubeCountY);
+        let local_index = Variable::builtin(Builtin::UnitPos);
 
         let n_invocations = scope.create_local(Elem::UInt);
         cpa!(scope, n_invocations = cube_dim_x);
diff --git a/crates/burn-jit/src/kernel/prng/normal.rs b/crates/burn-jit/src/kernel/prng/normal.rs
index 8e3691f7c2..e9b4b428a2 100644
--- a/crates/burn-jit/src/kernel/prng/normal.rs
+++ b/crates/burn-jit/src/kernel/prng/normal.rs
@@ -37,7 +37,7 @@ impl Prng for Normal {
         output: Variable,
     ) {
         let float_elem = Elem::Float(FloatKind::F32);
-        let item = output.item();
+        let item = output.item;
         let mean = args[0];
         let std = args[1];
         let two_pi = scope.create_with_value(2. * PI, float_elem);
diff --git a/crates/burn-jit/src/kernel/prng/uniform.rs b/crates/burn-jit/src/kernel/prng/uniform.rs
index 64dbc46aac..15490758df 100644
--- a/crates/burn-jit/src/kernel/prng/uniform.rs
+++ b/crates/burn-jit/src/kernel/prng/uniform.rs
@@ -35,7 +35,7 @@ impl Prng for Uniform {
         output: Variable,
     ) {
         let float_elem = Elem::Float(FloatKind::F32);
-        let item = output.item();
+        let item = output.item;
         let lower_bound = args[0];
         let upper_bound = args[1];
         let scale = scope.create_local(item);
diff --git a/crates/burn-jit/src/kernel/reduce/shared/argmax.rs b/crates/burn-jit/src/kernel/reduce/shared/argmax.rs
index 95ad8adbad..1b10a5f105 100644
--- a/crates/burn-jit/src/kernel/reduce/shared/argmax.rs
+++ b/crates/burn-jit/src/kernel/reduce/shared/argmax.rs
@@ -32,7 +32,7 @@ impl ReduceDimShared for Argmax {
         (value, index): Self::Accumulator,
     ) {
         let (value_shared_memory, index_shared_memory) = shared_memory;
-        let current_value = scope.create_local(value.item());
+        let current_value = scope.create_local(value.item);
 
         cpa!(scope, current_value = value_shared_memory[write_position]);
         let condition = scope.create_local(Elem::Bool);
@@ -49,7 +49,7 @@ impl ReduceDimShared for Argmax {
         read_position: Variable,
         i: Variable,
     ) -> Self::Accumulator {
-        let value = scope.create_local(input.item());
+        let value = scope.create_local(input.item);
         cpa!(scope, value = input[read_position]);
         (value, i)
     }
@@ -60,9 +60,9 @@ impl ReduceDimShared for Argmax {
         read_position: Variable,
     ) -> Self::Accumulator {
         let (value_shared_memory, index_shared_memory) = shared_memory;
-        let value = scope.create_local(value_shared_memory.item());
+        let value = scope.create_local(value_shared_memory.item);
         cpa!(scope, value = value_shared_memory[read_position]);
-        let index = scope.create_local(index_shared_memory.item());
+        let index = scope.create_local(index_shared_memory.item);
         cpa!(scope, index = index_shared_memory[read_position]);
         (value, index)
     }
@@ -75,7 +75,7 @@ impl ReduceDimShared for Argmax {
         _shape_reduce_dim: Variable,
     ) {
         let (_, index_shared_memory) = shared_memory;
-        let final_value = scope.create_local(output.item());
+        let final_value = scope.create_local(index_shared_memory.item);
         cpa!(scope, final_value = index_shared_memory[0]);
         cpa!(scope, output[write_position] = final_value);
     }
diff --git a/crates/burn-jit/src/kernel/reduce/shared/argmin.rs b/crates/burn-jit/src/kernel/reduce/shared/argmin.rs
index 79616dc967..fa8ff21af7 100644
--- a/crates/burn-jit/src/kernel/reduce/shared/argmin.rs
+++ b/crates/burn-jit/src/kernel/reduce/shared/argmin.rs
@@ -33,7 +33,7 @@ impl ReduceDimShared for Argmin {
         (value, index): Self::Accumulator,
     ) {
         let (value_shared_memory, index_shared_memory) = shared_memory;
-        let current_value = scope.create_local(value.item());
+        let current_value = scope.create_local(value.item);
 
         cpa!(scope, current_value = value_shared_memory[write_position]);
         let condition = scope.create_local(Elem::Bool);
@@ -50,7 +50,7 @@ impl ReduceDimShared for Argmin {
         read_position: Variable,
         i: Variable,
     ) -> Self::Accumulator {
-        let value = scope.create_local(input.item());
+        let value = scope.create_local(input.item);
         cpa!(scope, value = input[read_position]);
         (value, i)
     }
@@ -61,9 +61,9 @@ impl ReduceDimShared for Argmin {
         read_position: Variable,
     ) -> Self::Accumulator {
         let (value_shared_memory, index_shared_memory) = shared_memory;
-        let value = scope.create_local(value_shared_memory.item());
+        let value = scope.create_local(value_shared_memory.item);
         cpa!(scope, value = value_shared_memory[read_position]);
-        let index = scope.create_local(index_shared_memory.item());
+        let index = scope.create_local(index_shared_memory.item);
         cpa!(scope, index = index_shared_memory[read_position]);
         (value, index)
     }
@@ -76,7 +76,7 @@ impl ReduceDimShared for Argmin {
         _shape_reduce_dim: Variable,
     ) {
         let (_, index_shared_memory) = shared_memory;
-        let final_value = scope.create_local(output.item());
+        let final_value = scope.create_local(index_shared_memory.item);
         cpa!(scope, final_value = index_shared_memory[0]);
         cpa!(scope, output[write_position] = final_value);
     }
diff --git a/crates/burn-jit/src/kernel/reduce/shared/mean_dim.rs b/crates/burn-jit/src/kernel/reduce/shared/mean_dim.rs
index 0339d9da43..a397168729 100644
--- a/crates/burn-jit/src/kernel/reduce/shared/mean_dim.rs
+++ b/crates/burn-jit/src/kernel/reduce/shared/mean_dim.rs
@@ -16,7 +16,7 @@ impl ReduceDimShared for MeanDim {
         input_item: Item,
     ) -> Self::Accumulator {
         let shared_memory = scope.create_shared(input_item, shared_memory_size);
-        let neutral_element = scope.zero(shared_memory.item());
+        let neutral_element = scope.zero(shared_memory.item);
         cpa!(scope, shared_memory[write_position] = neutral_element);
         shared_memory
     }
@@ -27,8 +27,8 @@ impl ReduceDimShared for MeanDim {
         write_position: Variable,
         value: Self::Accumulator,
     ) {
-        let current_value = scope.create_local(value.item());
-        let computed = scope.create_local(value.item());
+        let current_value = scope.create_local(value.item);
+        let computed = scope.create_local(value.item);
         cpa!(scope, current_value = shared_memory[write_position]);
         cpa!(scope, computed = current_value + value);
         cpa!(scope, shared_memory[write_position] = computed);
@@ -40,7 +40,7 @@ impl ReduceDimShared for MeanDim {
         read_position: Variable,
         _i: Variable,
     ) -> Self::Accumulator {
-        let value = scope.create_local(input.item());
+        let value = scope.create_local(input.item);
         cpa!(scope, value = input[read_position]);
         value
     }
@@ -50,7 +50,7 @@ impl ReduceDimShared for MeanDim {
         shared_memory: Self::Accumulator,
         read_position: Variable,
     ) -> Variable {
-        let read_value = scope.create_local(shared_memory.item());
+        let read_value = scope.create_local(shared_memory.item);
         cpa!(scope, read_value = shared_memory[read_position]);
         read_value
     }
@@ -62,10 +62,10 @@ impl ReduceDimShared for MeanDim {
         write_position: Variable,
         shape_reduce_dim: Variable,
     ) {
-        let final_value = scope.create_local(output.item());
+        let final_value = scope.create_local(output.item);
         cpa!(scope, final_value = shared_memory[0]);
 
-        let denominator = scope.create_local(output.item());
+        let denominator = scope.create_local(output.item);
         cpa!(scope, denominator = cast(shape_reduce_dim));
         cpa!(scope, final_value = final_value / denominator);
         cpa!(scope, output[write_position] = final_value);
diff --git a/crates/burn-jit/src/kernel/reduce/shared/prod_dim.rs b/crates/burn-jit/src/kernel/reduce/shared/prod_dim.rs
index 961e192a8b..8d1e1ff6cd 100644
--- a/crates/burn-jit/src/kernel/reduce/shared/prod_dim.rs
+++ b/crates/burn-jit/src/kernel/reduce/shared/prod_dim.rs
@@ -16,7 +16,7 @@ impl ReduceDimShared for ProdDim {
         input_item: Item,
     ) -> Self::Accumulator {
         let shared_memory = scope.create_shared(input_item, shared_memory_size);
-        let neutral_element = scope.create_with_value(1, shared_memory.item());
+        let neutral_element = scope.create_with_value(1, shared_memory.item);
         cpa!(scope, shared_memory[write_position] = neutral_element);
         shared_memory
     }
@@ -27,8 +27,8 @@ impl ReduceDimShared for ProdDim {
         write_position: Variable,
         value: Self::Accumulator,
     ) {
-        let current_value = scope.create_local(value.item());
-        let computed = scope.create_local(value.item());
+        let current_value = scope.create_local(value.item);
+        let computed = scope.create_local(value.item);
         cpa!(scope, current_value = shared_memory[write_position]);
         cpa!(scope, computed = current_value * value);
         cpa!(scope, shared_memory[write_position] = computed);
@@ -40,7 +40,7 @@ impl ReduceDimShared for ProdDim {
         read_position: Variable,
         _i: Variable,
     ) -> Self::Accumulator {
-        let value = scope.create_local(input.item());
+        let value = scope.create_local(input.item);
         cpa!(scope, value = input[read_position]);
         value
     }
@@ -50,7 +50,7 @@ impl ReduceDimShared for ProdDim {
         shared_memory: Self::Accumulator,
         read_position: Variable,
     ) -> Self::Accumulator {
-        let read_value = scope.create_local(shared_memory.item());
+        let read_value = scope.create_local(shared_memory.item);
         cpa!(scope, read_value = shared_memory[read_position]);
         read_value
     }
@@ -62,7 +62,7 @@ impl ReduceDimShared for ProdDim {
         write_position: Variable,
         _shape_reduce_dim: Variable,
     ) {
-        let final_value = scope.create_local(output.item());
+        let final_value = scope.create_local(output.item);
         cpa!(scope, final_value = shared_memory[0]);
         cpa!(scope, output[write_position] = final_value);
     }
diff --git a/crates/burn-jit/src/kernel/reduce/shared/shader.rs b/crates/burn-jit/src/kernel/reduce/shared/shader.rs
index 1a7d73a252..3aa48ff8a5 100644
--- a/crates/burn-jit/src/kernel/reduce/shared/shader.rs
+++ b/crates/burn-jit/src/kernel/reduce/shared/shader.rs
@@ -1,6 +1,9 @@
 use cubecl::{
-    cpa, ir::KernelDefinition, prelude::CubeCount, CubeCountSettings, Execution, InputInfo,
-    KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo,
+    cpa,
+    ir::{Builtin, KernelDefinition, VariableKind},
+    prelude::CubeCount,
+    CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
+    OutputInfo,
 };
 use std::marker::PhantomData;
@@ -51,14 +54,8 @@ impl<RD: ReduceDimShared<EI>, R: JitRuntime, EI: JitElement, EO: JitElement> Kernel
         let item_input = EI::cube_elem().into();
         let item_output = EO::cube_elem().into();
 
-        let tensor = Variable::GlobalInputArray {
-            id: 0,
-            item: item_input,
-        };
-        let output = Variable::GlobalOutputArray {
-            id: 0,
-            item: item_output,
-        };
+        let tensor = Variable::new(VariableKind::GlobalInputArray(0), item_input);
+        let output = Variable::new(VariableKind::GlobalOutputArray(0), item_output);
 
         // Reduce groups are elements that are aligned along the reduce dim
         SharedReduceDimComputeShader {
@@ -112,16 +109,16 @@ impl<E: JitElement, RD: ReduceDimShared<E>> SharedReduceDimComputeShader<E, RD>
         let tensor = self.tensor;
         let output = self.output;
 
-        let rank = Variable::Rank;
+        let rank = Variable::builtin(Builtin::Rank);
         let dim: Variable = self.dim.into();
 
-        let cube_pos_x = Variable::CubePosX;
-        let cube_pos_y = Variable::CubePosY;
-        let cube_count_x = Variable::CubeCountX;
-        let local_invocation_id_x = Variable::UnitPosX;
-        let local_invocation_id_y = Variable::UnitPosY;
-        let cube_dim_x = Variable::CubeDimX;
-        let cube_dim_y = Variable::CubeDimY;
+        let cube_pos_x = Variable::builtin(Builtin::CubePosX);
+        let cube_pos_y = Variable::builtin(Builtin::CubePosY);
+        let cube_count_x = Variable::builtin(Builtin::CubeCountX);
+        let local_invocation_id_x = Variable::builtin(Builtin::UnitPosX);
+        let local_invocation_id_y = Variable::builtin(Builtin::UnitPosY);
+        let cube_dim_x = Variable::builtin(Builtin::CubeDimX);
+        let cube_dim_y = Variable::builtin(Builtin::CubeDimY);
 
         let stride_reduce_dim_input = scope.create_local(Elem::UInt);
         cpa!(scope, stride_reduce_dim_input = stride(tensor, dim));
@@ -162,12 +159,8 @@ impl<E: JitElement, RD: ReduceDimShared<E>> SharedReduceDimComputeShader<E, RD>
             })
         );
 
-        let shared_memory = RD::initialize_shared(
-            scope,
-            self.shared_memory_size as u32,
-            local_id,
-            tensor.item(),
-        );
+        let shared_memory =
+            RD::initialize_shared(scope, self.shared_memory_size as u32, local_id, tensor.item);
 
         // Load to shared memory, unrolled
         cpa!(
diff --git a/crates/burn-jit/src/kernel/reduce/shared/sum_dim.rs b/crates/burn-jit/src/kernel/reduce/shared/sum_dim.rs
index db85b09cb7..bbccb01bef 100644
--- a/crates/burn-jit/src/kernel/reduce/shared/sum_dim.rs
+++ b/crates/burn-jit/src/kernel/reduce/shared/sum_dim.rs
@@ -16,7 +16,7 @@ impl ReduceDimShared for SumDim {
         input_item: Item,
     ) -> Self::Accumulator {
         let shared_memory = scope.create_shared(input_item, shared_memory_size);
-        let neutral_element = scope.zero(shared_memory.item());
+        let neutral_element = scope.zero(shared_memory.item);
         cpa!(scope, shared_memory[write_position] = neutral_element);
         shared_memory
     }
@@ -27,8 +27,8 @@ impl ReduceDimShared for SumDim {
         write_position: Variable,
         value: Self::Accumulator,
     ) {
-        let current_value = scope.create_local(value.item());
-        let computed = scope.create_local(value.item());
+        let current_value = scope.create_local(value.item);
+        let computed = scope.create_local(value.item);
         cpa!(scope, current_value = shared_memory[write_position]);
         cpa!(scope, computed = current_value + value);
         cpa!(scope, shared_memory[write_position] = computed);
@@ -40,7 +40,7 @@ impl ReduceDimShared for SumDim {
         read_position: Variable,
         _i: Variable,
     ) -> Self::Accumulator {
-        let value = scope.create_local(input.item());
+        let value = scope.create_local(input.item);
         cpa!(scope, value = input[read_position]);
         value
     }
@@ -50,7 +50,7 @@ impl ReduceDimShared for SumDim {
         shared_memory: Self::Accumulator,
         read_position: Variable,
     ) -> Self::Accumulator {
-        let read_value = scope.create_local(shared_memory.item());
+        let read_value = scope.create_local(shared_memory.item);
         cpa!(scope, read_value = shared_memory[read_position]);
         read_value
     }
@@ -62,7 +62,7 @@ impl ReduceDimShared for SumDim {
         write_position: Variable,
         _shape_reduce_dim: Variable,
     ) {
-        let final_value = scope.create_local(output.item());
+        let final_value = scope.create_local(output.item);
         cpa!(scope, final_value = shared_memory[0]);
         cpa!(scope, output[write_position] = final_value);
     }
diff --git a/crates/burn-tensor/src/lib.rs b/crates/burn-tensor/src/lib.rs
index acb3d75ecf..75b037972f 100644
--- a/crates/burn-tensor/src/lib.rs
+++ b/crates/burn-tensor/src/lib.rs
@@ -73,6 +73,7 @@ mod cube_wgpu {
                 WgpuDevice::Existing(id) => {
                     DeviceId::new(5, (id.inner() % (u32::MAX as u64)) as u32)
                 }
+                WgpuDevice::DefaultDevice => DeviceId::new(6, 0),
             }
         }
     }
diff --git a/examples/image-classification-web/src/web.rs b/examples/image-classification-web/src/web.rs
index c44719006c..bcff11eb69 100644
--- a/examples/image-classification-web/src/web.rs
+++ b/examples/image-classification-web/src/web.rs
@@ -8,9 +8,13 @@ use core::convert::Into;
 
 use crate::model::{label::LABELS, normalizer::Normalizer, squeezenet::Model as SqueezenetModel};
 
-use burn::{backend::NdArray, prelude::*, tensor::activation::softmax};
+use burn::{
+    backend::{wgpu::init_device, NdArray},
+    prelude::*,
+    tensor::activation::softmax,
+};
 
-use burn::backend::wgpu::{init_async, AutoGraphicsApi, Wgpu, WgpuDevice};
+use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
 use burn_candle::Candle;
 use serde::Serialize;
@@ -106,7 +110,7 @@ impl ImageClassifier {
         log::info!("Loading the model to the Wgpu backend");
         let start = Instant::now();
         let device = WgpuDevice::default();
-        init_async::<AutoGraphicsApi>(&device, Default::default()).await;
+        init_device::<AutoGraphicsApi>(&device, Default::default()).await;
         self.model = ModelType::WithWgpuBackend(Model::new(&device));
         let duration = start.elapsed();
         log::debug!("Model is loaded to the Wgpu backend in {:?}", duration);
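
The burn-jit hunks above all apply the same mechanical rewrite to cubecl's IR: enum-style variants such as `Variable::AbsolutePos` and `Variable::GlobalScalar { id, elem }` become a `Variable` built from a `VariableKind` plus an `Item`, builtins go through `Variable::builtin`, and `item` is read as a public field instead of an `item()` getter. A minimal sketch of the before/after, using only the cubecl 0.4 API surface visible in the hunks (the function and variable names here are illustrative, not part of the patch):

```rust
use cubecl::ir::{Builtin, Elem, Item, Scope, Variable, VariableKind};

// Illustrative sketch only: mirrors the rewrite pattern repeated in the hunks above.
fn build_ir_sketch(scope: &mut Scope) {
    // cubecl 0.3 (removed lines):
    //   let id = Variable::AbsolutePos;
    //   let stride = Variable::GlobalScalar { id: 0, elem: Elem::UInt };
    //   let input = Variable::GlobalInputArray { id: 0, item };

    // cubecl 0.4 (added lines): the kind and the item are passed explicitly.
    let id = Variable::builtin(Builtin::AbsolutePos);
    let stride = Variable::new(VariableKind::GlobalScalar(0), Item::new(Elem::UInt));
    let input = Variable::new(VariableKind::GlobalInputArray(0), Item::new(Elem::UInt));

    // `item` is now a plain field, so `input.item()` becomes `input.item`.
    let result = scope.create_local(input.item);
    let _ = (id, stride, result);
}
```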