diff --git a/tree/dataframe/inc/ROOT/RDF/RActionSnapshot.hxx b/tree/dataframe/inc/ROOT/RDF/RActionSnapshot.hxx index b561511250c9b..19da04127dbeb 100644 --- a/tree/dataframe/inc/ROOT/RDF/RActionSnapshot.hxx +++ b/tree/dataframe/inc/ROOT/RDF/RActionSnapshot.hxx @@ -64,18 +64,15 @@ class R__CLING_PTRCHECK(off) RActionSnapshot final : public RActionBase { const auto &currentVariations = GetVariations(); // If this node hangs from the RLoopManager itself, just use that as the upstream node for each variation - auto nominalPrevNode = fPrevNodes.begin(); - if (static_cast(nominalPrevNode->get()) == fLoopManager) { - fPrevNodes.resize(1 + currentVariations.size(), *nominalPrevNode); + auto nominalPrevNode = fPrevNodes.front(); + if (static_cast(nominalPrevNode.get()) == fLoopManager) { + fPrevNodes.resize(1 + currentVariations.size(), nominalPrevNode); return; } // Otherwise, append one varied filter per variation - const auto &prevVariations = (*nominalPrevNode)->GetVariations(); - - fPrevNodes.reserve(1 + prevVariations.size()); - // Get valid iterator after resizing - nominalPrevNode = fPrevNodes.begin(); + const auto &prevVariations = nominalPrevNode->GetVariations(); + fPrevNodes.reserve(1 + currentVariations.size()); // Need to populate parts of the computation graph for which we have empty shells, e.g. 
RJittedFilters if (!currentVariations.empty()) @@ -83,9 +80,9 @@ class R__CLING_PTRCHECK(off) RActionSnapshot final : public RActionBase { for (const auto &variation : currentVariations) { if (IsStrInVec(variation, prevVariations)) { fPrevNodes.emplace_back( - std::static_pointer_cast((*nominalPrevNode)->GetVariedFilter(variation))); + std::static_pointer_cast(nominalPrevNode->GetVariedFilter(variation))); } else { - fPrevNodes.push_back(*nominalPrevNode); + fPrevNodes.push_back(nominalPrevNode); } } } diff --git a/tree/dataframe/test/CMakeLists.txt b/tree/dataframe/test/CMakeLists.txt index 476ec58deea73..aad10aab36b54 100644 --- a/tree/dataframe/test/CMakeLists.txt +++ b/tree/dataframe/test/CMakeLists.txt @@ -161,6 +161,7 @@ if(pyroot) endif() if(NOT MSVC OR win_broken_tests) ROOT_ADD_PYUNITTEST(dataframe_merge_results dataframe_merge_results.py) + ROOT_ADD_PYUNITTEST(dataframe_snapshot_py dataframe_snapshot.py) endif() endif() diff --git a/tree/dataframe/test/dataframe_snapshot.py b/tree/dataframe/test/dataframe_snapshot.py new file mode 100644 index 0000000000000..9ed94a3c0cfc3 --- /dev/null +++ b/tree/dataframe/test/dataframe_snapshot.py @@ -0,0 +1,59 @@ +import unittest + +import ROOT + + +class SnapshotTests(unittest.TestCase): + # Regression described in https://github.com/root-project/root/issues/20320#issuecomment-3553697692 + # This was caused by an iterator invalidation when a snapshot with JIT-ted filters is used + def test_snapshot(self): + df = ROOT.RDataFrame(10) + for var in ["pt", "eta", "phi", "pdgId", "mass", "tightId", "pfIsoId", "deltaEtaSC", "cutBased"]: + df = df.Define("Muon_%s" % var, "ROOT::RVec(2, 1.)") + df = df.Define("Electron_%s" % var, "ROOT::RVec(2, 1.)") + for var in ["pt", "eta", "phi", "pdgId", "mass"]: + for var2 in []: + df = df.Define("Muon_good_%s" % var2, "ROOT::RVec(2, 1.)") + df = df.Define( + "Muon_good_%s" % var, + "Muon_%s[abs(Muon_eta) < 2.4 && Muon_pt > 0 && Muon_tightId && Muon_pfIsoId >= 0]" % var, + ) + for var in 
["pt", "eta", "phi", "pdgId", "mass"]: + df = df.Define( + "Electron_good_%s" % var, + "Electron_%s[!(abs(Electron_eta+Electron_deltaEtaSC)>0 && abs(Electron_eta+Electron_deltaEtaSC)< 0) && abs(Electron_eta)<2.4 && Electron_pt > 0 && Electron_cutBased > 0]" + % var, + ) + + df = df.Define("Muon_IDSF", "1+0.01*(Muon_pt-40)") + df = df.Vary( + "Muon_IDSF", + "ROOT::VecOps::RVec>({1+0.005*(Muon_pt-40), 1+0.02*(Muon_pt-40)})", + ["down", "up"], + "muon_unc", + ) + df = df.Define("Electron_IDSF", "1+0.01*(Electron_pt-40)") + df = df.Vary( + "Electron_IDSF", + "ROOT::VecOps::RVec>({1+0.005*(Electron_pt-40), 1+0.02*(Electron_pt-40)})", + ["down", "up"], + "electron_unc", + ) + + df = df.Filter("(Muon_good_pt.size() + Electron_good_pt.size()) > 0") + + comprAlgo = getattr(ROOT.RCompressionSetting.EAlgorithm, "kZLIB") + opts = ROOT.RDF.RSnapshotOptions("RECREATE", comprAlgo, 0, 0, 99, False) + opts.fIncludeVariations = True + + snapshot = df.Snapshot("Events", "output.root", ["Electron_IDSF", "Muon_IDSF"], opts) + self.assertIsNotNone(snapshot) + + with ROOT.TFile.Open("output.root") as f: + tree = f.Get("Events") + self.assertIsNotNone(tree) + self.assertEqual(tree.GetEntries(), 10) + + +if __name__ == "__main__": + unittest.main()