Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1894,11 +1894,12 @@ class EnzymeBase {
Logic.CreateTrace(F, generativeFunctions, mode, has_dynamic_interface);

Value *trace =
Builder.CreateCall(newFunc->getFunctionType(), newFunc, args);
if (!F->getReturnType()->isVoidTy())
trace = Builder.CreateExtractValue(trace, {1});
Builder.CreateCall(interface->newTraceTy(), interface->newTrace(), {});

args.push_back(trace);

Builder.CreateCall(newFunc->getFunctionType(), newFunc, args);

// try to cast i8* returned from trace to CI->getRetType....
if (CI->getType() != trace->getType())
trace = Builder.CreatePointerCast(trace, CI->getType());

Expand Down
31 changes: 18 additions & 13 deletions enzyme/Enzyme/TraceGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,11 @@ class TraceGenerator final : public llvm::InstVisitor<TraceGenerator> {
{
ElseTerm->getParent()->setName("condition." + call.getName() +
".without.trace");
ElseChoice =

auto choice =
Builder.CreateCall(samplefn->getFunctionType(), samplefn,
sample_args, "sample." + call.getName());
ElseChoice = choice;
}

Builder.SetInsertPoint(new_call);
Expand Down Expand Up @@ -132,11 +134,16 @@ class TraceGenerator final : public llvm::InstVisitor<TraceGenerator> {
Logic.CreateTrace(called, tutils->generativeFunctions, tutils->mode,
tutils->hasDynamicTraceInterface());

auto trace = tutils->CreateTrace(Builder);

Instruction *tracecall;
switch (mode) {
case ProbProgMode::Trace: {
tracecall = Builder.CreateCall(samplefn->getFunctionType(), samplefn,
args, "trace." + called->getName());
SmallVector<Value *, 2> args_and_trace = SmallVector(args);
args_and_trace.push_back(trace);
tracecall =
Builder.CreateCall(samplefn->getFunctionType(), samplefn,
args_and_trace, "trace." + called->getName());
break;
}
case ProbProgMode::Condition: {
Expand All @@ -158,8 +165,9 @@ class TraceGenerator final : public llvm::InstVisitor<TraceGenerator> {
ThenTerm->getParent()->setName("condition." + call.getName() +
".with.trace");
SmallVector<Value *, 2> args_and_cond = SmallVector(args);
auto trace = tutils->GetTrace(Builder, address,
called->getName() + ".subtrace");
auto observations = tutils->GetTrace(Builder, address,
called->getName() + ".subtrace");
args_and_cond.push_back(observations);
args_and_cond.push_back(trace);
ThenTracecall = Builder.CreateCall(samplefn->getFunctionType(),
samplefn, args_and_cond,
Expand All @@ -171,8 +179,9 @@ class TraceGenerator final : public llvm::InstVisitor<TraceGenerator> {
ElseTerm->getParent()->setName("condition." + call.getName() +
".without.trace");
SmallVector<Value *, 2> args_and_null = SmallVector(args);
auto trace = ConstantPointerNull::get(cast<PointerType>(
auto observations = ConstantPointerNull::get(cast<PointerType>(
tutils->getTraceInterface()->newTraceTy()->getReturnType()));
args_and_null.push_back(observations);
args_and_null.push_back(trace);
ElseTracecall =
Builder.CreateCall(samplefn->getFunctionType(), samplefn,
Expand All @@ -188,14 +197,10 @@ class TraceGenerator final : public llvm::InstVisitor<TraceGenerator> {
}
}

Value *ret = Builder.CreateExtractValue(tracecall, {0});
Value *subtrace = Builder.CreateExtractValue(
tracecall, {1}, "newtrace." + called->getName());

tutils->InsertCall(Builder, address, subtrace);
tutils->InsertCall(Builder, address, trace);

ret->takeName(new_call);
new_call->replaceAllUsesWith(ret);
tracecall->takeName(new_call);
new_call->replaceAllUsesWith(tracecall);
new_call->eraseFromParent();
}
}
Expand Down
49 changes: 11 additions & 38 deletions enzyme/Enzyme/TraceUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class TraceUtils {
private:
TraceInterface *interface;
Value *dynamic_interface = nullptr;
Instruction *trace;
Value *trace;
Value *observations = nullptr;

public:
Expand Down Expand Up @@ -70,10 +70,9 @@ class TraceUtils {
if (mode == ProbProgMode::Condition)
params.push_back(traceType);

Type *RetTy = traceType;
if (!oldFunc->getReturnType()->isVoidTy())
RetTy = StructType::get(Context, {oldFunc->getReturnType(), traceType});
params.push_back(traceType);

Type *RetTy = oldFunc->getReturnType();
FunctionType *FTy = FunctionType::get(RetTy, params, oldFunc->isVarArg());

Twine Name = (mode == ProbProgMode::Condition ? "condition_" : "trace_") +
Expand All @@ -94,21 +93,27 @@ class TraceUtils {
}

if (has_dynamic_interface) {
auto arg = newFunc->arg_end() - (1 + (mode == ProbProgMode::Condition));
auto arg = newFunc->arg_end() - (2 + (mode == ProbProgMode::Condition));
dynamic_interface = arg;
arg->setName("interface");
arg->addAttr(Attribute::ReadOnly);
arg->addAttr(Attribute::NoCapture);
}

if (mode == ProbProgMode::Condition) {
auto arg = newFunc->arg_end() - 1;
auto arg = newFunc->arg_end() - 2;
observations = arg;
arg->setName("observations");
if (oldFunc->getReturnType()->isVoidTy())
arg->addAttr(Attribute::Returned);
}

auto arg = newFunc->arg_end() - 1;
trace = arg;
arg->setName("trace");
if (oldFunc->getReturnType()->isVoidTy())
arg->addAttr(Attribute::Returned);

SmallVector<ReturnInst *, 4> Returns;
#if LLVM_VERSION_MAJOR >= 13
CloneFunctionInto(newFunc, oldFunc, originalToNewFn,
Expand All @@ -126,38 +131,6 @@ class TraceUtils {
} else {
interface = new StaticTraceInterface(F->getParent());
}

// Create trace for current function

IRBuilder<> Builder(
newFunc->getEntryBlock().getFirstNonPHIOrDbgOrLifetime());
Builder.SetCurrentDebugLocation(oldFunc->getEntryBlock()
.getFirstNonPHIOrDbgOrLifetime()
->getDebugLoc());

trace = CreateTrace(Builder);

// Replace returns with ret trace

SmallVector<ReturnInst *, 3> toReplace;
for (auto &&BB : *newFunc) {
for (auto &&Inst : BB) {
if (auto Ret = dyn_cast<ReturnInst>(&Inst)) {
toReplace.push_back(Ret);
}
}
}

for (auto Ret : toReplace) {
IRBuilder<> Builder(Ret);
if (Ret->getReturnValue()) {
Value *retvals[2] = {Ret->getReturnValue(), trace};
Builder.CreateAggregateRet(retvals, 2);
} else {
Builder.CreateRet(trace);
}
Ret->eraseFromParent();
}
};

~TraceUtils() { delete interface; }
Expand Down
91 changes: 42 additions & 49 deletions enzyme/test/Enzyme/ProbProg/condition-dynamic.ll
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,18 @@ entry:

; CHECK: define i8* @condition(double* %data, i32 %n, i8* %trace, i8** %interface)
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = load i32, i32* @enzyme_condition
; CHECK-NEXT: %1 = load i32, i32* @enzyme_interface
; CHECK-NEXT: %2 = call { double, i8* } @condition_loss(double* %data, i32 %n, i8** %interface, i8* %trace)
; CHECK-NEXT: %3 = extractvalue { double, i8* } %2, 1
; CHECK-NEXT: ret i8* %3
; CHECK-NEXT: %0 = getelementptr inbounds i8*, i8** %interface, i32 4
; CHECK-NEXT: %1 = load i8*, i8** %0
; CHECK-NEXT: %new_trace = bitcast i8* %1 to i8* ()*
; CHECK-NEXT: %2 = load i32, i32* @enzyme_condition
; CHECK-NEXT: %3 = load i32, i32* @enzyme_interface
; CHECK-NEXT: %4 = call i8* %new_trace()
; CHECK-NEXT: %5 = call double @condition_loss(double* %data, i32 %n, i8** %interface, i8* %trace, i8* %4)
; CHECK-NEXT: ret i8* %4
; CHECK-NEXT: }


; CHECK: define internal { double, i8* } @condition_loss(double* %data, i32 %n, i8** %interface, i8* %observations)
; CHECK: define internal double @condition_loss(double* %data, i32 %n, i8** %interface, i8* %observations, i8* %trace)
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = getelementptr inbounds i8*, i8** %interface, i32 2
; CHECK-NEXT: %1 = load i8*, i8** %0
Expand All @@ -86,21 +89,20 @@ entry:
; CHECK-NEXT: %4 = getelementptr inbounds i8*, i8** %interface, i32 6
; CHECK-NEXT: %5 = load i8*, i8** %4
; CHECK-NEXT: %has_call = bitcast i8* %5 to i1 (i8*, i8*)*
; CHECK-NEXT: %call1.ptr = alloca double
; CHECK-NEXT: %6 = getelementptr inbounds i8*, i8** %interface, i32 3
; CHECK-NEXT: %6 = getelementptr inbounds i8*, i8** %interface, i32 4
; CHECK-NEXT: %7 = load i8*, i8** %6
; CHECK-NEXT: %insert_choice = bitcast i8* %7 to void (i8*, i8*, double, i8*, i64)*
; CHECK-NEXT: %8 = getelementptr inbounds i8*, i8** %interface, i32 1
; CHECK-NEXT: %new_trace = bitcast i8* %7 to i8* ()*
; CHECK-NEXT: %call1.ptr = alloca double
; CHECK-NEXT: %8 = getelementptr inbounds i8*, i8** %interface, i32 3
; CHECK-NEXT: %9 = load i8*, i8** %8
; CHECK-NEXT: %get_choice = bitcast i8* %9 to i64 (i8*, i8*, i8*, i64)*
; CHECK-NEXT: %call.ptr = alloca double
; CHECK-NEXT: %10 = getelementptr inbounds i8*, i8** %interface, i32 7
; CHECK-NEXT: %insert_choice = bitcast i8* %9 to void (i8*, i8*, double, i8*, i64)*
; CHECK-NEXT: %10 = getelementptr inbounds i8*, i8** %interface, i32 1
; CHECK-NEXT: %11 = load i8*, i8** %10
; CHECK-NEXT: %has_choice = bitcast i8* %11 to i1 (i8*, i8*)*
; CHECK-NEXT: %12 = getelementptr inbounds i8*, i8** %interface, i32 4
; CHECK-NEXT: %get_choice = bitcast i8* %11 to i64 (i8*, i8*, i8*, i64)*
; CHECK-NEXT: %call.ptr = alloca double
; CHECK-NEXT: %12 = getelementptr inbounds i8*, i8** %interface, i32 7
; CHECK-NEXT: %13 = load i8*, i8** %12
; CHECK-NEXT: %new_trace = bitcast i8* %13 to i8* ()*
; CHECK-NEXT: %trace = call i8* %new_trace()
; CHECK-NEXT: %has_choice = bitcast i8* %13 to i1 (i8*, i8*)*
; CHECK-NEXT: %has.choice.call = call i1 %has_choice(i8* %observations, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.1, i64 0, i64 0))
; CHECK-NEXT: br i1 %has.choice.call, label %condition.call.with.trace, label %condition.call.without.trace

Expand Down Expand Up @@ -139,30 +141,27 @@ entry:
; CHECK-NEXT: %18 = bitcast double %call1 to i64
; CHECK-NEXT: %19 = inttoptr i64 %18 to i8*
; CHECK-NEXT: call void %insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.2, i64 0, i64 0), double %likelihood.call1, i8* %19, i64 8)
; CHECK-NEXT: %trace1 = call i8* %new_trace()
; CHECK-NEXT: %has.call.call2 = call i1 %has_call(i8* %observations, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @0, i32 0, i32 0))
; CHECK-NEXT: br i1 %has.call.call2, label %condition.call2.with.trace, label %condition.call2.without.trace

; CHECK: condition.call2.with.trace: ; preds = %entry.cntd.cntd
; CHECK-NEXT: %calculate_loss.subtrace = call i8* %get_trace(i8* %observations, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @0, i32 0, i32 0))
; CHECK-NEXT: %condition.calculate_loss = call { double, i8* } @condition_calculate_loss(double %call, double %call1, double* %data, i32 %n, i8** %interface, i8* %calculate_loss.subtrace)
; CHECK-NEXT: %condition.calculate_loss = call double @condition_calculate_loss(double %call, double %call1, double* %data, i32 %n, i8** %interface, i8* %calculate_loss.subtrace, i8* %trace1)
; CHECK-NEXT: br label %entry.cntd.cntd.cntd

; CHECK: condition.call2.without.trace: ; preds = %entry.cntd.cntd
; CHECK-NEXT: %trace.calculate_loss = call { double, i8* } @condition_calculate_loss(double %call, double %call1, double* %data, i32 %n, i8** %interface, i8* null)
; CHECK-NEXT: %trace.calculate_loss = call double @condition_calculate_loss(double %call, double %call1, double* %data, i32 %n, i8** %interface, i8* null, i8* %trace1)
; CHECK-NEXT: br label %entry.cntd.cntd.cntd

; CHECK: entry.cntd.cntd.cntd: ; preds = %condition.call2.without.trace, %condition.call2.with.trace
; CHECK-NEXT: %call22 = phi { double, i8* } [ %condition.calculate_loss, %condition.call2.with.trace ], [ %trace.calculate_loss, %condition.call2.without.trace ]
; CHECK-NEXT: %call2 = extractvalue { double, i8* } %call22, 0
; CHECK-NEXT: %newtrace.calculate_loss = extractvalue { double, i8* } %call22, 1
; CHECK-NEXT: call void %insert_call(i8* %trace, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @0, i32 0, i32 0), i8* %newtrace.calculate_loss)
; CHECK-NEXT: %mrv = insertvalue { double, i8* } {{(undef|poison)}}, double %call2, 0
; CHECK-NEXT: %mrv1 = insertvalue { double, i8* } %mrv, i8* %trace, 1
; CHECK-NEXT: ret { double, i8* } %mrv1
; CHECK-NEXT: %call2 = phi double [ %condition.calculate_loss, %condition.call2.with.trace ], [ %trace.calculate_loss, %condition.call2.without.trace ]
; CHECK-NEXT: call void %insert_call(i8* %trace, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @0, i32 0, i32 0), i8* %trace1)
; CHECK-NEXT: ret double %call2
; CHECK-NEXT: }


; CHECK: define internal { double, i8* } @condition_calculate_loss(double %m, double %b, double* %data, i32 %n, i8** %interface, i8* %observations)
; CHECK: define internal double @condition_calculate_loss(double %m, double %b, double* %data, i32 %n, i8** %interface, i8* %observations, i8* %trace)
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = getelementptr inbounds i8*, i8** %interface, i32 3
; CHECK-NEXT: %1 = load i8*, i8** %0
Expand All @@ -174,10 +173,6 @@ entry:
; CHECK-NEXT: %4 = getelementptr inbounds i8*, i8** %interface, i32 7
; CHECK-NEXT: %5 = load i8*, i8** %4
; CHECK-NEXT: %has_choice = bitcast i8* %5 to i1 (i8*, i8*)*
; CHECK-NEXT: %6 = getelementptr inbounds i8*, i8** %interface, i32 4
; CHECK-NEXT: %7 = load i8*, i8** %6
; CHECK-NEXT: %new_trace = bitcast i8* %7 to i8* ()*
; CHECK-NEXT: %trace = call i8* %new_trace()
; CHECK-NEXT: %cmp19 = icmp sgt i32 %n, 0
; CHECK-NEXT: br i1 %cmp19, label %for.body.preheader, label %for.cond.cleanup

Expand All @@ -186,42 +181,40 @@ entry:
; CHECK-NEXT: br label %for.body

; CHECK: for.cond.cleanup: ; preds = %for.body.cntd, %entry
; CHECK-NEXT: %loss.0.lcssa = phi double [ 0.000000e+00, %entry ], [ %14, %for.body.cntd ]
; CHECK-NEXT: %mrv = insertvalue { double, i8* } {{(undef|poison)}}, double %loss.0.lcssa, 0
; CHECK-NEXT: %mrv1 = insertvalue { double, i8* } %mrv, i8* %trace, 1
; CHECK-NEXT: ret { double, i8* } %mrv1
; CHECK-NEXT: %loss.0.lcssa = phi double [ 0.000000e+00, %entry ], [ %12, %for.body.cntd ]
; CHECK-NEXT: ret double %loss.0.lcssa

; CHECK: for.body: ; preds = %for.body.cntd, %for.body.preheader
; CHECK-NEXT: %indvars.iv = phi i64 [ 0, %for.body.preheader ], [ %indvars.iv.next, %for.body.cntd ]
; CHECK-NEXT: %loss.021 = phi double [ 0.000000e+00, %for.body.preheader ], [ %14, %for.body.cntd ]
; CHECK-NEXT: %8 = trunc i64 %indvars.iv to i32
; CHECK-NEXT: %conv2 = sitofp i32 %8 to double
; CHECK-NEXT: %loss.021 = phi double [ 0.000000e+00, %for.body.preheader ], [ %12, %for.body.cntd ]
; CHECK-NEXT: %6 = trunc i64 %indvars.iv to i32
; CHECK-NEXT: %conv2 = sitofp i32 %6 to double
; CHECK-NEXT: %mul1 = fmul double %conv2, %m
; CHECK-NEXT: %9 = fadd double %mul1, %b
; CHECK-NEXT: %7 = fadd double %mul1, %b
; CHECK-NEXT: %has.choice.call = call i1 %has_choice(i8* %observations, i8* nocapture readonly getelementptr inbounds ([11 x i8], [11 x i8]* @.str, i64 0, i64 0))
; CHECK-NEXT: br i1 %has.choice.call, label %condition.call.with.trace, label %condition.call.without.trace

; CHECK: condition.call.with.trace: ; preds = %for.body
; CHECK-NEXT: %10 = bitcast double* %call.ptr to i8*
; CHECK-NEXT: %call.size = call i64 %get_choice(i8* %observations, i8* nocapture readonly getelementptr inbounds ([11 x i8], [11 x i8]* @.str, i64 0, i64 0), i8* %10, i64 8)
; CHECK-NEXT: %8 = bitcast double* %call.ptr to i8*
; CHECK-NEXT: %call.size = call i64 %get_choice(i8* %observations, i8* nocapture readonly getelementptr inbounds ([11 x i8], [11 x i8]* @.str, i64 0, i64 0), i8* %8, i64 8)
; CHECK-NEXT: %from.trace.call = load double, double* %call.ptr
; CHECK-NEXT: br label %for.body.cntd

; CHECK: condition.call.without.trace: ; preds = %for.body
; CHECK-NEXT: %sample.call = call double @normal(double %9, double 1.000000e+00)
; CHECK-NEXT: %sample.call = call double @normal(double %7, double 1.000000e+00)
; CHECK-NEXT: br label %for.body.cntd

; CHECK: for.body.cntd: ; preds = %condition.call.without.trace, %condition.call.with.trace
; CHECK-NEXT: %call = phi double [ %from.trace.call, %condition.call.with.trace ], [ %sample.call, %condition.call.without.trace ]
; CHECK-NEXT: %likelihood.call = call double @normal_logpdf(double %9, double 1.000000e+00, double %call)
; CHECK-NEXT: %11 = bitcast double %call to i64
; CHECK-NEXT: %12 = inttoptr i64 %11 to i8*
; CHECK-NEXT: call void %insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([11 x i8], [11 x i8]* @.str, i64 0, i64 0), double %likelihood.call, i8* %12, i64 8)
; CHECK-NEXT: %likelihood.call = call double @normal_logpdf(double %7, double 1.000000e+00, double %call)
; CHECK-NEXT: %9 = bitcast double %call to i64
; CHECK-NEXT: %10 = inttoptr i64 %9 to i8*
; CHECK-NEXT: call void %insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([11 x i8], [11 x i8]* @.str, i64 0, i64 0), double %likelihood.call, i8* %10, i64 8)
; CHECK-NEXT: %arrayidx3 = getelementptr inbounds double, double* %data, i64 %indvars.iv
; CHECK-NEXT: %13 = load double, double* %arrayidx3
; CHECK-NEXT: %sub = fsub double %call, %13
; CHECK-NEXT: %11 = load double, double* %arrayidx3
; CHECK-NEXT: %sub = fsub double %call, %11
; CHECK-NEXT: %mul2 = fmul double %sub, %sub
; CHECK-NEXT: %14 = fadd double %mul2, %loss.021
; CHECK-NEXT: %12 = fadd double %mul2, %loss.021
; CHECK-NEXT: %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
; CHECK-NEXT: %exitcond.not = icmp eq i64 %indvars.iv.next, %wide.trip.count
; CHECK-NEXT: br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
Expand Down
Loading