Skip to content

Commit a5ed21d

Browse files
authored
[CODEGEN][METAL] Fix ramp codegen (#14330)
Fix ramp node codegen for the metal backend. The default C codegen can cause problem in vector indices assignment. Confirmed on apple M2.
1 parent 2ff41c6 commit a5ed21d

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

src/target/source/codegen_metal.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,17 @@ void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // N
299299
os << ')';
300300
}
301301

302+
void CodeGenMetal::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*)
303+
PrintType(op->dtype, os);
304+
os << "(";
305+
for (int i = 0; i < op->lanes; ++i) {
306+
if (i != 0) os << ", ";
307+
os << "(" << PrintExpr(op->base) << ")"
308+
<< "+(" << PrintExpr(op->stride) << "*" << i << ")";
309+
}
310+
os << ')';
311+
}
312+
302313
void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
303314
if (op->op.same_as(builtin::reinterpret())) {
304315
// generate as_type<TYPE>(ARG)

src/target/source/codegen_metal.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class CodeGenMetal final : public CodeGenC {
5151
void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final;
5252
// overload visitor
5353
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
54+
void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*)
5455
void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*)
5556
void VisitExpr_(const FloatImmNode* op, std::ostream& os) final;
5657
// reuse parent's function.

0 commit comments

Comments
 (0)