@@ -3,7 +3,7 @@ using Core.Compiler: Const, isconstType, argtypes_to_type, tuple_tfunc, Const,
33 getfield_tfunc, _methods_by_ftype, VarTable, cache_lookup, nfields_tfunc,
44 ArgInfo, singleton_type, CallMeta, MethodMatchInfo, specialize_method,
55 PartialOpaque, UnionSplitApplyCallInfo, typeof_tfunc, apply_type_tfunc, instanceof_tfunc,
6- StmtInfo
6+ StmtInfo, NoCallInfo
77using Core: PartialStruct
88using Base. Meta
99
@@ -41,7 +41,11 @@ function Core.Compiler.abstract_call_gf_by_type(interp::ADInterpreter, @nospecia
4141 else
4242 rt2 = obtype
4343 end
44+ @static if VERSION ≥ v " 1.11.0-DEV.945"
45+ return CallMeta (rt2, call. exct, call. effects, RecurseInfo (call. info))
46+ else
4447 return CallMeta (rt2, call. effects, RecurseInfo (call. info))
48+ end
4549 end
4650
4751 # Check if there is a rrule for this function
@@ -56,7 +60,12 @@ function Core.Compiler.abstract_call_gf_by_type(interp::ADInterpreter, @nospecia
5660 end
5761 call = abstract_call_gf_by_type (lower_level (interp), ChainRules. rrule, ArgInfo (nothing , rrule_argtypes), rrule_atype, sv, - 1 )
5862 if call. rt != Const (nothing )
59- return CallMeta (getfield_tfunc (call. rt, Const (1 )), call. effects, RRuleInfo (call. rt, call. info))
63+ newrt = getfield_tfunc (call. rt, Const (1 ))
64+ @static if VERSION ≥ v " 1.11.0-DEV.945"
65+ return CallMeta (newrt, call. exct, call. effects, RRuleInfo (call. rt, call. info))
66+ else
67+ return CallMeta (newrt, call. exct, call. effects, RRuleInfo (call. rt, call. info))
68+ end
6069 end
6170 end
6271 end
@@ -74,26 +83,39 @@ function Core.Compiler.abstract_call_gf_by_type(interp::ADInterpreter, @nospecia
7483 return ret
7584end
7685
77- function abstract_accum (interp:: AbstractInterpreter , args :: Vector{Any} , sv:: InferenceState )
78- args = filter (x -> ! (widenconst (x) <: Union{ZeroTangent, NoTangent} ), args )
86+ function abstract_accum (interp:: AbstractInterpreter , argtypes :: Vector{Any} , sv:: InferenceState )
87+ argtypes = filter (@nospecialize (x) -> ! (widenconst (x) <: Union{ZeroTangent, NoTangent} ), argtypes )
7988
80- if length (args) == 0
81- return CallMeta (ZeroTangent, Effects (), nothing )
89+ if length (argtypes) == 0
90+ @static if VERSION ≥ v " 1.11.0-DEV.945"
91+ return CallMeta (ZeroTangent, Any, Effects (), NoCallInfo ())
92+ else
93+ return CallMeta (ZeroTangent, Effects (), NoCallInfo ())
94+ end
8295 end
8396
84- if length (args) == 1
85- return CallMeta (args[1 ], Effects (), nothing )
97+ if length (argtypes) == 1
98+ @static if VERSION ≥ v " 1.11.0-DEV.945"
99+ return CallMeta (argtypes[1 ], Any, Effects (), NoCallInfo ())
100+ else
101+ return CallMeta (argtypes[1 ], Effects (), NoCallInfo ())
102+ end
86103 end
87104
88- rtype = reduce (tmerge, args )
105+ rtype = reduce (tmerge, argtypes )
89106 if widenconst (rtype) <: Tuple
90107 targs = Any[]
91108 for i = 1 : nfields_tfunc (rtype). val
92- push! (targs, abstract_accum (interp, Any[getfield_tfunc (arg, Const (i)) for arg in args], sv). rt)
109+ push! (targs, abstract_accum (interp, Any[getfield_tfunc (arg, Const (i)) for arg in argtypes], sv). rt)
110+ end
111+ rt = tuple_tfunc (targs)
112+ @static if VERSION ≥ v " 1.11.0-DEV.945"
113+ return CallMeta (rt, Any, Effects (), NoCallInfo ())
114+ else
115+ return CallMeta (rt, Effects (), NoCallInfo ())
93116 end
94- return CallMeta (tuple_tfunc (targs), nothing )
95117 end
96- call = abstract_call (change_level (interp, 0 ), nothing , Any[typeof (accum), args ... ],
118+ call = abstract_call (change_level (interp, 0 ), nothing , Any[typeof (accum), argtypes ... ],
97119 sv:: InferenceState )
98120 return call
99121end
@@ -249,7 +271,12 @@ function infer_cc_backward(interp::ADInterpreter, cc::AbstractCompClosure, @nosp
249271 ft = argextype (inst. args[1 ], primal, primal. sptypes)
250272 f = singleton_type (ft)
251273 if isa (f, Core. Builtin)
252- call = CallMeta (backwards_tfunc (f, primal, inst, Δ), nothing )
274+ rt = backwards_tfunc (f, primal, inst, Δ)
275+ @static if VERSION ≥ v " 1.11.0-DEV.945"
276+ call = CallMeta (rt, Any, Effects (), NoCallInfo ())
277+ else
278+ call = CallMeta (rt, Effects (), NoCallInfo ())
279+ end
253280 else
254281 bail! (inst)
255282 continue
@@ -265,7 +292,12 @@ function infer_cc_backward(interp::ADInterpreter, cc::AbstractCompClosure, @nosp
265292 arg = getfield_tfunc (Δ, Const (1 ))
266293 call = abstract_call (interp, nothing , Any[clos, arg], sv)
267294 # No derivative wrt the functor
268- call = CallMeta (tuple_tfunc (Any[NoTangent; tuple_type_fields (call. rt)... ]), ReifyInfo (call. info))
295+ rt = tuple_tfunc (Any[NoTangent; tuple_type_fields (call. rt)... ])
296+ @static if VERSION ≥ v " 1.11.0-DEV.945"
297+ call = CallMeta (rt, Any, Effects (), ReifyInfo (call. info))
298+ else
299+ call = CallMeta (rt, Effects (), ReifyInfo (call. info))
300+ end
269301 else
270302 (level, close) = derive_closure_type (call_info)
271303 call = abstract_call (change_level (interp, level), ArgInfo (nothing , Any[close, Δ]), sv)
@@ -274,13 +306,23 @@ function infer_cc_backward(interp::ADInterpreter, cc::AbstractCompClosure, @nosp
274306
275307 if isa (info, UnionSplitApplyCallInfo)
276308 argts = Any[argextype (inst. args[i], primal, primal. sptypes) for i = 4 : length (inst. args)]
277- call = CallMeta (repackage_apply_rt (info, call. rt, argts),
278- UnionSplitApplyCallInfo ([ApplyCallInfo (call. info)]))
309+ rt = repackage_apply_rt (info, call. rt, argts)
310+ newinfo = UnionSplitApplyCallInfo ([ApplyCallInfo (call. info)])
311+ @static if VERSION ≥ v " 1.11.0-DEV.945"
312+ call = CallMeta (rt, Any, Effects (), newinfo)
313+ else
314+ call = CallMeta (rt, Effects (), newinfo)
315+ end
279316 end
280317
281318 if isa (call_info, ReifyInfo)
282319 new_rt = tuple_tfunc (Any[derive_closure_type (call. info)[2 ]; call. rt])
283- call = CallMeta (new_rt, RecurseInfo (call. info))
320+ newinfo = RecurseInfo (call. info)
321+ @static if VERSION ≥ v " 1.11.0-DEV.945"
322+ call = CallMeta (new_rt, Any, Effects (), newinfo)
323+ else
324+ call = CallMeta (new_rt, Effects (), newinfo)
325+ end
284326 end
285327
286328 if call. rt === Union{}
@@ -312,15 +354,23 @@ function infer_cc_backward(interp::ADInterpreter, cc::AbstractCompClosure, @nosp
312354 accum_call = abstract_accum (interp, this_arg_typs, sv)
313355 if accum_call. rt == Union{}
314356 @show accum_call. rt
315- return CallMeta (Union{}, false )
357+ @static if VERSION ≥ v " 1.11.0-DEV.945"
358+ return CallMeta (Union{}, Any, Effects (), NoCallInfo ())
359+ else
360+ return CallMeta (Union{}, Effects (), NoCallInfo ())
361+ end
316362 end
317363 push! (arg_accums, accum_call)
318364 tup_push! (tup_elemns, accum_call. rt)
319365 end
320366 end
321367
322368 rt = tuple_tfunc (Any[tup_elemns... ])
369+ @static if VERSION ≥ v " 1.11.0-DEV.945"
370+ return CallMeta (rt, Any, Effects (), CompClosInfo (cc, ssa_infos))
371+ else
323372 return CallMeta (rt, Effects (), CompClosInfo (cc, ssa_infos))
373+ end
324374end
325375
326376function infer_cc_forward (interp:: ADInterpreter , cc:: AbstractCompClosure , @nospecialize (cc_Δ), sv:: InferenceState )
@@ -389,7 +439,11 @@ function infer_cc_forward(interp::ADInterpreter, cc::AbstractCompClosure, @nospe
389439
390440 if isa (inst, ReturnNode)
391441 rt = accum_arg (inst. val)
392- return CallMeta (rt, CompClosInfo (cc, ssa_infos))
442+ @static if VERSION ≥ v " 1.11.0-DEV.945"
443+ return CallMeta (rt, Any, Effects (), CompClosInfo (cc, ssa_infos))
444+ else
445+ return CallMeta (rt, Effects (), CompClosInfo (cc, ssa_infos))
446+ end
393447 end
394448
395449 args = Any[]
@@ -451,7 +505,12 @@ function infer_cc_forward(interp::ADInterpreter, cc::AbstractCompClosure, @nospe
451505 arg = getfield_tfunc (Δ, Const (2 ))
452506 call = abstract_call (interp, nothing , Any[clos, arg], sv)
453507 # No derivative wrt the functor
454- call = CallMeta (tuple_tfunc (Any[NoTangent; tuple_type_fields (call. rt)... ]), ReifyInfo (call. info))
508+ newrt = tuple_tfunc (Any[NoTangent; tuple_type_fields (call. rt)... ])
509+ @static if VERSION ≥ v " 1.11.0-DEV.945"
510+ call = CallMeta (newrt, Any, Effects (), ReifyInfo (call. info))
511+ else
512+ call = CallMeta (newrt, Effects (), ReifyInfo (call. info))
513+ end
455514 # error()
456515 else
457516 (level, clos) = derive_closure_type (call_info)
@@ -461,11 +520,20 @@ function infer_cc_forward(interp::ADInterpreter, cc::AbstractCompClosure, @nospe
461520
462521 if isa (call_info, ReifyInfo)
463522 new_rt = tuple_tfunc (Any[call. rt; derive_closure_type (call. info)[2 ]])
464- call = CallMeta (new_rt, RecurseInfo ())
523+ @static if VERSION ≥ v " 1.11.0-DEV.945"
524+ call = CallMeta (new_rt, Any, Effects (), RecurseInfo ())
525+ else
526+ call = CallMeta (new_rt, Effects (), RecurseInfo ())
527+ end
465528 end
466529
467530 if isa (info, UnionSplitApplyCallInfo)
468- call = CallMeta (call. rt, UnionSplitApplyCallInfo ([ApplyCallInfo (call. info)]))
531+ newinfo = UnionSplitApplyCallInfo ([ApplyCallInfo (call. info)])
532+ @static if VERSION ≥ v " 1.11.0-DEV.945"
533+ call = CallMeta (call. rt, call. exct, Effects (), newinfo)
534+ else
535+ call = CallMeta (call. rt, Effects (), newinfo)
536+ end
469537 end
470538
471539 accums[i] = call. rt
@@ -485,13 +553,16 @@ function infer_comp_closure(interp::ADInterpreter, cc::AbstractCompClosure, @nos
485553end
486554
487555function infer_prim_closure (interp:: ADInterpreter , pc:: PrimClosure , @nospecialize (Δ), sv:: InferenceState )
488- @show (" enter" , pc)
489-
490556 if pc. seq == 1
491557 call = abstract_call (change_level (interp, pc. order), nothing , Any[pc. dual, Δ], sv)
492558 rt = call. rt
493559 @show (pc, Δ, rt)
494- return CallMeta (call. rt, PrimClosInfo (PrimClosure (pc. name, pc. order, pc. seq + 1 , nothing , call. info, pc. info_below)))
560+ newinfo = PrimClosInfo (PrimClosure (pc. name, pc. order, pc. seq + 1 , nothing , call. info, pc. info_below))
561+ @static if VERSION ≥ v " 1.11.0-DEV.945"
562+ return CallMeta (call. rt, call. exct, Effects (), newinfo)
563+ else
564+ return CallMeta (call. rt, Effects (), newinfo)
565+ end
495566 elseif pc. seq == 2
496567 ni = change_level (interp, pc. order)
497568 mi′ = specialize_method (pc. info_below. results. matches[1 ], true )
@@ -500,8 +571,12 @@ function infer_prim_closure(interp::ADInterpreter, pc::PrimClosure, @nospecializ
500571 call = infer_comp_closure (ni, cc, Δ, sv)
501572 rt = getfield_tfunc (call. rt, Const (2 ))
502573 @show (pc, Δ, rt)
503- return CallMeta (rt,
504- PrimClosInfo (PrimClosure (pc. name, pc. order, pc. seq + 1 , getfield_tfunc (call. rt, Const (1 )), call. info, pc. info_carried)))
574+ newinfo = PrimClosInfo (PrimClosure (pc. name, pc. order, pc. seq + 1 , getfield_tfunc (call. rt, Const (1 )), call. info, pc. info_carried))
575+ @static if VERSION ≥ v " 1.11.0-DEV.945"
576+ return CallMeta (rt, Any, Effects (), newinfo)
577+ else
578+ return CallMeta (rt, Effects (), newinfo)
579+ end
505580 elseif pc. seq == 3
506581 ni = change_level (interp, pc. order)
507582 mi′ = specialize_method (pc. info_carried. info. results. matches[1 ], true )
@@ -511,41 +586,62 @@ function infer_prim_closure(interp::ADInterpreter, pc::PrimClosure, @nospecializ
511586 Any[clos, tuple_tfunc (Any[Δ, pc. dual])], sv)
512587 rt = tuple_tfunc (Any[tuple_type_fields (call. rt)[2 : end ]. .. ])
513588 @show (pc, Δ, rt)
514- return CallMeta (rt,
515- PrimClosInfo (PrimClosure (pc. name, pc. order, pc. seq + 1 , nothing , call. info, pc. info_below)))
589+ newinfo = PrimClosInfo (PrimClosure (pc. name, pc. order, pc. seq + 1 , nothing , call. info, pc. info_below))
590+ @static if VERSION ≥ v " 1.11.0-DEV.945"
591+ return CallMeta (rt, Any, Effects (), newinfo)
592+ else
593+ return CallMeta (rt, Effects (), newinfo)
594+ end
516595 elseif mod (pc. seq, 4 ) == 0
517596 info = pc. info_below
518597 clos = AbstractCompClosure (info. clos. order, info. clos. seq + 1 , info. clos. primal_info, info. infos)
519-
520598 # Add back gradient w.r.t. rrule
521599 Δ = tuple_tfunc (Any[NoTangent, tuple_type_fields (Δ)... ])
522600 call = abstract_call (change_level (interp, pc. order), nothing , Any[clos, Δ], sv)
523601 rt = getfield_tfunc (call. rt, Const (1 ))
524602 @show (pc, Δ, rt)
525- return CallMeta (rt, PrimClosInfo (PrimClosure (pc. name, pc. order, pc. seq + 1 , getfield_tfunc (call. rt, Const (2 )), call. info, pc. info_carried)))
603+ newinfo = PrimClosInfo (PrimClosure (pc. name, pc. order, pc. seq + 1 , getfield_tfunc (call. rt, Const (2 )), call. info, pc. info_carried))
604+ @static if VERSION ≥ v " 1.11.0-DEV.945"
605+ return CallMeta (rt, Any, Effects (), newinfo)
606+ else
607+ return CallMeta (rt, Effects (), newinfo)
608+ end
526609 elseif mod (pc. seq, 4 ) == 1
527610 info = pc. info_carried
528611 clos = AbstractCompClosure (info. clos. order, info. clos. seq + 1 , info. clos. primal_info, info. infos)
529612 call = abstract_call (change_level (interp, pc. order), nothing , Any[clos, tuple_tfunc (Any[pc. dual, Δ])], sv)
530613 rt = call. rt
531614 @show (pc, Δ, rt)
532- return CallMeta (call. rt, PrimClosInfo (PrimClosure (pc. name, pc. order, pc. seq + 1 , nothing , call. info, pc. info_below)))
615+ newinfo = PrimClosInfo (PrimClosure (pc. name, pc. order, pc. seq + 1 , nothing , call. info, pc. info_below))
616+ @static if VERSION ≥ v " 1.11.0-DEV.945"
617+ return CallMeta (rt, Any, Effects (), newinfo)
618+ else
619+ return CallMeta (rt, Effects (), newinfo)
620+ end
533621 elseif mod (pc. seq, 4 ) == 2
534622 info = pc. info_below
535623 clos = AbstractCompClosure (info. clos. order, info. clos. seq + 1 , info. clos. primal_info, info. infos)
536624 call = abstract_call (change_level (interp, pc. order), nothing , Any[clos, Δ], sv)
537625 rt = getfield_tfunc (call. rt, Const (2 ))
538626 @show (pc, Δ, rt)
539- return CallMeta (rt,
540- PrimClosInfo (PrimClosure (pc. name, pc. order, pc. seq + 1 , getfield_tfunc (call. rt, Const (1 )), call. info, pc. info_carried)))
627+ newinfo = PrimClosInfo (PrimClosure (pc. name, pc. order, pc. seq + 1 , getfield_tfunc (call. rt, Const (1 )), call. info, pc. info_carried))
628+ @static if VERSION ≥ v " 1.11.0-DEV.945"
629+ return CallMeta (rt, Any, Effects (), newinfo)
630+ else
631+ return CallMeta (rt, Effects (), newinfo)
632+ end
541633 elseif mod (pc. seq, 4 ) == 3
542634 info = pc. info_carried
543635 clos = AbstractCompClosure (info. clos. order, info. clos. seq + 1 , info. clos. primal_info, info. infos)
544636 call = abstract_call (change_level (interp, pc. order), nothing , Any[clos, tuple_tfunc (Any[Δ, pc. dual])], sv)
545637 rt = tuple_tfunc (Any[tuple_type_fields (call. rt)[2 : end ]. .. ])
546638 @show (pc, Δ, rt)
547- return CallMeta (rt,
548- PrimClosInfo (PrimClosure (pc. name, pc. order, pc. seq + 1 , nothing , call. info, pc. info_below)))
639+ newinfo = PrimClosInfo (PrimClosure (pc. name, pc. order, pc. seq + 1 , nothing , call. info, pc. info_below))
640+ @static if VERSION ≥ v " 1.11.0-DEV.945"
641+ return CallMeta (rt, Any, Effects (), newinfo)
642+ else
643+ return CallMeta (rt, Effects (), newinfo)
644+ end
549645 end
550646 error ()
551647end
@@ -556,8 +652,7 @@ function Core.Compiler.abstract_call_opaque_closure(interp::ADInterpreter,
556652 if isa (closure. source, AbstractCompClosure)
557653 (;argtypes) = arginfo
558654 if length (argtypes) != = 2
559- error ()
560- return CallMeta (Union{}, false )
655+ error (" bad argtypes" )
561656 end
562657 return infer_comp_closure (interp, closure. source, argtypes[2 ], sv)
563658 elseif isa (closure. source, PrimClosure)
0 commit comments