Skip to content

Commit cf6312e

Browse files
committed
Fix Vulkan interleave SPIRV codegen. Fix a bug in Simplify_Shuffle. Fix a bug in Deinterleave.
1 parent 60621b8 commit cf6312e

File tree

8 files changed

+229
-103
lines changed

8 files changed

+229
-103
lines changed

src/CodeGen_Vulkan_Dev.cpp

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2041,31 +2041,21 @@ void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Shuffle *op) {
20412041
debug(3) << "\n";
20422042

20432043
if (arg_ids.size() == 1) {
2044-
20452044
// 1 argument, just do a simple assignment via a cast
20462045
SpvId result_id = cast_type(op->type, op->vectors[0].type(), arg_ids[0]);
20472046
builder.update_id(result_id);
20482047

20492048
} else if (arg_ids.size() == 2) {
2050-
2051-
// 2 arguments, use a composite insert to update even and odd indices
2052-
uint32_t even_idx = 0;
2053-
uint32_t odd_idx = 1;
2054-
SpvFactory::Indices even_indices;
2055-
SpvFactory::Indices odd_indices;
2056-
for (int i = 0; i < op_lanes; ++i) {
2057-
even_indices.push_back(even_idx);
2058-
odd_indices.push_back(odd_idx);
2059-
even_idx += 2;
2060-
odd_idx += 2;
2049+
// 2 arguments, use vector-shuffle with logical indices indexing into (vec1[0], vec1[1], ..., vec2[0], vec2[1], ...)
2050+
SpvFactory::Indices logical_indices;
2051+
for (int i = 0; i < arg_lanes; ++i) {
2052+
logical_indices.push_back(uint32_t(i));
2053+
logical_indices.push_back(uint32_t(i + arg_lanes));
20612054
}
20622055

20632056
SpvId type_id = builder.declare_type(op->type);
2064-
SpvId value_id = builder.declare_null_constant(op->type);
2065-
SpvId partial_id = builder.reserve_id(SpvResultId);
20662057
SpvId result_id = builder.reserve_id(SpvResultId);
2067-
builder.append(SpvFactory::composite_insert(type_id, partial_id, arg_ids[0], value_id, even_indices));
2068-
builder.append(SpvFactory::composite_insert(type_id, result_id, arg_ids[1], partial_id, odd_indices));
2058+
builder.append(SpvFactory::vector_shuffle(type_id, result_id, arg_ids[0], arg_ids[1], logical_indices));
20692059
builder.update_id(result_id);
20702060

20712061
} else {
@@ -2095,7 +2085,7 @@ void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Shuffle *op) {
20952085
} else if (op->is_extract_element()) {
20962086
int idx = op->indices[0];
20972087
internal_assert(idx >= 0);
2098-
internal_assert(idx <= op->vectors[0].type().lanes());
2088+
internal_assert(idx < op->vectors[0].type().lanes());
20992089
if (op->vectors[0].type().is_vector()) {
21002090
SpvFactory::Indices indices = {(uint32_t)idx};
21012091
SpvId type_id = builder.declare_type(op->type);

src/Deinterleave.cpp

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,10 @@ class Deinterleaver : public IRGraphMutator {
298298
} else {
299299

300300
Type t = op->type.with_lanes(new_lanes);
301+
internal_assert((op->type.lanes() - starting_lane + lane_stride - 1) / lane_stride == new_lanes)
302+
<< "Deinterleaving with lane stride " << lane_stride << " and staring lane " << starting_lane
303+
<< " for var of Type " << op->type << " to " << t << " drops lanes unexpectedly."
304+
<< " Deinterleaver probably recursed too deep into types of different lane count.";
301305
if (external_lets.contains(op->name) &&
302306
starting_lane == 0 &&
303307
lane_stride == 2) {
@@ -392,8 +396,12 @@ class Deinterleaver : public IRGraphMutator {
392396
int index = indices.front();
393397
for (const auto &i : op->vectors) {
394398
if (index < i.type().lanes()) {
395-
ScopedValue<int> lane(starting_lane, index);
396-
return mutate(i);
399+
if (i.type().lanes() == op->type.lanes()) {
400+
ScopedValue<int> scoped_starting_lane(starting_lane, index);
401+
return mutate(i);
402+
} else {
403+
return Shuffle::make(op->vectors, indices);
404+
}
397405
}
398406
index -= i.type().lanes();
399407
}
@@ -405,10 +413,18 @@ class Deinterleaver : public IRGraphMutator {
405413
};
406414

407415
Expr deinterleave(Expr e, int starting_lane, int lane_stride, int new_lanes, const Scope<> &lets) {
416+
debug(3) << "Deinterleave "
417+
<< "(start:" << starting_lane << ", stide:" << lane_stride << ", new_lanes:" << new_lanes << "): "
418+
<< e << " of Type: " << e.type() << "\n";
419+
Type original_type = e.type();
408420
e = substitute_in_all_lets(e);
409421
Deinterleaver d(starting_lane, lane_stride, new_lanes, lets);
410422
e = d.mutate(e);
411423
e = common_subexpression_elimination(e);
424+
Type final_type = e.type();
425+
int expected_lanes = (original_type.lanes() + lane_stride - starting_lane - 1) / lane_stride;
426+
internal_assert(original_type.code() == final_type.code()) << "Underlying types not identical after interleaving.";
427+
internal_assert(expected_lanes == final_type.lanes()) << "Number of lanes incorrect after interleaving: " << final_type.lanes() << "while expected was " << expected_lanes << ".";
412428
return simplify(e);
413429
}
414430

@@ -419,12 +435,12 @@ Expr extract_odd_lanes(const Expr &e, const Scope<> &lets) {
419435

420436
Expr extract_even_lanes(const Expr &e, const Scope<> &lets) {
421437
internal_assert(e.type().lanes() % 2 == 0);
422-
return deinterleave(e, 0, 2, (e.type().lanes() + 1) / 2, lets);
438+
return deinterleave(e, 0, 2, e.type().lanes() / 2, lets);
423439
}
424440

425441
Expr extract_mod3_lanes(const Expr &e, int lane, const Scope<> &lets) {
426442
internal_assert(e.type().lanes() % 3 == 0);
427-
return deinterleave(e, lane, 3, (e.type().lanes() + 2) / 3, lets);
443+
return deinterleave(e, lane, 3, e.type().lanes() / 3, lets);
428444
}
429445

430446
} // namespace

src/Simplify_Let.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *info) {
9898
Expr new_var = Variable::make(f.new_value.type(), f.new_name);
9999
Expr replacement = new_var;
100100

101-
debug(4) << "simplify let " << op->name << " = " << f.value << " in...\n";
101+
debug(4) << "simplify let " << op->name << " = (" << f.value.type() << ") " << f.value << " in...\n";
102102

103103
while (true) {
104104
const Variable *var = f.new_value.template as<Variable>();
@@ -180,6 +180,16 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *info) {
180180
f.new_value = cast->value;
181181
new_var = Variable::make(f.new_value.type(), f.new_name);
182182
replacement = substitute(f.new_name, Cast::make(cast->type, new_var), replacement);
183+
} else if (shuffle && shuffle->is_concat() && is_pure(shuffle)) {
184+
// Substitute in all concatenates as they will likely simplify
185+
// with other shuffles.
186+
// As the structure of this while loop makes it hard to peel off
187+
// pure operations from _all_ arguments to the Shuffle, we will
188+
// instead substitute all of the vars that go in the shuffle, and
189+
// instead guard against side effects by checking with `is_pure()`.
190+
replacement = substitute(f.new_name, shuffle, replacement);
191+
f.new_value = Expr();
192+
break;
183193
} else if (shuffle && shuffle->is_slice()) {
184194
// Replacing new_value below might free the shuffle
185195
// indices vector, so save them now.

src/Simplify_Shuffle.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -289,13 +289,18 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *info) {
289289
if (inner_shuffle->is_concat()) {
290290
int slice_min = op->indices.front();
291291
int slice_max = op->indices.back();
292+
if (slice_min > slice_max) {
293+
// Slices can go backward.
294+
std::swap(slice_min, slice_max);
295+
}
292296
int concat_index = 0;
293297
int new_slice_start = -1;
294298
vector<Expr> new_concat_vectors;
295299
for (const auto &v : inner_shuffle->vectors) {
296300
// Check if current concat vector overlaps with slice.
297-
if ((concat_index >= slice_min && concat_index <= slice_max) ||
298-
((concat_index + v.type().lanes() - 1) >= slice_min && (concat_index + v.type().lanes() - 1) <= slice_max)) {
301+
int overlap_max = std::min(slice_max, concat_index + v.type().lanes() - 1);
302+
int overlap_min = std::max(slice_min, concat_index);
303+
if (overlap_min <= overlap_max) {
299304
if (new_slice_start < 0) {
300305
new_slice_start = concat_index;
301306
}
@@ -305,7 +310,10 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *info) {
305310
concat_index += v.type().lanes();
306311
}
307312
if (new_concat_vectors.size() < inner_shuffle->vectors.size()) {
308-
return Shuffle::make_slice(Shuffle::make_concat(new_concat_vectors), op->slice_begin() - new_slice_start, op->slice_stride(), op->indices.size());
313+
return Shuffle::make_slice(Shuffle::make_concat(new_concat_vectors),
314+
op->slice_begin() - new_slice_start,
315+
op->slice_stride(),
316+
op->indices.size());
309317
}
310318
}
311319
}

src/runtime/opencl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ extern "C" WEAK void *halide_opencl_get_symbol(void *user_context, const char *n
3737
#ifdef WINDOWS
3838
"opencl.dll",
3939
#else
40-
"libOpenCL.so",
40+
"libOpenCL.so.1",
4141
"/System/Library/Frameworks/OpenCL.framework/OpenCL",
4242
#endif
4343
};

src/runtime/vulkan_internal.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,8 @@ const char *vk_get_error_name(VkResult error) {
279279
return "VK_ERROR_FORMAT_NOT_SUPPORTED";
280280
case VK_ERROR_FRAGMENTED_POOL:
281281
return "VK_ERROR_FRAGMENTED_POOL";
282+
case VK_ERROR_UNKNOWN:
283+
return "VK_ERROR_UNKNOWN";
282284
case VK_ERROR_SURFACE_LOST_KHR:
283285
return "VK_ERROR_SURFACE_LOST_KHR";
284286
case VK_ERROR_NATIVE_WINDOW_IN_USE_KHR:
@@ -302,6 +304,8 @@ const char *vk_get_error_name(VkResult error) {
302304
}
303305
}
304306

307+
#define vk_report_error(user_context, code, func) (error((user_context)) << "Vulkan: " << (func) << " returned " << vk_get_error_name((code)) << " (code: " << (code) << ") ")
308+
305309
// --------------------------------------------------------------------------
306310

307311
} // namespace

0 commit comments

Comments
 (0)