Skip to content

Commit

Permalink
Add wgpu-spirv and hip-jit features to text-classification example (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
syl20bnr authored Oct 30, 2024
1 parent 5730f02 commit f263e36
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 0 deletions.
3 changes: 3 additions & 0 deletions crates/burn-core/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ pub use burn_candle::Candle;
#[cfg(feature = "hip-jit")]
pub use burn_hip as hip_jit;

#[cfg(feature = "hip-jit")]
pub use burn_hip::Hip as HipJit;

#[cfg(feature = "tch")]
pub use burn_tch as libtorch;

Expand Down
2 changes: 2 additions & 0 deletions examples/text-classification/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"]
tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu"]
wgpu-spirv = ["wgpu", "burn/wgpu-spirv"]
cuda-jit = ["burn/cuda-jit"]
hip-jit = ["burn/hip-jit"]

[dependencies]
# Burn
Expand Down
12 changes: 12 additions & 0 deletions examples/text-classification/examples/ag-news-train.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,16 @@ mod cuda_jit {
}
}

#[cfg(feature = "hip-jit")]
mod hip_jit {
use crate::{launch, ElemType};
use burn::backend::{Autodiff, HipJit};

pub fn run() {
launch::<Autodiff<HipJit<ElemType, i32>>>(vec![Default::default()]);
}
}

fn main() {
#[cfg(any(
feature = "ndarray",
Expand All @@ -117,4 +127,6 @@ fn main() {
wgpu::run();
#[cfg(feature = "cuda-jit")]
cuda_jit::run();
#[cfg(feature = "hip-jit")]
hip_jit::run();
}

0 comments on commit f263e36

Please sign in to comment.