Skip to content

Commit 5894561

Browse files
authored
pass trace as argument (#992)
1 parent 1843339 commit 5894561

File tree

9 files changed

+129
-171
lines changed

9 files changed

+129
-171
lines changed

enzyme/Enzyme/Enzyme.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1894,11 +1894,12 @@ class EnzymeBase {
18941894
Logic.CreateTrace(F, generativeFunctions, mode, has_dynamic_interface);
18951895

18961896
Value *trace =
1897-
Builder.CreateCall(newFunc->getFunctionType(), newFunc, args);
1898-
if (!F->getReturnType()->isVoidTy())
1899-
trace = Builder.CreateExtractValue(trace, {1});
1897+
Builder.CreateCall(interface->newTraceTy(), interface->newTrace(), {});
1898+
1899+
args.push_back(trace);
1900+
1901+
Builder.CreateCall(newFunc->getFunctionType(), newFunc, args);
19001902

1901-
// try to cast i8* returned from trace to CI->getRetType....
19021903
if (CI->getType() != trace->getType())
19031904
trace = Builder.CreatePointerCast(trace, CI->getType());
19041905

enzyme/Enzyme/TraceGenerator.h

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,11 @@ class TraceGenerator final : public llvm::InstVisitor<TraceGenerator> {
8888
{
8989
ElseTerm->getParent()->setName("condition." + call.getName() +
9090
".without.trace");
91-
ElseChoice =
91+
92+
auto choice =
9293
Builder.CreateCall(samplefn->getFunctionType(), samplefn,
9394
sample_args, "sample." + call.getName());
95+
ElseChoice = choice;
9496
}
9597

9698
Builder.SetInsertPoint(new_call);
@@ -132,11 +134,16 @@ class TraceGenerator final : public llvm::InstVisitor<TraceGenerator> {
132134
Logic.CreateTrace(called, tutils->generativeFunctions, tutils->mode,
133135
tutils->hasDynamicTraceInterface());
134136

137+
auto trace = tutils->CreateTrace(Builder);
138+
135139
Instruction *tracecall;
136140
switch (mode) {
137141
case ProbProgMode::Trace: {
138-
tracecall = Builder.CreateCall(samplefn->getFunctionType(), samplefn,
139-
args, "trace." + called->getName());
142+
SmallVector<Value *, 2> args_and_trace = SmallVector(args);
143+
args_and_trace.push_back(trace);
144+
tracecall =
145+
Builder.CreateCall(samplefn->getFunctionType(), samplefn,
146+
args_and_trace, "trace." + called->getName());
140147
break;
141148
}
142149
case ProbProgMode::Condition: {
@@ -158,8 +165,9 @@ class TraceGenerator final : public llvm::InstVisitor<TraceGenerator> {
158165
ThenTerm->getParent()->setName("condition." + call.getName() +
159166
".with.trace");
160167
SmallVector<Value *, 2> args_and_cond = SmallVector(args);
161-
auto trace = tutils->GetTrace(Builder, address,
162-
called->getName() + ".subtrace");
168+
auto observations = tutils->GetTrace(Builder, address,
169+
called->getName() + ".subtrace");
170+
args_and_cond.push_back(observations);
163171
args_and_cond.push_back(trace);
164172
ThenTracecall = Builder.CreateCall(samplefn->getFunctionType(),
165173
samplefn, args_and_cond,
@@ -171,8 +179,9 @@ class TraceGenerator final : public llvm::InstVisitor<TraceGenerator> {
171179
ElseTerm->getParent()->setName("condition." + call.getName() +
172180
".without.trace");
173181
SmallVector<Value *, 2> args_and_null = SmallVector(args);
174-
auto trace = ConstantPointerNull::get(cast<PointerType>(
182+
auto observations = ConstantPointerNull::get(cast<PointerType>(
175183
tutils->getTraceInterface()->newTraceTy()->getReturnType()));
184+
args_and_null.push_back(observations);
176185
args_and_null.push_back(trace);
177186
ElseTracecall =
178187
Builder.CreateCall(samplefn->getFunctionType(), samplefn,
@@ -188,14 +197,10 @@ class TraceGenerator final : public llvm::InstVisitor<TraceGenerator> {
188197
}
189198
}
190199

191-
Value *ret = Builder.CreateExtractValue(tracecall, {0});
192-
Value *subtrace = Builder.CreateExtractValue(
193-
tracecall, {1}, "newtrace." + called->getName());
194-
195-
tutils->InsertCall(Builder, address, subtrace);
200+
tutils->InsertCall(Builder, address, trace);
196201

197-
ret->takeName(new_call);
198-
new_call->replaceAllUsesWith(ret);
202+
tracecall->takeName(new_call);
203+
new_call->replaceAllUsesWith(tracecall);
199204
new_call->eraseFromParent();
200205
}
201206
}

enzyme/Enzyme/TraceUtils.h

Lines changed: 11 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class TraceUtils {
2626
private:
2727
TraceInterface *interface;
2828
Value *dynamic_interface = nullptr;
29-
Instruction *trace;
29+
Value *trace;
3030
Value *observations = nullptr;
3131

3232
public:
@@ -70,10 +70,9 @@ class TraceUtils {
7070
if (mode == ProbProgMode::Condition)
7171
params.push_back(traceType);
7272

73-
Type *RetTy = traceType;
74-
if (!oldFunc->getReturnType()->isVoidTy())
75-
RetTy = StructType::get(Context, {oldFunc->getReturnType(), traceType});
73+
params.push_back(traceType);
7674

75+
Type *RetTy = oldFunc->getReturnType();
7776
FunctionType *FTy = FunctionType::get(RetTy, params, oldFunc->isVarArg());
7877

7978
Twine Name = (mode == ProbProgMode::Condition ? "condition_" : "trace_") +
@@ -94,21 +93,27 @@ class TraceUtils {
9493
}
9594

9695
if (has_dynamic_interface) {
97-
auto arg = newFunc->arg_end() - (1 + (mode == ProbProgMode::Condition));
96+
auto arg = newFunc->arg_end() - (2 + (mode == ProbProgMode::Condition));
9897
dynamic_interface = arg;
9998
arg->setName("interface");
10099
arg->addAttr(Attribute::ReadOnly);
101100
arg->addAttr(Attribute::NoCapture);
102101
}
103102

104103
if (mode == ProbProgMode::Condition) {
105-
auto arg = newFunc->arg_end() - 1;
104+
auto arg = newFunc->arg_end() - 2;
106105
observations = arg;
107106
arg->setName("observations");
108107
if (oldFunc->getReturnType()->isVoidTy())
109108
arg->addAttr(Attribute::Returned);
110109
}
111110

111+
auto arg = newFunc->arg_end() - 1;
112+
trace = arg;
113+
arg->setName("trace");
114+
if (oldFunc->getReturnType()->isVoidTy())
115+
arg->addAttr(Attribute::Returned);
116+
112117
SmallVector<ReturnInst *, 4> Returns;
113118
#if LLVM_VERSION_MAJOR >= 13
114119
CloneFunctionInto(newFunc, oldFunc, originalToNewFn,
@@ -126,38 +131,6 @@ class TraceUtils {
126131
} else {
127132
interface = new StaticTraceInterface(F->getParent());
128133
}
129-
130-
// Create trace for current function
131-
132-
IRBuilder<> Builder(
133-
newFunc->getEntryBlock().getFirstNonPHIOrDbgOrLifetime());
134-
Builder.SetCurrentDebugLocation(oldFunc->getEntryBlock()
135-
.getFirstNonPHIOrDbgOrLifetime()
136-
->getDebugLoc());
137-
138-
trace = CreateTrace(Builder);
139-
140-
// Replace returns with ret trace
141-
142-
SmallVector<ReturnInst *, 3> toReplace;
143-
for (auto &&BB : *newFunc) {
144-
for (auto &&Inst : BB) {
145-
if (auto Ret = dyn_cast<ReturnInst>(&Inst)) {
146-
toReplace.push_back(Ret);
147-
}
148-
}
149-
}
150-
151-
for (auto Ret : toReplace) {
152-
IRBuilder<> Builder(Ret);
153-
if (Ret->getReturnValue()) {
154-
Value *retvals[2] = {Ret->getReturnValue(), trace};
155-
Builder.CreateAggregateRet(retvals, 2);
156-
} else {
157-
Builder.CreateRet(trace);
158-
}
159-
Ret->eraseFromParent();
160-
}
161134
};
162135

163136
~TraceUtils() { delete interface; }

enzyme/test/Enzyme/ProbProg/condition-dynamic.ll

Lines changed: 42 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -67,15 +67,18 @@ entry:
6767

6868
; CHECK: define i8* @condition(double* %data, i32 %n, i8* %trace, i8** %interface)
6969
; CHECK-NEXT: entry:
70-
; CHECK-NEXT: %0 = load i32, i32* @enzyme_condition
71-
; CHECK-NEXT: %1 = load i32, i32* @enzyme_interface
72-
; CHECK-NEXT: %2 = call { double, i8* } @condition_loss(double* %data, i32 %n, i8** %interface, i8* %trace)
73-
; CHECK-NEXT: %3 = extractvalue { double, i8* } %2, 1
74-
; CHECK-NEXT: ret i8* %3
70+
; CHECK-NEXT: %0 = getelementptr inbounds i8*, i8** %interface, i32 4
71+
; CHECK-NEXT: %1 = load i8*, i8** %0
72+
; CHECK-NEXT: %new_trace = bitcast i8* %1 to i8* ()*
73+
; CHECK-NEXT: %2 = load i32, i32* @enzyme_condition
74+
; CHECK-NEXT: %3 = load i32, i32* @enzyme_interface
75+
; CHECK-NEXT: %4 = call i8* %new_trace()
76+
; CHECK-NEXT: %5 = call double @condition_loss(double* %data, i32 %n, i8** %interface, i8* %trace, i8* %4)
77+
; CHECK-NEXT: ret i8* %4
7578
; CHECK-NEXT: }
7679

7780

78-
; CHECK: define internal { double, i8* } @condition_loss(double* %data, i32 %n, i8** %interface, i8* %observations)
81+
; CHECK: define internal double @condition_loss(double* %data, i32 %n, i8** %interface, i8* %observations, i8* %trace)
7982
; CHECK-NEXT: entry:
8083
; CHECK-NEXT: %0 = getelementptr inbounds i8*, i8** %interface, i32 2
8184
; CHECK-NEXT: %1 = load i8*, i8** %0
@@ -86,21 +89,20 @@ entry:
8689
; CHECK-NEXT: %4 = getelementptr inbounds i8*, i8** %interface, i32 6
8790
; CHECK-NEXT: %5 = load i8*, i8** %4
8891
; CHECK-NEXT: %has_call = bitcast i8* %5 to i1 (i8*, i8*)*
89-
; CHECK-NEXT: %call1.ptr = alloca double
90-
; CHECK-NEXT: %6 = getelementptr inbounds i8*, i8** %interface, i32 3
92+
; CHECK-NEXT: %6 = getelementptr inbounds i8*, i8** %interface, i32 4
9193
; CHECK-NEXT: %7 = load i8*, i8** %6
92-
; CHECK-NEXT: %insert_choice = bitcast i8* %7 to void (i8*, i8*, double, i8*, i64)*
93-
; CHECK-NEXT: %8 = getelementptr inbounds i8*, i8** %interface, i32 1
94+
; CHECK-NEXT: %new_trace = bitcast i8* %7 to i8* ()*
95+
; CHECK-NEXT: %call1.ptr = alloca double
96+
; CHECK-NEXT: %8 = getelementptr inbounds i8*, i8** %interface, i32 3
9497
; CHECK-NEXT: %9 = load i8*, i8** %8
95-
; CHECK-NEXT: %get_choice = bitcast i8* %9 to i64 (i8*, i8*, i8*, i64)*
96-
; CHECK-NEXT: %call.ptr = alloca double
97-
; CHECK-NEXT: %10 = getelementptr inbounds i8*, i8** %interface, i32 7
98+
; CHECK-NEXT: %insert_choice = bitcast i8* %9 to void (i8*, i8*, double, i8*, i64)*
99+
; CHECK-NEXT: %10 = getelementptr inbounds i8*, i8** %interface, i32 1
98100
; CHECK-NEXT: %11 = load i8*, i8** %10
99-
; CHECK-NEXT: %has_choice = bitcast i8* %11 to i1 (i8*, i8*)*
100-
; CHECK-NEXT: %12 = getelementptr inbounds i8*, i8** %interface, i32 4
101+
; CHECK-NEXT: %get_choice = bitcast i8* %11 to i64 (i8*, i8*, i8*, i64)*
102+
; CHECK-NEXT: %call.ptr = alloca double
103+
; CHECK-NEXT: %12 = getelementptr inbounds i8*, i8** %interface, i32 7
101104
; CHECK-NEXT: %13 = load i8*, i8** %12
102-
; CHECK-NEXT: %new_trace = bitcast i8* %13 to i8* ()*
103-
; CHECK-NEXT: %trace = call i8* %new_trace()
105+
; CHECK-NEXT: %has_choice = bitcast i8* %13 to i1 (i8*, i8*)*
104106
; CHECK-NEXT: %has.choice.call = call i1 %has_choice(i8* %observations, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.1, i64 0, i64 0))
105107
; CHECK-NEXT: br i1 %has.choice.call, label %condition.call.with.trace, label %condition.call.without.trace
106108

@@ -139,30 +141,27 @@ entry:
139141
; CHECK-NEXT: %18 = bitcast double %call1 to i64
140142
; CHECK-NEXT: %19 = inttoptr i64 %18 to i8*
141143
; CHECK-NEXT: call void %insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.2, i64 0, i64 0), double %likelihood.call1, i8* %19, i64 8)
144+
; CHECK-NEXT: %trace1 = call i8* %new_trace()
142145
; CHECK-NEXT: %has.call.call2 = call i1 %has_call(i8* %observations, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @0, i32 0, i32 0))
143146
; CHECK-NEXT: br i1 %has.call.call2, label %condition.call2.with.trace, label %condition.call2.without.trace
144147

145148
; CHECK: condition.call2.with.trace: ; preds = %entry.cntd.cntd
146149
; CHECK-NEXT: %calculate_loss.subtrace = call i8* %get_trace(i8* %observations, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @0, i32 0, i32 0))
147-
; CHECK-NEXT: %condition.calculate_loss = call { double, i8* } @condition_calculate_loss(double %call, double %call1, double* %data, i32 %n, i8** %interface, i8* %calculate_loss.subtrace)
150+
; CHECK-NEXT: %condition.calculate_loss = call double @condition_calculate_loss(double %call, double %call1, double* %data, i32 %n, i8** %interface, i8* %calculate_loss.subtrace, i8* %trace1)
148151
; CHECK-NEXT: br label %entry.cntd.cntd.cntd
149152

150153
; CHECK: condition.call2.without.trace: ; preds = %entry.cntd.cntd
151-
; CHECK-NEXT: %trace.calculate_loss = call { double, i8* } @condition_calculate_loss(double %call, double %call1, double* %data, i32 %n, i8** %interface, i8* null)
154+
; CHECK-NEXT: %trace.calculate_loss = call double @condition_calculate_loss(double %call, double %call1, double* %data, i32 %n, i8** %interface, i8* null, i8* %trace1)
152155
; CHECK-NEXT: br label %entry.cntd.cntd.cntd
153156

154157
; CHECK: entry.cntd.cntd.cntd: ; preds = %condition.call2.without.trace, %condition.call2.with.trace
155-
; CHECK-NEXT: %call22 = phi { double, i8* } [ %condition.calculate_loss, %condition.call2.with.trace ], [ %trace.calculate_loss, %condition.call2.without.trace ]
156-
; CHECK-NEXT: %call2 = extractvalue { double, i8* } %call22, 0
157-
; CHECK-NEXT: %newtrace.calculate_loss = extractvalue { double, i8* } %call22, 1
158-
; CHECK-NEXT: call void %insert_call(i8* %trace, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @0, i32 0, i32 0), i8* %newtrace.calculate_loss)
159-
; CHECK-NEXT: %mrv = insertvalue { double, i8* } {{(undef|poison)}}, double %call2, 0
160-
; CHECK-NEXT: %mrv1 = insertvalue { double, i8* } %mrv, i8* %trace, 1
161-
; CHECK-NEXT: ret { double, i8* } %mrv1
158+
; CHECK-NEXT: %call2 = phi double [ %condition.calculate_loss, %condition.call2.with.trace ], [ %trace.calculate_loss, %condition.call2.without.trace ]
159+
; CHECK-NEXT: call void %insert_call(i8* %trace, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @0, i32 0, i32 0), i8* %trace1)
160+
; CHECK-NEXT: ret double %call2
162161
; CHECK-NEXT: }
163162

164163

165-
; CHECK: define internal { double, i8* } @condition_calculate_loss(double %m, double %b, double* %data, i32 %n, i8** %interface, i8* %observations)
164+
; CHECK: define internal double @condition_calculate_loss(double %m, double %b, double* %data, i32 %n, i8** %interface, i8* %observations, i8* %trace)
166165
; CHECK-NEXT: entry:
167166
; CHECK-NEXT: %0 = getelementptr inbounds i8*, i8** %interface, i32 3
168167
; CHECK-NEXT: %1 = load i8*, i8** %0
@@ -174,10 +173,6 @@ entry:
174173
; CHECK-NEXT: %4 = getelementptr inbounds i8*, i8** %interface, i32 7
175174
; CHECK-NEXT: %5 = load i8*, i8** %4
176175
; CHECK-NEXT: %has_choice = bitcast i8* %5 to i1 (i8*, i8*)*
177-
; CHECK-NEXT: %6 = getelementptr inbounds i8*, i8** %interface, i32 4
178-
; CHECK-NEXT: %7 = load i8*, i8** %6
179-
; CHECK-NEXT: %new_trace = bitcast i8* %7 to i8* ()*
180-
; CHECK-NEXT: %trace = call i8* %new_trace()
181176
; CHECK-NEXT: %cmp19 = icmp sgt i32 %n, 0
182177
; CHECK-NEXT: br i1 %cmp19, label %for.body.preheader, label %for.cond.cleanup
183178

@@ -186,42 +181,40 @@ entry:
186181
; CHECK-NEXT: br label %for.body
187182

188183
; CHECK: for.cond.cleanup: ; preds = %for.body.cntd, %entry
189-
; CHECK-NEXT: %loss.0.lcssa = phi double [ 0.000000e+00, %entry ], [ %14, %for.body.cntd ]
190-
; CHECK-NEXT: %mrv = insertvalue { double, i8* } {{(undef|poison)}}, double %loss.0.lcssa, 0
191-
; CHECK-NEXT: %mrv1 = insertvalue { double, i8* } %mrv, i8* %trace, 1
192-
; CHECK-NEXT: ret { double, i8* } %mrv1
184+
; CHECK-NEXT: %loss.0.lcssa = phi double [ 0.000000e+00, %entry ], [ %12, %for.body.cntd ]
185+
; CHECK-NEXT: ret double %loss.0.lcssa
193186

194187
; CHECK: for.body: ; preds = %for.body.cntd, %for.body.preheader
195188
; CHECK-NEXT: %indvars.iv = phi i64 [ 0, %for.body.preheader ], [ %indvars.iv.next, %for.body.cntd ]
196-
; CHECK-NEXT: %loss.021 = phi double [ 0.000000e+00, %for.body.preheader ], [ %14, %for.body.cntd ]
197-
; CHECK-NEXT: %8 = trunc i64 %indvars.iv to i32
198-
; CHECK-NEXT: %conv2 = sitofp i32 %8 to double
189+
; CHECK-NEXT: %loss.021 = phi double [ 0.000000e+00, %for.body.preheader ], [ %12, %for.body.cntd ]
190+
; CHECK-NEXT: %6 = trunc i64 %indvars.iv to i32
191+
; CHECK-NEXT: %conv2 = sitofp i32 %6 to double
199192
; CHECK-NEXT: %mul1 = fmul double %conv2, %m
200-
; CHECK-NEXT: %9 = fadd double %mul1, %b
193+
; CHECK-NEXT: %7 = fadd double %mul1, %b
201194
; CHECK-NEXT: %has.choice.call = call i1 %has_choice(i8* %observations, i8* nocapture readonly getelementptr inbounds ([11 x i8], [11 x i8]* @.str, i64 0, i64 0))
202195
; CHECK-NEXT: br i1 %has.choice.call, label %condition.call.with.trace, label %condition.call.without.trace
203196

204197
; CHECK: condition.call.with.trace: ; preds = %for.body
205-
; CHECK-NEXT: %10 = bitcast double* %call.ptr to i8*
206-
; CHECK-NEXT: %call.size = call i64 %get_choice(i8* %observations, i8* nocapture readonly getelementptr inbounds ([11 x i8], [11 x i8]* @.str, i64 0, i64 0), i8* %10, i64 8)
198+
; CHECK-NEXT: %8 = bitcast double* %call.ptr to i8*
199+
; CHECK-NEXT: %call.size = call i64 %get_choice(i8* %observations, i8* nocapture readonly getelementptr inbounds ([11 x i8], [11 x i8]* @.str, i64 0, i64 0), i8* %8, i64 8)
207200
; CHECK-NEXT: %from.trace.call = load double, double* %call.ptr
208201
; CHECK-NEXT: br label %for.body.cntd
209202

210203
; CHECK: condition.call.without.trace: ; preds = %for.body
211-
; CHECK-NEXT: %sample.call = call double @normal(double %9, double 1.000000e+00)
204+
; CHECK-NEXT: %sample.call = call double @normal(double %7, double 1.000000e+00)
212205
; CHECK-NEXT: br label %for.body.cntd
213206

214207
; CHECK: for.body.cntd: ; preds = %condition.call.without.trace, %condition.call.with.trace
215208
; CHECK-NEXT: %call = phi double [ %from.trace.call, %condition.call.with.trace ], [ %sample.call, %condition.call.without.trace ]
216-
; CHECK-NEXT: %likelihood.call = call double @normal_logpdf(double %9, double 1.000000e+00, double %call)
217-
; CHECK-NEXT: %11 = bitcast double %call to i64
218-
; CHECK-NEXT: %12 = inttoptr i64 %11 to i8*
219-
; CHECK-NEXT: call void %insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([11 x i8], [11 x i8]* @.str, i64 0, i64 0), double %likelihood.call, i8* %12, i64 8)
209+
; CHECK-NEXT: %likelihood.call = call double @normal_logpdf(double %7, double 1.000000e+00, double %call)
210+
; CHECK-NEXT: %9 = bitcast double %call to i64
211+
; CHECK-NEXT: %10 = inttoptr i64 %9 to i8*
212+
; CHECK-NEXT: call void %insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([11 x i8], [11 x i8]* @.str, i64 0, i64 0), double %likelihood.call, i8* %10, i64 8)
220213
; CHECK-NEXT: %arrayidx3 = getelementptr inbounds double, double* %data, i64 %indvars.iv
221-
; CHECK-NEXT: %13 = load double, double* %arrayidx3
222-
; CHECK-NEXT: %sub = fsub double %call, %13
214+
; CHECK-NEXT: %11 = load double, double* %arrayidx3
215+
; CHECK-NEXT: %sub = fsub double %call, %11
223216
; CHECK-NEXT: %mul2 = fmul double %sub, %sub
224-
; CHECK-NEXT: %14 = fadd double %mul2, %loss.021
217+
; CHECK-NEXT: %12 = fadd double %mul2, %loss.021
225218
; CHECK-NEXT: %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
226219
; CHECK-NEXT: %exitcond.not = icmp eq i64 %indvars.iv.next, %wide.trip.count
227220
; CHECK-NEXT: br i1 %exitcond.not, label %for.cond.cleanup, label %for.body

0 commit comments

Comments
 (0)