@@ -67,15 +67,18 @@ entry:
6767
6868; CHECK: define i8* @condition(double* %data, i32 %n, i8* %trace, i8** %interface)
6969; CHECK-NEXT: entry:
70- ; CHECK-NEXT: %0 = load i32, i32* @enzyme_condition
71- ; CHECK-NEXT: %1 = load i32, i32* @enzyme_interface
72- ; CHECK-NEXT: %2 = call { double, i8* } @condition_loss(double* %data, i32 %n, i8** %interface, i8* %trace)
73- ; CHECK-NEXT: %3 = extractvalue { double, i8* } %2, 1
74- ; CHECK-NEXT: ret i8* %3
70+ ; CHECK-NEXT: %0 = getelementptr inbounds i8*, i8** %interface, i32 4
71+ ; CHECK-NEXT: %1 = load i8*, i8** %0
72+ ; CHECK-NEXT: %new_trace = bitcast i8* %1 to i8* ()*
73+ ; CHECK-NEXT: %2 = load i32, i32* @enzyme_condition
74+ ; CHECK-NEXT: %3 = load i32, i32* @enzyme_interface
75+ ; CHECK-NEXT: %4 = call i8* %new_trace()
76+ ; CHECK-NEXT: %5 = call double @condition_loss(double* %data, i32 %n, i8** %interface, i8* %trace, i8* %4)
77+ ; CHECK-NEXT: ret i8* %4
7578; CHECK-NEXT: }
7679
7780
78- ; CHECK: define internal { double, i8* } @condition_loss(double* %data, i32 %n, i8** %interface, i8* %observations)
81+ ; CHECK: define internal double @condition_loss(double* %data, i32 %n, i8** %interface, i8* %observations, i8* %trace )
7982; CHECK-NEXT: entry:
8083; CHECK-NEXT: %0 = getelementptr inbounds i8*, i8** %interface, i32 2
8184; CHECK-NEXT: %1 = load i8*, i8** %0
@@ -86,21 +89,20 @@ entry:
8689; CHECK-NEXT: %4 = getelementptr inbounds i8*, i8** %interface, i32 6
8790; CHECK-NEXT: %5 = load i8*, i8** %4
8891; CHECK-NEXT: %has_call = bitcast i8* %5 to i1 (i8*, i8*)*
89- ; CHECK-NEXT: %call1.ptr = alloca double
90- ; CHECK-NEXT: %6 = getelementptr inbounds i8*, i8** %interface, i32 3
92+ ; CHECK-NEXT: %6 = getelementptr inbounds i8*, i8** %interface, i32 4
9193; CHECK-NEXT: %7 = load i8*, i8** %6
92- ; CHECK-NEXT: %insert_choice = bitcast i8* %7 to void (i8*, i8*, double, i8*, i64)*
93- ; CHECK-NEXT: %8 = getelementptr inbounds i8*, i8** %interface, i32 1
94+ ; CHECK-NEXT: %new_trace = bitcast i8* %7 to i8* ()*
95+ ; CHECK-NEXT: %call1.ptr = alloca double
96+ ; CHECK-NEXT: %8 = getelementptr inbounds i8*, i8** %interface, i32 3
9497; CHECK-NEXT: %9 = load i8*, i8** %8
95- ; CHECK-NEXT: %get_choice = bitcast i8* %9 to i64 (i8*, i8*, i8*, i64)*
96- ; CHECK-NEXT: %call.ptr = alloca double
97- ; CHECK-NEXT: %10 = getelementptr inbounds i8*, i8** %interface, i32 7
98+ ; CHECK-NEXT: %insert_choice = bitcast i8* %9 to void (i8*, i8*, double, i8*, i64)*
99+ ; CHECK-NEXT: %10 = getelementptr inbounds i8*, i8** %interface, i32 1
98100; CHECK-NEXT: %11 = load i8*, i8** %10
99- ; CHECK-NEXT: %has_choice = bitcast i8* %11 to i1 (i8*, i8*)*
100- ; CHECK-NEXT: %12 = getelementptr inbounds i8*, i8** %interface, i32 4
101+ ; CHECK-NEXT: %get_choice = bitcast i8* %11 to i64 (i8*, i8*, i8*, i64)*
102+ ; CHECK-NEXT: %call.ptr = alloca double
103+ ; CHECK-NEXT: %12 = getelementptr inbounds i8*, i8** %interface, i32 7
101104; CHECK-NEXT: %13 = load i8*, i8** %12
102- ; CHECK-NEXT: %new_trace = bitcast i8* %13 to i8* ()*
103- ; CHECK-NEXT: %trace = call i8* %new_trace()
105+ ; CHECK-NEXT: %has_choice = bitcast i8* %13 to i1 (i8*, i8*)*
104106; CHECK-NEXT: %has.choice.call = call i1 %has_choice(i8* %observations, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.1, i64 0, i64 0))
105107; CHECK-NEXT: br i1 %has.choice.call, label %condition.call.with.trace, label %condition.call.without.trace
106108
@@ -139,30 +141,27 @@ entry:
139141; CHECK-NEXT: %18 = bitcast double %call1 to i64
140142; CHECK-NEXT: %19 = inttoptr i64 %18 to i8*
141143; CHECK-NEXT: call void %insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.2, i64 0, i64 0), double %likelihood.call1, i8* %19, i64 8)
144+ ; CHECK-NEXT: %trace1 = call i8* %new_trace()
142145; CHECK-NEXT: %has.call.call2 = call i1 %has_call(i8* %observations, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @0, i32 0, i32 0))
143146; CHECK-NEXT: br i1 %has.call.call2, label %condition.call2.with.trace, label %condition.call2.without.trace
144147
145148; CHECK: condition.call2.with.trace: ; preds = %entry.cntd.cntd
146149; CHECK-NEXT: %calculate_loss.subtrace = call i8* %get_trace(i8* %observations, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @0, i32 0, i32 0))
147- ; CHECK-NEXT: %condition.calculate_loss = call { double, i8* } @condition_calculate_loss(double %call, double %call1, double* %data, i32 %n, i8** %interface, i8* %calculate_loss.subtrace)
150+ ; CHECK-NEXT: %condition.calculate_loss = call double @condition_calculate_loss(double %call, double %call1, double* %data, i32 %n, i8** %interface, i8* %calculate_loss.subtrace, i8* %trace1 )
148151; CHECK-NEXT: br label %entry.cntd.cntd.cntd
149152
150153; CHECK: condition.call2.without.trace: ; preds = %entry.cntd.cntd
151- ; CHECK-NEXT: %trace.calculate_loss = call { double, i8* } @condition_calculate_loss(double %call, double %call1, double* %data, i32 %n, i8** %interface, i8* null)
154+ ; CHECK-NEXT: %trace.calculate_loss = call double @condition_calculate_loss(double %call, double %call1, double* %data, i32 %n, i8** %interface, i8* null, i8* %trace1 )
152155; CHECK-NEXT: br label %entry.cntd.cntd.cntd
153156
154157; CHECK: entry.cntd.cntd.cntd: ; preds = %condition.call2.without.trace, %condition.call2.with.trace
155- ; CHECK-NEXT: %call22 = phi { double, i8* } [ %condition.calculate_loss, %condition.call2.with.trace ], [ %trace.calculate_loss, %condition.call2.without.trace ]
156- ; CHECK-NEXT: %call2 = extractvalue { double, i8* } %call22, 0
157- ; CHECK-NEXT: %newtrace.calculate_loss = extractvalue { double, i8* } %call22, 1
158- ; CHECK-NEXT: call void %insert_call(i8* %trace, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @0, i32 0, i32 0), i8* %newtrace.calculate_loss)
159- ; CHECK-NEXT: %mrv = insertvalue { double, i8* } {{(undef|poison)}}, double %call2, 0
160- ; CHECK-NEXT: %mrv1 = insertvalue { double, i8* } %mrv, i8* %trace, 1
161- ; CHECK-NEXT: ret { double, i8* } %mrv1
158+ ; CHECK-NEXT: %call2 = phi double [ %condition.calculate_loss, %condition.call2.with.trace ], [ %trace.calculate_loss, %condition.call2.without.trace ]
159+ ; CHECK-NEXT: call void %insert_call(i8* %trace, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @0, i32 0, i32 0), i8* %trace1)
160+ ; CHECK-NEXT: ret double %call2
162161; CHECK-NEXT: }
163162
164163
165- ; CHECK: define internal { double, i8* } @condition_calculate_loss(double %m, double %b, double* %data, i32 %n, i8** %interface, i8* %observations)
164+ ; CHECK: define internal double @condition_calculate_loss(double %m, double %b, double* %data, i32 %n, i8** %interface, i8* %observations, i8* %trace )
166165; CHECK-NEXT: entry:
167166; CHECK-NEXT: %0 = getelementptr inbounds i8*, i8** %interface, i32 3
168167; CHECK-NEXT: %1 = load i8*, i8** %0
@@ -174,10 +173,6 @@ entry:
174173; CHECK-NEXT: %4 = getelementptr inbounds i8*, i8** %interface, i32 7
175174; CHECK-NEXT: %5 = load i8*, i8** %4
176175; CHECK-NEXT: %has_choice = bitcast i8* %5 to i1 (i8*, i8*)*
177- ; CHECK-NEXT: %6 = getelementptr inbounds i8*, i8** %interface, i32 4
178- ; CHECK-NEXT: %7 = load i8*, i8** %6
179- ; CHECK-NEXT: %new_trace = bitcast i8* %7 to i8* ()*
180- ; CHECK-NEXT: %trace = call i8* %new_trace()
181176; CHECK-NEXT: %cmp19 = icmp sgt i32 %n, 0
182177; CHECK-NEXT: br i1 %cmp19, label %for.body.preheader, label %for.cond.cleanup
183178
@@ -186,42 +181,40 @@ entry:
186181; CHECK-NEXT: br label %for.body
187182
188183; CHECK: for.cond.cleanup: ; preds = %for.body.cntd, %entry
189- ; CHECK-NEXT: %loss.0.lcssa = phi double [ 0.000000e+00, %entry ], [ %14, %for.body.cntd ]
190- ; CHECK-NEXT: %mrv = insertvalue { double, i8* } {{(undef|poison)}}, double %loss.0.lcssa, 0
191- ; CHECK-NEXT: %mrv1 = insertvalue { double, i8* } %mrv, i8* %trace, 1
192- ; CHECK-NEXT: ret { double, i8* } %mrv1
184+ ; CHECK-NEXT: %loss.0.lcssa = phi double [ 0.000000e+00, %entry ], [ %12, %for.body.cntd ]
185+ ; CHECK-NEXT: ret double %loss.0.lcssa
193186
194187; CHECK: for.body: ; preds = %for.body.cntd, %for.body.preheader
195188; CHECK-NEXT: %indvars.iv = phi i64 [ 0, %for.body.preheader ], [ %indvars.iv.next, %for.body.cntd ]
196- ; CHECK-NEXT: %loss.021 = phi double [ 0.000000e+00, %for.body.preheader ], [ %14 , %for.body.cntd ]
197- ; CHECK-NEXT: %8 = trunc i64 %indvars.iv to i32
198- ; CHECK-NEXT: %conv2 = sitofp i32 %8 to double
189+ ; CHECK-NEXT: %loss.021 = phi double [ 0.000000e+00, %for.body.preheader ], [ %12 , %for.body.cntd ]
190+ ; CHECK-NEXT: %6 = trunc i64 %indvars.iv to i32
191+ ; CHECK-NEXT: %conv2 = sitofp i32 %6 to double
199192; CHECK-NEXT: %mul1 = fmul double %conv2, %m
200- ; CHECK-NEXT: %9 = fadd double %mul1, %b
193+ ; CHECK-NEXT: %7 = fadd double %mul1, %b
201194; CHECK-NEXT: %has.choice.call = call i1 %has_choice(i8* %observations, i8* nocapture readonly getelementptr inbounds ([11 x i8], [11 x i8]* @.str, i64 0, i64 0))
202195; CHECK-NEXT: br i1 %has.choice.call, label %condition.call.with.trace, label %condition.call.without.trace
203196
204197; CHECK: condition.call.with.trace: ; preds = %for.body
205- ; CHECK-NEXT: %10 = bitcast double* %call.ptr to i8*
206- ; CHECK-NEXT: %call.size = call i64 %get_choice(i8* %observations, i8* nocapture readonly getelementptr inbounds ([11 x i8], [11 x i8]* @.str, i64 0, i64 0), i8* %10 , i64 8)
198+ ; CHECK-NEXT: %8 = bitcast double* %call.ptr to i8*
199+ ; CHECK-NEXT: %call.size = call i64 %get_choice(i8* %observations, i8* nocapture readonly getelementptr inbounds ([11 x i8], [11 x i8]* @.str, i64 0, i64 0), i8* %8 , i64 8)
207200; CHECK-NEXT: %from.trace.call = load double, double* %call.ptr
208201; CHECK-NEXT: br label %for.body.cntd
209202
210203; CHECK: condition.call.without.trace: ; preds = %for.body
211- ; CHECK-NEXT: %sample.call = call double @normal(double %9 , double 1.000000e+00)
204+ ; CHECK-NEXT: %sample.call = call double @normal(double %7 , double 1.000000e+00)
212205; CHECK-NEXT: br label %for.body.cntd
213206
214207; CHECK: for.body.cntd: ; preds = %condition.call.without.trace, %condition.call.with.trace
215208; CHECK-NEXT: %call = phi double [ %from.trace.call, %condition.call.with.trace ], [ %sample.call, %condition.call.without.trace ]
216- ; CHECK-NEXT: %likelihood.call = call double @normal_logpdf(double %9 , double 1.000000e+00, double %call)
217- ; CHECK-NEXT: %11 = bitcast double %call to i64
218- ; CHECK-NEXT: %12 = inttoptr i64 %11 to i8*
219- ; CHECK-NEXT: call void %insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([11 x i8], [11 x i8]* @.str, i64 0, i64 0), double %likelihood.call, i8* %12 , i64 8)
209+ ; CHECK-NEXT: %likelihood.call = call double @normal_logpdf(double %7 , double 1.000000e+00, double %call)
210+ ; CHECK-NEXT: %9 = bitcast double %call to i64
211+ ; CHECK-NEXT: %10 = inttoptr i64 %9 to i8*
212+ ; CHECK-NEXT: call void %insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([11 x i8], [11 x i8]* @.str, i64 0, i64 0), double %likelihood.call, i8* %10 , i64 8)
220213; CHECK-NEXT: %arrayidx3 = getelementptr inbounds double, double* %data, i64 %indvars.iv
221- ; CHECK-NEXT: %13 = load double, double* %arrayidx3
222- ; CHECK-NEXT: %sub = fsub double %call, %13
214+ ; CHECK-NEXT: %11 = load double, double* %arrayidx3
215+ ; CHECK-NEXT: %sub = fsub double %call, %11
223216; CHECK-NEXT: %mul2 = fmul double %sub, %sub
224- ; CHECK-NEXT: %14 = fadd double %mul2, %loss.021
217+ ; CHECK-NEXT: %12 = fadd double %mul2, %loss.021
225218; CHECK-NEXT: %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
226219; CHECK-NEXT: %exitcond.not = icmp eq i64 %indvars.iv.next, %wide.trip.count
227220; CHECK-NEXT: br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
0 commit comments