Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add muldiv_c and muxadd peepopts #4740

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions passes/pmgen/Makefile.inc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ PEEPOPT_PATTERN = passes/pmgen/peepopt_shiftmul_right.pmg
PEEPOPT_PATTERN += passes/pmgen/peepopt_shiftmul_left.pmg
PEEPOPT_PATTERN += passes/pmgen/peepopt_shiftadd.pmg
PEEPOPT_PATTERN += passes/pmgen/peepopt_muldiv.pmg
PEEPOPT_PATTERN += passes/pmgen/peepopt_muldiv_c.pmg
PEEPOPT_PATTERN += passes/pmgen/peepopt_muxadd.pmg
PEEPOPT_PATTERN += passes/pmgen/peepopt_formal_clockgateff.pmg

passes/pmgen/peepopt_pm.h: passes/pmgen/pmgen.py $(PEEPOPT_PATTERN)
Expand Down
6 changes: 6 additions & 0 deletions passes/pmgen/peepopt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,12 @@ struct PeepoptPass : public Pass {
log("\n");
log("This pass employs the following rules by default:\n");
log("\n");
log(" * muxadd - Replace S?(A+B):A with A+(S?B:0)\n");
log("\n");
log(" * muldiv - Replace (A*B)/B with A\n");
log("\n");
log(" * muldiv_c - Replace (A*B)/C with A*(B/C) when C is a const divisible by B.\n");
log("\n");
log(" * shiftmul - Replace A>>(B*C) with A'>>(B<<K) where C and K are constants\n");
log(" and A' is derived from A by appropriately inserting padding\n");
log(" into the signal. (right variant)\n");
Expand Down Expand Up @@ -105,6 +109,8 @@ struct PeepoptPass : public Pass {
pm.run_shiftmul_right();
pm.run_shiftmul_left();
pm.run_muldiv();
pm.run_muldiv_c();
pm.run_muxadd();
}
}
}
Expand Down
76 changes: 76 additions & 0 deletions passes/pmgen/peepopt_muldiv_c.pmg
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
pattern muldiv_c
//
// Authored by Akash Levy of Silimate, Inc. under ISC license.
// Transforms mul->div into const->mul when b and c are divisible constants:
// y = (a * b_const) / c_const ===> a * eval(b_const / c_const)
//

state <SigSpec> a b_const mul_y

match mul
// Select multiplier
select mul->type == $mul
endmatch

code a b_const mul_y
// Get multiplier signals
a = port(mul, \A);
b_const = port(mul, \B);
mul_y = port(mul, \Y);

// Fanout of each multiplier Y bit should be 1 (no bit-split)
for (auto bit : mul_y)
if (nusers(bit) != 2)
reject;

// A and B can be interchanged
branch;
std::swap(a, b_const);
endcode

match div
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No check that the divider and multiplier are connected

// Select div of form (a * b_const) / c_const
select div->type == $div

// Check that b_const and c_const is constant
filter b_const.is_fully_const()
filter port(div, \B).is_fully_const()
endmatch

code
// Get div signals
SigSpec div_a = port(div, \A);
SigSpec c_const = port(div, \B);
SigSpec div_y = port(div, \Y);

// Get offset of multiplier result chunk in divider
int offset = GetSize(div_a) - GetSize(mul_y);

// Get properties and values of b_const and c_const
int b_const_width = mul->getParam(ID::B_WIDTH).as_int();
bool b_const_signed = mul->getParam(ID::B_SIGNED).as_bool();
bool c_const_signed = div->getParam(ID::B_SIGNED).as_bool();
int b_const_int = b_const.as_int(b_const_signed);
int c_const_int = c_const.as_int(c_const_signed);
int b_const_int_shifted = b_const_int << offset;

// Check that there are only zeros before offset
if (offset < 0 || !div_a.extract(0, offset).is_fully_zero())
reject;

// Check that b is divisible by c
if (b_const_int_shifted % c_const_int != 0)
reject;

// Rewire to only keep multiplier
mul->setPort(\B, Const(b_const_int_shifted / c_const_int, b_const_width));
mul->setPort(\Y, div_y);

// Remove divider
autoremove(div);

// Log, fixup, accept
log("muldiv_const pattern in %s: mul=%s, div=%s\n", log_id(module), log_id(mul), log_id(div));
mul->fixup_parameters();
accept;
endcode
57 changes: 57 additions & 0 deletions passes/pmgen/peepopt_muxadd.pmg
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
pattern muxadd
//
// Authored by Akash Levy of Silimate, Inc. under ISC license.
// Transforms add->mux into mux->add:
// y = s ? (a + b) : a ===> y = a + (s ? b : 0)
//

state <SigSpec> add_a add_b add_y

match add
// Select adder
select add->type == $add
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
select add->type == $add
select add->type == $add
choice <IdString> A {\A, \B}
define <IdString> B (A == \A ? \B : \A)
set add_y port(add, \Y)
set add_a port(add, A)
set add_b port(add, B)
set add_a_signed param(add, (A == \A) ? \A_SIGNED : \B_SIGNED))
set add_a_id A

and replace the branch-and-swap pattern further below

Requires new state

state <bool> add_a_signed
state <IdString> add_a_id

endmatch

code add_y add_a add_b
// Get adder signals
add_a = port(add, \A);
add_b = port(add, \B);
add_y = port(add, \Y);

// Fanout of each adder Y bit should be 1 (no bit-split)
for (auto bit : add_y)
if (nusers(bit) != 2)
reject;
Comment on lines +22 to +24
Copy link
Member

@povik povik Nov 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for (auto bit : add_y)
if (nusers(bit) != 2)
reject;
if (nusers(add_y) != 2)
reject;

is all we need and more idiomatic.

We know nusers(add_y) >= 2 by this being an output from add and mux being connected to it too. If any bit has an extra fanout then nusers(add_y) > 2


// A and B can be interchanged
branch;
std::swap(add_a, add_b);
endcode

match mux
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can also match on the case of swapped mux inputs

// Select mux of form s ? (a + b) : a, allow leading 0s when A_WIDTH != Y_WIDTH
select mux->type == $mux
index <SigSpec> port(mux, \A) === SigSpec({Const(State::S0, GetSize(add_y)-GetSize(add_a)), add_a})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be zero-padded if !param(add, \A_SIGNED).bool() and sign-extended otherwise (provided add_a is the A input)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also I am not sure this doesn't crash on add_y < add_a which I don't think we rule out anywhere

index <SigSpec> port(mux, \B) === add_y
endmatch

code
// Get mux signal
SigSpec mux_y = port(mux, \Y);

// Create new mid wire
SigSpec mid = module->addWire(NEW_ID, GetSize(add_b));

// Rewire
mux->setPort(\A, Const(State::S0, GetSize(add_b)));
mux->setPort(\B, add_b);
mux->setPort(\Y, mid);
add->setPort(\B, mid);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wrong side of the adder if the swap above was hit (we need to use add_a_id instead)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Err, more like add_b_id

add->setPort(\Y, mux_y);

// Log, fixup, accept
log("muxadd pattern in %s: mux=%s, add=%s\n", log_id(module), log_id(mux), log_id(add));
add->fixup_parameters();
mux->fixup_parameters();
accept;
endcode