Support vector type in AArch64 C abi #16645

Merged · 2 commits · May 30, 2016
3 changes: 2 additions & 1 deletion doc/manual/calling-c-and-fortran-code.rst
@@ -564,7 +564,8 @@ In the future, some of these restrictions may be reduced or eliminated.
SIMD Values
~~~~~~~~~~~

Note: This feature is currently implemented on 64-bit x86 platforms only.
Note: This feature is currently implemented on 64-bit x86
and AArch64 platforms only.

If a C/C++ routine has an argument or return value that is a native
SIMD type, the corresponding Julia type is a homogeneous tuple
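As an illustration of the mapping the manual paragraph above describes, here is a minimal sketch (not part of this diff; the function and library names are hypothetical) of a native NEON routine and, in the comment, the homogeneous-tuple-of-VecElement signature the Julia side would use:

#include <arm_neon.h>   /* NEON intrinsics; compiles for AArch64 targets only */

/* Hypothetical C routine with a native SIMD argument and return value.
 * Per the manual text above, the matching Julia ccall would declare the
 * SIMD type as a homogeneous tuple of VecElement, e.g.
 *   ccall((:double_lanes, "libfoo"), NTuple{4,VecElement{Int32}},
 *         (NTuple{4,VecElement{Int32}},), x)
 */
int32x4_t double_lanes(int32x4_t v)
{
    return vaddq_s32(v, v);   /* lane-wise v + v */
}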
75 changes: 65 additions & 10 deletions src/abi_aarch64.cpp
@@ -16,11 +16,53 @@ namespace {
typedef bool AbiState;
static const AbiState default_abi_state = 0;

static Type *get_llvm_vectype(jl_datatype_t *dt)
{
// Assume jl_is_datatype(dt) && !jl_is_abstracttype(dt)
// `!dt->mutabl && dt->pointerfree && !dt->haspadding && dt->nfields > 0`
size_t nfields = dt->nfields;
assert(nfields > 0);
if (nfields < 2)
return nullptr;
static Type *T_vec64 = VectorType::get(T_int32, 2);
static Type *T_vec128 = VectorType::get(T_int32, 4);
Type *lltype;
// A short vector should be either 8 or 16 bytes.
// Note that there are only two distinct fundamental types for
// short vectors, so we normalize them to <2 x i32> and <4 x i32>.
switch (dt->size) {
case 8:
lltype = T_vec64;
break;
case 16:
lltype = T_vec128;
break;
default:
return nullptr;
}
// Since `dt` is pointer-free, has no padding, and is 8 or 16 bytes in size,
// `ft0` must be concrete and immutable with no padding, and we don't need
// to check whether its size is legal since that is covered by
// the homogeneity check.
jl_datatype_t *ft0 = (jl_datatype_t*)jl_field_type(dt, 0);
// `ft0` should be a `VecElement` type and the true element type
// should be a `bitstype`
if (ft0->name != jl_vecelement_typename ||
((jl_datatype_t*)jl_field_type(ft0, 0))->nfields)
return nullptr;
for (int i = 1; i < nfields; i++) {
if (jl_field_type(dt, i) != (jl_value_t*)ft0) {
// Not homogeneous
return nullptr;
}
}
return lltype;
}

static Type *get_llvm_fptype(jl_datatype_t *dt)
{
// Assume jl_is_datatype(dt) && !jl_is_abstracttype(dt)
if (dt->mutabl || jl_datatype_nfields(dt) != 0)
return NULL;
// `!dt->mutabl && dt->pointerfree && !dt->haspadding && dt->nfields == 0`
Type *lltype;
// Check size first since it's cheaper.
switch (dt->size) {
@@ -37,9 +79,17 @@ static Type *get_llvm_fptype(jl_datatype_t *dt)
lltype = T_float128;
break;
default:
return NULL;
return nullptr;
}
return jl_is_floattype((jl_value_t*)dt) ? lltype : NULL;
return jl_is_floattype((jl_value_t*)dt) ? lltype : nullptr;
}

static Type *get_llvm_fp_or_vectype(jl_datatype_t *dt)
{
// Assume jl_is_datatype(dt) && !jl_is_abstracttype(dt)
if (dt->mutabl || !dt->pointerfree || dt->haspadding)
return nullptr;
return dt->nfields ? get_llvm_vectype(dt) : get_llvm_fptype(dt);
}

struct ElementType {
@@ -50,8 +100,6 @@ struct ElementType {

// Whether a type is a homogeneous floating-point aggregate (HFA) or a
// homogeneous short-vector aggregate (HVA). Returns the element type.
// We only handle HFA of HP, SP, DP and QP here since these are the only ones we
// have (no vectors).
// An Homogeneous Aggregate is a Composite Type where all of the Fundamental
// Data Types of the members that compose the type are the same.
// Note that it is the fundamental types that are important and not the member
@@ -62,6 +110,7 @@ static bool isHFAorHVA(jl_datatype_t *dt, size_t dsz, size_t &nele, ElementType
// dt is a pointerfree type, (all members are isbits)
// dsz == dt->size > 0
// 0 <= nele <= 3
// dt has no padding

// We ignore zero sized member here. This isn't really consistent with
// GCC for zero-sized array members. GCC seems to treat structs with
@@ -83,6 +132,14 @@
dt = (jl_datatype_t*)jl_field_type(dt, i);
continue;
}
if (Type *vectype = get_llvm_vectype(dt)) {
if ((ele.sz && dsz != ele.sz) || (ele.type && ele.type != vectype))
return false;
ele.type = vectype;
ele.sz = dsz;
nele++;
return true;
}
// Otherwise, process each member
for (;i < nfields;i++) {
size_t fieldsz = jl_field_size(dt, i);
@@ -183,9 +240,7 @@ static Type *classify_arg(jl_value_t *ty, bool *fpreg, bool *onstack,
// the argument is allocated to the least significant bits of register
// v[NSRN]. The NSRN is incremented by one. The argument has now been
// allocated.
// Note that this is missing QP float as well as short vector types since we
// don't really have those types.
if (get_llvm_fptype(dt)) {
if (get_llvm_fp_or_vectype(dt)) {
*fpreg = true;
return NULL;
}
@@ -323,7 +378,7 @@ Type *preferred_llvm_type(jl_value_t *ty, bool)
if (!jl_is_datatype(ty) || jl_is_abstracttype(ty))
return NULL;
jl_datatype_t *dt = (jl_datatype_t*)ty;
if (Type *fptype = get_llvm_fptype(dt))
if (Type *fptype = get_llvm_fp_or_vectype(dt))
return fptype;
bool fpreg = false;
bool onstack = false;
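For readers less familiar with the AAPCS64 terminology used in the comments above, a small illustrative C sketch (hypothetical type names, not part of this diff) of aggregates that would and would not be classified as HFA/HVA; AAPCS64 additionally caps homogeneous aggregates at four members:

#include <arm_neon.h>   /* NEON short-vector types; AArch64 targets only */

typedef struct {
    double re, im;        /* HFA: both members share the fundamental type `double` */
} hfa_complex;

typedef struct {
    float32x2_t lo, hi;   /* HVA: both members are 8-byte short vectors */
} hva_pair;

typedef struct {
    float32x2_t v;
    double d;             /* NOT homogeneous: a short vector mixed with a double */
} not_hva;

/* An HFA argument/return like this is passed and returned in floating-point
 * registers, which is the case the classification code above detects. */
hfa_complex conjugate(hfa_complex z)
{
    z.im = -z.im;
    return z;
}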
4 changes: 2 additions & 2 deletions src/alloc.c
@@ -843,7 +843,7 @@ JL_DLLEXPORT jl_datatype_t *jl_new_uninitialized_datatype(size_t nfields, int8_t
// For sake of Ahead-Of-Time (AOT) compilation, this routine has to work
// without LLVM being available.
unsigned jl_special_vector_alignment(size_t nfields, jl_value_t *t) {
if (!is_vecelement_type(t))
if (!jl_is_vecelement_type(t))
return 0;
// LLVM 3.7 and 3.8 either crash or generate wrong code for many
// SIMD vector sizes N. It seems the rule is that N can have at
@@ -859,7 +859,7 @@ unsigned jl_special_vector_alignment(size_t nfields, jl_value_t *t) {
return 0; // nfields has more than two 1s
assert(jl_datatype_nfields(t)==1);
jl_value_t *ty = jl_field_type(t, 0);
if( !jl_is_bitstype(ty) )
if (!jl_is_bitstype(ty))
// LLVM requires that a vector element be a primitive type.
// LLVM allows pointer types as vector elements, but until a
// motivating use case comes up for Julia, we reject pointers.
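The collapsed comment in jl_special_vector_alignment above restricts which vector lengths get special alignment, and the visible `return 0; // nfields has more than two 1s` suggests that lengths with more than two set bits are rejected. A standalone C sketch of that bit test (an illustration of the `m &= m - 1` trick, not the hidden code itself):

#include <stddef.h>
#include <stdio.h>

/* `n &= n - 1` clears the lowest set bit; doing it twice leaves zero
 * exactly when the original value had at most two 1 bits. */
static int at_most_two_bits_set(size_t n)
{
    n &= n - 1;   /* clear the lowest set bit, if any */
    n &= n - 1;   /* clear the next one, if any */
    return n == 0;
}

int main(void)
{
    for (size_t n = 1; n <= 8; n++)
        printf("%zu -> %d\n", n, at_most_two_bits_set(n));   /* only 7 fails */
    return 0;
}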
74 changes: 72 additions & 2 deletions src/ccalltest.c
@@ -344,6 +344,10 @@ JL_DLLEXPORT struct16 test_16(struct16 a, float b) {
return a;
}

// Note for AArch64:
// `i128` is a native type on aarch64 so the type here is wrong.
// However, it happens to have the same calling convention as `[2 x i64]`
// when used as first argument or return value.
#define int128_t struct3b
JL_DLLEXPORT int128_t test_128(int128_t a, int64_t b) {
//Unpack a Int128
@@ -393,16 +397,82 @@ JL_DLLEXPORT void *test_echo_p(void *p) {

#include <xmmintrin.h>

JL_DLLEXPORT __m128i test_m128i(__m128i a, __m128i b, __m128i c, __m128i d ) {
JL_DLLEXPORT __m128i test_m128i(__m128i a, __m128i b, __m128i c, __m128i d )
{
// 64-bit x86 has only level 2 SSE, which does not have a <4 x int32> multiplication,
// so we use floating-point instead, and assume caller knows about the hack.
return _mm_add_epi32(a,
_mm_cvtps_epi32(_mm_mul_ps(_mm_cvtepi32_ps(b),
_mm_cvtepi32_ps(_mm_sub_epi32(c,d)))));
}

JL_DLLEXPORT __m128 test_m128(__m128 a, __m128 b, __m128 c, __m128 d ) {
JL_DLLEXPORT __m128 test_m128(__m128 a, __m128 b, __m128 c, __m128 d )
{
return _mm_add_ps(a, _mm_mul_ps(b, _mm_sub_ps(c, d)));
}

#endif

#ifdef _CPU_AARCH64_
Review comment from a project member: Why are these tests aarch64 specific?

Reply from @yuyichao (Contributor, Author), May 30, 2016: __int128, __fp16, and the vector types used below are AArch64 extensions (especially when passed as arguments or return values). I'm pretty sure at least __fp16 is not allowed as a return value elsewhere (even when a similar type is defined in C).


JL_DLLEXPORT __int128 test_aa64_i128_1(int64_t v1, __int128 v2)
{
return v1 * 2 - v2;
}

typedef struct {
int32_t v1;
__int128 v2;
} struct_aa64_1;

JL_DLLEXPORT struct_aa64_1 test_aa64_i128_2(int64_t v1, __int128 v2,
struct_aa64_1 v3)
{
struct_aa64_1 x = {(int32_t)v1 / 2 + 1 - v3.v1, v2 * 2 - 1 - v3.v2};
return x;
}

typedef struct {
__fp16 v1;
double v2;
} struct_aa64_2;

JL_DLLEXPORT __fp16 test_aa64_fp16_1(int v1, float v2, double v3, __fp16 v4)
{
return (__fp16)(v1 + v2 * 2 + v3 * 3 + v4 * 4);
}

JL_DLLEXPORT struct_aa64_2 test_aa64_fp16_2(int v1, float v2,
double v3, __fp16 v4)
{
struct_aa64_2 x = {v4 / 2 + 1, v1 * 2 + v2 * 4 - v3};
return x;
}

#include <arm_neon.h>

JL_DLLEXPORT int64x2_t test_aa64_vec_1(int32x2_t v1, float _v2, int32x2_t v3)
{
int v2 = (int)_v2;
return vmovl_s32(v1 * v2 + v3);
}

// This is a homogeneous short vector aggregate
typedef struct {
int8x8_t v1;
float32x2_t v2;
} struct_aa64_3;

// This is NOT a homogeneous short vector aggregate
typedef struct {
float32x2_t v2;
int16x8_t v1;
} struct_aa64_4;

JL_DLLEXPORT struct_aa64_3 test_aa64_vec_2(struct_aa64_3 v1, struct_aa64_4 v2)
{
struct_aa64_3 x = {v1.v1 + vmovn_s16(v2.v1), v1.v2 - v2.v2};
return x;
}

#endif
8 changes: 4 additions & 4 deletions src/cgutils.cpp
@@ -385,7 +385,7 @@ static Type *julia_struct_to_llvm(jl_value_t *jt, bool *isboxed)
latypes.push_back(lty);
}
if (!isTuple) {
if (is_vecelement_type(jt))
if (jl_is_vecelement_type(jt))
// VecElement type is unwrapped in LLVM
jst->struct_decl = latypes[0];
else
@@ -1101,7 +1101,7 @@ static jl_cgval_t emit_getfield_knownidx(const jl_cgval_t &strct, unsigned idx,
}
else if (strct.ispointer()) { // something stack allocated
Value *addr;
if (is_vecelement_type((jl_value_t*)jt))
if (jl_is_vecelement_type((jl_value_t*)jt))
// VecElement types are unwrapped in LLVM.
addr = strct.V;
else
@@ -1678,7 +1678,7 @@ static jl_cgval_t emit_new_struct(jl_value_t *ty, size_t nargs, jl_value_t **arg
// or instead initialize the stack buffer with stores
bool init_as_value = false;
if (lt->isVectorTy() ||
is_vecelement_type(ty) ||
jl_is_vecelement_type(ty) ||
type_is_ghost(lt)) // maybe also check the size ?
init_as_value = true;

@@ -1714,7 +1714,7 @@ static jl_cgval_t emit_new_struct(jl_value_t *ty, size_t nargs, jl_value_t **arg
strct = builder.CreateInsertValue(strct, fval, ArrayRef<unsigned>(&idx,1));
else {
// Must be a VecElement type, which comes unwrapped in LLVM.
assert(is_vecelement_type(ty));
assert(jl_is_vecelement_type(ty));
strct = fval;
}
}
2 changes: 1 addition & 1 deletion src/julia.h
@@ -954,7 +954,7 @@ STATIC_INLINE int jl_is_tuple_type(void *t)
((jl_datatype_t*)(t))->name == jl_tuple_typename);
}

STATIC_INLINE int is_vecelement_type(jl_value_t* t)
STATIC_INLINE int jl_is_vecelement_type(jl_value_t* t)
{
return (jl_is_datatype(t) &&
((jl_datatype_t*)(t))->name == jl_vecelement_typename);
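A minimal hedged usage sketch of the renamed helper (assuming an embedding translation unit that includes julia.h; the wrapper function name is made up for illustration):

#include <julia.h>

/* Returns non-zero when the type of `v` is a VecElement wrapper, using the
 * jl_is_vecelement_type helper declared in julia.h above. */
static int value_is_vecelement(jl_value_t *v)
{
    return jl_is_vecelement_type(jl_typeof(v));
}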
81 changes: 80 additions & 1 deletion test/ccall.jl
@@ -531,7 +531,27 @@ typealias VecReg{N,T} NTuple{N,VecElement{T}}
typealias V4xF32 VecReg{4,Float32}
typealias V4xI32 VecReg{4,Int32}

if Sys.ARCH==:x86_64
immutable Struct_AA64_1
v1::Int32
v2::Int128
end
immutable Struct_AA64_2
v1::Float16
v2::Float64
end

# This is a homogeneous short vector aggregate
immutable Struct_AA64_3
v1::VecReg{8,Int8}
v2::VecReg{2,Float32}
end
# This is NOT a homogeneous short vector aggregate
immutable Struct_AA64_4
v2::VecReg{2,Float32}
v1::VecReg{8,Int16}
end

if Sys.ARCH === :x86_64

function test_sse(a1::V4xF32,a2::V4xF32,a3::V4xF32,a4::V4xF32)
ccall((:test_m128, libccalltest), V4xF32, (V4xF32,V4xF32,V4xF32,V4xF32), a1, a2, a3, a4)
@@ -556,4 +576,63 @@ if Sys.ARCH==:x86_64
# cfunction round-trip
@test rt_sse(a1,a2,a3,a4) == r
end
elseif Sys.ARCH === :aarch64
for v1 in 1:99:1000, v2 in -100:-1999:-20000
@test ccall((:test_aa64_i128_1, libccalltest), Int128,
(Int64, Int128), v1, v2) == v1 * 2 - v2
end
for v1 in 1:4, v2 in -4:-1, v3_1 in 3:5, v3_2 in 7:9
res = ccall((:test_aa64_i128_2, libccalltest), Struct_AA64_1,
(Int64, Int128, Struct_AA64_1),
v1, v2, Struct_AA64_1(v3_1, v3_2))
expected = Struct_AA64_1(v1 ÷ 2 + 1 - v3_1, v2 * 2 - 1 - v3_2)
@test res === expected
end
for v1 in 1:4, v2 in -4:-1, v3 in 3:5, v4 in -(1:3)
res = ccall((:test_aa64_fp16_1, libccalltest), Float16,
(Cint, Float32, Float64, Float16),
v1, v2, v3, v4)
expected = Float16(v1 + v2 * 2 + v3 * 3 + v4 * 4)
@test res === expected

res = ccall((:test_aa64_fp16_2, libccalltest), Struct_AA64_2,
(Cint, Float32, Float64, Float16),
v1, v2, v3, v4)
expected = Struct_AA64_2(v4 / 2 + 1, v1 * 2 + v2 * 4 - v3)
@test res === expected
end
for v1_1 in 1:4, v1_2 in -2:2, v2 in -4:-1, v3_1 in 3:5, v3_2 in 6:8
res = ccall((:test_aa64_vec_1, libccalltest),
VecReg{2,Int64},
(VecReg{2,Int32}, Float32, VecReg{2,Int32}),
(VecElement(Int32(v1_1)), VecElement(Int32(v1_2))),
v2, (VecElement(Int32(v3_1)), VecElement(Int32(v3_2))))
expected = (VecElement(v1_1 * v2 + v3_1), VecElement(v1_2 * v2 + v3_2))
@test res === expected
end
for v1_11 in 1:4, v1_12 in -2:2, v1_21 in 1:4, v1_22 in -2:2,
v2_11 in 1:4, v2_12 in -2:2, v2_21 in 1:4, v2_22 in -2:2
v1 = Struct_AA64_3((VecElement(Int8(v1_11)), VecElement(Int8(v1_12)),
VecElement(Int8(0)), VecElement(Int8(0)),
VecElement(Int8(0)), VecElement(Int8(0)),
VecElement(Int8(0)), VecElement(Int8(0))),
(VecElement(Float32(v1_21)),
VecElement(Float32(v1_22))))
v2 = Struct_AA64_4((VecElement(Float32(v2_21)),
VecElement(Float32(v2_22))),
(VecElement(Int16(v2_11)), VecElement(Int16(v2_12)),
VecElement(Int16(0)), VecElement(Int16(0)),
VecElement(Int16(0)), VecElement(Int16(0)),
VecElement(Int16(0)), VecElement(Int16(0))))
res = ccall((:test_aa64_vec_2, libccalltest),
Struct_AA64_3, (Struct_AA64_3, Struct_AA64_4), v1, v2)
expected = Struct_AA64_3((VecElement(Int8(v1_11 + v2_11)),
VecElement(Int8(v1_12 + v2_12)),
VecElement(Int8(0)), VecElement(Int8(0)),
VecElement(Int8(0)), VecElement(Int8(0)),
VecElement(Int8(0)), VecElement(Int8(0))),
(VecElement(Float32(v1_21 - v2_21)),
VecElement(Float32(v1_22 - v2_22))))
@test res === expected
end
end