Skip to content

Commit

Permalink
feat: constexpr uniform_hash and type fixes (#4415)
Browse files Browse the repository at this point in the history
* feat: constexpr uniform_hash and type fixes

* mark constexpr

* formatting

* add casts

* formatting

* integrate new tests
  • Loading branch information
jackgerrits authored Jan 6, 2023
1 parent c766d57 commit 07e066f
Show file tree
Hide file tree
Showing 17 changed files with 73 additions and 47 deletions.
2 changes: 1 addition & 1 deletion vowpalwabbit/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@ vw_add_test_executable(
FOR_LIB "common"
SOURCES
tests/basic_tokenize_test.cc
tests/text_utils_test.cc
tests/hash_test.cc
tests/text_utils_test.cc
)
44 changes: 25 additions & 19 deletions vowpalwabbit/common/include/vw/common/hash.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ license as described in the file LICENSE.
//
// Adopted for VW and contributed by Ariel Faigon.
//
// Constexpr changes from:
// https://github.com/AntonJohansson/StaticMurmur/blob/master/StaticMurmur.hpp

//-----------------------------------------------------------------------------
// MurmurHash3 was written by Austin Appleby, and is placed in the public
Expand Down Expand Up @@ -42,7 +44,7 @@ constexpr inline uint32_t rotl32(uint32_t x, int8_t r) noexcept { return (x << r

//-----------------------------------------------------------------------------
// Finalization mix - force all bits of a hash block to avalanche
VW_STD14_CONSTEXPR static inline uint32_t fmix(uint32_t h) noexcept
VW_STD14_CONSTEXPR inline uint32_t fmix(uint32_t h) noexcept
{
h ^= h >> 16;
h *= 0x85ebca6b;
Expand All @@ -56,31 +58,26 @@ VW_STD14_CONSTEXPR static inline uint32_t fmix(uint32_t h) noexcept
//-----------------------------------------------------------------------------
// Block read - if your platform needs to do endian-swapping or can only
// handle aligned reads, do the conversion here
static inline uint32_t getblock(const uint32_t* p, int i) noexcept
VW_STD14_CONSTEXPR inline uint32_t get_block(const char* p, size_t i)
{
uint32_t block = 0;
memcpy(&block, &p[i], sizeof(uint32_t));
uint32_t block = static_cast<uint8_t>(p[0 + i * 4]) << 0 | static_cast<uint8_t>(p[1 + i * 4]) << 8 |
static_cast<uint8_t>(p[2 + i * 4]) << 16 | static_cast<uint8_t>(p[3 + i * 4]) << 24;
return block;
}

} // namespace details

inline uint64_t uniform_hash(const void* key, size_t len, uint64_t seed)
VW_STD14_CONSTEXPR inline uint32_t murmurhash_x86_32(const char* data, size_t len, uint32_t seed)
{
const uint8_t* data = static_cast<const uint8_t*>(key);
const int nblocks = static_cast<int>(len) / 4;
const auto num_blocks = len / 4;

uint32_t h1 = static_cast<uint32_t>(seed);
auto h1 = seed;

const uint32_t c1 = 0xcc9e2d51;
const uint32_t c2 = 0x1b873593;

// --- body
const uint32_t* blocks = (const uint32_t*)(data + nblocks * 4);

for (int i = -nblocks; i; i++)
for (size_t i = 0; i < num_blocks; i++)
{
uint32_t k1 = details::getblock(blocks, i);
uint32_t k1 = details::get_block(data, i);

k1 *= c1;
k1 = details::rotl32(k1, 15);
Expand All @@ -92,7 +89,7 @@ inline uint64_t uniform_hash(const void* key, size_t len, uint64_t seed)
}

// --- tail
const uint8_t* tail = data + nblocks * 4;
const char* tail = data + num_blocks * 4;

uint32_t k1 = 0;

Expand All @@ -101,13 +98,13 @@ inline uint64_t uniform_hash(const void* key, size_t len, uint64_t seed)
switch (len & 3u)
{
case 3:
k1 ^= tail[2] << 16;
k1 ^= static_cast<unsigned char>(tail[2]) << 16;
VW_FALLTHROUGH
case 2:
k1 ^= tail[1] << 8;
k1 ^= static_cast<unsigned char>(tail[1]) << 8;
VW_FALLTHROUGH
case 1:
k1 ^= tail[0];
k1 ^= static_cast<unsigned char>(tail[0]);
k1 *= c1;
k1 = details::rotl32(k1, 15);
k1 *= c2;
Expand All @@ -122,7 +119,16 @@ inline uint64_t uniform_hash(const void* key, size_t len, uint64_t seed)

return details::fmix(h1);
}
} // namespace details

/// Public hashing entry point in the VW namespace.
///
/// Thin forwarding wrapper: passes its arguments unchanged to
/// details::murmurhash_x86_32 and returns that function's 32-bit result.
///
/// @param data Pointer to the first byte to hash.
/// @param len  Number of bytes to hash, starting at data.
/// @param seed 32-bit seed handed straight through to the underlying hash.
/// @return The 32-bit value produced by details::murmurhash_x86_32.
VW_STD14_CONSTEXPR inline uint32_t uniform_hash(const char* data, size_t len, uint32_t seed)
{
return details::murmurhash_x86_32(data, len, seed);
}
} // namespace VW

VW_DEPRECATED("uniform_hash has been moved into VW namespace")
inline uint64_t uniform_hash(const void* key, size_t len, uint64_t seed) { return VW::uniform_hash(key, len, seed); }
inline uint64_t uniform_hash(const void* key, size_t len, uint64_t seed)
{
return VW::uniform_hash(reinterpret_cast<const char*>(key), len, static_cast<uint32_t>(seed));
}
20 changes: 17 additions & 3 deletions vowpalwabbit/core/include/vw/core/hashstring.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,16 @@
#include <cstdint>
#include <string>

inline uint64_t hashall(const char* s, size_t len, uint64_t h) { return VW::uniform_hash(s, len, h); }
namespace VW
{
namespace details
{
/// Hashes the len bytes starting at s by delegating directly to
/// VW::uniform_hash, with h supplied as the seed argument.
///
/// The signature matches the hash_func_t function-pointer type, and
/// get_hasher returns this function for the "all" hash mode.
///
/// @param s   Pointer to the first byte to hash.
/// @param len Number of bytes to hash.
/// @param h   Value forwarded as the seed of VW::uniform_hash.
/// @return The 32-bit result of VW::uniform_hash(s, len, h).
VW_STD14_CONSTEXPR inline uint32_t hashall(const char* s, size_t len, uint32_t h)
{
return VW::uniform_hash(s, len, h);
}

inline uint64_t hashstring(const char* s, size_t len, uint64_t h)
VW_STD14_CONSTEXPR inline uint32_t hashstring(const char* s, size_t len, uint32_t h)
{
const char* front = s;
while (len > 0 && front[0] <= 0x20 && static_cast<int>(front[0]) >= 0)
Expand All @@ -32,7 +39,14 @@ inline uint64_t hashstring(const char* s, size_t len, uint64_t h)

return ret + h;
}
} // namespace details

using hash_func_t = uint64_t (*)(const char*, size_t, uint64_t);
using hash_func_t = uint32_t (*)(const char*, size_t, uint32_t);

hash_func_t get_hasher(const std::string& s);

} // namespace VW
using hash_func_t VW_DEPRECATED("Moved into VW namespace") = uint32_t (*)(const char*, size_t, uint32_t);

VW_DEPRECATED("Moved into VW namespace")
VW::hash_func_t get_hasher(const std::string& s);
6 changes: 3 additions & 3 deletions vowpalwabbit/core/include/vw/core/parse_slates_example_json.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

template <bool audit>
VW_DEPRECATED("parse_slates_example_json moved to VW::parsers::json::details::parse_slates_example_json")
void parse_slates_example_json(const VW::label_parser& lbl_parser, hash_func_t hash_func, uint64_t hash_seed,
void parse_slates_example_json(const VW::label_parser& lbl_parser, VW::hash_func_t hash_func, uint64_t hash_seed,
uint64_t parse_mask, bool chain_hash, VW::multi_ex& examples, char* line, size_t length,
VW::example_factory_t example_factory, void* ex_factory_context,
std::unordered_map<uint64_t, VW::example*>* dedup_examples = nullptr)
Expand Down Expand Up @@ -36,11 +36,11 @@ void parse_slates_example_dsjson(VW::workspace& all, VW::multi_ex& examples, cha
}

// Define extern template specializations so they don't get initialized when this file is included
extern template void parse_slates_example_json<true>(const VW::label_parser& lbl_parser, hash_func_t hash_func,
extern template void parse_slates_example_json<true>(const VW::label_parser& lbl_parser, VW::hash_func_t hash_func,
uint64_t hash_seed, uint64_t parse_mask, bool chain_hash, VW::multi_ex& examples, char* line, size_t length,
VW::example_factory_t example_factory, void* ex_factory_context,
std::unordered_map<uint64_t, VW::example*>* dedup_examples);
extern template void parse_slates_example_json<false>(const VW::label_parser& lbl_parser, hash_func_t hash_func,
extern template void parse_slates_example_json<false>(const VW::label_parser& lbl_parser, VW::hash_func_t hash_func,
uint64_t hash_seed, uint64_t parse_mask, bool chain_hash, VW::multi_ex& examples, char* line, size_t length,
VW::example_factory_t example_factory, void* ex_factory_context,
std::unordered_map<uint64_t, VW::example*>* dedup_examples);
Expand Down
5 changes: 3 additions & 2 deletions vowpalwabbit/core/src/cb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ static void parse_label_cb(VW::cb_label& ld, VW::label_parser_reuse_mem& reuse_m
if (reuse_mem.tokens.empty() || reuse_mem.tokens.size() > 3) { THROW("malformed cost specification: " << word); }

f.partial_prediction = 0.;
f.action = static_cast<uint32_t>(hashstring(reuse_mem.tokens[0].data(), reuse_mem.tokens[0].length(), 0));
f.action =
static_cast<uint32_t>(VW::details::hashstring(reuse_mem.tokens[0].data(), reuse_mem.tokens[0].length(), 0));
f.cost = FLT_MAX;

if (reuse_mem.tokens.size() > 1) { f.cost = VW::details::float_of_string(reuse_mem.tokens[1], logger); }
Expand Down Expand Up @@ -199,7 +200,7 @@ void parse_label_cb_eval(VW::cb_eval_label& ld, VW::label_parser_reuse_mem& reus
{
if (words.size() < 2) THROW("Evaluation can not happen without an action and an exploration");

ld.action = static_cast<uint32_t>(hashstring(words[0].data(), words[0].length(), 0));
ld.action = static_cast<uint32_t>(VW::details::hashstring(words[0].data(), words[0].length(), 0));

// TODO - make this a span and there is no allocation
const auto rest_of_tokens = std::vector<VW::string_view>(words.begin() + 1, words.end());
Expand Down
2 changes: 1 addition & 1 deletion vowpalwabbit/core/src/cost_sensitive.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ void parse_label(VW::cs_label& ld, VW::label_parser_reuse_mem& reuse_mem, const
{
f.class_index = ldict
? ldict->get(reuse_mem.tokens[0], logger)
: static_cast<uint32_t>(hashstring(reuse_mem.tokens[0].data(), reuse_mem.tokens[0].length(), 0));
: static_cast<uint32_t>(VW::details::hashstring(reuse_mem.tokens[0].data(), reuse_mem.tokens[0].length(), 0));
if (reuse_mem.tokens.size() == 1 && f.x >= 0)
{ // test examples are specified just by un-valued class #s
f.x = FLT_MAX;
Expand Down
8 changes: 5 additions & 3 deletions vowpalwabbit/core/src/hashstring.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@

#include <string>

hash_func_t get_hasher(const std::string& s)
// Global-namespace compatibility shim: the header declares this overload as
// deprecated ("Moved into VW namespace"); it simply forwards to VW::get_hasher.
VW::hash_func_t get_hasher(const std::string& s) { return VW::get_hasher(s); }

VW::hash_func_t VW::get_hasher(const std::string& s)
{
if (s == "strings") { return hashstring; }
else if (s == "all") { return hashall; }
if (s == "strings") { return VW::details::hashstring; }
else if (s == "all") { return VW::details::hashall; }
else
THROW("Unknown hash function: " << s);
}
2 changes: 1 addition & 1 deletion vowpalwabbit/core/src/parse_args.cc
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ void parse_feature_tweaks(options_i& options, VW::workspace& all, bool interacti
options.add_and_parse(feature_options);

// feature manipulation
all.example_parser->hasher = get_hasher(hash_function);
all.example_parser->hasher = VW::get_hasher(hash_function);

if (options.was_supplied("spelling"))
{
Expand Down
4 changes: 2 additions & 2 deletions vowpalwabbit/core/src/parse_slates_example_json.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
#include "vw/common/future_compat.h"

// Explicitly instantiate templates only in this source file
template void parse_slates_example_json<true>(const VW::label_parser& lbl_parser, hash_func_t hash_func,
template void parse_slates_example_json<true>(const VW::label_parser& lbl_parser, VW::hash_func_t hash_func,
uint64_t hash_seed, uint64_t parse_mask, bool chain_hash, VW::multi_ex& examples, char* line, size_t length,
VW::example_factory_t example_factory, void* ex_factory_context,
std::unordered_map<uint64_t, VW::example*>* dedup_examples);
template void parse_slates_example_json<false>(const VW::label_parser& lbl_parser, hash_func_t hash_func,
template void parse_slates_example_json<false>(const VW::label_parser& lbl_parser, VW::hash_func_t hash_func,
uint64_t hash_seed, uint64_t parse_mask, bool chain_hash, VW::multi_ex& examples, char* line, size_t length,
VW::example_factory_t example_factory, void* ex_factory_context,
std::unordered_map<uint64_t, VW::example*>* dedup_examples);
Expand Down
3 changes: 2 additions & 1 deletion vowpalwabbit/core/src/reductions/cats_tree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,8 @@ void cats_tree::learn(LEARNER::single_learner& base, example& ec)
if (ec.weight < weight_th)
{
// generate a new seed
uint64_t new_random_seed = VW::uniform_hash(&app_seed, sizeof(app_seed), app_seed);
uint64_t new_random_seed = VW::uniform_hash(
reinterpret_cast<const char*>(&app_seed), sizeof(app_seed), static_cast<uint32_t>(app_seed));
// pick a uniform random number between 0.0 - .001f
float random_draw = merand48(new_random_seed) * weight_th;
if (random_draw < ec.weight) { ec.weight = weight_th; }
Expand Down
2 changes: 1 addition & 1 deletion vowpalwabbit/core/src/reductions/search/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ class search_private
size_t operator()(const byte_array& key) const
{
size_t sz = *key.get();
return VW::uniform_hash(key.get(), sz, SEARCH_HASH_SEED);
return VW::uniform_hash(reinterpret_cast<const char*>(key.get()), sz, static_cast<uint32_t>(SEARCH_HASH_SEED));
}
};

Expand Down
3 changes: 2 additions & 1 deletion vowpalwabbit/core/tests/random_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
TEST(Rand, ReproduceMaxBoundaryIssue)
{
uint64_t seed = 58587211;
const uint64_t new_random_seed = VW::uniform_hash(&seed, sizeof(seed), seed);
const uint64_t new_random_seed =
VW::uniform_hash(reinterpret_cast<const char*>(&seed), sizeof(seed), static_cast<uint32_t>(seed));
EXPECT_EQ(new_random_seed, 2244123448);

float random_draw = merand48_noadvance(new_random_seed);
Expand Down
6 changes: 3 additions & 3 deletions vowpalwabbit/json_parser/src/parse_example_json.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1495,7 +1495,7 @@ class Context
{
public:
VW::label_parser _label_parser;
hash_func_t _hash_func;
VW::hash_func_t _hash_func;
uint64_t _hash_seed;
uint64_t _parse_mask;
bool _chain_hash;
Expand Down Expand Up @@ -1565,7 +1565,7 @@ class Context
root_state = &default_state;
}

void init(const VW::label_parser& lbl_parser, hash_func_t hash_func, uint64_t hash_seed, uint64_t parse_mask,
void init(const VW::label_parser& lbl_parser, VW::hash_func_t hash_func, uint64_t hash_seed, uint64_t parse_mask,
bool chain_hash, VW::label_parser_reuse_mem* reuse_mem, const VW::named_labels* ldict, VW::io::logger* logger)
{
assert(reuse_mem != nullptr);
Expand Down Expand Up @@ -1633,7 +1633,7 @@ class VWReaderHandler : public rapidjson::BaseReaderHandler<rapidjson::UTF8<>, V
public:
Context<audit> ctx;

void init(const VW::label_parser& lbl_parser, hash_func_t hash_func, uint64_t hash_seed, uint64_t parse_mask,
void init(const VW::label_parser& lbl_parser, VW::hash_func_t hash_func, uint64_t hash_seed, uint64_t parse_mask,
bool chain_hash, VW::label_parser_reuse_mem* reuse_mem, const VW::named_labels* ldict, VW::io::logger* logger,
VW::multi_ex* examples, rapidjson::InsituStringStream* stream, const char* stream_end,
VW::example_factory_t example_factory, void* example_factory_context,
Expand Down
4 changes: 2 additions & 2 deletions vowpalwabbit/json_parser/src/parse_example_slates_json.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ inline float get_number(const rapidjson::Value& value)

template <bool audit>
void handle_features_value(const char* key_namespace, const Value& value, VW::example* current_example,
std::vector<VW::parsers::json::details::namespace_builder<audit>>& namespaces, hash_func_t hash_func,
std::vector<VW::parsers::json::details::namespace_builder<audit>>& namespaces, VW::hash_func_t hash_func,
uint64_t hash_seed, uint64_t parse_mask, bool chain_hash)
{
assert(key_namespace != nullptr);
Expand Down Expand Up @@ -152,7 +152,7 @@ void handle_features_value(const char* key_namespace, const Value& value, VW::ex
// NO_SANITIZE_UNDEFINED needed because example_factory function pointer may be typecasted
template <bool audit>
void NO_SANITIZE_UNDEFINED parse_context(const Value& context, const VW::label_parser& lbl_parser,
hash_func_t hash_func, uint64_t hash_seed, uint64_t parse_mask, bool chain_hash, VW::multi_ex& examples,
VW::hash_func_t hash_func, uint64_t hash_seed, uint64_t parse_mask, bool chain_hash, VW::multi_ex& examples,
VW::example_factory_t example_factory, void* ex_factory_context, VW::multi_ex& slot_examples,
std::unordered_map<uint64_t, VW::example*>* dedup_examples = nullptr)
{
Expand Down
2 changes: 1 addition & 1 deletion vowpalwabbit/slim/include/vw/slim/model_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class model_parser
// avoid alignment issues for 32/64bit types on e.g. Android/ARM
memcpy(&val, data, sizeof(T));

if (compute_checksum) { _checksum = (uint32_t)VW::uniform_hash(&val, sizeof(T), _checksum); }
if (compute_checksum) { _checksum = VW::uniform_hash(reinterpret_cast<const char*>(&val), sizeof(T), _checksum); }

#ifdef MODEL_PARSER_DEBUG
log << " '" << val << '\'' << std::endl;
Expand Down
4 changes: 2 additions & 2 deletions vowpalwabbit/slim/src/example_predict_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ example_predict_builder::example_predict_builder(
{
_feature_index_bit_mask = ((uint64_t)1 << feature_index_num_bits) - 1;
add_namespace(namespace_name[0]);
_namespace_hash = hashstring(namespace_name, strlen(namespace_name), 0);
_namespace_hash = VW::details::hashstring(namespace_name, strlen(namespace_name), 0);
}

example_predict_builder::example_predict_builder(
Expand All @@ -32,7 +32,7 @@ void example_predict_builder::add_namespace(VW::namespace_index feature_group)
void example_predict_builder::push_feature_string(const char* feature_name, VW::feature_value value)
{
VW::feature_index feature_hash =
_feature_index_bit_mask & hashstring(feature_name, strlen(feature_name), _namespace_hash);
_feature_index_bit_mask & VW::details::hashstring(feature_name, strlen(feature_name), _namespace_hash);
_ex->feature_space[_namespace_idx].push_back(value, feature_hash);
}

Expand Down
3 changes: 2 additions & 1 deletion vowpalwabbit/text_parser/src/parse_example_text.cc
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,8 @@ class tc_parser
}

VW::string_view spelling_strview(_spelling.data(), _spelling.size());
word_hash = hashstring(spelling_strview.data(), spelling_strview.length(), (uint64_t)_channel_hash);
word_hash =
VW::details::hashstring(spelling_strview.data(), spelling_strview.length(), (uint64_t)_channel_hash);
spell_fs.push_back(_v, word_hash, VW::details::SPELLING_NAMESPACE);
if (audit)
{
Expand Down

0 comments on commit 07e066f

Please sign in to comment.