Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Switch to scalar code for unimplemented interactions in las simd code path. #4487

Merged
merged 9 commits into from
Feb 3, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,15 @@ VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_large_action_space_set
}
}

if (use_simd_in_one_pass_svd_impl &&
(options.was_supplied("cubic") || options.was_supplied("interactions") ||
options.was_supplied("experimental_full_name_interactions")))
{
all.logger.err_warn(
"Large action space with SIMD only supports quadratic interactions for now. Using scalar code path.");
use_simd_in_one_pass_svd_impl = false;
}

VW::LEARNER::multi_learner* base = as_multiline(stack_builder.setup_base_learner());
all.example_parser->lbl_parser = VW::cb_label_parser_global;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,18 @@ float compute_dot_prod_avx2(uint64_t column_index, VW::workspace* _all, uint64_t
if (!extent_interactions.empty())
{
// TODO: Add support for extent_interactions.
THROW("Extent_interactions are not supported yet in LAS SIMD implementations");
// This code should not be reachable, since we checked conflicting command line options.
_all->logger.err_error("Extent_interactions are not supported yet in large action space with SIMD implementations");
}

for (const auto& ns : interactions)
{
if (ns.size() != 2)
{
// TODO: Add support for interactions other than quadratics.
THROW("Generic interactions are not supported yet in LAS SIMD implementations")
// This code should not be reachable, since we checked conflicting command line options.
_all->logger.err_error(
"Generic interactions are not supported yet in large action space with SIMD implementations");
}

const bool same_namespace = (!_all->permutations && (ns[0] == ns[1]));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,18 @@ float compute_dot_prod_avx512(uint64_t column_index, VW::workspace* _all, uint64
if (!extent_interactions.empty())
{
// TODO: Add support for extent_interactions.
THROW("Extent_interactions are not supported yet in LAS SIMD implementations");
// This code should not be reachable, since we checked conflicting command line options.
_all->logger.err_error("Extent_interactions are not supported yet in large action space with SIMD implementations");
}

for (const auto& ns : interactions)
{
if (ns.size() != 2)
{
// TODO: Add support for interactions other than quadratics.
THROW("Generic interactions are not supported yet in LAS SIMD implementations")
// This code should not be reachable, since we checked conflicting command line options.
_all->logger.err_error(
"Generic interactions are not supported yet in large action space with SIMD implementations");
}

const bool same_namespace = (!_all->permutations && (ns[0] == ns[1]));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ one_pass_svd_impl::one_pass_svd_impl(VW::workspace* all, uint64_t d, uint64_t se
{
if (cpu_supports_avx512()) { _use_simd = simd_type::AVX512; }
else if (cpu_supports_avx2()) { _use_simd = simd_type::AVX2; }
else { all->logger.err_warn("System does not support AVX512 or AVX2. Using scalar code path."); }
}
#else
_UNUSED(use_explicit_simd);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ class one_pass_svd_impl
// for testing purposes only
void _test_only_set_rank(uint64_t rank);
bool _set_testing_components = false;
#ifdef BUILD_LAS_WITH_SIMD
bool _test_only_use_simd() { return _use_simd != simd_type::NO_SIMD; }
#endif

private:
VW::workspace* _all;
Expand Down
105 changes: 69 additions & 36 deletions vowpalwabbit/core/tests/cb_las_one_pass_svd_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,8 @@ TEST(Las, ComputeDotProdScalarAndSimdHaveSameResults)

TEST(Las, ScalarAndSimdGenerateSamePredictions)
{
const bool cpu_supports_simd = (VW::cb_explore_adf::cpu_supports_avx512() || VW::cb_explore_adf::cpu_supports_avx2());

auto generate_example = [](int num_namespaces, int num_features)
{
std::string s;
Expand Down Expand Up @@ -338,13 +340,32 @@ TEST(Las, ScalarAndSimdGenerateSamePredictions)
std::vector<std::string> vw_cmd{"--cb_explore_adf", "--large_action_space", "--quiet"};

auto vw_scalar = VW::initialize(VW::make_unique<VW::config::options_cli>(vw_cmd));

VW::LEARNER::multi_learner* learner_scalar =
as_multiline(vw_scalar->l->get_learner_by_name_prefix("cb_explore_adf_large_action_space"));
auto* action_space_scalar =
(internal_action_space_op*)learner_scalar->get_internal_type_erased_data_pointer_test_use_only();
EXPECT_NE(action_space_scalar, nullptr);

EXPECT_FALSE(action_space_scalar->explore.impl._test_only_use_simd());

VW::multi_ex ex_scalar;
for (const auto& example : examples) { ex_scalar.push_back(VW::read_example(*vw_scalar, example)); }
vw_scalar->predict(ex_scalar);
auto& scores_scalar = ex_scalar[0]->pred.a_s;

vw_cmd.push_back("--las_hint_explicit_simd");
auto vw_simd = VW::initialize(VW::make_unique<VW::config::options_cli>(vw_cmd));

VW::LEARNER::multi_learner* learner_simd =
as_multiline(vw_simd->l->get_learner_by_name_prefix("cb_explore_adf_large_action_space"));
auto* action_space_simd =
(internal_action_space_op*)learner_simd->get_internal_type_erased_data_pointer_test_use_only();
EXPECT_NE(action_space_simd, nullptr);

if (cpu_supports_simd) { EXPECT_TRUE(action_space_simd->explore.impl._test_only_use_simd()); }
else { EXPECT_FALSE(action_space_simd->explore.impl._test_only_use_simd()); }

VW::multi_ex ex_simd;
for (const auto& example : examples) { ex_simd.push_back(VW::read_example(*vw_simd, example)); }
vw_simd->predict(ex_simd);
Expand All @@ -365,13 +386,32 @@ TEST(Las, ScalarAndSimdGenerateSamePredictions)
std::vector<std::string> vw_cmd{"--cb_explore_adf", "--large_action_space", "--quiet", "-q::"};

auto vw_scalar = VW::initialize(VW::make_unique<VW::config::options_cli>(vw_cmd));

VW::LEARNER::multi_learner* learner_scalar =
as_multiline(vw_scalar->l->get_learner_by_name_prefix("cb_explore_adf_large_action_space"));
auto* action_space_scalar =
(internal_action_space_op*)learner_scalar->get_internal_type_erased_data_pointer_test_use_only();
EXPECT_NE(action_space_scalar, nullptr);

EXPECT_FALSE(action_space_scalar->explore.impl._test_only_use_simd());

VW::multi_ex ex_scalar;
for (const auto& example : examples) { ex_scalar.push_back(VW::read_example(*vw_scalar, example)); }
vw_scalar->predict(ex_scalar);
auto& scores_scalar = ex_scalar[0]->pred.a_s;

vw_cmd.push_back("--las_hint_explicit_simd");
auto vw_simd = VW::initialize(VW::make_unique<VW::config::options_cli>(vw_cmd));

VW::LEARNER::multi_learner* learner_simd =
as_multiline(vw_simd->l->get_learner_by_name_prefix("cb_explore_adf_large_action_space"));
auto* action_space_simd =
(internal_action_space_op*)learner_simd->get_internal_type_erased_data_pointer_test_use_only();
EXPECT_NE(action_space_simd, nullptr);

if (cpu_supports_simd) { EXPECT_TRUE(action_space_simd->explore.impl._test_only_use_simd()); }
else { EXPECT_FALSE(action_space_simd->explore.impl._test_only_use_simd()); }

VW::multi_ex ex_simd;
for (const auto& example : examples) { ex_simd.push_back(VW::read_example(*vw_simd, example)); }
vw_simd->predict(ex_simd);
Expand All @@ -393,13 +433,32 @@ TEST(Las, ScalarAndSimdGenerateSamePredictions)
"--cb_explore_adf", "--large_action_space", "--quiet", "-q::", "--ignore=A", "--ignore_linear=B"};

auto vw_scalar = VW::initialize(VW::make_unique<VW::config::options_cli>(vw_cmd));

VW::LEARNER::multi_learner* learner_scalar =
as_multiline(vw_scalar->l->get_learner_by_name_prefix("cb_explore_adf_large_action_space"));
auto* action_space_scalar =
(internal_action_space_op*)learner_scalar->get_internal_type_erased_data_pointer_test_use_only();
EXPECT_NE(action_space_scalar, nullptr);

EXPECT_FALSE(action_space_scalar->explore.impl._test_only_use_simd());

VW::multi_ex ex_scalar;
for (const auto& example : examples) { ex_scalar.push_back(VW::read_example(*vw_scalar, example)); }
vw_scalar->predict(ex_scalar);
auto& scores_scalar = ex_scalar[0]->pred.a_s;

vw_cmd.push_back("--las_hint_explicit_simd");
auto vw_simd = VW::initialize(VW::make_unique<VW::config::options_cli>(vw_cmd));

VW::LEARNER::multi_learner* learner_simd =
as_multiline(vw_simd->l->get_learner_by_name_prefix("cb_explore_adf_large_action_space"));
auto* action_space_simd =
(internal_action_space_op*)learner_simd->get_internal_type_erased_data_pointer_test_use_only();
EXPECT_NE(action_space_simd, nullptr);

if (cpu_supports_simd) { EXPECT_TRUE(action_space_simd->explore.impl._test_only_use_simd()); }
else { EXPECT_FALSE(action_space_simd->explore.impl._test_only_use_simd()); }

VW::multi_ex ex_simd;
for (const auto& example : examples) { ex_simd.push_back(VW::read_example(*vw_simd, example)); }
vw_simd->predict(ex_simd);
Expand All @@ -419,51 +478,25 @@ TEST(Las, ScalarAndSimdGenerateSamePredictions)
// Cubics & generic interactions are not supported yet
auto vw_simd = VW::initialize(vwtest::make_args(
"--cb_explore_adf", "--large_action_space", "--quiet", "--cubic", ":::", "--las_hint_explicit_simd"));
VW::multi_ex ex_simd;
for (const auto& example : examples) { ex_simd.push_back(VW::read_example(*vw_simd, example)); }

EXPECT_THROW(
{
try
{
vw_simd->predict(ex_simd);
}
catch (const VW::vw_exception& e)
{
EXPECT_STREQ("Generic interactions are not supported yet in LAS SIMD implementations", e.what());
throw;
}
},
VW::vw_exception);
VW::LEARNER::multi_learner* learner =
as_multiline(vw_simd->l->get_learner_by_name_prefix("cb_explore_adf_large_action_space"));
auto* action_space = (internal_action_space_op*)learner->get_internal_type_erased_data_pointer_test_use_only();
EXPECT_NE(action_space, nullptr);

vw_simd->finish_example(ex_simd);
EXPECT_FALSE(action_space->explore.impl._test_only_use_simd());
}
{
// Extent interactions are not supported yet
const std::string vw_cmd =
"--cb_explore_adf --large_action_space --quiet --experimental_full_name_interactions A|B";

auto vw_simd = VW::initialize(vwtest::make_args("--cb_explore_adf", "--large_action_space", "--quiet",
"--experimental_full_name_interactions", "A|B", "--las_hint_explicit_simd"));

VW::multi_ex ex_simd;
for (const auto& example : examples) { ex_simd.push_back(VW::read_example(*vw_simd, example)); }

EXPECT_THROW(
{
try
{
vw_simd->predict(ex_simd);
}
catch (const VW::vw_exception& e)
{
EXPECT_STREQ("Extent_interactions are not supported yet in LAS SIMD implementations", e.what());
throw;
}
},
VW::vw_exception);
VW::LEARNER::multi_learner* learner =
as_multiline(vw_simd->l->get_learner_by_name_prefix("cb_explore_adf_large_action_space"));
auto* action_space = (internal_action_space_op*)learner->get_internal_type_erased_data_pointer_test_use_only();
EXPECT_NE(action_space, nullptr);

vw_simd->finish_example(ex_simd);
EXPECT_FALSE(action_space->explore.impl._test_only_use_simd());
}
}
#endif