Skip to content

Commit

Permalink
Merge branch 'develop' into support_multi_multimachine
Browse files Browse the repository at this point in the history
  • Loading branch information
AndSonder authored Nov 21, 2023
2 parents 0bff0af + 192a5f8 commit 34feef0
Show file tree
Hide file tree
Showing 262 changed files with 7,334 additions and 3,969 deletions.
11 changes: 11 additions & 0 deletions cmake/external/gloo.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,17 @@ if(CMAKE_COMPILER_IS_GNUCC)
${SOURCE_DIR}/gloo/ < ${types_header})
endif()
endif()

file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/gloo/linux.cc.patch
linux_cc_ethtool)
if(GLOO_PATCH_COMMAND STREQUAL "")
set(GLOO_PATCH_COMMAND git checkout -- . && git checkout ${GLOO_TAG} && patch
-Nd ${SOURCE_DIR}/gloo/common/ < ${linux_cc_ethtool})
else()
set(GLOO_PATCH_COMMAND ${GLOO_PATCH_COMMAND} && patch -Nd
${SOURCE_DIR}/gloo/common/ < ${linux_cc_ethtool})
endif()

include_directories(${GLOO_INCLUDE_DIR})

ExternalProject_Add(
Expand Down
62 changes: 33 additions & 29 deletions paddle/cinn/adt/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,35 +1,39 @@
add_subdirectory(print_utils)
if(NOT CINN_ONLY)
add_subdirectory(print_utils)

core_gather_headers()
core_gather_headers()

gather_srcs(
cinnapi_src
SRCS
anchor_sd_equation_context.cc
equation_function.cc
equation_solver.cc
equation_value.cc
generate_map_expr.cc
get_sub_reshape_dim_ranges.cc
igroup.cc
index_expr_infer_context.cc
kgroup.cc
m_ir.cc
naive_bidirection_equation_generator.cc
naive_op_equation_context.cc
partition_op_stmts.cc
schedule_descriptor.cc
schedule_dim.cc
schedule_mesh.cc
simplify_value.cc
write_broadcast_disabled_bidirection_equation_generator.cc)
gather_srcs(
cinnapi_src
SRCS
adapter_tensor.cc
anchor_sd_equation_context.cc
equation_function.cc
equation_solver.cc
equation_value.cc
generate_map_expr.cc
get_sub_reshape_dim_ranges.cc
igroup.cc
index_expr_infer_context.cc
kgroup.cc
m_ir.cc
naive_bidirection_equation_generator.cc
naive_op_equation_context.cc
partition_op_stmts.cc
schedule_descriptor.cc
schedule_dim.cc
schedule_mesh.cc
simplify_value.cc
write_broadcast_disabled_bidirection_equation_generator.cc)

cinn_cc_test(equation_value_match_trait_test SRCS
equation_value_match_trait_test.cc DEPS gtest glog)
cinn_cc_test(equation_value_match_trait_test SRCS
equation_value_match_trait_test.cc DEPS gtest glog)

cinn_cc_test(tree_test SRCS tree_test.cc DEPS gtest glog)
cinn_cc_test(tree_test SRCS tree_test.cc DEPS gtest glog)

cinn_cc_test(inline_translator_test SRCS inline_translator_test.cc DEPS
cinncore)
cinn_cc_test(inline_translator_test SRCS inline_translator_test.cc DEPS
cinncore)

message(STATUS "ADT srcs: ${cinnapi_src}")
message(STATUS "ADT srcs: ${cinnapi_src}")

endif()
44 changes: 44 additions & 0 deletions paddle/cinn/adt/adapter_tensor.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/adt/adapter_tensor.h"
#include "glog/logging.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"

namespace cinn::adt::adapter {

std::size_t Tensor::GetRank() const {
return cinn::hlir::framework::pir::CompatibleInfo::ValueShape(node_data)
.size();
}

std::vector<int32_t> Tensor::GetShape() const {
std::vector<int32_t> ret{};
for (int dim_size :
cinn::hlir::framework::pir::CompatibleInfo::ValueShape(node_data)) {
ret.emplace_back(dim_size);
}
return ret;
}

std::size_t Tensor::GetNumel() const {
std::size_t ret = 1;
for (int dim_size :
cinn::hlir::framework::pir::CompatibleInfo::ValueShape(node_data)) {
ret = ret * dim_size;
}
return ret;
}

} // namespace cinn::adt::adapter
45 changes: 7 additions & 38 deletions paddle/cinn/adt/adapter_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,59 +13,28 @@
// limitations under the License.

#pragma once
#include "glog/logging.h"

#include "paddle/cinn/adt/adt.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/node.h"
#include "paddle/pir/core/value.h"

namespace cinn::adt::adapter {

struct Tensor final {
const hlir::framework::NodeData* node_data;
const hlir::framework::Graph* graph;
::pir::Value node_data;

bool operator==(const Tensor& other) const {
return this->node_data == other.node_data && this->graph == other.graph;
return this->node_data == other.node_data;
}

std::size_t GetRank() const {
const auto& shape_dict =
graph->GetAttrs<absl::flat_hash_map<std::string, utils::ShapeType>>(
"infershape");
CHECK(shape_dict.count(node_data->id()))
<< "Can't find " << node_data->id() << " 's shape!";
return shape_dict.at(node_data->id()).size();
}
std::size_t GetRank() const;

const std::vector<int32_t>& GetShape() const {
const auto& shape_dict =
graph->GetAttrs<absl::flat_hash_map<std::string, utils::ShapeType>>(
"infershape");
CHECK(shape_dict.count(node_data->id()))
<< "Can't find " << node_data->id() << " 's shape!";
return shape_dict.at(node_data->id());
}
std::vector<int32_t> GetShape() const;

std::size_t GetNumel() const {
const auto& shape_dict =
graph->GetAttrs<absl::flat_hash_map<std::string, utils::ShapeType>>(
"infershape");
CHECK(shape_dict.count(node_data->id()))
<< "Can't find " << node_data->id() << " 's shape!";
std::vector<int32_t> shape = shape_dict.at(node_data->id());
std::size_t ret = 1;
for (int32_t dim_size : shape) {
ret = ret * dim_size;
}
return ret;
}
std::size_t GetNumel() const;
};

inline std::size_t GetHashValueImpl(const Tensor& tensor) {
return hash_combine(
std::hash<const hlir::framework::NodeData*>()(tensor.node_data),
std::hash<const hlir::framework::Graph*>()(tensor.graph));
return std::hash<::pir::Value>()(tensor.node_data);
}

} // namespace cinn::adt::adapter
5 changes: 0 additions & 5 deletions paddle/cinn/adt/equation_value.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,6 @@
#include "paddle/cinn/adt/equation.h"
#include "paddle/cinn/adt/match.h"

namespace cinn::hlir::framework {
class Node;
class NodeData;
} // namespace cinn::hlir::framework

namespace cinn::adt {

DEFINE_ADT_TAG(tPointer);
Expand Down
Loading

0 comments on commit 34feef0

Please sign in to comment.