Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ntuple] Cache and expose field type name in REntry #17053

Merged
merged 4 commits into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions tree/ntuple/v7/inc/ROOT/REntry.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ private:
std::vector<RFieldBase::RValue> fValues;
/// For fast lookup of token IDs given a (sub)field name present in the entry
std::unordered_map<std::string, std::size_t> fFieldName2Token;
/// To ensure that the entry is standalone, a copy of all field types
std::vector<std::string> fFieldTypes;
hahnjo marked this conversation as resolved.
Show resolved Hide resolved

// Creation of entries is done by the RNTupleModel class

Expand All @@ -88,6 +90,7 @@ private:
void AddValue(RFieldBase::RValue &&value)
{
fFieldName2Token[value.GetField().GetQualifiedFieldName()] = fValues.size();
fFieldTypes.push_back(value.GetField().GetTypeName());
fValues.emplace_back(std::move(value));
}

Expand All @@ -96,6 +99,7 @@ private:
std::shared_ptr<T> AddValue(RField<T> &field)
{
fFieldName2Token[field.GetQualifiedFieldName()] = fValues.size();
fFieldTypes.push_back(field.GetTypeName());
auto value = field.CreateValue();
fValues.emplace_back(value);
return value.template GetPtr<T>();
Expand Down Expand Up @@ -134,14 +138,26 @@ private:
}
}

/// This function has linear complexity, only use for more helpful error messages!
const std::string &FindFieldName(RFieldToken token) const
{
for (const auto &[fieldName, index] : fFieldName2Token) {
if (index == token.fIndex) {
return fieldName;
}
}
// Should never happen, but avoid compiler warning about "returning reference to local temporary object".
static const std::string empty = "";
return empty;
}

template <typename T>
void EnsureMatchingType(RFieldToken token [[maybe_unused]]) const
{
if constexpr (!std::is_void_v<T>) {
const auto &v = fValues[token.fIndex];
if (v.GetField().GetTypeName() != RField<T>::TypeName()) {
throw RException(R__FAIL("type mismatch for field " + v.GetField().GetQualifiedFieldName() + ": " +
v.GetField().GetTypeName() + " vs. " + RField<T>::TypeName()));
if (fFieldTypes[token.fIndex] != RField<T>::TypeName()) {
throw RException(R__FAIL("type mismatch for field " + FindFieldName(token) + ": " +
fFieldTypes[token.fIndex] + " vs. " + RField<T>::TypeName()));
}
}
}
Expand Down Expand Up @@ -215,6 +231,14 @@ public:
return GetPtr<T>(GetToken(fieldName));
}

const std::string &GetTypeName(RFieldToken token) const
{
EnsureMatchingModel(token);
return fFieldTypes[token.fIndex];
}

const std::string &GetTypeName(std::string_view fieldName) const { return GetTypeName(GetToken(fieldName)); }

std::uint64_t GetModelId() const { return fModelId; }
std::uint64_t GetSchemaId() const { return fSchemaId; }

Expand Down
11 changes: 10 additions & 1 deletion tree/ntuple/v7/test/ntuple_basics.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,7 @@ TEST(REntry, Basics)

auto e = model->CreateEntry();
EXPECT_EQ(e->GetModelId(), model->GetModelId());
EXPECT_EQ(e->GetSchemaId(), model->GetSchemaId());
for (const auto &v : *e) {
EXPECT_STREQ("pt", v.GetField().GetFieldName().c_str());
}
Expand All @@ -695,6 +696,9 @@ TEST(REntry, Basics)
EXPECT_THROW(e->GetToken("eta"), ROOT::Experimental::RException);
EXPECT_THROW(model->GetToken("eta"), ROOT::Experimental::RException);

EXPECT_EQ("float", e->GetTypeName("pt"));
EXPECT_EQ("float", e->GetTypeName(model->GetToken("pt")));

auto ptrPt = std::make_shared<float>();
e->BindValue("pt", ptrPt);
EXPECT_EQ(ptrPt.get(), e->GetPtr<float>("pt").get());
Expand Down Expand Up @@ -726,7 +730,12 @@ TEST(REntry, Basics)
EXPECT_EQ(&pt, e->GetPtr<void>("pt").get());

e->EmplaceNewValue(model->GetToken("pt"));
EXPECT_NE(&pt, e->GetPtr<void>("pt").get());
ptrPt = e->GetPtr<float>("pt");
EXPECT_NE(&pt, ptrPt.get());

// Tokens are standalone and can be used after model destruction
model.reset();
EXPECT_EQ(ptrPt, e->GetPtr<float>("pt"));
}

TEST(RFieldBase, CreateObject)
Expand Down
Loading