diff --git a/sycl/source/detail/scheduler/graph_builder.cpp b/sycl/source/detail/scheduler/graph_builder.cpp index 5f27ca5c02559..1132d32b2f605 100644 --- a/sycl/source/detail/scheduler/graph_builder.cpp +++ b/sycl/source/detail/scheduler/graph_builder.cpp @@ -20,6 +20,7 @@ #include #include +#include #include #include #include @@ -985,22 +986,6 @@ Scheduler::GraphBuilder::addCG(std::unique_ptr CommandGroup, createGraphForCommand(NewCmd.get(), NewCmd->getCG(), isInteropHostTask(NewCmd.get()), Reqs, Events, Queue, FusionCmd->auxiliaryCommands()); - // We need to check the commands that this kernel depends on for any other - // commands that have been submitted to another queue which is also in - // fusion mode. If we detect such another command, we cancel fusion for that - // other queue to avoid circular dependencies. - // Handle requirements on any commands part of another active fusion. - for (auto &Dep : NewCmd->MDeps) { - auto *DepCmd = Dep.MDepCommand; - if (!DepCmd) { - continue; - } - if (DepCmd->getQueue() != Queue && isPartOfActiveFusion(DepCmd)) { - printFusionWarning("Aborting fusion because of requirement from a " - "different fusion process"); - cancelFusion(DepCmd->getQueue(), ToEnqueue); - } - } // Set the fusion command, so we recognize when another command depends on a // kernel in the fusion list. @@ -1431,6 +1416,63 @@ void Scheduler::GraphBuilder::cancelFusion(QueueImplPtr Queue, PlaceholderCmd->setFusionStatus(KernelFusionCommand::FusionStatus::CANCELLED); } +static bool isPartOfFusion(Command *Cmd, KernelFusionCommand *Fusion) { + if (Cmd->getType() == Command::RUN_CG) { + return static_cast(Cmd)->MFusionCmd == Fusion; + } + return false; +} + +static bool checkForCircularDependency(Command *, bool, KernelFusionCommand *); + +static bool createsCircularDependency(Command *Cmd, bool PredPartOfFusion, + KernelFusionCommand *Fusion) { + if (isPartOfFusion(Cmd, Fusion)) { + // If this is part of the fusion and the predecessor also was, we can stop + // the traversal here. A direct dependency between two kernels in the same + // fusion will never form a cyclic dependency and by iterating over all + // commands in a fusion, we will detect any cycles originating from the + // current command. + // If the predecessor was not part of the fusion, but the current command + // is, we have found a potential cycle in the dependency graph. + return !PredPartOfFusion; + } + return checkForCircularDependency(Cmd, false, Fusion); +} + +static bool checkForCircularDependency(Command *Cmd, bool IsPartOfFusion, + KernelFusionCommand *Fusion) { + // Check the requirement dependencies. + for (auto &Dep : Cmd->MDeps) { + auto *DepCmd = Dep.MDepCommand; + if (!DepCmd) { + continue; + } + if (createsCircularDependency(DepCmd, IsPartOfFusion, Fusion)) { + return true; + } + } + for (auto &Ev : Cmd->getPreparedDepsEvents()) { + auto *EvDepCmd = static_cast(Ev->getCommand()); + if (!EvDepCmd) { + continue; + } + if (createsCircularDependency(EvDepCmd, IsPartOfFusion, Fusion)) { + return true; + } + } + for (auto &Ev : Cmd->getPreparedHostDepsEvents()) { + auto *EvDepCmd = static_cast(Ev->getCommand()); + if (!EvDepCmd) { + continue; + } + if (createsCircularDependency(EvDepCmd, IsPartOfFusion, Fusion)) { + return true; + } + } + return false; +} + EventImplPtr Scheduler::GraphBuilder::completeFusion(QueueImplPtr Queue, std::vector &ToEnqueue, @@ -1451,8 +1493,26 @@ Scheduler::GraphBuilder::completeFusion(QueueImplPtr Queue, auto *PlaceholderCmd = FusionList->second.get(); auto &CmdList = PlaceholderCmd->getFusionList(); - // TODO: The logic to invoke the JIT compiler to create a fused kernel from - // the list will be added in a later PR. + // We need to check if fusing the kernel would create a circular dependency. A + // circular dependency would arise, if a kernel in the fusion list + // *indirectly* depends on another kernel in the fusion list. Here, indirectly + // means, that the dependency is created through a third command not part of + // the fusion, on which this kernel depends and which in turn depends on + // another kernel in fusion list. + bool CreatesCircularDep = + std::any_of(CmdList.begin(), CmdList.end(), [&](ExecCGCommand *Cmd) { + return checkForCircularDependency(Cmd, true, PlaceholderCmd); + }); + if (CreatesCircularDep) { + // If fusing would create a fused kernel, cancel the fusion. + printFusionWarning( + "Aborting fusion because it would create a circular dependency"); + auto LastEvent = PlaceholderCmd->getEvent(); + this->cancelFusion(Queue, ToEnqueue); + return LastEvent; + } + + // Call the JIT compiler to generate a new fused kernel. auto FusedCG = detail::jit_compiler::get_instance().fuseKernels( Queue, CmdList, PropList);