Skip to content

Commit

Permalink
Support Julia 1.12
Browse files Browse the repository at this point in the history
  • Loading branch information
eschnett committed May 2, 2024
1 parent 3907459 commit 898e85e
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 32 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SIMD"
uuid = "fdea26ae-647d-5447-a871-4b548cad5224"
authors = ["Erik Schnetter <[email protected]>", "Kristoffer Carlsson <[email protected]>"]
version = "3.4.6"
version = "3.4.7"

[deps]
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Expand Down
61 changes: 46 additions & 15 deletions src/LLVM_intrinsics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,15 @@ const d = Dict{DataType, String}(
Float64 => "double",
)
# Add the Ptr translations
foreach(x -> (d[Ptr{x}] = d[Int]), collect(keys(d)))
# Julia <=1.11 (LLVM <=16) passes `Ptr{T}` as `i64`, Julia >=1.12 (LLVM >=17) passes them as `T*`.
# Use `argtoptr` e.g. as `%ptr = $argtoptr $(d[Ptr{T}]) %0 to <$N x $(d[T])>*`
@static if VERSION >= v"1.12-DEV"
const argtoptr = "bitcast"
foreach(x -> (d[Ptr{x}] = "$(d[x])*"), collect(keys(d)))
else
const argtoptr = "inttoptr"
foreach(x -> (d[Ptr{x}] = "$(d[Int])"), collect(keys(d)))
end

# LT = LLVM Type (scalar and vectors), we keep type names intentionally short
# to make the signatures smaller
Expand Down Expand Up @@ -462,7 +470,7 @@ temporal_str(temporal) = temporal ? ", !nontemporal !{i32 1}" : ""
@generated function load(x::Type{LVec{N, T}}, ptr::Ptr{T},
::Val{Al}=Val(false), ::Val{Te}=Val(false)) where {N, T, Al, Te}
s = """
%ptr = inttoptr $(d[Int]) %0 to <$N x $(d[T])>*
%ptr = $argtoptr $(d[Ptr{T}]) %0 to <$N x $(d[T])>*
%res = load <$N x $(d[T])>, <$N x $(d[T])>* %ptr, align $(n_align(Al, N, T)) $(temporal_str(Te))
ret <$N x $(d[T])> %res
"""
Expand All @@ -478,10 +486,10 @@ end
mod = """
declare <$N x $(d[T])> @llvm.masked.load.$(suffix(N, T))(<$N x $(d[T])>*, i32, <$N x i1>, <$N x $(d[T])>)
define <$N x $(d[T])> @entry($(d[Int]), <$(N) x i8>) #0 {
define <$N x $(d[T])> @entry($(d[Ptr{T}]), <$(N) x i8>) #0 {
top:
%mask = trunc <$(N) x i8> %1 to <$(N) x i1>
%ptr = inttoptr $(d[Int]) %0 to <$N x $(d[T])>*
%ptr = $argtoptr $(d[Ptr{T}]) %0 to <$N x $(d[T])>*
%res = call <$N x $(d[T])> @llvm.masked.load.$(suffix(N, T))(<$N x $(d[T])>* %ptr, i32 $(n_align(Al, N, T)), <$N x i1> %mask, <$N x $(d[T])> zeroinitializer)
ret <$N x $(d[T])> %res
}
Expand All @@ -497,7 +505,7 @@ end
@generated function store(x::LVec{N, T}, ptr::Ptr{T},
::Val{Al}=Val(false), ::Val{Te}=Val(false)) where {N, T, Al, Te}
s = """
%ptr = inttoptr $(d[Int]) %1 to <$N x $(d[T])>*
%ptr = $argtoptr $(d[Ptr{T}]) %1 to <$N x $(d[T])>*
store <$N x $(d[T])> %0, <$N x $(d[T])>* %ptr, align $(n_align(Al, N, T)) $(temporal_str(Te))
ret void
"""
Expand All @@ -514,10 +522,10 @@ end
mod = """
declare void @llvm.masked.store.$(suffix(N, T))(<$N x $(d[T])>, <$N x $(d[T])>*, i32, <$N x i1>)
define void @entry(<$N x $(d[T])>, $(d[Int]), <$(N) x i8>) #0 {
define void @entry(<$N x $(d[T])>, $(d[Ptr{T}]), <$(N) x i8>) #0 {
top:
%mask = trunc <$(N) x i8> %2 to <$(N) x i1>
%ptr = inttoptr $(d[Int]) %1 to <$N x $(d[T])>*
%ptr = $argtoptr $(d[Ptr{T}]) %1 to <$N x $(d[T])>*
call void @llvm.masked.store.$(suffix(N, T))(<$N x $(d[T])> %0, <$N x $(d[T])>* %ptr, i32 $(n_align(Al, N, T)), <$N x i1> %mask)
ret void
}
Expand All @@ -535,10 +543,10 @@ end
mod = """
declare <$N x $(d[T])> @llvm.masked.expandload.$(suffix(N, T))($(d[T])*, <$N x i1>, <$N x $(d[T])>)
define <$N x $(d[T])> @entry($(d[Int]), <$(N) x i8>) #0 {
define <$N x $(d[T])> @entry($(d[Ptr{T}]), <$(N) x i8>) #0 {
top:
%mask = trunc <$(N) x i8> %1 to <$(N) x i1>
%ptr = inttoptr $(d[Int]) %0 to $(d[T])*
%ptr = $argtoptr $(d[Ptr{T}]) %0 to $(d[T])*
%res = call <$N x $(d[T])> @llvm.masked.expandload.$(suffix(N, T))($(d[T])* %ptr, <$N x i1> %mask, <$N x $(d[T])> zeroinitializer)
ret <$N x $(d[T])> %res
}
Expand All @@ -556,10 +564,10 @@ end
mod = """
declare void @llvm.masked.compressstore.$(suffix(N, T))(<$N x $(d[T])>, $(d[T])*, <$N x i1>)
define void @entry(<$N x $(d[T])>, $(d[Int]), <$(N) x i8>) #0 {
define void @entry(<$N x $(d[T])>, $(d[Ptr{T}]), <$(N) x i8>) #0 {
top:
%mask = trunc <$(N) x i8> %2 to <$(N) x i1>
%ptr = inttoptr $(d[Int]) %1 to $(d[T])*
%ptr = $argtoptr $(d[Ptr{T}]) %1 to $(d[T])*
call void @llvm.masked.compressstore.$(suffix(N, T))(<$N x $(d[T])> %0, $(d[T])* %ptr, <$N x i1> %mask)
ret void
}
Expand All @@ -583,10 +591,10 @@ end
mod = """
declare <$N x $(d[T])> @llvm.masked.gather.$(suffix(N, T))(<$N x $(d[T])*>, i32, <$N x i1>, <$N x $(d[T])>)
define <$N x $(d[T])> @entry(<$N x $(d[Int])>, <$(N) x i8>) #0 {
define <$N x $(d[T])> @entry(<$N x $(d[Ptr{T}])>, <$(N) x i8>) #0 {
top:
%mask = trunc <$(N) x i8> %1 to <$(N) x i1>
%ptrs = inttoptr <$N x $(d[Int])> %0 to <$N x $(d[T])*>
%ptrs = $argtoptr <$N x $(d[Ptr{T}])> %0 to <$N x $(d[T])*>
%res = call <$N x $(d[T])> @llvm.masked.gather.$(suffix(N, T))(<$N x $(d[T])*> %ptrs, i32 $(n_align(Al, N, T)), <$N x i1> %mask, <$N x $(d[T])> zeroinitializer)
ret <$N x $(d[T])> %res
}
Expand All @@ -604,10 +612,10 @@ end
mod = """
declare void @llvm.masked.scatter.$(suffix(N, T))(<$N x $(d[T])>, <$N x $(d[T])*>, i32, <$N x i1>)
define void @entry(<$N x $(d[T])>, <$N x $(d[Int])>, <$(N) x i8>) #0 {
define void @entry(<$N x $(d[T])>, <$N x $(d[Ptr{T}])>, <$(N) x i8>) #0 {
top:
%mask = trunc <$(N) x i8> %2 to <$(N) x i1>
%ptrs = inttoptr <$N x $(d[Int])> %1 to <$N x $(d[T])*>
%ptrs = $argtoptr <$N x $(d[Ptr{T}])> %1 to <$N x $(d[T])*>
call void @llvm.masked.scatter.$(suffix(N, T))(<$N x $(d[T])> %0, <$N x $(d[T])*> %ptrs, i32 $(n_align(Al, N, T)), <$N x i1> %mask)
ret void
}
Expand Down Expand Up @@ -751,6 +759,29 @@ for (fs, (from, to)) in zip([CONVERSION_FLOAT_TO_INT, CONVERSION_INT_TO_FL
end
end

@generated function inttoptr(::Type{LVec{N, Ptr{T2}}}, x::LVec{N, T1}) where {N, T1 <: IntegerTypes, T2 <: Union{IntegerTypes, FloatingTypes}}
convert = VERSION >= v"1.12-DEV" ? "inttoptr" : "bitcast"
s = """
%2 = $convert <$(N) x $(d[T1])> %0 to <$(N) x $(d[Ptr{T2}])>
ret <$(N) x $(d[Ptr{T2}])> %2
"""
return :(
$(Expr(:meta, :inline));
Base.llvmcall($s, LVec{N, Ptr{T2}}, Tuple{LVec{N, T1}}, x)
)
end

@generated function ptrtoint(::Type{LVec{N, T2}}, x::LVec{N, Ptr{T1}}) where {N, T1 <: Union{IntegerTypes, FloatingTypes}, T2 <: IntegerTypes}
convert = VERSION >= v"1.12-DEV" ? "ptrtoint" : "bitcast"
s = """
%2 = $convert <$(N) x $(d[Ptr{T1}])> %0 to <$(N) x $(d[T2])>
ret <$(N) x $(d[T2])> %2
"""
return :(
$(Expr(:meta, :inline));
Base.llvmcall($s, LVec{N, T2}, Tuple{LVec{N, Ptr{T1}}}, x)
)
end

###########
# Bitcast #
Expand Down
38 changes: 22 additions & 16 deletions src/simdvec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,12 @@ Base.copy(v::Vec) = v

@inline Base.convert(::Type{Vec{N,T}}, v::Vec{N,T}) where {N,T} = v
@inline function Base.convert(::Type{Vec{N, T1}}, v::Vec{N, T2}) where {T1, T2, N}
if T1 <: Union{IntegerTypes, Ptr}
if T2 <: Union{IntegerTypes, Ptr, Bool}
if T1 <: Ptr
return Vec(Intrinsics.inttoptr(Intrinsics.LVec{N, T1}, v.data))
elseif T1 <: IntegerTypes
if T2 <: Ptr
return Vec(Intrinsics.ptrtoint(Intrinsics.LVec{N, T1}, v.data))
elseif T2 <: Union{IntegerTypes, Bool}
if sizeof(T1) < sizeof(T2)
return Vec(Intrinsics.trunc(Intrinsics.LVec{N, T1}, v.data))
elseif sizeof(T1) == sizeof(T2)
Expand All @@ -43,8 +47,7 @@ Base.copy(v::Vec) = v
return Vec(Intrinsics.fptosi(Intrinsics.LVec{N, T1}, v.data))
end
end
end
if T1 <: FloatingTypes
elseif T1 <: FloatingTypes
if T2 <: UIntTypes
return Vec(Intrinsics.uitofp(Intrinsics.LVec{N, T1}, v.data))
elseif T2 <: IntTypes
Expand Down Expand Up @@ -303,18 +306,21 @@ _signed(::Type{Float64}) = Int64
signbit(reinterpret(Vec{N, _signed(T)}, x))

# Pointer arithmetic
for op in (:+, :-)
@eval begin
# Cast pointer to Int and back
@inline Base.$op(x::Vec{N,Ptr{T}}, y::Vec{N,Ptr{T}}) where {N,T} =
convert(Vec{N, Ptr{T}}, ($(op)(convert(Vec{N, Int}, x), convert(Vec{N, Int}, y))))
@inline Base.$op(x::Vec{N,Ptr{T}}, y::Union{IntegerTypes}) where {N,T} = $(op)(x, Vec{N,Ptr{T}}(y))
@inline Base.$op(x::IntegerTypes, y::Union{Vec{N,Ptr{T}}}) where {N,T} = $(op)(y, x)

@inline Base.$op(x::Vec{N,<:IntegerTypes}, y::Ptr{T}) where {N,T} = $(op)(Vec{N,Ptr{T}}(x), Vec{N,Ptr{T}}(y))
@inline Base.$op(x::Ptr{T}, y::Vec{N,<:IntegerTypes}) where {N,T} = $(op)(y, x)
end
end
# Cast pointer to Int and back
@inline Base.:+(x::Vec{N,Ptr{T}}, y::Vec{N,<:IntegerTypes}) where {N,T} = convert(Vec{N,Ptr{T}}, convert(Vec{N,UInt}, x) + y)
@inline Base.:-(x::Vec{N,Ptr{T}}, y::Vec{N,<:IntegerTypes}) where {N,T} = convert(Vec{N,Ptr{T}}, convert(Vec{N,UInt}, x) - y)
@inline Base.:+(x::Ptr{T}, y::Vec{N,<:IntegerTypes}) where {N,T} = convert(Vec{N,Ptr{T}}, convert(UInt, x) + y)
@inline Base.:-(x::Ptr{T}, y::Vec{N,<:IntegerTypes}) where {N,T} = convert(Vec{N,Ptr{T}}, convert(UInt, x) - y)
@inline Base.:+(x::Vec{N,Ptr{T}}, y::IntegerTypes) where {N,T} = convert(Vec{N,Ptr{T}}, convert(Vec{N,UInt}, x) + y)
@inline Base.:-(x::Vec{N,Ptr{T}}, y::IntegerTypes) where {N,T} = convert(Vec{N,Ptr{T}}, convert(Vec{N,UInt}, x) - y)

@inline Base.:+(y::Vec{N,<:IntegerTypes}, x::Vec{N,Ptr{T}}, ) where {N,T} = x + y
@inline Base.:+(y::Vec{N,<:IntegerTypes}, x::Ptr{T}) where {N,T} = x + y
@inline Base.:+(y::IntegerTypes, x::Vec{N,Ptr{T}}) where {N,T} = x + y

@inline Base.:-(x::Vec{N,Ptr{T}}, y::Vec{N,Ptr{T}}) where {N,T} = convert(Vec{N,Int}, x) - convert(Vec{N,Int}, y)
@inline Base.:-(x::Ptr{T}, y::Vec{N,Ptr{T}}) where {N,T} = convert(UInt, x) % Int - convert(Vec{N,Int}, y)
@inline Base.:-(x::Vec{N,Ptr{T}}, y::Ptr{T}) where {N,T} = convert(Vec{N,Int}, x) - convert(UInt, y) % Int

# Bitshifts
# See https://github.com/JuliaLang/julia/blob/7426625b5c07b0d93110293246089a259a0a677d/src/intrinsics.cpp#L1179-L1196
Expand Down

0 comments on commit 898e85e

Please sign in to comment.