Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ntuple] Implement minimal Python API #17104

Merged
merged 3 commits into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions bindings/pyroot/pythonizations/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ if(tmva)
endif()
endif()

if(root7)
list(APPEND PYROOT_EXTRA_PYTHON_SOURCES
ROOT/_pythonization/_rntuple.py)
endif()

list(APPEND PYROOT_EXTRA_HEADERS
inc/TPyDispatcher.h)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Author: Jonas Hahnfeld CERN 11/2024

################################################################################
# Copyright (C) 1995-2024, Rene Brun and Fons Rademakers. #
# All rights reserved. #
# #
# For the licensing terms see $ROOTSYS/LICENSE. #
# For the list of contributors see $ROOTSYS/README/CREDITS. #
################################################################################

from . import pythonization
from ._pyz_utils import MethodTemplateGetter, MethodTemplateWrapper


def _REntry_GetPtr(self, key):
# key can be either a RFieldToken already or a string. In the latter case, get a token to use it twice.
if (
not hasattr(type(key), "__cpp_name__")
or type(key).__cpp_name__ != "ROOT::Experimental::REntry::RFieldToken"
):
key = self.GetToken(key)
fieldType = self.GetTypeName(key)
return self._GetPtr[fieldType](key)


def _REntry_getitem(self, key):
ptr = self.GetPtr(key)
return ptr.get()[0]


def _REntry_setitem(self, key, value):
ptr = self.GetPtr(key)
ptr.get()[0] = value


@pythonization("REntry", ns="ROOT::Experimental")
def pythonize_REntry(klass):
klass._GetPtr = klass.GetPtr
klass.GetPtr = _REntry_GetPtr

klass.__getitem__ = _REntry_getitem
klass.__setitem__ = _REntry_setitem


def _RNTupleModel_CreateBare(*args):
if len(args) >= 1:
raise ValueError("no support for passing explicit RFieldZero")
import ROOT

return ROOT.Experimental.RNTupleModel._CreateBare()


def _RNTupleModel_GetDefaultEntry(self):
raise RuntimeError("default entries are not supported in Python, call CreateEntry")


class _RNTupleModel_MakeField(MethodTemplateWrapper):
def __call__(self, *args):
self._original_method(*args)
# We do not support default entries in Python, so do not even return the nullptr.
return


@pythonization("RNTupleModel", ns="ROOT::Experimental")
def pythonize_RNTupleModel(klass):
# We do not support default entries in Python, so always create a bare model.
klass.Create = _RNTupleModel_CreateBare
klass._CreateBare = klass.CreateBare
klass.CreateBare = _RNTupleModel_CreateBare

klass.GetDefaultEntry = _RNTupleModel_GetDefaultEntry

klass.MakeField = MethodTemplateGetter(klass.MakeField, _RNTupleModel_MakeField)


def _RNTupleReader_LoadEntry(self, *args):
if len(args) < 2:
raise ValueError(
"default entries are not supported in Python, pass explicit entry"
)
return self._LoadEntry(*args)


@pythonization("RNTupleReader", ns="ROOT::Experimental")
def pythonize_RNTupleReader(klass):
klass._LoadEntry = klass.LoadEntry
klass.LoadEntry = _RNTupleReader_LoadEntry


def _RNTupleWriter_Fill(self, *args):
if len(args) < 1:
raise ValueError(
"default entries are not supported in Python, pass explicit entry"
)
return self._Fill(*args)


def _RNTupleWriter_exit(self, *args):
self.CommitDataset()
return False


@pythonization("RNTupleWriter", ns="ROOT::Experimental")
def pythonize_RNTupleWriter(klass):
klass._Fill = klass.Fill
klass.Fill = _RNTupleWriter_Fill

klass.__enter__ = lambda writer: writer
klass.__exit__ = _RNTupleWriter_exit
10 changes: 7 additions & 3 deletions tree/ntuple/v7/test/ntuple_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,13 @@ def test_write_read(self):

model = RNTupleModel.Create()
model.MakeField["int"]("f")
writer = RNTupleWriter.Recreate(model, "ntpl", "test_ntuple_py_write_read.root")
writer.Fill()
del writer
with RNTupleWriter.Recreate(model, "ntpl", "test_ntuple_py_write_read.root") as writer:
entry = writer.CreateEntry()
entry["f"] = 42
writer.Fill(entry)

reader = RNTupleReader.Open("ntpl", "test_ntuple_py_write_read.root")
self.assertEqual(reader.GetNEntries(), 1)
entry = reader.GetModel().CreateEntry()
reader.LoadEntry(0, entry)
self.assertEqual(entry["f"], 42)
Loading