Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions sycl/source/detail/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,19 +132,22 @@ std::shared_ptr<node_impl> graph_impl::addNodesToExits(
std::shared_ptr<node_impl> graph_impl::addSubgraphNodes(
const std::shared_ptr<exec_graph_impl> &SubGraphExec) {
std::map<std::shared_ptr<node_impl>, std::shared_ptr<node_impl>> NodesMap;
std::list<std::shared_ptr<node_impl>> NewNodesList;

std::list<std::shared_ptr<node_impl>> NodesList = SubGraphExec->getSchedule();
std::list<std::shared_ptr<node_impl>> NewNodesList{NodesList.size()};

// Duplication of nodes
std::list<std::shared_ptr<node_impl>>::iterator NewNodesIt =
NewNodesList.end();
for (std::list<std::shared_ptr<node_impl>>::const_iterator NodeIt =
NodesList.end();
NodeIt != NodesList.begin();) {
--NodeIt;
--NewNodesIt;
auto Node = *NodeIt;
std::shared_ptr<node_impl> NodeCopy;
duplicateNode(Node, NodeCopy);
NewNodesList.push_back(NodeCopy);
*NewNodesIt = NodeCopy;
NodesMap.insert({Node, NodeCopy});
for (auto &NextNode : Node->MSuccessors) {
if (NodesMap.find(NextNode) != NodesMap.end()) {
Expand Down
37 changes: 21 additions & 16 deletions sycl/source/detail/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,42 +74,47 @@ class node_impl {

/// Tests if two nodes have the same content,
/// i.e. same command group
/// This function should only be used for internal purposes.
/// A true return from this operator is not a guarantee that the nodes are
/// equal according to the Common reference semantics. But this function is
/// a helper to verify that two nodes contain equivalent Command Groups.
/// @param Node node to compare with
/// @return true if two nodes have equivalent command groups. false otherwise.
bool operator==(const node_impl &Node) {
  // Nodes holding different kinds of command group are never equivalent.
  if (MCGType != Node.MCGType)
    return false;

  // From here on both nodes share the same command-group type, so the
  // static_casts below apply the same concrete type to both sides.
  switch (MCGType) {
  case sycl::detail::CG::CGTYPE::Kernel: {
    // Kernel nodes are compared by kernel name only.
    sycl::detail::CGExecKernel *ExecKernelA =
        static_cast<sycl::detail::CGExecKernel *>(MCommandGroup.get());
    sycl::detail::CGExecKernel *ExecKernelB =
        static_cast<sycl::detail::CGExecKernel *>(Node.MCommandGroup.get());
    return ExecKernelA->MKernelName.compare(ExecKernelB->MKernelName) == 0;
  }
  case sycl::detail::CG::CGTYPE::CopyUSM: {
    // USM copies are compared by source, destination and length.
    // CopyB must come from Node: reading it from this node's own command
    // group would compare the node with itself and always yield true.
    sycl::detail::CGCopyUSM *CopyA =
        static_cast<sycl::detail::CGCopyUSM *>(MCommandGroup.get());
    sycl::detail::CGCopyUSM *CopyB =
        static_cast<sycl::detail::CGCopyUSM *>(Node.MCommandGroup.get());
    return (CopyA->getSrc() == CopyB->getSrc()) &&
           (CopyA->getDst() == CopyB->getDst()) &&
           (CopyA->getLength() == CopyB->getLength());
  }
  case sycl::detail::CG::CGTYPE::CopyAccToAcc:
  case sycl::detail::CG::CGTYPE::CopyAccToPtr:
  case sycl::detail::CG::CGTYPE::CopyPtrToAcc: {
    // Accessor/pointer copies are compared by source and destination.
    // As above, CopyB is taken from Node rather than from this node.
    sycl::detail::CGCopy *CopyA =
        static_cast<sycl::detail::CGCopy *>(MCommandGroup.get());
    sycl::detail::CGCopy *CopyB =
        static_cast<sycl::detail::CGCopy *>(Node.MCommandGroup.get());
    return (CopyA->getSrc() == CopyB->getSrc()) &&
           (CopyA->getDst() == CopyB->getDst());
  }
  default:
    // Any other command-group type has no comparison rule here; it trips
    // the assertion in debug builds and compares unequal otherwise.
    assert(false && "Unexpected command group type!");
    return false;
  }
}

/// Recursively add nodes to execution stack.
Expand Down