Skip to content

Commit

Permalink
Better handling for Union-type fields, particularly of singletons (#4…
Browse files Browse the repository at this point in the history
…3163)

fix #43123

(cherry picked from commit d44a534)
  • Loading branch information
vtjnash authored and KristofferC committed Nov 26, 2021
1 parent 7249528 commit 8906af3
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 72 deletions.
43 changes: 24 additions & 19 deletions src/cgutils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2147,22 +2147,25 @@ static bool emit_getfield_unknownidx(jl_codectx_t &ctx,
return false;
}

static jl_cgval_t emit_unionload(jl_codectx_t &ctx, Value *addr, Value *ptindex, jl_value_t *jfty, size_t fsz, size_t al, MDNode *tbaa, bool mutabl)
{
Instruction *tindex0 = tbaa_decorate(tbaa_unionselbyte, ctx.builder.CreateAlignedLoad(T_int8, ptindex, Align(1)));
//tindex0->setMetadata(LLVMContext::MD_range, MDNode::get(jl_LLVMContext, {
// ConstantAsMetadata::get(ConstantInt::get(T_int8, 0)),
// ConstantAsMetadata::get(ConstantInt::get(T_int8, union_max)) }));
static jl_cgval_t emit_unionload(jl_codectx_t &ctx, Value *addr, Value *ptindex,
jl_value_t *jfty, size_t fsz, size_t al, MDNode *tbaa, bool mutabl,
unsigned union_max, MDNode *tbaa_ptindex)
{
Instruction *tindex0 = tbaa_decorate(tbaa_ptindex, ctx.builder.CreateAlignedLoad(T_int8, ptindex, Align(1)));
tindex0->setMetadata(LLVMContext::MD_range, MDNode::get(jl_LLVMContext, {
ConstantAsMetadata::get(ConstantInt::get(T_int8, 0)),
ConstantAsMetadata::get(ConstantInt::get(T_int8, union_max)) }));
Value *tindex = ctx.builder.CreateNUWAdd(ConstantInt::get(T_int8, 1), tindex0);
if (mutabl) {
if (fsz > 0 && mutabl) {
// move value to an immutable stack slot (excluding tindex)
Type *ET = IntegerType::get(jl_LLVMContext, 8 * al);
AllocaInst *lv = emit_static_alloca(ctx, ET);
lv->setOperand(0, ConstantInt::get(T_int32, (fsz + al - 1) / al));
Type *AT = ArrayType::get(IntegerType::get(jl_LLVMContext, 8 * al), (fsz + al - 1) / al);
AllocaInst *lv = emit_static_alloca(ctx, AT);
if (al > 1)
lv->setAlignment(Align(al));
emit_memcpy(ctx, lv, tbaa, addr, tbaa, fsz, al);
addr = lv;
}
return mark_julia_slot(addr, jfty, tindex, tbaa);
return mark_julia_slot(fsz > 0 ? addr : nullptr, jfty, tindex, tbaa);
}

// If `nullcheck` is not NULL and a pointer NULL check is necessary
Expand Down Expand Up @@ -2236,7 +2239,8 @@ static jl_cgval_t emit_getfield_knownidx(jl_codectx_t &ctx, const jl_cgval_t &st
}
else if (jl_is_uniontype(jfty)) {
size_t fsz = 0, al = 0;
bool isptr = !jl_islayout_inline(jfty, &fsz, &al);
int union_max = jl_islayout_inline(jfty, &fsz, &al);
bool isptr = (union_max == 0);
assert(!isptr && fsz == jl_field_size(jt, idx) - 1); (void)isptr;
Value *ptindex;
if (isboxed) {
Expand All @@ -2246,7 +2250,7 @@ static jl_cgval_t emit_getfield_knownidx(jl_codectx_t &ctx, const jl_cgval_t &st
else {
ptindex = emit_struct_gep(ctx, cast<StructType>(lt), staddr, byte_offset + fsz);
}
return emit_unionload(ctx, addr, ptindex, jfty, fsz, al, tbaa, jt->name->mutabl);
return emit_unionload(ctx, addr, ptindex, jfty, fsz, al, tbaa, jt->name->mutabl, union_max, tbaa_unionselbyte);
}
assert(jl_is_concrete_type(jfty));
if (!jt->name->mutabl && !(maybe_null && (jfty == (jl_value_t*)jl_bool_type ||
Expand Down Expand Up @@ -3298,7 +3302,8 @@ static jl_cgval_t emit_setfield(jl_codectx_t &ctx,
jl_value_t *jfty = jl_field_type(sty, idx0);
if (!jl_field_isptr(sty, idx0) && jl_is_uniontype(jfty)) {
size_t fsz = 0, al = 0;
bool isptr = !jl_islayout_inline(jfty, &fsz, &al);
int union_max = jl_islayout_inline(jfty, &fsz, &al);
bool isptr = (union_max == 0);
assert(!isptr && fsz == jl_field_size(sty, idx0) - 1); (void)isptr;
// compute tindex from rhs
jl_cgval_t rhs_union = convert_julia_type(ctx, rhs, jfty);
Expand All @@ -3310,9 +3315,9 @@ static jl_cgval_t emit_setfield(jl_codectx_t &ctx,
BasicBlock *BB = ctx.builder.GetInsertBlock();
jl_cgval_t oldval = rhs;
if (!issetfield)
oldval = emit_unionload(ctx, addr, ptindex, jfty, fsz, al, strct.tbaa, true);
Value *Success;
BasicBlock *DoneBB;
oldval = emit_unionload(ctx, addr, ptindex, jfty, fsz, al, strct.tbaa, true, union_max, tbaa_unionselbyte);
Value *Success = NULL;
BasicBlock *DoneBB = NULL;
if (isreplacefield || ismodifyfield) {
if (ismodifyfield) {
if (needlock)
Expand All @@ -3329,13 +3334,13 @@ static jl_cgval_t emit_setfield(jl_codectx_t &ctx,
emit_typecheck(ctx, rhs, jfty, fname);
rhs = update_julia_type(ctx, rhs, jfty);
}
rhs_union = convert_julia_type(ctx, rhs, jfty);
rhs_union = convert_julia_type(ctx, rhs, jfty);
if (rhs_union.typ == jl_bottom_type)
return jl_cgval_t();
if (needlock)
emit_lockstate_value(ctx, strct, true);
cmp = oldval;
oldval = emit_unionload(ctx, addr, ptindex, jfty, fsz, al, strct.tbaa, true);
oldval = emit_unionload(ctx, addr, ptindex, jfty, fsz, al, strct.tbaa, true, union_max, tbaa_unionselbyte);
}
BasicBlock *XchgBB = BasicBlock::Create(jl_LLVMContext, "xchg", ctx.f);
DoneBB = BasicBlock::Create(jl_LLVMContext, "done_xchg", ctx.f);
Expand Down
94 changes: 45 additions & 49 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2375,25 +2375,18 @@ static Value *emit_box_compare(jl_codectx_t &ctx, const jl_cgval_t &arg1, const
Value *nullcheck1, Value *nullcheck2)
{
if (jl_pointer_egal(arg1.typ) || jl_pointer_egal(arg2.typ)) {
assert((arg1.isboxed || arg1.constant) && (arg2.isboxed || arg2.constant) &&
"Expected unboxed cases to be handled earlier");
Value *varg1 = arg1.constant ? literal_pointer_val(ctx, arg1.constant) : arg1.V;
Value *varg2 = arg2.constant ? literal_pointer_val(ctx, arg2.constant) : arg2.V;
varg1 = maybe_decay_tracked(ctx, varg1);
varg2 = maybe_decay_tracked(ctx, varg2);
if (cast<PointerType>(varg1->getType())->getAddressSpace() != cast<PointerType>(varg2->getType())->getAddressSpace()) {
varg1 = decay_derived(ctx, varg1);
varg2 = decay_derived(ctx, varg2);
}
return ctx.builder.CreateICmpEQ(emit_bitcast(ctx, varg1, T_pint8),
emit_bitcast(ctx, varg2, T_pint8));
// if we can be certain we won't try to load from the pointer (because
// we know boxed is trivial), we can skip the separate null checks
// and just do the ICmpEQ test
if (!arg1.TIndex && !arg2.TIndex)
nullcheck1 = nullcheck2 = nullptr;
}

return emit_nullcheck_guard2(ctx, nullcheck1, nullcheck2, [&] {
Value *varg1 = arg1.constant ? literal_pointer_val(ctx, arg1.constant) : maybe_bitcast(ctx, value_to_pointer(ctx, arg1).V, T_pjlvalue);
Value *varg2 = arg2.constant ? literal_pointer_val(ctx, arg2.constant) : maybe_bitcast(ctx, value_to_pointer(ctx, arg2).V, T_pjlvalue);
varg1 = decay_derived(ctx, varg1);
varg2 = decay_derived(ctx, varg2);
Value *varg1 = decay_derived(ctx, boxed(ctx, arg1));
Value *varg2 = decay_derived(ctx, boxed(ctx, arg2));
if (jl_pointer_egal(arg1.typ) || jl_pointer_egal(arg2.typ)) {
return ctx.builder.CreateICmpEQ(varg1, varg2);
}
Value *neq = ctx.builder.CreateICmpNE(varg1, varg2);
return emit_guarded_test(ctx, neq, true, [&] {
Value *dtarg = emit_typeof_boxed(ctx, arg1);
Expand Down Expand Up @@ -2938,28 +2931,28 @@ static bool emit_builtin_call(jl_codectx_t &ctx, jl_cgval_t *ret, jl_value_t *f,
*ret = ghostValue(ety);
}
else if (!isboxed && jl_is_uniontype(ety)) {
Type *AT = ArrayType::get(IntegerType::get(jl_LLVMContext, 8 * al), (elsz + al - 1) / al);
Value *data = emit_bitcast(ctx, emit_arrayptr(ctx, ary, ary_ex), AT->getPointerTo());
// isbits union selector bytes are stored after a->maxsize
Value *ndims = (nd == -1 ? emit_arrayndims(ctx, ary) : ConstantInt::get(T_int16, nd));
Value *is_vector = ctx.builder.CreateICmpEQ(ndims, ConstantInt::get(T_int16, 1));
Value *data = emit_arrayptr(ctx, ary, ary_ex);
Value *offset = emit_arrayoffset(ctx, ary, nd);
Value *selidx_v = ctx.builder.CreateSub(emit_vectormaxsize(ctx, ary), ctx.builder.CreateZExt(offset, T_size));
Value *selidx_m = emit_arraylen(ctx, ary);
Value *selidx = ctx.builder.CreateSelect(is_vector, selidx_v, selidx_m);
Value *ptindex = ctx.builder.CreateInBoundsGEP(AT, data, selidx);
Value *ptindex;
if (elsz == 0) {
ptindex = data;
}
else {
Type *AT = ArrayType::get(IntegerType::get(jl_LLVMContext, 8 * al), (elsz + al - 1) / al);
data = emit_bitcast(ctx, data, AT->getPointerTo());
// isbits union selector bytes are stored after a->maxsize
Value *ndims = (nd == -1 ? emit_arrayndims(ctx, ary) : ConstantInt::get(T_int16, nd));
Value *is_vector = ctx.builder.CreateICmpEQ(ndims, ConstantInt::get(T_int16, 1));
Value *selidx_v = ctx.builder.CreateSub(emit_vectormaxsize(ctx, ary), ctx.builder.CreateZExt(offset, T_size));
Value *selidx_m = emit_arraylen(ctx, ary);
Value *selidx = ctx.builder.CreateSelect(is_vector, selidx_v, selidx_m);
ptindex = ctx.builder.CreateInBoundsGEP(AT, data, selidx);
data = ctx.builder.CreateInBoundsGEP(AT, data, idx);
}
ptindex = emit_bitcast(ctx, ptindex, T_pint8);
ptindex = ctx.builder.CreateInBoundsGEP(T_int8, ptindex, offset);
ptindex = ctx.builder.CreateInBoundsGEP(T_int8, ptindex, idx);
Instruction *tindex = tbaa_decorate(tbaa_arrayselbyte, ctx.builder.CreateAlignedLoad(T_int8, ptindex, Align(1)));
tindex->setMetadata(LLVMContext::MD_range, MDNode::get(jl_LLVMContext, {
ConstantAsMetadata::get(ConstantInt::get(T_int8, 0)),
ConstantAsMetadata::get(ConstantInt::get(T_int8, union_max)) }));
AllocaInst *lv = emit_static_alloca(ctx, AT);
if (al > 1)
lv->setAlignment(Align(al));
emit_memcpy(ctx, lv, tbaa_arraybuf, ctx.builder.CreateInBoundsGEP(AT, data, idx), tbaa_arraybuf, elsz, al, false);
*ret = mark_julia_slot(lv, ety, ctx.builder.CreateNUWAdd(ConstantInt::get(T_int8, 1), tindex), tbaa_arraybuf);
*ret = emit_unionload(ctx, data, ptindex, ety, elsz, al, tbaa_arraybuf, true, union_max, tbaa_arrayselbyte);
}
else {
MDNode *aliasscope = (f == jl_builtin_const_arrayref) ? ctx.aliasscope : nullptr;
Expand Down Expand Up @@ -3045,28 +3038,31 @@ static bool emit_builtin_call(jl_codectx_t &ctx, jl_cgval_t *ret, jl_value_t *f,
if (!isboxed && jl_is_uniontype(ety)) {
Type *AT = ArrayType::get(IntegerType::get(jl_LLVMContext, 8 * al), (elsz + al - 1) / al);
Value *data = emit_bitcast(ctx, emit_arrayptr(ctx, ary, ary_ex), AT->getPointerTo());
Value *offset = emit_arrayoffset(ctx, ary, nd);
// compute tindex from val
jl_cgval_t rhs_union = convert_julia_type(ctx, val, ety);
Value *tindex = compute_tindex_unboxed(ctx, rhs_union, ety);
tindex = ctx.builder.CreateNUWSub(tindex, ConstantInt::get(T_int8, 1));
Value *ndims = (nd == -1 ? emit_arrayndims(ctx, ary) : ConstantInt::get(T_int16, nd));
Value *is_vector = ctx.builder.CreateICmpEQ(ndims, ConstantInt::get(T_int16, 1));
Value *offset = emit_arrayoffset(ctx, ary, nd);
Value *selidx_v = ctx.builder.CreateSub(emit_vectormaxsize(ctx, ary), ctx.builder.CreateZExt(offset, T_size));
Value *selidx_m = emit_arraylen(ctx, ary);
Value *selidx = ctx.builder.CreateSelect(is_vector, selidx_v, selidx_m);
Value *ptindex = ctx.builder.CreateInBoundsGEP(AT, data, selidx);
Value *ptindex;
if (elsz == 0) {
ptindex = data;
}
else {
Value *ndims = (nd == -1 ? emit_arrayndims(ctx, ary) : ConstantInt::get(T_int16, nd));
Value *is_vector = ctx.builder.CreateICmpEQ(ndims, ConstantInt::get(T_int16, 1));
Value *selidx_v = ctx.builder.CreateSub(emit_vectormaxsize(ctx, ary), ctx.builder.CreateZExt(offset, T_size));
Value *selidx_m = emit_arraylen(ctx, ary);
Value *selidx = ctx.builder.CreateSelect(is_vector, selidx_v, selidx_m);
ptindex = ctx.builder.CreateInBoundsGEP(AT, data, selidx);
data = ctx.builder.CreateInBoundsGEP(AT, data, idx);
}
ptindex = emit_bitcast(ctx, ptindex, T_pint8);
ptindex = ctx.builder.CreateInBoundsGEP(T_int8, ptindex, offset);
ptindex = ctx.builder.CreateInBoundsGEP(T_int8, ptindex, idx);
tbaa_decorate(tbaa_arrayselbyte, ctx.builder.CreateStore(tindex, ptindex));
if (jl_is_datatype(val.typ) && jl_datatype_size(val.typ) == 0) {
// no-op
}
else {
// copy data
Value *addr = ctx.builder.CreateInBoundsGEP(AT, data, idx);
emit_unionmove(ctx, addr, tbaa_arraybuf, val, nullptr);
if (elsz > 0 && (!jl_is_datatype(val.typ) || jl_datatype_size(val.typ) > 0)) {
// copy data (if any)
emit_unionmove(ctx, data, tbaa_arraybuf, val, nullptr);
}
}
else {
Expand Down
13 changes: 9 additions & 4 deletions src/rtutils.c
Original file line number Diff line number Diff line change
Expand Up @@ -1001,12 +1001,14 @@ static size_t jl_static_show_x_(JL_STREAM *out, jl_value_t *v, jl_datatype_t *vt
n += jl_printf(out, ")}[");
size_t j, tlen = jl_array_len(v);
jl_array_t *av = (jl_array_t*)v;
jl_datatype_t *el_type = (jl_datatype_t*)jl_tparam0(vt);
jl_value_t *el_type = jl_tparam0(vt);
char *typetagdata = (!av->flags.ptrarray && jl_is_uniontype(el_type)) ? jl_array_typetagdata(av) : NULL;
int nlsep = 0;
if (av->flags.ptrarray) {
// print arrays with newlines, unless the elements are probably small
for (j = 0; j < tlen; j++) {
jl_value_t *p = jl_array_ptr_ref(av, j);
jl_value_t **ptr = ((jl_value_t**)av->data) + j;
jl_value_t *p = *ptr;
if (p != NULL && (uintptr_t)p >= 4096U) {
jl_value_t *p_ty = jl_typeof(p);
if ((uintptr_t)p_ty >= 4096U) {
Expand All @@ -1022,11 +1024,14 @@ static size_t jl_static_show_x_(JL_STREAM *out, jl_value_t *v, jl_datatype_t *vt
n += jl_printf(out, "\n ");
for (j = 0; j < tlen; j++) {
if (av->flags.ptrarray) {
n += jl_static_show_x(out, jl_array_ptr_ref(v, j), depth);
jl_value_t **ptr = ((jl_value_t**)av->data) + j;
n += jl_static_show_x(out, *ptr, depth);
}
else {
char *ptr = ((char*)av->data) + j * av->elsize;
n += jl_static_show_x_(out, (jl_value_t*)ptr, el_type, depth);
n += jl_static_show_x_(out, (jl_value_t*)ptr,
typetagdata ? (jl_datatype_t*)jl_nth_union_component(el_type, typetagdata[j]) : (jl_datatype_t*)el_type,
depth);
}
if (j != tlen - 1)
n += jl_printf(out, nlsep ? ",\n " : ", ");
Expand Down
10 changes: 10 additions & 0 deletions test/compiler/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -658,3 +658,13 @@ function f42645()
res
end
@test ((f42645()::B42645).y::A42645{Int}).x

# issue #43123
@noinline cmp43123(a::Some, b::Some) = something(a) === something(b)
@noinline cmp43123(a, b) = a[] === b[]
@test cmp43123(Some{Function}(+), Some{Union{typeof(+), typeof(-)}}(+))
@test !cmp43123(Some{Function}(+), Some{Union{typeof(+), typeof(-)}}(-))
@test cmp43123(Ref{Function}(+), Ref{Union{typeof(+), typeof(-)}}(+))
@test !cmp43123(Ref{Function}(+), Ref{Union{typeof(+), typeof(-)}}(-))
@test cmp43123(Function[+], Union{typeof(+), typeof(-)}[+])
@test !cmp43123(Function[+], Union{typeof(+), typeof(-)}[-])

0 comments on commit 8906af3

Please sign in to comment.