Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

atomics: optimize atomic modify operations (mostly) #42017

Merged
merged 2 commits into from
Sep 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions base/atomics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -356,13 +356,13 @@ for typ in atomictypes
rt = "$lt, $lt*"
irt = "$ilt, $ilt*"
@eval getindex(x::Atomic{$typ}) =
llvmcall($"""
GC.@preserve x llvmcall($"""
%ptr = inttoptr i$WORD_SIZE %0 to $lt*
%rv = load atomic $rt %ptr acquire, align $(gc_alignment(typ))
ret $lt %rv
""", $typ, Tuple{Ptr{$typ}}, unsafe_convert(Ptr{$typ}, x))
@eval setindex!(x::Atomic{$typ}, v::$typ) =
llvmcall($"""
GC.@preserve x llvmcall($"""
%ptr = inttoptr i$WORD_SIZE %0 to $lt*
store atomic $lt %1, $lt* %ptr release, align $(gc_alignment(typ))
ret void
Expand All @@ -371,7 +371,7 @@ for typ in atomictypes
# Note: atomic_cas! succeeded (i.e. it stored "new") if and only if the result is "cmp"
if typ <: Integer
@eval atomic_cas!(x::Atomic{$typ}, cmp::$typ, new::$typ) =
llvmcall($"""
GC.@preserve x llvmcall($"""
%ptr = inttoptr i$WORD_SIZE %0 to $lt*
%rs = cmpxchg $lt* %ptr, $lt %1, $lt %2 acq_rel acquire
%rv = extractvalue { $lt, i1 } %rs, 0
Expand All @@ -380,7 +380,7 @@ for typ in atomictypes
unsafe_convert(Ptr{$typ}, x), cmp, new)
else
@eval atomic_cas!(x::Atomic{$typ}, cmp::$typ, new::$typ) =
llvmcall($"""
GC.@preserve x llvmcall($"""
%iptr = inttoptr i$WORD_SIZE %0 to $ilt*
%icmp = bitcast $lt %1 to $ilt
%inew = bitcast $lt %2 to $ilt
Expand All @@ -403,15 +403,15 @@ for typ in atomictypes
if rmwop in arithmetic_ops && !(typ <: ArithmeticTypes) continue end
if typ <: Integer
@eval $fn(x::Atomic{$typ}, v::$typ) =
llvmcall($"""
GC.@preserve x llvmcall($"""
%ptr = inttoptr i$WORD_SIZE %0 to $lt*
%rv = atomicrmw $rmw $lt* %ptr, $lt %1 acq_rel
ret $lt %rv
""", $typ, Tuple{Ptr{$typ}, $typ}, unsafe_convert(Ptr{$typ}, x), v)
else
rmwop === :xchg || continue
@eval $fn(x::Atomic{$typ}, v::$typ) =
llvmcall($"""
GC.@preserve x llvmcall($"""
%iptr = inttoptr i$WORD_SIZE %0 to $ilt*
%ival = bitcast $lt %1 to $ilt
%irv = atomicrmw $rmw $ilt* %iptr, $ilt %ival acq_rel
Expand Down
23 changes: 13 additions & 10 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1249,6 +1249,8 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
return abstract_apply(interp, argtypes, sv, max_methods)
elseif f === invoke
return abstract_invoke(interp, argtypes, sv)
elseif f === modifyfield!
return abstract_modifyfield!(interp, argtypes, sv)
end
return CallMeta(abstract_call_builtin(interp, f, fargs, argtypes, sv, max_methods), false)
elseif f === Core.kwfunc
Expand Down Expand Up @@ -1515,7 +1517,8 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
return abstract_eval_special_value(interp, e, vtypes, sv)
end
e = e::Expr
if e.head === :call
ehead = e.head
if ehead === :call
ea = e.args
argtypes = collect_argtypes(interp, ea, vtypes, sv)
if argtypes === nothing
Expand All @@ -1525,7 +1528,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
sv.stmt_info[sv.currpc] = callinfo.info
t = callinfo.rt
end
elseif e.head === :new
elseif ehead === :new
t = instanceof_tfunc(abstract_eval_value(interp, e.args[1], vtypes, sv))[1]
if isconcretetype(t) && !ismutabletype(t)
args = Vector{Any}(undef, length(e.args)-1)
Expand Down Expand Up @@ -1562,7 +1565,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
end
end
end
elseif e.head === :splatnew
elseif ehead === :splatnew
t = instanceof_tfunc(abstract_eval_value(interp, e.args[1], vtypes, sv))[1]
if length(e.args) == 2 && isconcretetype(t) && !ismutabletype(t)
at = abstract_eval_value(interp, e.args[2], vtypes, sv)
Expand All @@ -1575,7 +1578,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
t = PartialStruct(t, at.fields::Vector{Any})
end
end
elseif e.head === :new_opaque_closure
elseif ehead === :new_opaque_closure
t = Union{}
if length(e.args) >= 5
ea = e.args
Expand All @@ -1594,29 +1597,29 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
end
end
end
elseif e.head === :foreigncall
elseif ehead === :foreigncall
abstract_eval_value(interp, e.args[1], vtypes, sv)
t = sp_type_rewrap(e.args[2], sv.linfo, true)
for i = 3:length(e.args)
if abstract_eval_value(interp, e.args[i], vtypes, sv) === Bottom
t = Bottom
end
end
elseif e.head === :cfunction
elseif ehead === :cfunction
t = e.args[1]
isa(t, Type) || (t = Any)
abstract_eval_cfunction(interp, e, vtypes, sv)
elseif e.head === :method
elseif ehead === :method
t = (length(e.args) == 1) ? Any : Nothing
elseif e.head === :copyast
elseif ehead === :copyast
t = abstract_eval_value(interp, e.args[1], vtypes, sv)
if t isa Const && t.val isa Expr
# `copyast` makes copies of Exprs
t = Expr
end
elseif e.head === :invoke
elseif ehead === :invoke || ehead === :invoke_modify
error("type inference data-flow error: tried to double infer a function")
elseif e.head === :isdefined
elseif ehead === :isdefined
sym = e.args[1]
t = Bool
if isa(sym, SlotNumber)
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp
return 0
end
return error_path ? params.inline_error_path_cost : params.inline_nonleaf_penalty
elseif head === :foreigncall || head === :invoke
elseif head === :foreigncall || head === :invoke || head == :invoke_modify
# Calls whose "return type" is Union{} do not actually return:
# they are errors. Since these are not part of the typical
# run-time of the function, we omit them from
Expand Down
16 changes: 16 additions & 0 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1141,6 +1141,22 @@ function process_simple!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, sta
ir.stmts[idx][:inst] = res
return nothing
end
if (sig.f === modifyfield! || sig.ft ⊑ typeof(modifyfield!)) && 5 <= length(stmt.args) <= 6
let info = ir.stmts[idx][:info]
info isa MethodResultPure && (info = info.info)
info isa ConstCallInfo && (info = info.call)
info isa MethodMatchInfo || return nothing
length(info.results) == 1 || return nothing
match = info.results[1]::MethodMatch
match.fully_covers || return nothing
case = compileable_specialization(state.et, match)
case === nothing && return nothing
stmt.head = :invoke_modify
pushfirst!(stmt.args, case)
ir.stmts[idx][:inst] = stmt
end
return nothing
end

check_effect_free!(ir, stmt, calltype, idx)

Expand Down
3 changes: 2 additions & 1 deletion base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,8 @@ function getindex(x::UseRef)
end

function is_relevant_expr(e::Expr)
return e.head in (:call, :invoke, :new, :splatnew, :(=), :(&),
return e.head in (:call, :invoke, :invoke_modify,
:new, :splatnew, :(=), :(&),
:gc_preserve_begin, :gc_preserve_end,
:foreigncall, :isdefined, :copyast,
:undefcheck, :throw_undef_if_not,
Expand Down
32 changes: 31 additions & 1 deletion base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -939,10 +939,40 @@ function modifyfield!_tfunc(o, f, op, v)
@nospecialize
T = _fieldtype_tfunc(o, isconcretetype(o), f)
T === Bottom && return Bottom
# note: we could sometimes refine this to a PartialStruct if we analyzed `op(o.f, v)::T`
PT = Const(Pair)
return instanceof_tfunc(apply_type_tfunc(PT, T, T))[1]
end
function abstract_modifyfield!(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::InferenceState)
nargs = length(argtypes)
if !isempty(argtypes) && isvarargtype(argtypes[nargs])
nargs - 1 <= 6 || return CallMeta(Bottom, false)
nargs > 3 || return CallMeta(Any, false)
else
5 <= nargs <= 6 || return CallMeta(Bottom, false)
end
o = unwrapva(argtypes[2])
f = unwrapva(argtypes[3])
RT = modifyfield!_tfunc(o, f, Any, Any)
info = false
if nargs >= 5 && RT !== Bottom
# we may be able to refine this to a PartialStruct by analyzing `op(o.f, v)::T`
# as well as compute the info for the method matches
op = unwrapva(argtypes[4])
v = unwrapva(argtypes[5])
TF = getfield_tfunc(o, f)
push!(sv.ssavalue_uses[sv.currpc], sv.currpc) # temporarily disable `call_result_unused` check for this call
callinfo = abstract_call(interp, nothing, Any[op, TF, v], sv, #=max_methods=# 1)
aviatesk marked this conversation as resolved.
Show resolved Hide resolved
pop!(sv.ssavalue_uses[sv.currpc], sv.currpc)
TF2 = tmeet(callinfo.rt, widenconst(TF))
if TF2 === Bottom
RT = Bottom
elseif isconcretetype(RT) && has_nontrivial_const_info(TF2) # isconcrete condition required to form a PartialStruct
RT = PartialStruct(RT, Any[TF, TF2])
end
info = callinfo.info
end
return CallMeta(RT, info)
end
replacefield!_tfunc(o, f, x, v, success_order, failure_order) = (@nospecialize; replacefield!_tfunc(o, f, x, v))
replacefield!_tfunc(o, f, x, v, success_order) = (@nospecialize; replacefield!_tfunc(o, f, x, v))
function replacefield!_tfunc(o, f, x, v)
Expand Down
8 changes: 5 additions & 3 deletions base/compiler/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
const VALID_EXPR_HEADS = IdDict{Symbol,UnitRange{Int}}(
:call => 1:typemax(Int),
:invoke => 2:typemax(Int),
:invoke_modify => 3:typemax(Int),
:static_parameter => 1:1,
:(&) => 1:1,
:(=) => 2:2,
Expand Down Expand Up @@ -78,7 +79,7 @@ end

function _validate_val!(@nospecialize(x), errors, ssavals::BitSet)
if isa(x, Expr)
if x.head === :call || x.head === :invoke
if x.head === :call || x.head === :invoke || x.head === :invoke_modify
f = x.args[1]
if f isa GlobalRef && (f.name === :cglobal) && x.head === :call
# TODO: these are not yet linearized
Expand Down Expand Up @@ -138,7 +139,8 @@ function validate_code!(errors::Vector{>:InvalidCodeError}, c::CodeInfo, is_top_
end
validate_val!(lhs)
validate_val!(rhs)
elseif head === :call || head === :invoke || head === :gc_preserve_end || head === :meta ||
elseif head === :call || head === :invoke || x.head === :invoke_modify ||
head === :gc_preserve_end || head === :meta ||
head === :inbounds || head === :foreigncall || head === :cfunction ||
head === :const || head === :enter || head === :leave || head === :pop_exception ||
head === :method || head === :global || head === :static_parameter ||
Expand Down Expand Up @@ -238,7 +240,7 @@ end

function is_valid_rvalue(@nospecialize(x))
is_valid_argument(x) && return true
if isa(x, Expr) && x.head in (:new, :splatnew, :the_exception, :isdefined, :call, :invoke, :foreigncall, :cfunction, :gc_preserve_begin, :copyast)
if isa(x, Expr) && x.head in (:new, :splatnew, :the_exception, :isdefined, :call, :invoke, :invoke_modify, :foreigncall, :cfunction, :gc_preserve_begin, :copyast)
return true
end
return false
Expand Down
2 changes: 2 additions & 0 deletions src/ast.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ extern "C" {

// head symbols for each expression type
jl_sym_t *call_sym; jl_sym_t *invoke_sym;
jl_sym_t *invoke_modify_sym;
jl_sym_t *empty_sym; jl_sym_t *top_sym;
jl_sym_t *module_sym; jl_sym_t *slot_sym;
jl_sym_t *export_sym; jl_sym_t *import_sym;
Expand Down Expand Up @@ -345,6 +346,7 @@ void jl_init_common_symbols(void)
empty_sym = jl_symbol("");
call_sym = jl_symbol("call");
invoke_sym = jl_symbol("invoke");
invoke_modify_sym = jl_symbol("invoke_modify");
foreigncall_sym = jl_symbol("foreigncall");
cfunction_sym = jl_symbol("cfunction");
quote_sym = jl_symbol("quote");
Expand Down
49 changes: 30 additions & 19 deletions src/cgutils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1547,17 +1547,23 @@ static jl_cgval_t typed_store(jl_codectx_t &ctx,
Value *parent, // for the write barrier, NULL if no barrier needed
bool isboxed, AtomicOrdering Order, AtomicOrdering FailOrder, unsigned alignment,
bool needlock, bool issetfield, bool isreplacefield, bool isswapfield, bool ismodifyfield,
bool maybe_null_if_boxed, const std::string &fname)
bool maybe_null_if_boxed, const jl_cgval_t *modifyop, const std::string &fname)
{
auto newval = [&](const jl_cgval_t &lhs) {
jl_cgval_t argv[3] = { cmp, lhs, rhs };
Value *callval = emit_jlcall(ctx, jlapplygeneric_func, nullptr, argv, 3, JLCALL_F_CC);
argv[0] = mark_julia_type(ctx, callval, true, jl_any_type);
if (!jl_subtype(argv[0].typ, jltype)) {
emit_typecheck(ctx, argv[0], jltype, fname + "typed_store");
argv[0] = update_julia_type(ctx, argv[0], jltype);
}
return argv[0];
const jl_cgval_t argv[3] = { cmp, lhs, rhs };
jl_cgval_t ret;
if (modifyop) {
ret = emit_invoke(ctx, *modifyop, argv, 3, (jl_value_t*)jl_any_type);
}
else {
Value *callval = emit_jlcall(ctx, jlapplygeneric_func, nullptr, argv, 3, JLCALL_F_CC);
ret = mark_julia_type(ctx, callval, true, jl_any_type);
}
if (!jl_subtype(ret.typ, jltype)) {
emit_typecheck(ctx, ret, jltype, fname + "typed_store");
ret = update_julia_type(ctx, ret, jltype);
}
return ret;
};
assert(!needlock || parent != nullptr);
Type *elty = isboxed ? T_prjlvalue : julia_type_to_llvm(ctx, jltype);
Expand All @@ -1570,7 +1576,7 @@ static jl_cgval_t typed_store(jl_codectx_t &ctx,
else if (isreplacefield) {
Value *Success = emit_f_is(ctx, cmp, ghostValue(jltype));
Success = ctx.builder.CreateZExt(Success, T_int8);
jl_cgval_t argv[2] = {ghostValue(jltype), mark_julia_type(ctx, Success, false, jl_bool_type)};
const jl_cgval_t argv[2] = {ghostValue(jltype), mark_julia_type(ctx, Success, false, jl_bool_type)};
jl_datatype_t *rettyp = jl_apply_cmpswap_type(jltype);
return emit_new_struct(ctx, (jl_value_t*)rettyp, 2, argv);
}
Expand All @@ -1579,7 +1585,7 @@ static jl_cgval_t typed_store(jl_codectx_t &ctx,
}
else { // modifyfield
jl_cgval_t oldval = ghostValue(jltype);
jl_cgval_t argv[2] = { oldval, newval(oldval) };
const jl_cgval_t argv[2] = { oldval, newval(oldval) };
jl_datatype_t *rettyp = jl_apply_modify_type(jltype);
return emit_new_struct(ctx, (jl_value_t*)rettyp, 2, argv);
}
Expand Down Expand Up @@ -1862,7 +1868,7 @@ static jl_cgval_t typed_store(jl_codectx_t &ctx,
}
}
if (ismodifyfield) {
jl_cgval_t argv[2] = { oldval, rhs };
const jl_cgval_t argv[2] = { oldval, rhs };
jl_datatype_t *rettyp = jl_apply_modify_type(jltype);
oldval = emit_new_struct(ctx, (jl_value_t*)rettyp, 2, argv);
}
Expand All @@ -1881,7 +1887,7 @@ static jl_cgval_t typed_store(jl_codectx_t &ctx,
oldval = mark_julia_type(ctx, instr, isboxed, jltype);
if (isreplacefield) {
Success = ctx.builder.CreateZExt(Success, T_int8);
jl_cgval_t argv[2] = {oldval, mark_julia_type(ctx, Success, false, jl_bool_type)};
const jl_cgval_t argv[2] = {oldval, mark_julia_type(ctx, Success, false, jl_bool_type)};
jl_datatype_t *rettyp = jl_apply_cmpswap_type(jltype);
oldval = emit_new_struct(ctx, (jl_value_t*)rettyp, 2, argv);
}
Expand Down Expand Up @@ -3269,7 +3275,7 @@ static jl_cgval_t emit_setfield(jl_codectx_t &ctx,
jl_cgval_t rhs, jl_cgval_t cmp,
bool checked, bool wb, AtomicOrdering Order, AtomicOrdering FailOrder,
bool needlock, bool issetfield, bool isreplacefield, bool isswapfield, bool ismodifyfield,
const std::string &fname)
const jl_cgval_t *modifyop, const std::string &fname)
{
if (!sty->name->mutabl && checked) {
std::string msg = fname + "immutable struct of type "
Expand Down Expand Up @@ -3309,9 +3315,14 @@ static jl_cgval_t emit_setfield(jl_codectx_t &ctx,
if (ismodifyfield) {
if (needlock)
emit_lockstate_value(ctx, strct, false);
jl_cgval_t argv[3] = { cmp, oldval, rhs };
Value *callval = emit_jlcall(ctx, jlapplygeneric_func, nullptr, argv, 3, JLCALL_F_CC);
rhs = mark_julia_type(ctx, callval, true, jl_any_type);
const jl_cgval_t argv[3] = { cmp, oldval, rhs };
if (modifyop) {
rhs = emit_invoke(ctx, *modifyop, argv, 3, (jl_value_t*)jl_any_type);
}
else {
Value *callval = emit_jlcall(ctx, jlapplygeneric_func, nullptr, argv, 3, JLCALL_F_CC);
rhs = mark_julia_type(ctx, callval, true, jl_any_type);
}
if (!jl_subtype(rhs.typ, jfty)) {
emit_typecheck(ctx, rhs, jfty, fname);
rhs = update_julia_type(ctx, rhs, jfty);
Expand Down Expand Up @@ -3364,7 +3375,7 @@ static jl_cgval_t emit_setfield(jl_codectx_t &ctx,
return typed_store(ctx, addr, NULL, rhs, cmp, jfty, strct.tbaa, nullptr,
wb ? maybe_bitcast(ctx, data_pointer(ctx, strct), T_pjlvalue) : nullptr,
isboxed, Order, FailOrder, align,
needlock, issetfield, isreplacefield, isswapfield, ismodifyfield, maybe_null, fname);
needlock, issetfield, isreplacefield, isswapfield, ismodifyfield, maybe_null, modifyop, fname);
}
}

Expand Down Expand Up @@ -3543,7 +3554,7 @@ static jl_cgval_t emit_new_struct(jl_codectx_t &ctx, jl_value_t *ty, size_t narg
else
need_wb = false;
emit_typecheck(ctx, rhs, jl_svecref(sty->types, i), "new");
emit_setfield(ctx, sty, strctinfo, i, rhs, jl_cgval_t(), false, need_wb, AtomicOrdering::NotAtomic, AtomicOrdering::NotAtomic, false, true, false, false, false, "");
emit_setfield(ctx, sty, strctinfo, i, rhs, jl_cgval_t(), false, need_wb, AtomicOrdering::NotAtomic, AtomicOrdering::NotAtomic, false, true, false, false, false, nullptr, "");
}
return strctinfo;
}
Expand Down
Loading