diff --git a/src/codegen.cpp b/src/codegen.cpp index a51f9ea2020ca..11213e54a0450 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -5273,7 +5273,11 @@ static jl_cgval_t emit_call_specfun_other(jl_codectx_t &ctx, bool is_opaque_clos AllocaInst *result = nullptr; if (returninfo.cc == jl_returninfo_t::SRet || returninfo.cc == jl_returninfo_t::Union) { - result = emit_static_alloca(ctx, returninfo.union_bytes, Align(returninfo.union_align)); + if (returninfo.all_roots) { + result = emit_static_roots(ctx, returninfo.union_bytes / sizeof(void *)); + } else { + result = emit_static_alloca(ctx, returninfo.union_bytes, Align(returninfo.union_align)); + } setName(ctx.emission_context, result, "sret_box"); argvals[idx] = result; idx++; @@ -8179,7 +8183,6 @@ static jl_returninfo_t get_specsig_function(jl_codegen_params_t ¶ms, Module Type *rt = NULL; Type *srt = NULL; Type *T_prjlvalue = PointerType::get(M->getContext(), AddressSpace::Tracked); - bool all_roots = false; uint64_t tracked_count = 0; if (jlrettype == (jl_value_t*)jl_bottom_type) { rt = getVoidTy(M->getContext()); @@ -8199,8 +8202,8 @@ static jl_returninfo_t get_specsig_function(jl_codegen_params_t ¶ms, Module // convert all_roots to only union_bytes props.union_bytes = return_roots * sizeof(void*); props.union_minalign = props.union_align = sizeof(void*); + //props.all_roots = true; //return_roots = 0; - //all_roots = true; } props.return_roots = (int) return_roots; if (props.union_bytes) { @@ -8225,7 +8228,6 @@ static jl_returninfo_t get_specsig_function(jl_codegen_params_t ¶ms, Module if (rt != getVoidTy(M->getContext()) && deserves_sret(jlrettype, rt)) { auto tracked = CountTrackedPointers(rt, true); assert(!tracked.derived); - all_roots = tracked.all; tracked_count = tracked.count; if (tracked.count && !tracked.all) { props.return_roots = tracked.count; @@ -8234,6 +8236,7 @@ static jl_returninfo_t get_specsig_function(jl_codegen_params_t ¶ms, Module props.cc = jl_returninfo_t::SRet; props.union_bytes = jl_datatype_size(jlrettype); props.union_align = props.union_minalign = julia_alignment(jlrettype); + props.all_roots = tracked.all; // sret is always passed from alloca assert(M); fsig.push_back(PointerType::get(M->getContext(), M->getDataLayout().getAllocaAddrSpace())); @@ -8254,7 +8257,7 @@ static jl_returninfo_t get_specsig_function(jl_codegen_params_t ¶ms, Module assert(srt); AttrBuilder param(M->getContext()); param.addStructRetAttr(srt); - if (all_roots) { + if (props.all_roots) { assert(!props.return_roots); param.addAttribute("julia.return_roots", std::to_string(tracked_count)); } @@ -8266,7 +8269,7 @@ static jl_returninfo_t get_specsig_function(jl_codegen_params_t ¶ms, Module } if (props.cc == jl_returninfo_t::Union) { AttrBuilder param(M->getContext()); - if (all_roots) { + if (props.all_roots) { assert(!props.return_roots); param.addAttribute("julia.return_roots", std::to_string(tracked_count)); } diff --git a/src/jitlayers.h b/src/jitlayers.h index 82b022e90cc99..64b5b4f8b8bac 100644 --- a/src/jitlayers.h +++ b/src/jitlayers.h @@ -216,6 +216,7 @@ struct jl_returninfo_t { size_t union_align; size_t union_minalign; unsigned return_roots; + bool all_roots; }; struct jl_codegen_call_target_t {