Skip to content

Commit

Permalink
feat: Switch to scalar code for unimplemented interactions in las sim…
Browse files Browse the repository at this point in the history
…d code path. (#4487)

* Add warnings.

* Update vowpalwabbit/core/src/reductions/cb/cb_explore_adf_large_action_space.cc

Co-authored-by: olgavrou <[email protected]>

* Update vowpalwabbit/core/src/reductions/cb/details/large_action/compute_dot_prod_avx512.cc

Co-authored-by: olgavrou <[email protected]>

* Update vowpalwabbit/core/src/reductions/cb/details/large_action/compute_dot_prod_avx2.cc

Co-authored-by: olgavrou <[email protected]>

* Update vowpalwabbit/core/src/reductions/cb/details/large_action/compute_dot_prod_avx2.cc

Co-authored-by: olgavrou <[email protected]>

* Update vowpalwabbit/core/src/reductions/cb/details/large_action/compute_dot_prod_avx512.cc

Co-authored-by: olgavrou <[email protected]>

* Fix format.

* Fix format.

---------

Co-authored-by: olgavrou <[email protected]>
  • Loading branch information
zwd-ms and olgavrou authored Feb 3, 2023
1 parent 0fcb078 commit 2fe6f2f
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,15 @@ VW::LEARNER::base_learner* VW::reductions::cb_explore_adf_large_action_space_set
}
}

if (use_simd_in_one_pass_svd_impl &&
(options.was_supplied("cubic") || options.was_supplied("interactions") ||
options.was_supplied("experimental_full_name_interactions")))
{
all.logger.err_warn(
"Large action space with SIMD only supports quadratic interactions for now. Using scalar code path.");
use_simd_in_one_pass_svd_impl = false;
}

VW::LEARNER::multi_learner* base = as_multiline(stack_builder.setup_base_learner());
all.example_parser->lbl_parser = VW::cb_label_parser_global;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,18 @@ float compute_dot_prod_avx2(uint64_t column_index, VW::workspace* _all, uint64_t
if (!extent_interactions.empty())
{
// TODO: Add support for extent_interactions.
THROW("Extent_interactions are not supported yet in LAS SIMD implementations");
// This code should not be reachable, since we checked conflicting command line options.
_all->logger.err_error("Extent_interactions are not supported yet in large action space with SIMD implementations");
}

for (const auto& ns : interactions)
{
if (ns.size() != 2)
{
// TODO: Add support for interactions other than quadratics.
THROW("Generic interactions are not supported yet in LAS SIMD implementations")
// This code should not be reachable, since we checked conflicting command line options.
_all->logger.err_error(
"Generic interactions are not supported yet in large action space with SIMD implementations");
}

const bool same_namespace = (!_all->permutations && (ns[0] == ns[1]));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,18 @@ float compute_dot_prod_avx512(uint64_t column_index, VW::workspace* _all, uint64
if (!extent_interactions.empty())
{
// TODO: Add support for extent_interactions.
THROW("Extent_interactions are not supported yet in LAS SIMD implementations");
// This code should not be reachable, since we checked conflicting command line options.
_all->logger.err_error("Extent_interactions are not supported yet in large action space with SIMD implementations");
}

for (const auto& ns : interactions)
{
if (ns.size() != 2)
{
// TODO: Add support for interactions other than quadratics.
THROW("Generic interactions are not supported yet in LAS SIMD implementations")
// This code should not be reachable, since we checked conflicting command line options.
_all->logger.err_error(
"Generic interactions are not supported yet in large action space with SIMD implementations");
}

const bool same_namespace = (!_all->permutations && (ns[0] == ns[1]));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ one_pass_svd_impl::one_pass_svd_impl(VW::workspace* all, uint64_t d, uint64_t se
{
if (cpu_supports_avx512()) { _use_simd = simd_type::AVX512; }
else if (cpu_supports_avx2()) { _use_simd = simd_type::AVX2; }
else { all->logger.err_warn("System does not support AVX512 or AVX2. Using scalar code path."); }
}
#else
_UNUSED(use_explicit_simd);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ class one_pass_svd_impl
// for testing purposes only
void _test_only_set_rank(uint64_t rank);
bool _set_testing_components = false;
#ifdef BUILD_LAS_WITH_SIMD
bool _test_only_use_simd() { return _use_simd != simd_type::NO_SIMD; }
#endif

private:
VW::workspace* _all;
Expand Down
105 changes: 69 additions & 36 deletions vowpalwabbit/core/tests/cb_las_one_pass_svd_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,8 @@ TEST(Las, ComputeDotProdScalarAndSimdHaveSameResults)

TEST(Las, ScalarAndSimdGenerateSamePredictions)
{
const bool cpu_supports_simd = (VW::cb_explore_adf::cpu_supports_avx512() || VW::cb_explore_adf::cpu_supports_avx2());

auto generate_example = [](int num_namespaces, int num_features)
{
std::string s;
Expand Down Expand Up @@ -338,13 +340,32 @@ TEST(Las, ScalarAndSimdGenerateSamePredictions)
std::vector<std::string> vw_cmd{"--cb_explore_adf", "--large_action_space", "--quiet"};

auto vw_scalar = VW::initialize(VW::make_unique<VW::config::options_cli>(vw_cmd));

VW::LEARNER::multi_learner* learner_scalar =
as_multiline(vw_scalar->l->get_learner_by_name_prefix("cb_explore_adf_large_action_space"));
auto* action_space_scalar =
(internal_action_space_op*)learner_scalar->get_internal_type_erased_data_pointer_test_use_only();
EXPECT_NE(action_space_scalar, nullptr);

EXPECT_FALSE(action_space_scalar->explore.impl._test_only_use_simd());

VW::multi_ex ex_scalar;
for (const auto& example : examples) { ex_scalar.push_back(VW::read_example(*vw_scalar, example)); }
vw_scalar->predict(ex_scalar);
auto& scores_scalar = ex_scalar[0]->pred.a_s;

vw_cmd.push_back("--las_hint_explicit_simd");
auto vw_simd = VW::initialize(VW::make_unique<VW::config::options_cli>(vw_cmd));

VW::LEARNER::multi_learner* learner_simd =
as_multiline(vw_simd->l->get_learner_by_name_prefix("cb_explore_adf_large_action_space"));
auto* action_space_simd =
(internal_action_space_op*)learner_simd->get_internal_type_erased_data_pointer_test_use_only();
EXPECT_NE(action_space_simd, nullptr);

if (cpu_supports_simd) { EXPECT_TRUE(action_space_simd->explore.impl._test_only_use_simd()); }
else { EXPECT_FALSE(action_space_simd->explore.impl._test_only_use_simd()); }

VW::multi_ex ex_simd;
for (const auto& example : examples) { ex_simd.push_back(VW::read_example(*vw_simd, example)); }
vw_simd->predict(ex_simd);
Expand All @@ -365,13 +386,32 @@ TEST(Las, ScalarAndSimdGenerateSamePredictions)
std::vector<std::string> vw_cmd{"--cb_explore_adf", "--large_action_space", "--quiet", "-q::"};

auto vw_scalar = VW::initialize(VW::make_unique<VW::config::options_cli>(vw_cmd));

VW::LEARNER::multi_learner* learner_scalar =
as_multiline(vw_scalar->l->get_learner_by_name_prefix("cb_explore_adf_large_action_space"));
auto* action_space_scalar =
(internal_action_space_op*)learner_scalar->get_internal_type_erased_data_pointer_test_use_only();
EXPECT_NE(action_space_scalar, nullptr);

EXPECT_FALSE(action_space_scalar->explore.impl._test_only_use_simd());

VW::multi_ex ex_scalar;
for (const auto& example : examples) { ex_scalar.push_back(VW::read_example(*vw_scalar, example)); }
vw_scalar->predict(ex_scalar);
auto& scores_scalar = ex_scalar[0]->pred.a_s;

vw_cmd.push_back("--las_hint_explicit_simd");
auto vw_simd = VW::initialize(VW::make_unique<VW::config::options_cli>(vw_cmd));

VW::LEARNER::multi_learner* learner_simd =
as_multiline(vw_simd->l->get_learner_by_name_prefix("cb_explore_adf_large_action_space"));
auto* action_space_simd =
(internal_action_space_op*)learner_simd->get_internal_type_erased_data_pointer_test_use_only();
EXPECT_NE(action_space_simd, nullptr);

if (cpu_supports_simd) { EXPECT_TRUE(action_space_simd->explore.impl._test_only_use_simd()); }
else { EXPECT_FALSE(action_space_simd->explore.impl._test_only_use_simd()); }

VW::multi_ex ex_simd;
for (const auto& example : examples) { ex_simd.push_back(VW::read_example(*vw_simd, example)); }
vw_simd->predict(ex_simd);
Expand All @@ -393,13 +433,32 @@ TEST(Las, ScalarAndSimdGenerateSamePredictions)
"--cb_explore_adf", "--large_action_space", "--quiet", "-q::", "--ignore=A", "--ignore_linear=B"};

auto vw_scalar = VW::initialize(VW::make_unique<VW::config::options_cli>(vw_cmd));

VW::LEARNER::multi_learner* learner_scalar =
as_multiline(vw_scalar->l->get_learner_by_name_prefix("cb_explore_adf_large_action_space"));
auto* action_space_scalar =
(internal_action_space_op*)learner_scalar->get_internal_type_erased_data_pointer_test_use_only();
EXPECT_NE(action_space_scalar, nullptr);

EXPECT_FALSE(action_space_scalar->explore.impl._test_only_use_simd());

VW::multi_ex ex_scalar;
for (const auto& example : examples) { ex_scalar.push_back(VW::read_example(*vw_scalar, example)); }
vw_scalar->predict(ex_scalar);
auto& scores_scalar = ex_scalar[0]->pred.a_s;

vw_cmd.push_back("--las_hint_explicit_simd");
auto vw_simd = VW::initialize(VW::make_unique<VW::config::options_cli>(vw_cmd));

VW::LEARNER::multi_learner* learner_simd =
as_multiline(vw_simd->l->get_learner_by_name_prefix("cb_explore_adf_large_action_space"));
auto* action_space_simd =
(internal_action_space_op*)learner_simd->get_internal_type_erased_data_pointer_test_use_only();
EXPECT_NE(action_space_simd, nullptr);

if (cpu_supports_simd) { EXPECT_TRUE(action_space_simd->explore.impl._test_only_use_simd()); }
else { EXPECT_FALSE(action_space_simd->explore.impl._test_only_use_simd()); }

VW::multi_ex ex_simd;
for (const auto& example : examples) { ex_simd.push_back(VW::read_example(*vw_simd, example)); }
vw_simd->predict(ex_simd);
Expand All @@ -419,51 +478,25 @@ TEST(Las, ScalarAndSimdGenerateSamePredictions)
// Cubics & generic interactions are not supported yet
auto vw_simd = VW::initialize(vwtest::make_args(
"--cb_explore_adf", "--large_action_space", "--quiet", "--cubic", ":::", "--las_hint_explicit_simd"));
VW::multi_ex ex_simd;
for (const auto& example : examples) { ex_simd.push_back(VW::read_example(*vw_simd, example)); }

EXPECT_THROW(
{
try
{
vw_simd->predict(ex_simd);
}
catch (const VW::vw_exception& e)
{
EXPECT_STREQ("Generic interactions are not supported yet in LAS SIMD implementations", e.what());
throw;
}
},
VW::vw_exception);
VW::LEARNER::multi_learner* learner =
as_multiline(vw_simd->l->get_learner_by_name_prefix("cb_explore_adf_large_action_space"));
auto* action_space = (internal_action_space_op*)learner->get_internal_type_erased_data_pointer_test_use_only();
EXPECT_NE(action_space, nullptr);

vw_simd->finish_example(ex_simd);
EXPECT_FALSE(action_space->explore.impl._test_only_use_simd());
}
{
// Extent interactions are not supported yet
const std::string vw_cmd =
"--cb_explore_adf --large_action_space --quiet --experimental_full_name_interactions A|B";

auto vw_simd = VW::initialize(vwtest::make_args("--cb_explore_adf", "--large_action_space", "--quiet",
"--experimental_full_name_interactions", "A|B", "--las_hint_explicit_simd"));

VW::multi_ex ex_simd;
for (const auto& example : examples) { ex_simd.push_back(VW::read_example(*vw_simd, example)); }

EXPECT_THROW(
{
try
{
vw_simd->predict(ex_simd);
}
catch (const VW::vw_exception& e)
{
EXPECT_STREQ("Extent_interactions are not supported yet in LAS SIMD implementations", e.what());
throw;
}
},
VW::vw_exception);
VW::LEARNER::multi_learner* learner =
as_multiline(vw_simd->l->get_learner_by_name_prefix("cb_explore_adf_large_action_space"));
auto* action_space = (internal_action_space_op*)learner->get_internal_type_erased_data_pointer_test_use_only();
EXPECT_NE(action_space, nullptr);

vw_simd->finish_example(ex_simd);
EXPECT_FALSE(action_space->explore.impl._test_only_use_simd());
}
}
#endif

0 comments on commit 2fe6f2f

Please sign in to comment.