Skip to content

Commit a722fdc

Browse files
author
Maxime France-Pillois
authored
[SYCL][Graph] Improves node_impl equal operator and new nodes list allocation policy (#303)
1 parent b3cf214 commit a722fdc

File tree

2 files changed

+26
-18
lines changed

2 files changed

+26
-18
lines changed

sycl/source/detail/graph_impl.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,19 +132,22 @@ std::shared_ptr<node_impl> graph_impl::addNodesToExits(
132132
std::shared_ptr<node_impl> graph_impl::addSubgraphNodes(
133133
const std::shared_ptr<exec_graph_impl> &SubGraphExec) {
134134
std::map<std::shared_ptr<node_impl>, std::shared_ptr<node_impl>> NodesMap;
135-
std::list<std::shared_ptr<node_impl>> NewNodesList;
136135

137136
std::list<std::shared_ptr<node_impl>> NodesList = SubGraphExec->getSchedule();
137+
std::list<std::shared_ptr<node_impl>> NewNodesList{NodesList.size()};
138138

139139
// Duplication of nodes
140+
std::list<std::shared_ptr<node_impl>>::iterator NewNodesIt =
141+
NewNodesList.end();
140142
for (std::list<std::shared_ptr<node_impl>>::const_iterator NodeIt =
141143
NodesList.end();
142144
NodeIt != NodesList.begin();) {
143145
--NodeIt;
146+
--NewNodesIt;
144147
auto Node = *NodeIt;
145148
std::shared_ptr<node_impl> NodeCopy;
146149
duplicateNode(Node, NodeCopy);
147-
NewNodesList.push_back(NodeCopy);
150+
*NewNodesIt = NodeCopy;
148151
NodesMap.insert({Node, NodeCopy});
149152
for (auto &NextNode : Node->MSuccessors) {
150153
if (NodesMap.find(NextNode) != NodesMap.end()) {

sycl/source/detail/graph_impl.hpp

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -74,42 +74,47 @@ class node_impl {
7474

7575
/// Tests if two nodes have the same content,
7676
/// i.e. same command group
77+
/// This function should only be used for internal purposes.
78+
/// A true return from this operator is not a guarantee that the nodes are
79+
/// equal according to the Common reference semantics. But this function is
80+
/// a helper to verify that two nodes contain equivalent Command Groups.
7781
/// @param Node node to compare with
82+
/// @return true if two nodes have equivalent command groups, false otherwise.
7883
bool operator==(const node_impl &Node) {
  // Nodes holding different kinds of command group never match. This check
  // also guarantees that the downcasts below are applied to two command
  // groups of the same CG type.
  if (MCGType != Node.MCGType)
    return false;

  switch (MCGType) {
  case sycl::detail::CG::CGTYPE::Kernel: {
    // Kernel nodes are considered equivalent when their kernel names match.
    sycl::detail::CGExecKernel *ExecKernelA =
        static_cast<sycl::detail::CGExecKernel *>(MCommandGroup.get());
    sycl::detail::CGExecKernel *ExecKernelB =
        static_cast<sycl::detail::CGExecKernel *>(Node.MCommandGroup.get());
    return ExecKernelA->MKernelName.compare(ExecKernelB->MKernelName) == 0;
  }
  case sycl::detail::CG::CGTYPE::CopyUSM: {
    // USM copies are equivalent when source, destination and length match.
    sycl::detail::CGCopyUSM *CopyA =
        static_cast<sycl::detail::CGCopyUSM *>(MCommandGroup.get());
    // The right-hand side must come from Node. Reading this node's own
    // command group here would compare the copy with itself and always
    // report equality.
    sycl::detail::CGCopyUSM *CopyB =
        static_cast<sycl::detail::CGCopyUSM *>(Node.MCommandGroup.get());
    return (CopyA->getSrc() == CopyB->getSrc()) &&
           (CopyA->getDst() == CopyB->getDst()) &&
           (CopyA->getLength() == CopyB->getLength());
  }
  case sycl::detail::CG::CGTYPE::CopyAccToAcc:
  case sycl::detail::CG::CGTYPE::CopyAccToPtr:
  case sycl::detail::CG::CGTYPE::CopyPtrToAcc: {
    // These three copy kinds are equivalent when source and destination
    // match.
    sycl::detail::CGCopy *CopyA =
        static_cast<sycl::detail::CGCopy *>(MCommandGroup.get());
    // As above, the right-hand side must be taken from Node.
    sycl::detail::CGCopy *CopyB =
        static_cast<sycl::detail::CGCopy *>(Node.MCommandGroup.get());
    return (CopyA->getSrc() == CopyB->getSrc()) &&
           (CopyA->getDst() == CopyB->getDst());
  }
  default:
    // Any other command group type is not handled by this comparison.
    // Assert when assertions are enabled; otherwise report the nodes as
    // different.
    assert(false && "Unexpected command group type!");
    return false;
  }
}
114119

115120
/// Recursively add nodes to execution stack.

0 commit comments

Comments
 (0)