Skip to content

Commit

Permalink
AVRO-3984 [C++] Improve code generated for unions (#3047)
Browse files Browse the repository at this point in the history
* AVRO-3984 [C++] Getters created by avrogencpp return a reference instead of a value to avoid calling copy constructor of large classes

* AVRO-3984 [C++] Add getter for generated unions that returns a mutable reference. This allows the user to modify values in union branches after creation (#3047)

* AVRO-3984 [C++] Add move setters for generated unions to provide a more efficient way to set a value (#3047)

* AVRO-3984 [C++] Use std::move in decode implementation of codec_traits for unions to avoid a copy (#3047)

* AVRO-3984 [C++] Generate an enum for each union type that maps the branch names to the corresponding index. This allows the user to avoid checks against "magic numbers" (#3047)

* AVRO-3984 [C++] Add additional checks for the union branch in testUnionMethods test (#3047)

* AVRO-3984 [C++] Add additional branch() method that returns the Branch enum directly; this avoids a manual static_cast (#3047)

---------

Co-authored-by: hwse <[email protected]>
  • Loading branch information
hwse and hwse authored Aug 21, 2024
1 parent 789b2a0 commit f350a8f
Show file tree
Hide file tree
Showing 4 changed files with 265 additions and 9 deletions.
3 changes: 2 additions & 1 deletion lang/c++/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ gen (crossref cr)
gen (primitivetypes pt)
gen (cpp_reserved_words cppres)
gen (cpp_reserved_words_union_typedef cppres_union)
gen (big_union big_union)

add_executable (avrogencpp impl/avrogencpp.cc)
target_link_libraries (avrogencpp avrocpp_s)
Expand Down Expand Up @@ -226,7 +227,7 @@ add_dependencies (AvrogencppTests bigrecord_hh bigrecord_r_hh bigrecord2_hh
union_array_union_hh union_map_union_hh union_conflict_hh
recursive_hh reuse_hh circulardep_hh tree1_hh tree2_hh crossref_hh
primitivetypes_hh empty_record_hh cpp_reserved_words_union_typedef_hh
union_empty_record_hh)
union_empty_record_hh big_union_hh)

include (InstallRequiredSystemLibraries)

Expand Down
60 changes: 52 additions & 8 deletions lang/c++/impl/avrogencpp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -313,12 +313,21 @@ static void generateGetterAndSetter(ostream &os,

os << "inline\n";

os << type << sn << "get_" << name << "() const {\n"
os << "const " << type << "&" << sn << "get_" << name << "() const {\n"
<< " if (idx_ != " << idx << ") {\n"
<< " throw avro::Exception(\"Invalid type for "
<< "union " << structName << "\");\n"
<< " }\n"
<< " return std::any_cast<" << type << " >(value_);\n"
<< " return *std::any_cast<" << type << " >(&value_);\n"
<< "}\n\n";

os << "inline\n"
<< type << "&" << sn << "get_" << name << "() {\n"
<< " if (idx_ != " << idx << ") {\n"
<< " throw avro::Exception(\"Invalid type for "
<< "union " << structName << "\");\n"
<< " }\n"
<< " return *std::any_cast<" << type << " >(&value_);\n"
<< "}\n\n";

os << "inline\n"
Expand All @@ -327,6 +336,13 @@ static void generateGetterAndSetter(ostream &os,
<< " idx_ = " << idx << ";\n"
<< " value_ = v;\n"
<< "}\n\n";

os << "inline\n"
<< "void" << sn << "set_" << name
<< "(" << type << "&& v) {\n"
<< " idx_ = " << idx << ";\n"
<< " value_ = std::move(v);\n"
<< "}\n\n";
}

static void generateConstructor(ostream &os,
Expand Down Expand Up @@ -376,8 +392,33 @@ string CodeGen::generateUnionType(const NodePtr &n) {
<< "private:\n"
<< " size_t idx_;\n"
<< " std::any value_;\n"
<< "public:\n"
<< " size_t idx() const { return idx_; }\n";
<< "public:\n";

os_ << " /** enum representing union branches as returned by the idx() function */\n"
<< " enum class Branch: size_t {\n";

// generate an enum that maps the branch name to the corresponding index (as returned by idx())
std::set<std::string> used_branch_names;
for (size_t i = 0; i < c; ++i) {
// escape reserved literals for c++
auto branch_name = decorate(names[i]);
// avoid rare collisions, e.g. someone might name their struct int_
if (used_branch_names.find(branch_name) != used_branch_names.end()) {
size_t postfix = 2;
std::string escaped_name = branch_name + "_" + std::to_string(postfix);
while (used_branch_names.find(escaped_name) != used_branch_names.end()) {
++postfix;
escaped_name = branch_name + "_" + std::to_string(postfix);
}
branch_name = escaped_name;
}
os_ << " " << branch_name << " = " << i << ",\n";
used_branch_names.insert(branch_name);
}
os_ << " };\n";

os_ << " size_t idx() const { return idx_; }\n";
os_ << " Branch branch() const { return static_cast<Branch>(idx_); }\n";

for (size_t i = 0; i < c; ++i) {
const NodePtr &nn = n->leafAt(i);
Expand All @@ -392,9 +433,11 @@ string CodeGen::generateUnionType(const NodePtr &n) {
} else {
const string &type = types[i];
const string &name = names[i];
os_ << " " << type << " get_" << name << "() const;\n"
" void set_"
<< name << "(const " << type << "& v);\n";
os_ << " "
<< "const " << type << "& get_" << name << "() const;\n"
<< " " << type << "& get_" << name << "();\n"
<< " void set_" << name << "(const " << type << "& v);\n"
<< " void set_" << name << "(" << type << "&& v);\n";
pendingGettersAndSetters.emplace_back(result, type, name, i);
}
}
Expand Down Expand Up @@ -645,7 +688,7 @@ void CodeGen::generateUnionTraits(const NodePtr &n) {
os_ << " {\n"
<< " " << cppTypeOf(nn) << " vv;\n"
<< " avro::decode(d, vv);\n"
<< " v.set_" << cppNameOf(nn) << "(vv);\n"
<< " v.set_" << cppNameOf(nn) << "(std::move(vv));\n"
<< " }\n";
}
os_ << " break;\n";
Expand Down Expand Up @@ -730,6 +773,7 @@ void CodeGen::generate(const ValidSchema &schema) {

os_ << "#include <sstream>\n"
<< "#include <any>\n"
<< "#include <utility>\n"
<< "#include \"" << includePrefix_ << "Specific.hh\"\n"
<< "#include \"" << includePrefix_ << "Encoder.hh\"\n"
<< "#include \"" << includePrefix_ << "Decoder.hh\"\n"
Expand Down
101 changes: 101 additions & 0 deletions lang/c++/jsonschemas/big_union
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
{
"type": "record",
"doc": "Top level Doc.",
"name": "RootRecord",
"fields": [
{
"name": "big_union",
"doc": "A large union containing the primitive types, an array, a map and records.",
"type": [
"null",
"boolean",
"int",
"long",
"float",
"double",
{
"type": "fixed",
"size": 16,
"name": "MD5"
},
"string",
{
"type": "record",
"name": "Vec2",
"fields": [
{
"name": "x",
"type": "long"
},
{
"name": "y",
"type": "long"
}
]
},
{
"type": "record",
"name": "Vec3",
"fields": [
{
"name": "x",
"type": "long"
},
{
"name": "y",
"type": "long"
},
{
"name": "z",
"type": "long"
}
]
},
{
"type": "enum",
"name": "Suit",
"symbols": [
"SPADES",
"HEARTS",
"DIAMONDS",
"CLUBS"
]
},
{
"type": "array",
"items": "string",
"default": []
},
{
"type": "map",
"values": "long",
"default": {}
},
{
"type": "record",
"name": "int_",
"doc": "try to force a collision with int",
"fields": []
},
{
"type": "record",
"name": "int__",
"doc": "try to force a collision with int",
"fields": []
},
{
"type": "record",
"name": "Int",
"doc": "name similar to primitive name",
"fields": []
},
{
"type": "record",
"name": "_Int",
"doc": "name with underscore as prefix",
"fields": []
}
]
}
]
}
110 changes: 110 additions & 0 deletions lang/c++/test/AvrogencppTests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/

#include "Compiler.hh"
#include "big_union.hh"
#include "bigrecord.hh"
#include "bigrecord_r.hh"
#include "tweet.hh"
Expand Down Expand Up @@ -132,6 +133,14 @@ void checkDefaultValues(const testgen_r::RootRecord &r) {
BOOST_CHECK_EQUAL(r.byteswithDefaultValue.get_bytes()[1], 0xaa);
}

// Teach Boost.Test how to print the generated Branch enum so that values of
// this type can be used as operands of BOOST_CHECK_EQUAL. The enumerator is
// logged as its numeric index.
template<>
struct boost::test_tools::tt_detail::print_log_value<big_union::RootRecord::big_union_t::Branch> {
    void operator()(std::ostream &os, const big_union::RootRecord::big_union_t::Branch &value) const {
        const auto index = static_cast<size_t>(value);
        os << "big_union_t::Branch{";
        os << index;
        os << "}";
    }
};

void testEncoding() {
ValidSchema s;
ifstream ifs("jsonschemas/bigrecord");
Expand Down Expand Up @@ -300,6 +309,105 @@ void testEmptyRecord() {
BOOST_CHECK_EQUAL(calc2.stack[2].idx(), 2);
}

/**
 * Exercises the accessors generated for union types on the bigrecord schema:
 * the mutable reference getter, the rvalue (move) setter, the Branch enum and
 * the const reference getter. The record is pushed through a binary
 * encode/decode round trip before the values are checked.
 */
void testUnionMethods() {
    ValidSchema schema;
    ifstream ifs_w("jsonschemas/bigrecord");
    compileJsonSchema(ifs_w, schema);

    testgen::RootRecord record;
    // initialize the map and fill it in place through the mutable getter
    record.myunion.set_map({});
    record.myunion.get_map()["zero"] = 0;
    record.myunion.get_map()["one"] = 1;

    std::vector<uint8_t> bytes{1, 2, 3, 4};
    record.anotherunion.set_bytes(std::move(bytes));
    // after move assignment the local variable should be empty
    BOOST_CHECK(bytes.empty());

    // encode the record into an in-memory stream
    unique_ptr<OutputStream> out_stream = memoryOutputStream();
    EncoderPtr encoder = validatingEncoder(schema, binaryEncoder());
    encoder->init(*out_stream);
    avro::encode(*encoder, record);
    encoder->flush();

    // decode it again into a fresh record
    DecoderPtr decoder = validatingDecoder(schema, binaryDecoder());
    unique_ptr<InputStream> is = memoryInputStream(*out_stream);
    decoder->init(*is);
    testgen::RootRecord decoded_record;
    avro::decode(*decoder, decoded_record);

    // check that a reference can be obtained from a union
    BOOST_CHECK(decoded_record.myunion.branch() == testgen::RootRecord::myunion_t::Branch::map);
    const std::map<std::string, int32_t> &read_map = decoded_record.myunion.get_map();
    BOOST_CHECK_EQUAL(read_map.size(), 2);
    BOOST_CHECK_EQUAL(read_map.at("zero"), 0);
    BOOST_CHECK_EQUAL(read_map.at("one"), 1);

    BOOST_CHECK(decoded_record.anotherunion.branch() == testgen::RootRecord::anotherunion_t::Branch::bytes);
    // bind to the returned const reference instead of copying the vector;
    // this is the access pattern the reference-returning getter is meant for
    const std::vector<uint8_t> &read_bytes = decoded_record.anotherunion.get_bytes();
    const std::vector<uint8_t> expected_bytes{1, 2, 3, 4};
    BOOST_CHECK_EQUAL_COLLECTIONS(read_bytes.begin(), read_bytes.end(), expected_bytes.begin(), expected_bytes.end());
}

// Verifies that branch() reports the expected Branch enumerator after each
// generated setter of the big_union schema has been called. The checks also
// pin the enumerator names the generator is expected to produce for names
// that need escaping or de-duplication.
void testUnionBranchEnum() {
big_union::RootRecord record;

using Branch = big_union::RootRecord::big_union_t::Branch;

// a freshly constructed record is expected to report the null branch even
// before any setter has been called
BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::null);
record.big_union.set_null();
BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::null);

// The enumerators for the following primitive branches are expected with a
// trailing underscore (presumably because the plain names are C++ keywords
// and get escaped by the generator -- confirm against decorate()).
record.big_union.set_bool(false);
BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::bool_);

record.big_union.set_int(123);
BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::int_);

record.big_union.set_long(456);
BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::long_);

record.big_union.set_float(555.555f);
BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::float_);

record.big_union.set_double(777.777);
BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::double_);

// named and complex branches are expected under their plain names
record.big_union.set_MD5({});
BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::MD5);

record.big_union.set_string("test");
BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::string);

record.big_union.set_Vec2({});
BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::Vec2);

record.big_union.set_Vec3({});
BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::Vec3);

record.big_union.set_Suit(big_union::Suit::CLUBS);
BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::Suit);

record.big_union.set_array({});
BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::array);

record.big_union.set_map({});
BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::map);

// The record named "int_" cannot reuse Branch::int_, which is already
// expected for the primitive int above, so it is expected under the
// de-duplicated name int__2.
record.big_union.set_int_({});
BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::int__2);

// the record named "int__" is expected to keep its own name
record.big_union.set_int__({});
BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::int__);

// names that differ from a primitive only by case or by a leading
// underscore are expected unchanged
record.big_union.set_Int({});
BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::Int);

record.big_union.set__Int({});
BOOST_CHECK_EQUAL(record.big_union.branch(), Branch::_Int);
}

boost::unit_test::test_suite *init_unit_test_suite(int /*argc*/, char * /*argv*/[]) {
auto *ts = BOOST_TEST_SUITE("Code generator tests");
ts->add(BOOST_TEST_CASE(testEncoding));
Expand All @@ -308,5 +416,7 @@ boost::unit_test::test_suite *init_unit_test_suite(int /*argc*/, char * /*argv*/
ts->add(BOOST_TEST_CASE(testEncoding2<umu::r1>));
ts->add(BOOST_TEST_CASE(testNamespace));
ts->add(BOOST_TEST_CASE(testEmptyRecord));
ts->add(BOOST_TEST_CASE(testUnionMethods));
ts->add(BOOST_TEST_CASE(testUnionBranchEnum));
return ts;
}

0 comments on commit f350a8f

Please sign in to comment.