Skip to content

Commit

Permalink
generalize compute interface to take impl borrow
Browse files Browse the repository at this point in the history
instead of & or &mut
so that &mut can be passed as input
  • Loading branch information
crop2000 committed Nov 29, 2024
1 parent 36d3e64 commit 65a13e9
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 41 deletions.
50 changes: 20 additions & 30 deletions compiler/generator/rust/rust_code_container.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,8 @@ void RustCodeContainer::produceClass()

tab(n, *fOut);
*fOut << "use std::convert::TryInto;";
tab(n, *fOut);
*fOut << "use std::borrow;";

// Generate gub containers
generateSubContainers();
Expand Down Expand Up @@ -527,7 +529,6 @@ void RustCodeContainer::produceClass()
generateComputeFrame(n + 1);
} else {
generateCompute(n + 1);
generateComputeInterface(n + 1);
}

tab(n, *fOut);
Expand Down Expand Up @@ -627,18 +628,24 @@ void RustCodeContainer::generateComputeHeader(int n, std::ostream* fOut, int fNu
{
// Compute "compute" declaration
tab(n, *fOut);
*fOut << "pub fn compute_arrays("
<< "&mut self, " << fFullCount << ": usize, inputs: &[&[FaustFloat] ; " << fNumInputs
<< "]"
<< ", outputs: &mut [&mut [FaustFloat] ; " << fNumOutputs << "]) {";
}

void RustCodeContainer::generateComputeInterfaceHeader(int n, std::ostream* fOut, int fNumInputs,
int fNumOutputs)
{
*fOut << "pub fn compute("
<< "&mut self, " << fFullCount << ": usize, inputs: & [& [FaustFloat] ]"
<< ", outputs: & mut[& mut[FaustFloat] ]) {";
tab(n, *fOut);
*fOut << "pub fn compute<InType, OutType>(";
tab(n + 1, *fOut);
*fOut << "&mut self,";
tab(n + 1, *fOut);
*fOut << "count: usize,";
tab(n + 1, *fOut);
*fOut << "inputs: impl borrow::Borrow<[InType]>,";
tab(n + 1, *fOut);
*fOut << "mut outputs: impl borrow::BorrowMut<[OutType]>,";
tab(n, *fOut);
*fOut << ") where";
tab(n + 1, *fOut);
*fOut << "InType: borrow::Borrow<[FaustFloat]>,";
tab(n + 1, *fOut);
*fOut << "OutType: borrow::BorrowMut<[FaustFloat]>,";
tab(n, *fOut);
*fOut << "{";
tab(n + 1, *fOut);
}

Expand Down Expand Up @@ -686,23 +693,6 @@ void RustCodeContainer::generateComputeFrame(int n)
tab(n, *fOut);
}

void RustCodeContainer::generateComputeInterface(int n)
{
// Generates declaration
tab(n, *fOut);
generateComputeInterfaceHeader(n, fOut, fNumInputs, fNumOutputs);

*fOut << "let input_array = inputs.split_at(" << fNumInputs
<< ").0.try_into().expect(\"too few input buffers\");";
tab(n + 1, *fOut);
*fOut << "let output_array = outputs.split_at_mut(" << fNumOutputs
<< ").0.try_into().expect(\"too few output buffers\");";
tab(n + 1, *fOut);
*fOut << "self.compute_arrays(count, input_array, output_array);";
tab(n, *fOut);
*fOut << "}" << endl;
}

// Scalar
RustScalarCodeContainer::RustScalarCodeContainer(const string& name, int numInputs, int numOutputs,
std::ostream* out, int sub_container_type)
Expand Down
2 changes: 0 additions & 2 deletions compiler/generator/rust/rust_code_container.hh
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,6 @@ class RustCodeContainer : public virtual CodeContainer {

virtual void produceClass();
void generateComputeHeader(int n, std::ostream* fOut, int fNumInputs, int fNumOutputs);
void generateComputeInterfaceHeader(int n, std::ostream* fOut, int fNumInputs, int fNumOutputs);
void generateComputeInterface(int tab);
void generateComputeFrame(int tab);
virtual void generateCompute(int tab) = 0;
void produceInternal();
Expand Down
24 changes: 15 additions & 9 deletions compiler/generator/rust/rust_instructions.hh
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,9 @@ class RustInstVisitor : public TextInstVisitor {
virtual void visit(DeclareBufferIterators* inst)
{
/* Generates an expression like:
let [outputs0, outputs1, ..] = outputs;
let outputs0 = outputs0[..count].iter_mut();
let outputs1 = outputs1[..count].iter_mut();
let [outputs0, outputs1] = outputs.borrow_mut();
let outputs0 = outputs0.borrow_mut[..count].iter_mut();
let outputs1 = outputs1.borrow_mut[..count].iter_mut();
*/

// Don't generate if no channels
Expand All @@ -241,24 +241,30 @@ class RustInstVisitor : public TextInstVisitor {
for (int i = 0; i < inst->fChannels; ++i) {
*fOut << name << i << ", ";
}
*fOut << "] = " << name << ";";
*fOut << "] = " << name;
if (inst->fMutable) {
*fOut << ".borrow_mut() else { panic!(\"wrong number of outputs\"); };";
} else {
*fOut << ".borrow() else { panic!(\"wrong number of inputs\"); };";
}

// Build fixed size iterator variables

for (int i = 0; i < inst->fChannels; ++i) {
tab(fTab, *fOut);
*fOut << "let " << name << i << " = " << name << i << "[..count]";
*fOut << "let " << name << i << " = " << name << i;
;
if (inst->fMutable) {
if (inst->fChunk) {
*fOut << ".chunks_mut(vsize as usize);";
*fOut << ".borrow_mut()[..count].chunks_mut(vsize as usize);";
} else {
*fOut << ".iter_mut();";
*fOut << ".borrow_mut()[..count].iter_mut();";
}
} else {
if (inst->fChunk) {
*fOut << ".chunks(vsize as usize);";
*fOut << ".borrow()[..count].chunks(vsize as usize);";
} else {
*fOut << ".iter();";
*fOut << ".borrow()[..count].iter();";
}
}
}
Expand Down

0 comments on commit 65a13e9

Please sign in to comment.