diff --git a/cinn/backends/codegen_c_test.cc b/cinn/backends/codegen_c_test.cc index 00e4dbcfe8a5a..55c4d67fd1fd8 100644 --- a/cinn/backends/codegen_c_test.cc +++ b/cinn/backends/codegen_c_test.cc @@ -105,5 +105,82 @@ void add1(const struct cinn_buffer_t *A, const struct cinn_buffer_t *B, struct c EXPECT_EQ(utils::Trim(out), utils::Trim(target_str)); } +TEST(CodeGenC, module_with_transform) { + lang::Placeholder A("A", {100, 20}); + lang::Placeholder B("B", {100, 20}); + + lang::Buffer C_buf, D_buf; + + // An inlined tensor, should not appear in final C code! It can be used by any times and expand its expression there. + auto inlined0 = lang::Compute({100, 20}, [&](Var i, Var j) { return A(i, j) * 2.f + 1.f; }); + + auto C = lang::Compute( + {100, 20}, [&](Var i, Var j) { return A(i, j) + B(i, j) + inlined0(i, j); }, "C"); + C->Bind(C_buf); + + auto D = lang::Compute( + {100, 20}, [&](Var i, Var j) { return C(i, j) * 2.f * inlined0(i, j); }, "D"); + D->Bind(D_buf); + + poly::Iterator i_outer, i_inner; + std::tie(i_outer, i_inner) = C->stage()->Split(poly::DefaultIterator(0), 4); + + D->stage()->Tile(poly::DefaultIterator(0), poly::DefaultIterator(1), 4, 16); + + Target target; + target.arch = Target::Arch ::X86; + target.bits = Target::Bit ::k32; + target.os = Target::OS ::Linux; + lang::Module module("module1", target); + + auto funcs = lang::Lower("add1", {A, B, C, D}); + + ASSERT_EQ(funcs.size(), 1UL); + + module.Append(funcs.front()); + module.Append(C_buf); + + std::stringstream ss; + CodeGenC codegen(ss, target); + codegen.Compile(module); + + auto out = ss.str(); + std::cout << "codegen C:" << std::endl << out << std::endl; + + auto tgt = R"ROC( +#ifndef _MODULE1_CINN_H_ +#define _MODULE1_CINN_H_ + +#include +#include + +cinn_buffer_t* C = cinn_buffer_t::new_(0/*target*/); +void add1(const struct cinn_buffer_t *A, const struct cinn_buffer_t *B, const struct cinn_buffer_t *C, struct cinn_buffer_t *D) +{ + cinn_buffer_malloc(D); + for (int32_t i_outer = 0; (i_outer <= 24); i_outer += 1){ + for (int32_t i_inner = 0; (i_inner <= 3); i_inner += 1){ + for (int32_t j = 0; (j <= 19); j += 1){ + C[((((4 * i_outer) + i_inner) * 20) + j)] = ((A[((((4 * i_outer) + i_inner) * 20) + j)] + B[((((4 * i_outer) + i_inner) * 20) + j)]) + ((A[((((4 * i_outer) + i_inner) * 20) + j)] * 2) + 1)); + }; + }; + }; + for (int32_t i_outer = 0; (i_outer <= 24); i_outer += 1){ + for (int32_t i_inner = 0; (i_inner <= 3); i_inner += 1){ + for (int32_t j_outer = 0; (j_outer <= 1); j_outer += 1){ + for (int32_t j_inner = 0; (j_inner <= min(15, ((-16 * j_outer) + 19))); j_inner += 1){ + D[((((4 * i_outer) + i_inner) * 20) + ((16 * j_outer) + j_inner))] = ((C[((((4 * i_outer) + i_inner) * 20) + ((16 * j_outer) + j_inner))] * 2) * ((A[((((4 * i_outer) + i_inner) * 20) + ((16 * j_outer) + j_inner))] * 2) + 1)); + }; + }; + }; + }; +} + +#endif // _MODULE1_CINN_H_ +)ROC"; + + ASSERT_EQ(utils::Trim(out), utils::Trim(tgt)); +} + } // namespace backends } // namespace cinn diff --git a/cinn/poly/stage.h b/cinn/poly/stage.h index 2d8fc86cdc099..bd516c0478ca3 100644 --- a/cinn/poly/stage.h +++ b/cinn/poly/stage.h @@ -104,5 +104,7 @@ inline std::string InnerName(const Iterator& iterator); inline std::string OuterName(const std::string& name); inline std::string OuterName(const Iterator& iterator); +inline Iterator DefaultIterator(int i) { return Iterator(common::axis_name(i)); } + } // namespace poly } // namespace cinn