Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions sway-ir/src/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,12 @@ impl Block {
context.blocks[self.0].instructions = insts;
}

pub fn insert_instructions_after(&self, context: &mut Context, value: Value, insts: impl IntoIterator<Item = Value>) {
let block_ins = &mut context.blocks[self.0].instructions;
let pos = block_ins.iter().position(|x| x == &value).unwrap();
block_ins.splice(pos+1..pos+1, insts);
}

/// Replace an instruction in this block with another. Will return a ValueNotFound on error.
/// Any use of the old instruction value will also be replaced by the new value throughout the
/// owning function if `replace_uses` is set.
Expand Down
2 changes: 2 additions & 0 deletions sway-ir/src/optimize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ pub mod sroa;
pub use sroa::*;
pub mod fn_dedup;
pub use fn_dedup::*;
pub mod branchless;
pub use branchless::*;

mod target_fuel;

Expand Down
150 changes: 150 additions & 0 deletions sway-ir/src/optimize/branchless.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
use sway_features::ExperimentalFeatures;

use crate::{AnalysisResults, Block, BranchToWithArgs, Constant, ConstantContent, Context, Function, InstOp, IrError, Pass, PassMutability, ScopedPass, Value, ValueDatum};


pub const BRANCHLESS_NAME: &str = "branchless";

pub fn create_branchless() -> Pass {
Pass {
name: BRANCHLESS_NAME,
descr: "Branchless",
deps: vec![],
runner: ScopedPass::FunctionPass(PassMutability::Transform(branchless)),
}
}

// check if a block simple calls another block with a u64
fn is_block_simple_integer<'a>(context: &'a Context, branch: &BranchToWithArgs) -> Option<(&'a Constant, &'a Block)> {
let b = &context.blocks[branch.block.0];

if b.instructions.len() > 1 {
return None;
}

let v = &b.instructions[0];
let v = &context.values[v.0];
match &v.value {
crate::ValueDatum::Instruction(i) => match &i.op {
InstOp::Branch(branch) => {
if branch.args.len() != 1 {
return None;
}

let arg0 = &context.values[branch.args[0].0];
match &arg0.value {
crate::ValueDatum::Constant(constant) => Some((constant, &branch.block)),
_ => None,
}
},
_ => None,
},
_ => None,
}
}

fn find_cbr(context: &mut Context, function: Function) -> Option<(Block, Value, Block, Value, Constant, Constant)> {
for (block, value) in function.instruction_iter(context) {
match &context.values[value.0].value {
ValueDatum::Argument(_) => {},
ValueDatum::Constant(_) => {},
ValueDatum::Instruction(instruction) => {
match &instruction.op {
InstOp::ConditionalBranch { cond_value, true_block, false_block } => {
let target_block_true = is_block_simple_integer(context, &true_block);
let target_block_false = is_block_simple_integer(context, &false_block);

// both branches call the same block
match (target_block_true, target_block_false) {
(Some((constant_true, target_block_true)), Some((constant_false, target_block_false))) if target_block_true == target_block_false => {
return Some((block, value, *target_block_true, *cond_value, *constant_true, *constant_false));
},
_ => {},
}
},
_ => {},
}
},
};
}

None
}

pub fn branchless(
context: &mut Context,
_: &AnalysisResults,
function: Function,
) -> Result<bool, IrError> {
let mut modified = false;
return Ok(false);

loop {
if let Some((block, instr_val, target_block, cond_value, constant_true, constant_false)) = find_cbr(context, function) {
block.remove_instruction(context, instr_val);

let one = ConstantContent::new_uint(context, 64, 1);
let one = Constant::unique(context, one);
let one = Value::new_constant(context, one);
let a = Value::new_constant(context, constant_true);
let b = Value::new_constant(context, constant_false);

// c is a boolean (1 or 0)
// Can we use predication?
// x = c * a + (1 − c) * b
let c_times_a = Value::new_instruction(context, block, InstOp::BinaryOp { op: crate::BinaryOpKind::Mul, arg1: cond_value, arg2: a });
let one_minus_c = Value::new_instruction(context, block, InstOp::BinaryOp { op: crate::BinaryOpKind::Sub, arg1: one, arg2: cond_value });
let one_minus_c_times_b = Value::new_instruction(context, block, InstOp::BinaryOp { op: crate::BinaryOpKind::Mul, arg1: one_minus_c, arg2: b });
let x = Value::new_instruction(context, block, InstOp::BinaryOp { op: crate::BinaryOpKind::Add, arg1: c_times_a, arg2: one_minus_c_times_b });

block.insert_instructions_after(context, cond_value, [c_times_a, one_minus_c, one_minus_c_times_b, x]);

let call_target_block = Value::new_instruction(context, block, InstOp::Branch(BranchToWithArgs {
block: target_block,
args: vec![x]
}));

let block = &mut context.blocks[block.0];
block.instructions.push(call_target_block);

modified = true;
} else {
break;
}
}

eprintln!("{}", context.to_string());

Ok(modified)
}

#[cfg(test)]
mod tests {
use crate::tests::assert_optimization;
use super::BRANCHLESS_NAME;

#[test]
fn branchless_optimized() {
let before_optimization = format!(
"
fn main(baba !68: u64) -> u64, !71 {{
entry(baba: u64):
v0 = const u64 0, !72
cbr v0, block0(), block1(), !73

block0():
v2 = const u64 1, !76
br block2(v2)

block1():
v3 = const u64 2, !77
br block2(v3)

block2(v4: u64):
ret u64 v4
}}
",
);
assert_optimization(&[BRANCHLESS_NAME], &before_optimization, Some(["const u64 1, !76"]));
}
}
17 changes: 5 additions & 12 deletions sway-ir/src/pass_manager.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,5 @@
use crate::{
create_arg_demotion_pass, create_ccp_pass, create_const_demotion_pass,
create_const_folding_pass, create_cse_pass, create_dce_pass, create_dom_fronts_pass,
create_dominators_pass, create_escaped_symbols_pass, create_fn_dedup_debug_profile_pass,
create_fn_dedup_release_profile_pass, create_fn_inline_pass, create_globals_dce_pass,
create_mem2reg_pass, create_memcpyopt_pass, create_misc_demotion_pass,
create_module_printer_pass, create_module_verifier_pass, create_postorder_pass,
create_ret_demotion_pass, create_simplify_cfg_pass, create_sroa_pass, Context, Function,
IrError, Module, ARG_DEMOTION_NAME, CCP_NAME, CONST_DEMOTION_NAME, CONST_FOLDING_NAME,
CSE_NAME, DCE_NAME, FN_DEDUP_DEBUG_PROFILE_NAME, FN_DEDUP_RELEASE_PROFILE_NAME, FN_INLINE_NAME,
GLOBALS_DCE_NAME, MEM2REG_NAME, MEMCPYOPT_NAME, MISC_DEMOTION_NAME, RET_DEMOTION_NAME,
SIMPLIFY_CFG_NAME, SROA_NAME,
create_arg_demotion_pass, create_branchless, create_ccp_pass, create_const_demotion_pass, create_const_folding_pass, create_cse_pass, create_dce_pass, create_dom_fronts_pass, create_dominators_pass, create_escaped_symbols_pass, create_fn_dedup_debug_profile_pass, create_fn_dedup_release_profile_pass, create_fn_inline_pass, create_globals_dce_pass, create_mem2reg_pass, create_memcpyopt_pass, create_misc_demotion_pass, create_module_printer_pass, create_module_verifier_pass, create_postorder_pass, create_ret_demotion_pass, create_simplify_cfg_pass, create_sroa_pass, Context, Function, IrError, Module, ARG_DEMOTION_NAME, BRANCHLESS_NAME, CCP_NAME, CONST_DEMOTION_NAME, CONST_FOLDING_NAME, CSE_NAME, DCE_NAME, FN_DEDUP_DEBUG_PROFILE_NAME, FN_DEDUP_RELEASE_PROFILE_NAME, FN_INLINE_NAME, GLOBALS_DCE_NAME, MEM2REG_NAME, MEMCPYOPT_NAME, MISC_DEMOTION_NAME, RET_DEMOTION_NAME, SIMPLIFY_CFG_NAME, SROA_NAME
};
use downcast_rs::{impl_downcast, Downcast};
use rustc_hash::FxHashMap;
Expand Down Expand Up @@ -164,7 +154,8 @@ pub struct PassManager {
}

impl PassManager {
pub const OPTIMIZATION_PASSES: [&'static str; 14] = [
pub const OPTIMIZATION_PASSES: [&'static str; 15] = [
BRANCHLESS_NAME,
FN_INLINE_NAME,
SIMPLIFY_CFG_NAME,
SROA_NAME,
Expand Down Expand Up @@ -395,6 +386,7 @@ pub fn register_known_passes(pm: &mut PassManager) {
pm.register(create_fn_dedup_debug_profile_pass());
pm.register(create_mem2reg_pass());
pm.register(create_sroa_pass());
pm.register(create_branchless());
pm.register(create_fn_inline_pass());
pm.register(create_const_folding_pass());
pm.register(create_ccp_pass());
Expand All @@ -415,6 +407,7 @@ pub fn create_o1_pass_group() -> PassGroup {
// Configure to run our passes.
o1.append_pass(MEM2REG_NAME);
o1.append_pass(FN_DEDUP_RELEASE_PROFILE_NAME);
o1.append_pass(BRANCHLESS_NAME);
o1.append_pass(FN_INLINE_NAME);
o1.append_pass(SIMPLIFY_CFG_NAME);
o1.append_pass(GLOBALS_DCE_NAME);
Expand Down
63 changes: 47 additions & 16 deletions test/src/e2e_vm_tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,11 @@ fn print_receipts(output: &mut String, receipts: &[Receipt]) {
}
}

struct RunResult {
size: Option<u64>,
gas: Option<u64>,
}

impl TestContext {
async fn deploy_contract(
&self,
Expand Down Expand Up @@ -306,7 +311,7 @@ impl TestContext {
})
}

async fn run(&self, test: TestDescription, output: &mut String, verbose: bool) -> Result<()> {
async fn run(&self, test: TestDescription, output: &mut String, verbose: bool) -> Result<RunResult> {
let TestDescription {
name,
suffix,
Expand Down Expand Up @@ -342,6 +347,11 @@ impl TestContext {
expected_result
};

let mut r = RunResult {
size: None,
gas: None,
};

match category {
TestCategory::Runs => {
let expected_result = expected_result.expect("No expected result found. This is likely because test.toml is missing either an \"expected_result_new_encoding\" or \"expected_result\" entry");
Expand All @@ -358,6 +368,7 @@ impl TestContext {

for p in packages {
let bytecode_len = p.bytecode.bytes.len();
r.size = Some(bytecode_len as u64);

let configurables = match &p.program_abi {
sway_core::asm_generation::ProgramABI::Fuel(abi) => {
Expand Down Expand Up @@ -408,6 +419,13 @@ impl TestContext {
harness::VMExecutionResult::Fuel(state, receipts, ecal) => {
print_receipts(output, &receipts);

if let Some(gas_used) = receipts.iter().filter_map(|x| match x {
Receipt::ScriptResult { gas_used, .. } => Some(*gas_used),
_ => None
}).last() {
r.gas = Some(gas_used);
}

use std::fmt::Write;
let _ = writeln!(output, " {}", "Captured Output".green().bold());
for captured in ecal.captured.iter() {
Expand Down Expand Up @@ -483,7 +501,7 @@ impl TestContext {
output.push_str(&out);
result?;
}
Ok(())
Ok(r)
}
}

Expand Down Expand Up @@ -543,7 +561,7 @@ impl TestContext {
output.push_str(&out);
}
}
Ok(())
Ok(r)
}

TestCategory::FailsToCompile => {
Expand All @@ -560,7 +578,7 @@ impl TestContext {
Err(anyhow::Error::msg("Test compiles but is expected to fail"))
} else {
check_file_checker(checker, &name, output)?;
Ok(())
Ok(r)
}
}

Expand Down Expand Up @@ -654,7 +672,7 @@ impl TestContext {
_ => {}
};

Ok(())
Ok(r)
}

TestCategory::UnitTestsPass => {
Expand Down Expand Up @@ -729,6 +747,8 @@ impl TestContext {
decoded_logs, expected_decoded_test_logs
);
}

r
})
}

Expand Down Expand Up @@ -833,18 +853,29 @@ pub async fn run(filter_config: &FilterConfig, run_config: &RunConfig) -> Result
context.run(test, &mut output, run_config.verbose).await
};

if let Err(err) = result {
println!(" {}", "failed".red().bold());
println!("{}", textwrap::indent(err.to_string().as_str(), " "));
println!("{}", textwrap::indent(&output, " "));
number_of_tests_failed += 1;
failed_tests.push(name);
} else {
println!(" {}", "ok".green().bold());
match result {
Err(err) => {
println!(" {}", "failed".red().bold());
println!("{}", textwrap::indent(err.to_string().as_str(), " "));
println!("{}", textwrap::indent(&output, " "));
number_of_tests_failed += 1;
failed_tests.push(name);
}
Ok(r) => {
if let Some(size) = r.size {
print!(" {} bytes ", size);
}

// If verbosity is requested then print it out.
if run_config.verbose && !output.is_empty() {
println!("{}", textwrap::indent(&output, " "));
if let Some(gas) = r.gas {
print!(" {} gas used ", gas);
}

println!(" {}", "ok".green().bold());

// If verbosity is requested then print it out.
if run_config.verbose && !output.is_empty() {
println!("{}", textwrap::indent(&output, " "));
}
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
script;

#[inline(never)]
fn f(baba:u64) -> u64 {
if baba == 0 {
1
} else {
2
}
}

fn main(baba: u64) -> u64 {
baba + 1
f(baba)
}
Loading