Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: constexpr uniform_hash and type fixes #4415

Merged
merged 10 commits into from
Jan 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vowpalwabbit/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@ vw_add_test_executable(
FOR_LIB "common"
SOURCES
tests/basic_tokenize_test.cc
tests/text_utils_test.cc
tests/hash_test.cc
tests/text_utils_test.cc
)
44 changes: 25 additions & 19 deletions vowpalwabbit/common/include/vw/common/hash.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ license as described in the file LICENSE.
//
// Adopted for VW and contributed by Ariel Faigon.
//
// Constexpr changes from:
// https://github.com/AntonJohansson/StaticMurmur/blob/master/StaticMurmur.hpp

//-----------------------------------------------------------------------------
// MurmurHash3 was written by Austin Appleby, and is placed in the public
Expand Down Expand Up @@ -42,7 +44,7 @@ constexpr inline uint32_t rotl32(uint32_t x, int8_t r) noexcept { return (x << r

//-----------------------------------------------------------------------------
// Finalization mix - force all bits of a hash block to avalanche
VW_STD14_CONSTEXPR static inline uint32_t fmix(uint32_t h) noexcept
VW_STD14_CONSTEXPR inline uint32_t fmix(uint32_t h) noexcept
{
h ^= h >> 16;
h *= 0x85ebca6b;
Expand All @@ -56,31 +58,26 @@ VW_STD14_CONSTEXPR static inline uint32_t fmix(uint32_t h) noexcept
//-----------------------------------------------------------------------------
// Block read - if your platform needs to do endian-swapping or can only
// handle aligned reads, do the conversion here
static inline uint32_t getblock(const uint32_t* p, int i) noexcept
VW_STD14_CONSTEXPR inline uint32_t get_block(const char* p, size_t i)
{
uint32_t block = 0;
memcpy(&block, &p[i], sizeof(uint32_t));
uint32_t block = static_cast<uint8_t>(p[0 + i * 4]) << 0 | static_cast<uint8_t>(p[1 + i * 4]) << 8 |
static_cast<uint8_t>(p[2 + i * 4]) << 16 | static_cast<uint8_t>(p[3 + i * 4]) << 24;
return block;
}

} // namespace details

inline uint64_t uniform_hash(const void* key, size_t len, uint64_t seed)
VW_STD14_CONSTEXPR inline uint32_t murmurhash_x86_32(const char* data, size_t len, uint32_t seed)
{
const uint8_t* data = static_cast<const uint8_t*>(key);
const int nblocks = static_cast<int>(len) / 4;
const auto num_blocks = len / 4;

uint32_t h1 = static_cast<uint32_t>(seed);
auto h1 = seed;

const uint32_t c1 = 0xcc9e2d51;
const uint32_t c2 = 0x1b873593;

// --- body
const uint32_t* blocks = (const uint32_t*)(data + nblocks * 4);

for (int i = -nblocks; i; i++)
for (size_t i = 0; i < num_blocks; i++)
{
uint32_t k1 = details::getblock(blocks, i);
uint32_t k1 = details::get_block(data, i);

k1 *= c1;
k1 = details::rotl32(k1, 15);
Expand All @@ -92,7 +89,7 @@ inline uint64_t uniform_hash(const void* key, size_t len, uint64_t seed)
}

// --- tail
const uint8_t* tail = data + nblocks * 4;
const char* tail = data + num_blocks * 4;

uint32_t k1 = 0;

Expand All @@ -101,13 +98,13 @@ inline uint64_t uniform_hash(const void* key, size_t len, uint64_t seed)
switch (len & 3u)
{
case 3:
k1 ^= tail[2] << 16;
k1 ^= static_cast<unsigned char>(tail[2]) << 16;
VW_FALLTHROUGH
case 2:
k1 ^= tail[1] << 8;
k1 ^= static_cast<unsigned char>(tail[1]) << 8;
VW_FALLTHROUGH
case 1:
k1 ^= tail[0];
k1 ^= static_cast<unsigned char>(tail[0]);
k1 *= c1;
k1 = details::rotl32(k1, 15);
k1 *= c2;
Expand All @@ -122,7 +119,16 @@ inline uint64_t uniform_hash(const void* key, size_t len, uint64_t seed)

return details::fmix(h1);
}
} // namespace details

VW_STD14_CONSTEXPR inline uint32_t uniform_hash(const char* data, size_t len, uint32_t seed)
{
return details::murmurhash_x86_32(data, len, seed);
}
} // namespace VW

VW_DEPRECATED("uniform_hash has been moved into VW namespace")
inline uint64_t uniform_hash(const void* key, size_t len, uint64_t seed) { return VW::uniform_hash(key, len, seed); }
inline uint64_t uniform_hash(const void* key, size_t len, uint64_t seed)
{
return VW::uniform_hash(reinterpret_cast<const char*>(key), len, static_cast<uint32_t>(seed));
}
20 changes: 17 additions & 3 deletions vowpalwabbit/core/include/vw/core/hashstring.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,16 @@
#include <cstdint>
#include <string>

inline uint64_t hashall(const char* s, size_t len, uint64_t h) { return VW::uniform_hash(s, len, h); }
namespace VW
{
namespace details
{
VW_STD14_CONSTEXPR inline uint32_t hashall(const char* s, size_t len, uint32_t h)
{
return VW::uniform_hash(s, len, h);
}

inline uint64_t hashstring(const char* s, size_t len, uint64_t h)
VW_STD14_CONSTEXPR inline uint32_t hashstring(const char* s, size_t len, uint32_t h)
{
const char* front = s;
while (len > 0 && front[0] <= 0x20 && static_cast<int>(front[0]) >= 0)
Expand All @@ -32,7 +39,14 @@ inline uint64_t hashstring(const char* s, size_t len, uint64_t h)

return ret + h;
}
} // namespace details

using hash_func_t = uint64_t (*)(const char*, size_t, uint64_t);
using hash_func_t = uint32_t (*)(const char*, size_t, uint32_t);

hash_func_t get_hasher(const std::string& s);

} // namespace VW
using hash_func_t VW_DEPRECATED("Moved into VW namespace") = uint32_t (*)(const char*, size_t, uint32_t);

VW_DEPRECATED("Moved into VW namespace")
VW::hash_func_t get_hasher(const std::string& s);
6 changes: 3 additions & 3 deletions vowpalwabbit/core/include/vw/core/parse_slates_example_json.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

template <bool audit>
VW_DEPRECATED("parse_slates_example_json moved to VW::parsers::json::details::parse_slates_example_json")
void parse_slates_example_json(const VW::label_parser& lbl_parser, hash_func_t hash_func, uint64_t hash_seed,
void parse_slates_example_json(const VW::label_parser& lbl_parser, VW::hash_func_t hash_func, uint64_t hash_seed,
uint64_t parse_mask, bool chain_hash, VW::multi_ex& examples, char* line, size_t length,
VW::example_factory_t example_factory, void* ex_factory_context,
std::unordered_map<uint64_t, VW::example*>* dedup_examples = nullptr)
Expand Down Expand Up @@ -36,11 +36,11 @@ void parse_slates_example_dsjson(VW::workspace& all, VW::multi_ex& examples, cha
}

// Define extern template specializations so they don't get initialized when this file is included
extern template void parse_slates_example_json<true>(const VW::label_parser& lbl_parser, hash_func_t hash_func,
extern template void parse_slates_example_json<true>(const VW::label_parser& lbl_parser, VW::hash_func_t hash_func,
uint64_t hash_seed, uint64_t parse_mask, bool chain_hash, VW::multi_ex& examples, char* line, size_t length,
VW::example_factory_t example_factory, void* ex_factory_context,
std::unordered_map<uint64_t, VW::example*>* dedup_examples);
extern template void parse_slates_example_json<false>(const VW::label_parser& lbl_parser, hash_func_t hash_func,
extern template void parse_slates_example_json<false>(const VW::label_parser& lbl_parser, VW::hash_func_t hash_func,
uint64_t hash_seed, uint64_t parse_mask, bool chain_hash, VW::multi_ex& examples, char* line, size_t length,
VW::example_factory_t example_factory, void* ex_factory_context,
std::unordered_map<uint64_t, VW::example*>* dedup_examples);
Expand Down
5 changes: 3 additions & 2 deletions vowpalwabbit/core/src/cb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ static void parse_label_cb(VW::cb_label& ld, VW::label_parser_reuse_mem& reuse_m
if (reuse_mem.tokens.empty() || reuse_mem.tokens.size() > 3) { THROW("malformed cost specification: " << word); }

f.partial_prediction = 0.;
f.action = static_cast<uint32_t>(hashstring(reuse_mem.tokens[0].data(), reuse_mem.tokens[0].length(), 0));
f.action =
static_cast<uint32_t>(VW::details::hashstring(reuse_mem.tokens[0].data(), reuse_mem.tokens[0].length(), 0));
f.cost = FLT_MAX;

if (reuse_mem.tokens.size() > 1) { f.cost = VW::details::float_of_string(reuse_mem.tokens[1], logger); }
Expand Down Expand Up @@ -199,7 +200,7 @@ void parse_label_cb_eval(VW::cb_eval_label& ld, VW::label_parser_reuse_mem& reus
{
if (words.size() < 2) THROW("Evaluation can not happen without an action and an exploration");

ld.action = static_cast<uint32_t>(hashstring(words[0].data(), words[0].length(), 0));
ld.action = static_cast<uint32_t>(VW::details::hashstring(words[0].data(), words[0].length(), 0));

// TODO - make this a span and there is no allocation
const auto rest_of_tokens = std::vector<VW::string_view>(words.begin() + 1, words.end());
Expand Down
2 changes: 1 addition & 1 deletion vowpalwabbit/core/src/cost_sensitive.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ void parse_label(VW::cs_label& ld, VW::label_parser_reuse_mem& reuse_mem, const
{
f.class_index = ldict
? ldict->get(reuse_mem.tokens[0], logger)
: static_cast<uint32_t>(hashstring(reuse_mem.tokens[0].data(), reuse_mem.tokens[0].length(), 0));
: static_cast<uint32_t>(VW::details::hashstring(reuse_mem.tokens[0].data(), reuse_mem.tokens[0].length(), 0));
if (reuse_mem.tokens.size() == 1 && f.x >= 0)
{ // test examples are specified just by un-valued class #s
f.x = FLT_MAX;
Expand Down
8 changes: 5 additions & 3 deletions vowpalwabbit/core/src/hashstring.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@

#include <string>

hash_func_t get_hasher(const std::string& s)
VW::hash_func_t get_hasher(const std::string& s) { return VW::get_hasher(s); }

VW::hash_func_t VW::get_hasher(const std::string& s)
{
if (s == "strings") { return hashstring; }
else if (s == "all") { return hashall; }
if (s == "strings") { return VW::details::hashstring; }
else if (s == "all") { return VW::details::hashall; }
else
THROW("Unknown hash function: " << s);
}
2 changes: 1 addition & 1 deletion vowpalwabbit/core/src/parse_args.cc
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ void parse_feature_tweaks(options_i& options, VW::workspace& all, bool interacti
options.add_and_parse(feature_options);

// feature manipulation
all.example_parser->hasher = get_hasher(hash_function);
all.example_parser->hasher = VW::get_hasher(hash_function);

if (options.was_supplied("spelling"))
{
Expand Down
4 changes: 2 additions & 2 deletions vowpalwabbit/core/src/parse_slates_example_json.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
#include "vw/common/future_compat.h"

// Explicitly instantiate templates only in this source file
template void parse_slates_example_json<true>(const VW::label_parser& lbl_parser, hash_func_t hash_func,
template void parse_slates_example_json<true>(const VW::label_parser& lbl_parser, VW::hash_func_t hash_func,
uint64_t hash_seed, uint64_t parse_mask, bool chain_hash, VW::multi_ex& examples, char* line, size_t length,
VW::example_factory_t example_factory, void* ex_factory_context,
std::unordered_map<uint64_t, VW::example*>* dedup_examples);
template void parse_slates_example_json<false>(const VW::label_parser& lbl_parser, hash_func_t hash_func,
template void parse_slates_example_json<false>(const VW::label_parser& lbl_parser, VW::hash_func_t hash_func,
uint64_t hash_seed, uint64_t parse_mask, bool chain_hash, VW::multi_ex& examples, char* line, size_t length,
VW::example_factory_t example_factory, void* ex_factory_context,
std::unordered_map<uint64_t, VW::example*>* dedup_examples);
Expand Down
3 changes: 2 additions & 1 deletion vowpalwabbit/core/src/reductions/cats_tree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,8 @@ void cats_tree::learn(LEARNER::single_learner& base, example& ec)
if (ec.weight < weight_th)
{
// generate a new seed
uint64_t new_random_seed = VW::uniform_hash(&app_seed, sizeof(app_seed), app_seed);
uint64_t new_random_seed = VW::uniform_hash(
reinterpret_cast<const char*>(&app_seed), sizeof(app_seed), static_cast<uint32_t>(app_seed));
// pick a uniform random number between 0.0 - .001f
float random_draw = merand48(new_random_seed) * weight_th;
if (random_draw < ec.weight) { ec.weight = weight_th; }
Expand Down
2 changes: 1 addition & 1 deletion vowpalwabbit/core/src/reductions/search/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ class search_private
size_t operator()(const byte_array& key) const
{
size_t sz = *key.get();
return VW::uniform_hash(key.get(), sz, SEARCH_HASH_SEED);
return VW::uniform_hash(reinterpret_cast<const char*>(key.get()), sz, static_cast<uint32_t>(SEARCH_HASH_SEED));
}
};

Expand Down
3 changes: 2 additions & 1 deletion vowpalwabbit/core/tests/random_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
TEST(Rand, ReproduceMaxBoundaryIssue)
{
uint64_t seed = 58587211;
const uint64_t new_random_seed = VW::uniform_hash(&seed, sizeof(seed), seed);
const uint64_t new_random_seed =
VW::uniform_hash(reinterpret_cast<const char*>(&seed), sizeof(seed), static_cast<uint32_t>(seed));
EXPECT_EQ(new_random_seed, 2244123448);

float random_draw = merand48_noadvance(new_random_seed);
Expand Down
6 changes: 3 additions & 3 deletions vowpalwabbit/json_parser/src/parse_example_json.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1495,7 +1495,7 @@ class Context
{
public:
VW::label_parser _label_parser;
hash_func_t _hash_func;
VW::hash_func_t _hash_func;
uint64_t _hash_seed;
uint64_t _parse_mask;
bool _chain_hash;
Expand Down Expand Up @@ -1565,7 +1565,7 @@ class Context
root_state = &default_state;
}

void init(const VW::label_parser& lbl_parser, hash_func_t hash_func, uint64_t hash_seed, uint64_t parse_mask,
void init(const VW::label_parser& lbl_parser, VW::hash_func_t hash_func, uint64_t hash_seed, uint64_t parse_mask,
bool chain_hash, VW::label_parser_reuse_mem* reuse_mem, const VW::named_labels* ldict, VW::io::logger* logger)
{
assert(reuse_mem != nullptr);
Expand Down Expand Up @@ -1633,7 +1633,7 @@ class VWReaderHandler : public rapidjson::BaseReaderHandler<rapidjson::UTF8<>, V
public:
Context<audit> ctx;

void init(const VW::label_parser& lbl_parser, hash_func_t hash_func, uint64_t hash_seed, uint64_t parse_mask,
void init(const VW::label_parser& lbl_parser, VW::hash_func_t hash_func, uint64_t hash_seed, uint64_t parse_mask,
bool chain_hash, VW::label_parser_reuse_mem* reuse_mem, const VW::named_labels* ldict, VW::io::logger* logger,
VW::multi_ex* examples, rapidjson::InsituStringStream* stream, const char* stream_end,
VW::example_factory_t example_factory, void* example_factory_context,
Expand Down
4 changes: 2 additions & 2 deletions vowpalwabbit/json_parser/src/parse_example_slates_json.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ inline float get_number(const rapidjson::Value& value)

template <bool audit>
void handle_features_value(const char* key_namespace, const Value& value, VW::example* current_example,
std::vector<VW::parsers::json::details::namespace_builder<audit>>& namespaces, hash_func_t hash_func,
std::vector<VW::parsers::json::details::namespace_builder<audit>>& namespaces, VW::hash_func_t hash_func,
uint64_t hash_seed, uint64_t parse_mask, bool chain_hash)
{
assert(key_namespace != nullptr);
Expand Down Expand Up @@ -152,7 +152,7 @@ void handle_features_value(const char* key_namespace, const Value& value, VW::ex
// NO_SANITIZE_UNDEFINED needed because example_factory function pointer may be typecasted
template <bool audit>
void NO_SANITIZE_UNDEFINED parse_context(const Value& context, const VW::label_parser& lbl_parser,
hash_func_t hash_func, uint64_t hash_seed, uint64_t parse_mask, bool chain_hash, VW::multi_ex& examples,
VW::hash_func_t hash_func, uint64_t hash_seed, uint64_t parse_mask, bool chain_hash, VW::multi_ex& examples,
VW::example_factory_t example_factory, void* ex_factory_context, VW::multi_ex& slot_examples,
std::unordered_map<uint64_t, VW::example*>* dedup_examples = nullptr)
{
Expand Down
2 changes: 1 addition & 1 deletion vowpalwabbit/slim/include/vw/slim/model_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class model_parser
// avoid alignment issues for 32/64bit types on e.g. Android/ARM
memcpy(&val, data, sizeof(T));

if (compute_checksum) { _checksum = (uint32_t)VW::uniform_hash(&val, sizeof(T), _checksum); }
if (compute_checksum) { _checksum = VW::uniform_hash(reinterpret_cast<const char*>(&val), sizeof(T), _checksum); }

#ifdef MODEL_PARSER_DEBUG
log << " '" << val << '\'' << std::endl;
Expand Down
4 changes: 2 additions & 2 deletions vowpalwabbit/slim/src/example_predict_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ example_predict_builder::example_predict_builder(
{
_feature_index_bit_mask = ((uint64_t)1 << feature_index_num_bits) - 1;
add_namespace(namespace_name[0]);
_namespace_hash = hashstring(namespace_name, strlen(namespace_name), 0);
_namespace_hash = VW::details::hashstring(namespace_name, strlen(namespace_name), 0);
}

example_predict_builder::example_predict_builder(
Expand All @@ -32,7 +32,7 @@ void example_predict_builder::add_namespace(VW::namespace_index feature_group)
void example_predict_builder::push_feature_string(const char* feature_name, VW::feature_value value)
{
VW::feature_index feature_hash =
_feature_index_bit_mask & hashstring(feature_name, strlen(feature_name), _namespace_hash);
_feature_index_bit_mask & VW::details::hashstring(feature_name, strlen(feature_name), _namespace_hash);
_ex->feature_space[_namespace_idx].push_back(value, feature_hash);
}

Expand Down
3 changes: 2 additions & 1 deletion vowpalwabbit/text_parser/src/parse_example_text.cc
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,8 @@ class tc_parser
}

VW::string_view spelling_strview(_spelling.data(), _spelling.size());
word_hash = hashstring(spelling_strview.data(), spelling_strview.length(), (uint64_t)_channel_hash);
word_hash =
VW::details::hashstring(spelling_strview.data(), spelling_strview.length(), (uint64_t)_channel_hash);
spell_fs.push_back(_v, word_hash, VW::details::SPELLING_NAMESPACE);
if (audit)
{
Expand Down