-
Notifications
You must be signed in to change notification settings - Fork 12.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add pretty, ui, and feature-gate tests for the enzyme/autodiff frontend
- Loading branch information
Showing
15 changed files
with
744 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
#![feature(prelude_import)] | ||
#![no_std] | ||
//@ needs-enzyme | ||
|
||
#![feature(autodiff)] | ||
#[prelude_import] | ||
use ::std::prelude::rust_2015::*; | ||
#[macro_use] | ||
extern crate std; | ||
//@ pretty-mode:expanded | ||
//@ pretty-compare-only | ||
//@ pp-exact:autodiff_forward.pp | ||
|
||
// Test that forward mode ad macros are expanded correctly. | ||
|
||
use std::autodiff::autodiff; | ||
|
||
#[rustc_autodiff] | ||
#[inline(never)] | ||
pub fn f1(x: &[f64], y: f64) -> f64 { | ||
|
||
|
||
|
||
// Not the most interesting derivative, but who are we to judge | ||
|
||
// We want to be sure that the same function can be differentiated in different ways | ||
|
||
::core::panicking::panic("not implemented") | ||
} | ||
#[rustc_autodiff(Forward, Dual, Const, Dual,)] | ||
#[inline(never)] | ||
pub fn df1(x: &[f64], bx: &[f64], y: f64) -> (f64, f64) { | ||
unsafe { asm!("NOP", options(pure, nomem)); }; | ||
::core::hint::black_box(f1(x, y)); | ||
::core::hint::black_box((bx,)); | ||
::core::hint::black_box((f1(x, y), f64::default())) | ||
} | ||
#[rustc_autodiff] | ||
#[inline(never)] | ||
pub fn f2(x: &[f64], y: f64) -> f64 { | ||
::core::panicking::panic("not implemented") | ||
} | ||
#[rustc_autodiff(Forward, Dual, Const, Const,)] | ||
#[inline(never)] | ||
pub fn df2(x: &[f64], bx: &[f64], y: f64) -> f64 { | ||
unsafe { asm!("NOP", options(pure, nomem)); }; | ||
::core::hint::black_box(f2(x, y)); | ||
::core::hint::black_box((bx,)); | ||
::core::hint::black_box(f2(x, y)) | ||
} | ||
#[rustc_autodiff] | ||
#[inline(never)] | ||
pub fn f3(x: &[f64], y: f64) -> f64 { | ||
::core::panicking::panic("not implemented") | ||
} | ||
#[rustc_autodiff(ForwardFirst, Dual, Const, Const,)] | ||
#[inline(never)] | ||
pub fn df3(x: &[f64], bx: &[f64], y: f64) -> f64 { | ||
unsafe { asm!("NOP", options(pure, nomem)); }; | ||
::core::hint::black_box(f3(x, y)); | ||
::core::hint::black_box((bx,)); | ||
::core::hint::black_box(f3(x, y)) | ||
} | ||
#[rustc_autodiff] | ||
#[inline(never)] | ||
pub fn f4() {} | ||
#[rustc_autodiff(Forward, None)] | ||
#[inline(never)] | ||
pub fn df4() { | ||
unsafe { asm!("NOP", options(pure, nomem)); }; | ||
::core::hint::black_box(f4()); | ||
::core::hint::black_box(()); | ||
} | ||
#[rustc_autodiff] | ||
#[inline(never)] | ||
#[rustc_autodiff] | ||
#[inline(never)] | ||
#[rustc_autodiff] | ||
#[inline(never)] | ||
pub fn f5(x: &[f64], y: f64) -> f64 { | ||
::core::panicking::panic("not implemented") | ||
} | ||
#[rustc_autodiff(Forward, Const, Dual, Const,)] | ||
#[inline(never)] | ||
pub fn df5_y(x: &[f64], y: f64, by: f64) -> f64 { | ||
unsafe { asm!("NOP", options(pure, nomem)); }; | ||
::core::hint::black_box(f5(x, y)); | ||
::core::hint::black_box((by,)); | ||
::core::hint::black_box(f5(x, y)) | ||
} | ||
#[rustc_autodiff(Forward, Dual, Const, Const,)] | ||
#[inline(never)] | ||
pub fn df5_x(x: &[f64], bx: &[f64], y: f64) -> f64 { | ||
unsafe { asm!("NOP", options(pure, nomem)); }; | ||
::core::hint::black_box(f5(x, y)); | ||
::core::hint::black_box((bx,)); | ||
::core::hint::black_box(f5(x, y)) | ||
} | ||
#[rustc_autodiff(Reverse, Duplicated, Const, Active,)] | ||
#[inline(never)] | ||
pub fn df5_rev(x: &[f64], dx: &mut [f64], y: f64, dret: f64) -> f64 { | ||
unsafe { asm!("NOP", options(pure, nomem)); }; | ||
::core::hint::black_box(f5(x, y)); | ||
::core::hint::black_box((dx, dret)); | ||
::core::hint::black_box(f5(x, y)) | ||
} | ||
fn main() {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
//@ needs-enzyme | ||
|
||
#![feature(autodiff)] | ||
//@ pretty-mode:expanded | ||
//@ pretty-compare-only | ||
//@ pp-exact:autodiff_forward.pp | ||
|
||
// Test that forward mode ad macros are expanded correctly. | ||
|
||
use std::autodiff::autodiff; | ||
|
||
#[autodiff(df1, Forward, Dual, Const, Dual)] | ||
pub fn f1(x: &[f64], y: f64) -> f64 { | ||
unimplemented!() | ||
} | ||
|
||
#[autodiff(df2, Forward, Dual, Const, Const)] | ||
pub fn f2(x: &[f64], y: f64) -> f64 { | ||
unimplemented!() | ||
} | ||
|
||
#[autodiff(df3, ForwardFirst, Dual, Const, Const)] | ||
pub fn f3(x: &[f64], y: f64) -> f64 { | ||
unimplemented!() | ||
} | ||
|
||
// Not the most interesting derivative, but who are we to judge | ||
#[autodiff(df4, Forward)] | ||
pub fn f4() {} | ||
|
||
// We want to be sure that the same function can be differentiated in different ways | ||
#[autodiff(df5_rev, Reverse, Duplicated, Const, Active)] | ||
#[autodiff(df5_x, Forward, Dual, Const, Const)] | ||
#[autodiff(df5_y, Forward, Const, Dual, Const)] | ||
pub fn f5(x: &[f64], y: f64) -> f64 { | ||
unimplemented!() | ||
} | ||
|
||
fn main() {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
#![feature(prelude_import)] | ||
#![no_std] | ||
//@ needs-enzyme | ||
|
||
#![feature(autodiff)] | ||
#[prelude_import] | ||
use ::std::prelude::rust_2015::*; | ||
#[macro_use] | ||
extern crate std; | ||
//@ pretty-mode:expanded | ||
//@ pretty-compare-only | ||
//@ pp-exact:autodiff_reverse.pp | ||
|
||
// Test that reverse mode ad macros are expanded correctly. | ||
|
||
use std::autodiff::autodiff; | ||
|
||
#[rustc_autodiff] | ||
#[inline(never)] | ||
pub fn f1(x: &[f64], y: f64) -> f64 { | ||
|
||
// Not the most interesting derivative, but who are we to judge | ||
|
||
|
||
// What happens if we already have Reverse in type (enum variant decl) and value (enum variant | ||
// constructor) namespace? > It's expected to work normally. | ||
::core::panicking::panic("not implemented") | ||
} | ||
#[rustc_autodiff(Reverse, Duplicated, Const, Active,)] | ||
#[inline(never)] | ||
pub fn df1(x: &[f64], dx: &mut [f64], y: f64, dret: f64) -> f64 { | ||
unsafe { asm!("NOP", options(pure, nomem)); }; | ||
::core::hint::black_box(f1(x, y)); | ||
::core::hint::black_box((dx, dret)); | ||
::core::hint::black_box(f1(x, y)) | ||
} | ||
#[rustc_autodiff] | ||
#[inline(never)] | ||
pub fn f2() {} | ||
#[rustc_autodiff(Reverse, None)] | ||
#[inline(never)] | ||
pub fn df2() { | ||
unsafe { asm!("NOP", options(pure, nomem)); }; | ||
::core::hint::black_box(f2()); | ||
::core::hint::black_box(()); | ||
} | ||
#[rustc_autodiff] | ||
#[inline(never)] | ||
pub fn f3(x: &[f64], y: f64) -> f64 { | ||
::core::panicking::panic("not implemented") | ||
} | ||
#[rustc_autodiff(ReverseFirst, Duplicated, Const, Active,)] | ||
#[inline(never)] | ||
pub fn df3(x: &[f64], dx: &mut [f64], y: f64, dret: f64) -> f64 { | ||
unsafe { asm!("NOP", options(pure, nomem)); }; | ||
::core::hint::black_box(f3(x, y)); | ||
::core::hint::black_box((dx, dret)); | ||
::core::hint::black_box(f3(x, y)) | ||
} | ||
enum Foo { Reverse, } | ||
use Foo::Reverse; | ||
#[rustc_autodiff] | ||
#[inline(never)] | ||
pub fn f4(x: f32) { ::core::panicking::panic("not implemented") } | ||
#[rustc_autodiff(Reverse, Const, None)] | ||
#[inline(never)] | ||
pub fn df4(x: f32) { | ||
unsafe { asm!("NOP", options(pure, nomem)); }; | ||
::core::hint::black_box(f4(x)); | ||
::core::hint::black_box(()); | ||
} | ||
#[rustc_autodiff] | ||
#[inline(never)] | ||
pub fn f5(x: *const f32, y: &f32) { | ||
::core::panicking::panic("not implemented") | ||
} | ||
#[rustc_autodiff(Reverse, DuplicatedOnly, Duplicated, None)] | ||
#[inline(never)] | ||
pub unsafe fn df5(x: *const f32, dx: *mut f32, y: &f32, dy: &mut f32) { | ||
unsafe { asm!("NOP", options(pure, nomem)); }; | ||
::core::hint::black_box(f5(x, y)); | ||
::core::hint::black_box((dx, dy)); | ||
} | ||
fn main() {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
//@ needs-enzyme | ||
|
||
#![feature(autodiff)] | ||
//@ pretty-mode:expanded | ||
//@ pretty-compare-only | ||
//@ pp-exact:autodiff_reverse.pp | ||
|
||
// Test that reverse mode ad macros are expanded correctly. | ||
|
||
use std::autodiff::autodiff; | ||
|
||
#[autodiff(df1, Reverse, Duplicated, Const, Active)] | ||
pub fn f1(x: &[f64], y: f64) -> f64 { | ||
unimplemented!() | ||
} | ||
|
||
// Not the most interesting derivative, but who are we to judge | ||
#[autodiff(df2, Reverse)] | ||
pub fn f2() {} | ||
|
||
#[autodiff(df3, ReverseFirst, Duplicated, Const, Active)] | ||
pub fn f3(x: &[f64], y: f64) -> f64 { | ||
unimplemented!() | ||
} | ||
|
||
enum Foo { Reverse } | ||
use Foo::Reverse; | ||
// What happens if we already have Reverse in type (enum variant decl) and value (enum variant | ||
// constructor) namespace? > It's expected to work normally. | ||
#[autodiff(df4, Reverse, Const)] | ||
pub fn f4(x: f32) { | ||
unimplemented!() | ||
} | ||
|
||
#[autodiff(df5, Reverse, DuplicatedOnly, Duplicated)] | ||
pub fn f5(x: *const f32, y: &f32) { | ||
unimplemented!() | ||
} | ||
|
||
fn main() {} |
Oops, something went wrong.