diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index a154776d3288..12dcbf55414f 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -1894,11 +1894,12 @@ class EnzymeBase { Logic.CreateTrace(F, generativeFunctions, mode, has_dynamic_interface); Value *trace = - Builder.CreateCall(newFunc->getFunctionType(), newFunc, args); - if (!F->getReturnType()->isVoidTy()) - trace = Builder.CreateExtractValue(trace, {1}); + Builder.CreateCall(interface->newTraceTy(), interface->newTrace(), {}); + + args.push_back(trace); + + Builder.CreateCall(newFunc->getFunctionType(), newFunc, args); - // try to cast i8* returned from trace to CI->getRetType.... if (CI->getType() != trace->getType()) trace = Builder.CreatePointerCast(trace, CI->getType()); diff --git a/enzyme/Enzyme/TraceGenerator.h b/enzyme/Enzyme/TraceGenerator.h index dbd7aad89519..99131ff6bb5a 100644 --- a/enzyme/Enzyme/TraceGenerator.h +++ b/enzyme/Enzyme/TraceGenerator.h @@ -88,9 +88,11 @@ class TraceGenerator final : public llvm::InstVisitor { { ElseTerm->getParent()->setName("condition." + call.getName() + ".without.trace"); - ElseChoice = + + auto choice = Builder.CreateCall(samplefn->getFunctionType(), samplefn, sample_args, "sample." + call.getName()); + ElseChoice = choice; } Builder.SetInsertPoint(new_call); @@ -132,11 +134,16 @@ class TraceGenerator final : public llvm::InstVisitor { Logic.CreateTrace(called, tutils->generativeFunctions, tutils->mode, tutils->hasDynamicTraceInterface()); + auto trace = tutils->CreateTrace(Builder); + Instruction *tracecall; switch (mode) { case ProbProgMode::Trace: { - tracecall = Builder.CreateCall(samplefn->getFunctionType(), samplefn, - args, "trace." + called->getName()); + SmallVector args_and_trace = SmallVector(args); + args_and_trace.push_back(trace); + tracecall = + Builder.CreateCall(samplefn->getFunctionType(), samplefn, + args_and_trace, "trace." + called->getName()); break; } case ProbProgMode::Condition: { @@ -158,8 +165,9 @@ class TraceGenerator final : public llvm::InstVisitor { ThenTerm->getParent()->setName("condition." + call.getName() + ".with.trace"); SmallVector args_and_cond = SmallVector(args); - auto trace = tutils->GetTrace(Builder, address, - called->getName() + ".subtrace"); + auto observations = tutils->GetTrace(Builder, address, + called->getName() + ".subtrace"); + args_and_cond.push_back(observations); args_and_cond.push_back(trace); ThenTracecall = Builder.CreateCall(samplefn->getFunctionType(), samplefn, args_and_cond, @@ -171,8 +179,9 @@ class TraceGenerator final : public llvm::InstVisitor { ElseTerm->getParent()->setName("condition." + call.getName() + ".without.trace"); SmallVector args_and_null = SmallVector(args); - auto trace = ConstantPointerNull::get(cast( + auto observations = ConstantPointerNull::get(cast( tutils->getTraceInterface()->newTraceTy()->getReturnType())); + args_and_null.push_back(observations); args_and_null.push_back(trace); ElseTracecall = Builder.CreateCall(samplefn->getFunctionType(), samplefn, @@ -188,14 +197,10 @@ class TraceGenerator final : public llvm::InstVisitor { } } - Value *ret = Builder.CreateExtractValue(tracecall, {0}); - Value *subtrace = Builder.CreateExtractValue( - tracecall, {1}, "newtrace." + called->getName()); - - tutils->InsertCall(Builder, address, subtrace); + tutils->InsertCall(Builder, address, trace); - ret->takeName(new_call); - new_call->replaceAllUsesWith(ret); + tracecall->takeName(new_call); + new_call->replaceAllUsesWith(tracecall); new_call->eraseFromParent(); } } diff --git a/enzyme/Enzyme/TraceUtils.h b/enzyme/Enzyme/TraceUtils.h index 907eb2a06c08..e8accaaac545 100644 --- a/enzyme/Enzyme/TraceUtils.h +++ b/enzyme/Enzyme/TraceUtils.h @@ -26,7 +26,7 @@ class TraceUtils { private: TraceInterface *interface; Value *dynamic_interface = nullptr; - Instruction *trace; + Value *trace; Value *observations = nullptr; public: @@ -70,10 +70,9 @@ class TraceUtils { if (mode == ProbProgMode::Condition) params.push_back(traceType); - Type *RetTy = traceType; - if (!oldFunc->getReturnType()->isVoidTy()) - RetTy = StructType::get(Context, {oldFunc->getReturnType(), traceType}); + params.push_back(traceType); + Type *RetTy = oldFunc->getReturnType(); FunctionType *FTy = FunctionType::get(RetTy, params, oldFunc->isVarArg()); Twine Name = (mode == ProbProgMode::Condition ? "condition_" : "trace_") + @@ -94,7 +93,7 @@ class TraceUtils { } if (has_dynamic_interface) { - auto arg = newFunc->arg_end() - (1 + (mode == ProbProgMode::Condition)); + auto arg = newFunc->arg_end() - (2 + (mode == ProbProgMode::Condition)); dynamic_interface = arg; arg->setName("interface"); arg->addAttr(Attribute::ReadOnly); @@ -102,13 +101,19 @@ class TraceUtils { } if (mode == ProbProgMode::Condition) { - auto arg = newFunc->arg_end() - 1; + auto arg = newFunc->arg_end() - 2; observations = arg; arg->setName("observations"); if (oldFunc->getReturnType()->isVoidTy()) arg->addAttr(Attribute::Returned); } + auto arg = newFunc->arg_end() - 1; + trace = arg; + arg->setName("trace"); + if (oldFunc->getReturnType()->isVoidTy()) + arg->addAttr(Attribute::Returned); + SmallVector Returns; #if LLVM_VERSION_MAJOR >= 13 CloneFunctionInto(newFunc, oldFunc, originalToNewFn, @@ -126,38 +131,6 @@ class TraceUtils { } else { interface = new StaticTraceInterface(F->getParent()); } - - // Create trace for current function - - IRBuilder<> Builder( - newFunc->getEntryBlock().getFirstNonPHIOrDbgOrLifetime()); - Builder.SetCurrentDebugLocation(oldFunc->getEntryBlock() - .getFirstNonPHIOrDbgOrLifetime() - ->getDebugLoc()); - - trace = CreateTrace(Builder); - - // Replace returns with ret trace - - SmallVector toReplace; - for (auto &&BB : *newFunc) { - for (auto &&Inst : BB) { - if (auto Ret = dyn_cast(&Inst)) { - toReplace.push_back(Ret); - } - } - } - - for (auto Ret : toReplace) { - IRBuilder<> Builder(Ret); - if (Ret->getReturnValue()) { - Value *retvals[2] = {Ret->getReturnValue(), trace}; - Builder.CreateAggregateRet(retvals, 2); - } else { - Builder.CreateRet(trace); - } - Ret->eraseFromParent(); - } }; ~TraceUtils() { delete interface; } diff --git a/enzyme/test/Enzyme/ProbProg/condition-dynamic.ll b/enzyme/test/Enzyme/ProbProg/condition-dynamic.ll index 250a207d006d..8835e0b67186 100644 --- a/enzyme/test/Enzyme/ProbProg/condition-dynamic.ll +++ b/enzyme/test/Enzyme/ProbProg/condition-dynamic.ll @@ -67,15 +67,18 @@ entry: ; CHECK: define i8* @condition(double* %data, i32 %n, i8* %trace, i8** %interface) ; CHECK-NEXT: entry: -; CHECK-NEXT: %0 = load i32, i32* @enzyme_condition -; CHECK-NEXT: %1 = load i32, i32* @enzyme_interface -; CHECK-NEXT: %2 = call { double, i8* } @condition_loss(double* %data, i32 %n, i8** %interface, i8* %trace) -; CHECK-NEXT: %3 = extractvalue { double, i8* } %2, 1 -; CHECK-NEXT: ret i8* %3 +; CHECK-NEXT: %0 = getelementptr inbounds i8*, i8** %interface, i32 4 +; CHECK-NEXT: %1 = load i8*, i8** %0 +; CHECK-NEXT: %new_trace = bitcast i8* %1 to i8* ()* +; CHECK-NEXT: %2 = load i32, i32* @enzyme_condition +; CHECK-NEXT: %3 = load i32, i32* @enzyme_interface +; CHECK-NEXT: %4 = call i8* %new_trace() +; CHECK-NEXT: %5 = call double @condition_loss(double* %data, i32 %n, i8** %interface, i8* %trace, i8* %4) +; CHECK-NEXT: ret i8* %4 ; CHECK-NEXT: } -; CHECK: define internal { double, i8* } @condition_loss(double* %data, i32 %n, i8** %interface, i8* %observations) +; CHECK: define internal double @condition_loss(double* %data, i32 %n, i8** %interface, i8* %observations, i8* %trace) ; CHECK-NEXT: entry: ; CHECK-NEXT: %0 = getelementptr inbounds i8*, i8** %interface, i32 2 ; CHECK-NEXT: %1 = load i8*, i8** %0 @@ -86,21 +89,20 @@ entry: ; CHECK-NEXT: %4 = getelementptr inbounds i8*, i8** %interface, i32 6 ; CHECK-NEXT: %5 = load i8*, i8** %4 ; CHECK-NEXT: %has_call = bitcast i8* %5 to i1 (i8*, i8*)* -; CHECK-NEXT: %call1.ptr = alloca double -; CHECK-NEXT: %6 = getelementptr inbounds i8*, i8** %interface, i32 3 +; CHECK-NEXT: %6 = getelementptr inbounds i8*, i8** %interface, i32 4 ; CHECK-NEXT: %7 = load i8*, i8** %6 -; CHECK-NEXT: %insert_choice = bitcast i8* %7 to void (i8*, i8*, double, i8*, i64)* -; CHECK-NEXT: %8 = getelementptr inbounds i8*, i8** %interface, i32 1 +; CHECK-NEXT: %new_trace = bitcast i8* %7 to i8* ()* +; CHECK-NEXT: %call1.ptr = alloca double +; CHECK-NEXT: %8 = getelementptr inbounds i8*, i8** %interface, i32 3 ; CHECK-NEXT: %9 = load i8*, i8** %8 -; CHECK-NEXT: %get_choice = bitcast i8* %9 to i64 (i8*, i8*, i8*, i64)* -; CHECK-NEXT: %call.ptr = alloca double -; CHECK-NEXT: %10 = getelementptr inbounds i8*, i8** %interface, i32 7 +; CHECK-NEXT: %insert_choice = bitcast i8* %9 to void (i8*, i8*, double, i8*, i64)* +; CHECK-NEXT: %10 = getelementptr inbounds i8*, i8** %interface, i32 1 ; CHECK-NEXT: %11 = load i8*, i8** %10 -; CHECK-NEXT: %has_choice = bitcast i8* %11 to i1 (i8*, i8*)* -; CHECK-NEXT: %12 = getelementptr inbounds i8*, i8** %interface, i32 4 +; CHECK-NEXT: %get_choice = bitcast i8* %11 to i64 (i8*, i8*, i8*, i64)* +; CHECK-NEXT: %call.ptr = alloca double +; CHECK-NEXT: %12 = getelementptr inbounds i8*, i8** %interface, i32 7 ; CHECK-NEXT: %13 = load i8*, i8** %12 -; CHECK-NEXT: %new_trace = bitcast i8* %13 to i8* ()* -; CHECK-NEXT: %trace = call i8* %new_trace() +; CHECK-NEXT: %has_choice = bitcast i8* %13 to i1 (i8*, i8*)* ; CHECK-NEXT: %has.choice.call = call i1 %has_choice(i8* %observations, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.1, i64 0, i64 0)) ; CHECK-NEXT: br i1 %has.choice.call, label %condition.call.with.trace, label %condition.call.without.trace @@ -139,30 +141,27 @@ entry: ; CHECK-NEXT: %18 = bitcast double %call1 to i64 ; CHECK-NEXT: %19 = inttoptr i64 %18 to i8* ; CHECK-NEXT: call void %insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.2, i64 0, i64 0), double %likelihood.call1, i8* %19, i64 8) +; CHECK-NEXT: %trace1 = call i8* %new_trace() ; CHECK-NEXT: %has.call.call2 = call i1 %has_call(i8* %observations, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @0, i32 0, i32 0)) ; CHECK-NEXT: br i1 %has.call.call2, label %condition.call2.with.trace, label %condition.call2.without.trace ; CHECK: condition.call2.with.trace: ; preds = %entry.cntd.cntd ; CHECK-NEXT: %calculate_loss.subtrace = call i8* %get_trace(i8* %observations, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @0, i32 0, i32 0)) -; CHECK-NEXT: %condition.calculate_loss = call { double, i8* } @condition_calculate_loss(double %call, double %call1, double* %data, i32 %n, i8** %interface, i8* %calculate_loss.subtrace) +; CHECK-NEXT: %condition.calculate_loss = call double @condition_calculate_loss(double %call, double %call1, double* %data, i32 %n, i8** %interface, i8* %calculate_loss.subtrace, i8* %trace1) ; CHECK-NEXT: br label %entry.cntd.cntd.cntd ; CHECK: condition.call2.without.trace: ; preds = %entry.cntd.cntd -; CHECK-NEXT: %trace.calculate_loss = call { double, i8* } @condition_calculate_loss(double %call, double %call1, double* %data, i32 %n, i8** %interface, i8* null) +; CHECK-NEXT: %trace.calculate_loss = call double @condition_calculate_loss(double %call, double %call1, double* %data, i32 %n, i8** %interface, i8* null, i8* %trace1) ; CHECK-NEXT: br label %entry.cntd.cntd.cntd ; CHECK: entry.cntd.cntd.cntd: ; preds = %condition.call2.without.trace, %condition.call2.with.trace -; CHECK-NEXT: %call22 = phi { double, i8* } [ %condition.calculate_loss, %condition.call2.with.trace ], [ %trace.calculate_loss, %condition.call2.without.trace ] -; CHECK-NEXT: %call2 = extractvalue { double, i8* } %call22, 0 -; CHECK-NEXT: %newtrace.calculate_loss = extractvalue { double, i8* } %call22, 1 -; CHECK-NEXT: call void %insert_call(i8* %trace, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @0, i32 0, i32 0), i8* %newtrace.calculate_loss) -; CHECK-NEXT: %mrv = insertvalue { double, i8* } {{(undef|poison)}}, double %call2, 0 -; CHECK-NEXT: %mrv1 = insertvalue { double, i8* } %mrv, i8* %trace, 1 -; CHECK-NEXT: ret { double, i8* } %mrv1 +; CHECK-NEXT: %call2 = phi double [ %condition.calculate_loss, %condition.call2.with.trace ], [ %trace.calculate_loss, %condition.call2.without.trace ] +; CHECK-NEXT: call void %insert_call(i8* %trace, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @0, i32 0, i32 0), i8* %trace1) +; CHECK-NEXT: ret double %call2 ; CHECK-NEXT: } -; CHECK: define internal { double, i8* } @condition_calculate_loss(double %m, double %b, double* %data, i32 %n, i8** %interface, i8* %observations) +; CHECK: define internal double @condition_calculate_loss(double %m, double %b, double* %data, i32 %n, i8** %interface, i8* %observations, i8* %trace) ; CHECK-NEXT: entry: ; CHECK-NEXT: %0 = getelementptr inbounds i8*, i8** %interface, i32 3 ; CHECK-NEXT: %1 = load i8*, i8** %0 @@ -174,10 +173,6 @@ entry: ; CHECK-NEXT: %4 = getelementptr inbounds i8*, i8** %interface, i32 7 ; CHECK-NEXT: %5 = load i8*, i8** %4 ; CHECK-NEXT: %has_choice = bitcast i8* %5 to i1 (i8*, i8*)* -; CHECK-NEXT: %6 = getelementptr inbounds i8*, i8** %interface, i32 4 -; CHECK-NEXT: %7 = load i8*, i8** %6 -; CHECK-NEXT: %new_trace = bitcast i8* %7 to i8* ()* -; CHECK-NEXT: %trace = call i8* %new_trace() ; CHECK-NEXT: %cmp19 = icmp sgt i32 %n, 0 ; CHECK-NEXT: br i1 %cmp19, label %for.body.preheader, label %for.cond.cleanup @@ -186,42 +181,40 @@ entry: ; CHECK-NEXT: br label %for.body ; CHECK: for.cond.cleanup: ; preds = %for.body.cntd, %entry -; CHECK-NEXT: %loss.0.lcssa = phi double [ 0.000000e+00, %entry ], [ %14, %for.body.cntd ] -; CHECK-NEXT: %mrv = insertvalue { double, i8* } {{(undef|poison)}}, double %loss.0.lcssa, 0 -; CHECK-NEXT: %mrv1 = insertvalue { double, i8* } %mrv, i8* %trace, 1 -; CHECK-NEXT: ret { double, i8* } %mrv1 +; CHECK-NEXT: %loss.0.lcssa = phi double [ 0.000000e+00, %entry ], [ %12, %for.body.cntd ] +; CHECK-NEXT: ret double %loss.0.lcssa ; CHECK: for.body: ; preds = %for.body.cntd, %for.body.preheader ; CHECK-NEXT: %indvars.iv = phi i64 [ 0, %for.body.preheader ], [ %indvars.iv.next, %for.body.cntd ] -; CHECK-NEXT: %loss.021 = phi double [ 0.000000e+00, %for.body.preheader ], [ %14, %for.body.cntd ] -; CHECK-NEXT: %8 = trunc i64 %indvars.iv to i32 -; CHECK-NEXT: %conv2 = sitofp i32 %8 to double +; CHECK-NEXT: %loss.021 = phi double [ 0.000000e+00, %for.body.preheader ], [ %12, %for.body.cntd ] +; CHECK-NEXT: %6 = trunc i64 %indvars.iv to i32 +; CHECK-NEXT: %conv2 = sitofp i32 %6 to double ; CHECK-NEXT: %mul1 = fmul double %conv2, %m -; CHECK-NEXT: %9 = fadd double %mul1, %b +; CHECK-NEXT: %7 = fadd double %mul1, %b ; CHECK-NEXT: %has.choice.call = call i1 %has_choice(i8* %observations, i8* nocapture readonly getelementptr inbounds ([11 x i8], [11 x i8]* @.str, i64 0, i64 0)) ; CHECK-NEXT: br i1 %has.choice.call, label %condition.call.with.trace, label %condition.call.without.trace ; CHECK: condition.call.with.trace: ; preds = %for.body -; CHECK-NEXT: %10 = bitcast double* %call.ptr to i8* -; CHECK-NEXT: %call.size = call i64 %get_choice(i8* %observations, i8* nocapture readonly getelementptr inbounds ([11 x i8], [11 x i8]* @.str, i64 0, i64 0), i8* %10, i64 8) +; CHECK-NEXT: %8 = bitcast double* %call.ptr to i8* +; CHECK-NEXT: %call.size = call i64 %get_choice(i8* %observations, i8* nocapture readonly getelementptr inbounds ([11 x i8], [11 x i8]* @.str, i64 0, i64 0), i8* %8, i64 8) ; CHECK-NEXT: %from.trace.call = load double, double* %call.ptr ; CHECK-NEXT: br label %for.body.cntd ; CHECK: condition.call.without.trace: ; preds = %for.body -; CHECK-NEXT: %sample.call = call double @normal(double %9, double 1.000000e+00) +; CHECK-NEXT: %sample.call = call double @normal(double %7, double 1.000000e+00) ; CHECK-NEXT: br label %for.body.cntd ; CHECK: for.body.cntd: ; preds = %condition.call.without.trace, %condition.call.with.trace ; CHECK-NEXT: %call = phi double [ %from.trace.call, %condition.call.with.trace ], [ %sample.call, %condition.call.without.trace ] -; CHECK-NEXT: %likelihood.call = call double @normal_logpdf(double %9, double 1.000000e+00, double %call) -; CHECK-NEXT: %11 = bitcast double %call to i64 -; CHECK-NEXT: %12 = inttoptr i64 %11 to i8* -; CHECK-NEXT: call void %insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([11 x i8], [11 x i8]* @.str, i64 0, i64 0), double %likelihood.call, i8* %12, i64 8) +; CHECK-NEXT: %likelihood.call = call double @normal_logpdf(double %7, double 1.000000e+00, double %call) +; CHECK-NEXT: %9 = bitcast double %call to i64 +; CHECK-NEXT: %10 = inttoptr i64 %9 to i8* +; CHECK-NEXT: call void %insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([11 x i8], [11 x i8]* @.str, i64 0, i64 0), double %likelihood.call, i8* %10, i64 8) ; CHECK-NEXT: %arrayidx3 = getelementptr inbounds double, double* %data, i64 %indvars.iv -; CHECK-NEXT: %13 = load double, double* %arrayidx3 -; CHECK-NEXT: %sub = fsub double %call, %13 +; CHECK-NEXT: %11 = load double, double* %arrayidx3 +; CHECK-NEXT: %sub = fsub double %call, %11 ; CHECK-NEXT: %mul2 = fmul double %sub, %sub -; CHECK-NEXT: %14 = fadd double %mul2, %loss.021 +; CHECK-NEXT: %12 = fadd double %mul2, %loss.021 ; CHECK-NEXT: %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1 ; CHECK-NEXT: %exitcond.not = icmp eq i64 %indvars.iv.next, %wide.trip.count ; CHECK-NEXT: br i1 %exitcond.not, label %for.cond.cleanup, label %for.body diff --git a/enzyme/test/Enzyme/ProbProg/condition-static.ll b/enzyme/test/Enzyme/ProbProg/condition-static.ll index 760b172cd56d..811313be6bf4 100644 --- a/enzyme/test/Enzyme/ProbProg/condition-static.ll +++ b/enzyme/test/Enzyme/ProbProg/condition-static.ll @@ -72,17 +72,16 @@ entry: ; CHECK: define i8* @condition(double* %data, i32 %n, i8* %trace) ; CHECK-NEXT: entry: ; CHECK-NEXT: %0 = load i32, i32* @enzyme_condition -; CHECK-NEXT: %1 = call { double, i8* } @condition_loss(double* %data, i32 %n, i8* %trace) -; CHECK-NEXT: %2 = extractvalue { double, i8* } %1, 1 -; CHECK-NEXT: ret i8* %2 +; CHECK-NEXT: %1 = call i8* @__enzyme_newtrace() +; CHECK-NEXT: %2 = call double @condition_loss(double* %data, i32 %n, i8* %trace, i8* %1) +; CHECK-NEXT: ret i8* %1 ; CHECK-NEXT: } -; CHECK: define internal { double, i8* } @condition_loss(double* %data, i32 %n, i8* %observations) +; CHECK: define internal double @condition_loss(double* %data, i32 %n, i8* %observations, i8* %trace) ; CHECK-NEXT: entry: ; CHECK-NEXT: %call1.ptr = alloca double ; CHECK-NEXT: %call.ptr = alloca double -; CHECK-NEXT: %trace = call i8* @__enzyme_newtrace() ; CHECK-NEXT: %has.choice.call = call i1 @__enzyme_has_choice(i8* %observations, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.1, i64 0, i64 0)) ; CHECK-NEXT: br i1 %has.choice.call, label %condition.call.with.trace, label %condition.call.without.trace @@ -121,33 +120,29 @@ entry: ; CHECK-NEXT: %4 = bitcast double %call1 to i64 ; CHECK-NEXT: %5 = inttoptr i64 %4 to i8* ; CHECK-NEXT: call void @__enzyme_insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.2, i64 0, i64 0), double %likelihood.call1, i8* %5, i64 8) +; CHECK-NEXT: %trace1 = call i8* @__enzyme_newtrace() ; CHECK-NEXT: %has.call.call2 = call i1 @__enzyme_has_call(i8* %observations, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @0, i32 0, i32 0)) ; CHECK-NEXT: br i1 %has.call.call2, label %condition.call2.with.trace, label %condition.call2.without.trace ; CHECK: condition.call2.with.trace: ; preds = %entry.cntd.cntd ; CHECK-NEXT: %calculate_loss.subtrace = call i8* @__enzyme_get_trace(i8* %observations, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @0, i32 0, i32 0)) -; CHECK-NEXT: %condition.calculate_loss = call { double, i8* } @condition_calculate_loss(double %call, double %call1, double* %data, i32 %n, i8* %calculate_loss.subtrace) +; CHECK-NEXT: %condition.calculate_loss = call double @condition_calculate_loss(double %call, double %call1, double* %data, i32 %n, i8* %calculate_loss.subtrace, i8* %trace1) ; CHECK-NEXT: br label %entry.cntd.cntd.cntd ; CHECK: condition.call2.without.trace: ; preds = %entry.cntd.cntd -; CHECK-NEXT: %trace.calculate_loss = call { double, i8* } @condition_calculate_loss(double %call, double %call1, double* %data, i32 %n, i8* null) +; CHECK-NEXT: %trace.calculate_loss = call double @condition_calculate_loss(double %call, double %call1, double* %data, i32 %n, i8* null, i8* %trace1) ; CHECK-NEXT: br label %entry.cntd.cntd.cntd ; CHECK: entry.cntd.cntd.cntd: ; preds = %condition.call2.without.trace, %condition.call2.with.trace -; CHECK-NEXT: %call22 = phi { double, i8* } [ %condition.calculate_loss, %condition.call2.with.trace ], [ %trace.calculate_loss, %condition.call2.without.trace ] -; CHECK-NEXT: %call2 = extractvalue { double, i8* } %call22, 0 -; CHECK-NEXT: %newtrace.calculate_loss = extractvalue { double, i8* } %call22, 1 -; CHECK-NEXT: call void @__enzyme_insert_call(i8* %trace, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @0, i32 0, i32 0), i8* %newtrace.calculate_loss) -; CHECK-NEXT: %mrv = insertvalue { double, i8* } {{(undef|poison)}}, double %call2, 0 -; CHECK-NEXT: %mrv1 = insertvalue { double, i8* } %mrv, i8* %trace, 1 -; CHECK-NEXT: ret { double, i8* } %mrv1 +; CHECK-NEXT: %call2 = phi double [ %condition.calculate_loss, %condition.call2.with.trace ], [ %trace.calculate_loss, %condition.call2.without.trace ] +; CHECK-NEXT: call void @__enzyme_insert_call(i8* %trace, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @0, i32 0, i32 0), i8* %trace1) +; CHECK-NEXT: ret double %call2 ; CHECK-NEXT: } -; CHECK: define internal { double, i8* } @condition_calculate_loss(double %m, double %b, double* %data, i32 %n, i8* %observations) +; CHECK: define internal double @condition_calculate_loss(double %m, double %b, double* %data, i32 %n, i8* %observations, i8* %trace) ; CHECK-NEXT: entry: ; CHECK-NEXT: %call.ptr = alloca double -; CHECK-NEXT: %trace = call i8* @__enzyme_newtrace() ; CHECK-NEXT: %cmp19 = icmp sgt i32 %n, 0 ; CHECK-NEXT: br i1 %cmp19, label %for.body.preheader, label %for.cond.cleanup @@ -157,9 +152,7 @@ entry: ; CHECK: for.cond.cleanup: ; preds = %for.body.cntd, %entry ; CHECK-NEXT: %loss.0.lcssa = phi double [ 0.000000e+00, %entry ], [ %6, %for.body.cntd ] -; CHECK-NEXT: %mrv = insertvalue { double, i8* } {{(undef|poison)}}, double %loss.0.lcssa, 0 -; CHECK-NEXT: %mrv1 = insertvalue { double, i8* } %mrv, i8* %trace, 1 -; CHECK-NEXT: ret { double, i8* } %mrv1 +; CHECK-NEXT: ret double %loss.0.lcssa ; CHECK: for.body: ; preds = %for.body.cntd, %for.body.preheader ; CHECK-NEXT: %indvars.iv = phi i64 [ 0, %for.body.preheader ], [ %indvars.iv.next, %for.body.cntd ] diff --git a/enzyme/test/Enzyme/ProbProg/simple-condition.ll b/enzyme/test/Enzyme/ProbProg/simple-condition.ll index 5e6e9587a50d..5554829395e9 100644 --- a/enzyme/test/Enzyme/ProbProg/simple-condition.ll +++ b/enzyme/test/Enzyme/ProbProg/simple-condition.ll @@ -44,16 +44,15 @@ entry: ; CHECK: define i8* @condition(i8* %trace) ; CHECK-NEXT: entry: ; CHECK-NEXT: %0 = load i32, i32* @enzyme_condition -; CHECK-NEXT: %1 = call i8* @condition_test(i8* %trace) +; CHECK-NEXT: %1 = call i8* @__enzyme_newtrace() +; CHECK-NEXT: call void @condition_test(i8* %trace, i8* %1) ; CHECK-NEXT: ret i8* %1 ; CHECK-NEXT: } - -; CHECK: define internal i8* @condition_test(i8* %observations) +; CHECK: define internal void @condition_test(i8* %observations, i8* %trace) ; CHECK-NEXT: entry: ; CHECK-NEXT: %x.ptr = alloca double ; CHECK-NEXT: %mu.ptr = alloca double -; CHECK-NEXT: %trace = call i8* @__enzyme_newtrace() ; CHECK-NEXT: %has.choice.mu = call i1 @__enzyme_has_choice(i8* %observations, i8* nocapture readonly getelementptr inbounds ([3 x i8], [3 x i8]* @.str, i64 0, i64 0)) ; CHECK-NEXT: br i1 %has.choice.mu, label %condition.mu.with.trace, label %condition.mu.without.trace @@ -92,5 +91,5 @@ entry: ; CHECK-NEXT: %4 = bitcast double %x to i64 ; CHECK-NEXT: %5 = inttoptr i64 %4 to i8* ; CHECK-NEXT: call void @__enzyme_insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.1, i64 0, i64 0), double %likelihood.x, i8* %5, i64 8) -; CHECK-NEXT: ret i8* %trace +; CHECK-NEXT: ret void ; CHECK-NEXT: } \ No newline at end of file diff --git a/enzyme/test/Enzyme/ProbProg/simple-trace.ll b/enzyme/test/Enzyme/ProbProg/simple-trace.ll index 02ad3160cb6d..791db0188a95 100644 --- a/enzyme/test/Enzyme/ProbProg/simple-trace.ll +++ b/enzyme/test/Enzyme/ProbProg/simple-trace.ll @@ -33,14 +33,14 @@ entry: ; CHECK: define i8* @generate() ; CHECK-NEXT: entry: -; CHECK-NEXT: %0 = call i8* @trace_test() +; CHECK-NEXT: %0 = call i8* @__enzyme_newtrace() +; CHECK-NEXT: call void @trace_test(i8* %0) ; CHECK-NEXT: ret i8* %0 ; CHECK-NEXT: } -; CHECK: define internal i8* @trace_test() +; CHECK: define internal void @trace_test(i8* %trace) ; CHECK-NEXT: entry: -; CHECK-NEXT: %trace = call i8* @__enzyme_newtrace() ; CHECK-NEXT: %mu = call double @normal(double 0.000000e+00, double 1.000000e+00) ; CHECK-NEXT: %likelihood.mu = call double @normal_logpdf(double 0.000000e+00, double 1.000000e+00, double %mu) ; CHECK-NEXT: %0 = bitcast double %mu to i64 @@ -51,5 +51,5 @@ entry: ; CHECK-NEXT: %2 = bitcast double %x to i64 ; CHECK-NEXT: %3 = inttoptr i64 %2 to i8* ; CHECK-NEXT: call void @__enzyme_insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.1, i64 0, i64 0), double %likelihood.x, i8* %3, i64 8) -; CHECK-NEXT: ret i8* %trace +; CHECK-NEXT: ret void ; CHECK-NEXT: } \ No newline at end of file diff --git a/enzyme/test/Enzyme/ProbProg/trace-dynamic.ll b/enzyme/test/Enzyme/ProbProg/trace-dynamic.ll index e23c8c6da12a..ad0630808ee4 100644 --- a/enzyme/test/Enzyme/ProbProg/trace-dynamic.ll +++ b/enzyme/test/Enzyme/ProbProg/trace-dynamic.ll @@ -87,32 +87,31 @@ entry: ; CHECK: define i8* @generate(double* %data, i32 %n, i8** %interface) ; CHECK-NEXT: entry: -; CHECK-NEXT: %0 = load i32, i32* @enzyme_interface -; CHECK-NEXT: %1 = call { double, i8* } @trace_loss(double* %data, i32 %n, i8** %interface) -; CHECK-NEXT: %2 = extractvalue { double, i8* } %1, 1 -; CHECK-NEXT: ret i8* %2 +; CHECK-NEXT: %0 = getelementptr inbounds i8*, i8** %interface, i32 4 +; CHECK-NEXT: %1 = load i8*, i8** %0 +; CHECK-NEXT: %new_trace = bitcast i8* %1 to i8* ()* +; CHECK-NEXT: %2 = load i32, i32* @enzyme_interface +; CHECK-NEXT: %3 = call i8* %new_trace() +; CHECK-NEXT: %4 = call double @trace_loss(double* %data, i32 %n, i8** %interface, i8* %3) +; CHECK-NEXT: ret i8* %3 ; CHECK-NEXT: } -; CHECK: define internal { double, i8* } @trace_loss(double* %data, i32 %n, i8** %interface) +; CHECK: define internal double @trace_loss(double* %data, i32 %n, i8** %interface, i8* %trace) ; CHECK-NEXT: entry: ; CHECK-NEXT: %0 = getelementptr inbounds i8*, i8** %interface, i32 3 ; CHECK-NEXT: %1 = load i8*, i8** %0 ; CHECK-NEXT: %insert_choice = bitcast i8* %1 to void (i8*, i8*, double, i8*, i64)* -; CHECK-NEXT: %2 = getelementptr inbounds i8*, i8** %interface, i32 4 -; CHECK-NEXT: %3 = load i8*, i8** %2 -; CHECK-NEXT: %new_trace = bitcast i8* %3 to i8* ()* -; CHECK-NEXT: %trace = call i8* %new_trace() ; CHECK-NEXT: %call = call double @normal(double 0.000000e+00, double 1.000000e+00) ; CHECK-NEXT: %likelihood.call = call double @normal_logpdf(double 0.000000e+00, double 1.000000e+00, double %call) -; CHECK-NEXT: %4 = bitcast double %call to i64 -; CHECK-NEXT: %5 = inttoptr i64 %4 to i8* -; CHECK-NEXT: call void %insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.1, i64 0, i64 0), double %likelihood.call, i8* %5, i64 8) +; CHECK-NEXT: %2 = bitcast double %call to i64 +; CHECK-NEXT: %3 = inttoptr i64 %2 to i8* +; CHECK-NEXT: call void %insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.1, i64 0, i64 0), double %likelihood.call, i8* %3, i64 8) ; CHECK-NEXT: %call1 = call double @normal(double 0.000000e+00, double 1.000000e+00) ; CHECK-NEXT: %likelihood.call1 = call double @normal_logpdf(double 0.000000e+00, double 1.000000e+00, double %call1) -; CHECK-NEXT: %6 = bitcast double %call1 to i64 -; CHECK-NEXT: %7 = inttoptr i64 %6 to i8* -; CHECK-NEXT: call void %insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.2, i64 0, i64 0), double %likelihood.call1, i8* %7, i64 8) +; CHECK-NEXT: %4 = bitcast double %call1 to i64 +; CHECK-NEXT: %5 = inttoptr i64 %4 to i8* +; CHECK-NEXT: call void %insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.2, i64 0, i64 0), double %likelihood.call1, i8* %5, i64 8) ; CHECK-NEXT: %cmp19.i = icmp sgt i32 %n, 0 ; CHECK-NEXT: br i1 %cmp19.i, label %for.body.preheader.i, label %calculate_loss.exit @@ -122,28 +121,26 @@ entry: ; CHECK: for.body.i: ; preds = %for.body.i, %for.body.preheader.i ; CHECK-NEXT: %indvars.iv.i = phi i64 [ 0, %for.body.preheader.i ], [ %indvars.iv.next.i, %for.body.i ] -; CHECK-NEXT: %loss.021.i = phi double [ 0.000000e+00, %for.body.preheader.i ], [ %13, %for.body.i ] -; CHECK-NEXT: %8 = trunc i64 %indvars.iv.i to i32 -; CHECK-NEXT: %conv2.i = sitofp i32 %8 to double +; CHECK-NEXT: %loss.021.i = phi double [ 0.000000e+00, %for.body.preheader.i ], [ %11, %for.body.i ] +; CHECK-NEXT: %6 = trunc i64 %indvars.iv.i to i32 +; CHECK-NEXT: %conv2.i = sitofp i32 %6 to double ; CHECK-NEXT: %mul1 = fmul double %conv2.i, %call -; CHECK-NEXT: %9 = fadd double %mul1, %call1 -; CHECK-NEXT: %call.i = call double @normal(double %9, double 1.000000e+00) -; CHECK-NEXT: %likelihood.call.i = call double @normal_logpdf(double %9, double 1.000000e+00, double %call.i) -; CHECK-NEXT: %10 = bitcast double %call.i to i64 -; CHECK-NEXT: %11 = inttoptr i64 %10 to i8* -; CHECK-NEXT: call void %insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([11 x i8], [11 x i8]* @.str, i64 0, i64 0), double %likelihood.call.i, i8* %11, i64 8) +; CHECK-NEXT: %7 = fadd double %mul1, %call1 +; CHECK-NEXT: %call.i = call double @normal(double %7, double 1.000000e+00) +; CHECK-NEXT: %likelihood.call.i = call double @normal_logpdf(double %7, double 1.000000e+00, double %call.i) +; CHECK-NEXT: %8 = bitcast double %call.i to i64 +; CHECK-NEXT: %9 = inttoptr i64 %8 to i8* +; CHECK-NEXT: call void %insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([11 x i8], [11 x i8]* @.str, i64 0, i64 0), double %likelihood.call.i, i8* %9, i64 8) ; CHECK-NEXT: %arrayidx3.i = getelementptr inbounds double, double* %data, i64 %indvars.iv.i -; CHECK-NEXT: %12 = load double, double* %arrayidx3.i -; CHECK-NEXT: %sub.i = fsub double %call.i, %12 +; CHECK-NEXT: %10 = load double, double* %arrayidx3.i +; CHECK-NEXT: %sub.i = fsub double %call.i, %10 ; CHECK-NEXT: %mul2 = fmul double %sub.i, %sub.i -; CHECK-NEXT: %13 = fadd double %mul2, %loss.021.i +; CHECK-NEXT: %11 = fadd double %mul2, %loss.021.i ; CHECK-NEXT: %indvars.iv.next.i = add nuw nsw i64 %indvars.iv.i, 1 ; CHECK-NEXT: %exitcond.not.i = icmp eq i64 %indvars.iv.next.i, %wide.trip.count.i ; CHECK-NEXT: br i1 %exitcond.not.i, label %calculate_loss.exit, label %for.body.i ; CHECK: calculate_loss.exit: ; preds = %for.body.i, %entry -; CHECK-NEXT: %loss.0.lcssa.i = phi double [ 0.000000e+00, %entry ], [ %13, %for.body.i ] -; CHECK-NEXT: %mrv = insertvalue { double, i8* } {{(undef|poison)}}, double %loss.0.lcssa.i, 0 -; CHECK-NEXT: %mrv1 = insertvalue { double, i8* } %mrv, i8* %trace, 1 -; CHECK-NEXT: ret { double, i8* } %mrv1 +; CHECK-NEXT: %loss.0.lcssa.i = phi double [ 0.000000e+00, %entry ], [ %11, %for.body.i ] +; CHECK-NEXT: ret double %loss.0.lcssa.i ; CHECK-NEXT: } \ No newline at end of file diff --git a/enzyme/test/Enzyme/ProbProg/trace-static.ll b/enzyme/test/Enzyme/ProbProg/trace-static.ll index d4b62ea33133..2241ef304b38 100644 --- a/enzyme/test/Enzyme/ProbProg/trace-static.ll +++ b/enzyme/test/Enzyme/ProbProg/trace-static.ll @@ -90,15 +90,14 @@ entry: ; CHECK: define i8* @generate(double* %data, i32 %n) ; CHECK-NEXT: entry: -; CHECK-NEXT: %0 = call { double, i8* } @trace_loss(double* %data, i32 %n) -; CHECK-NEXT: %1 = extractvalue { double, i8* } %0, 1 -; CHECK-NEXT: ret i8* %1 +; CHECK-NEXT: %0 = call i8* @__enzyme_newtrace() +; CHECK-NEXT: %1 = call double @trace_loss(double* %data, i32 %n, i8* %0) +; CHECK-NEXT: ret i8* %0 ; CHECK-NEXT: } -; CHECK: define internal { double, i8* } @trace_loss(double* %data, i32 %n) +; CHECK: define internal double @trace_loss(double* %data, i32 %n, i8* %trace) ; CHECK-NEXT: entry: -; CHECK-NEXT: %trace = call i8* @__enzyme_newtrace() ; CHECK-NEXT: %call = call double @normal(double 0.000000e+00, double 1.000000e+00) ; CHECK-NEXT: %likelihood.call = call double @normal_logpdf(double 0.000000e+00, double 1.000000e+00, double %call) ; CHECK-NEXT: %0 = bitcast double %call to i64 @@ -139,7 +138,5 @@ entry: ; CHECK: calculate_loss.exit: ; preds = %for.body.i, %entry ; CHECK-NEXT: %loss.0.lcssa.i = phi double [ 0.000000e+00, %entry ], [ %9, %for.body.i ] -; CHECK-NEXT: %mrv = insertvalue { double, i8* } {{(undef|poison)}}, double %loss.0.lcssa.i, 0 -; CHECK-NEXT: %mrv1 = insertvalue { double, i8* } %mrv, i8* %trace, 1 -; CHECK-NEXT: ret { double, i8* } %mrv1 +; CHECK-NEXT: ret double %loss.0.lcssa.i ; CHECK-NEXT: } \ No newline at end of file