Skip to content
This repository has been archived by the owner on Dec 22, 2023. It is now read-only.

Commit

Permalink
🔥 Always enable damping
Browse files Browse the repository at this point in the history
  • Loading branch information
eigenein committed Jun 23, 2023
1 parent f3ad11a commit 7976973
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 32 deletions.
9 changes: 2 additions & 7 deletions src/trainer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,8 @@ impl GridSearch {
info!(n_votes = votes.len(), "✅ Gotcha!");
report_memory_usage();

let params = iproduct!(1..=self.high_neighbors, [false, true], [false, true]).map(
|(n_neighbors, enable_damping, include_negative)| Params {
enable_damping,
n_neighbors,
include_negative,
},
);
let params = iproduct!(1..=self.high_neighbors, [false, true])
.map(|(n_neighbors, include_negative)| Params { n_neighbors, include_negative });
search(&mut votes, self.n_partitions, self.test_proportion, params);
info!("🏁 Finished search");

Expand Down
30 changes: 6 additions & 24 deletions src/trainer/item_item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@ use crate::{
/// Model parameters.
#[derive(Debug, Args, Serialize, Deserialize, Copy, Clone)]
pub struct Params {
#[clap(long, env = "BLITZ_TANKS_MODEL_ENABLE_DAMPING")]
pub enable_damping: bool,

#[clap(long, env = "BLITZ_TANKS_MODEL_NEIGHBORS")]
/// Number of top similar vehicles to include in a prediction.
pub n_neighbors: usize,
Expand All @@ -30,7 +27,7 @@ impl Params {
pub fn fit(self, votes: &[Vote]) -> Model {
let mut votes = votes.iter().into_group_map_by(|vote| vote.tank_id);
let biases = Self::calculate_biases(&votes);
let mut similarities = self.calculate_similarities(&mut votes, &biases);
let mut similarities = Self::calculate_similarities(&mut votes, &biases);
Self::sort_similarities(&mut similarities);
Model {
created_at: Utc::now(),
Expand Down Expand Up @@ -72,7 +69,6 @@ impl Params {
/// by decreasing similarity in respect to the former.
#[must_use]
fn calculate_similarities(
&self,
votes: &mut HashMap<u16, Vec<&Vote>>,
biases: &HashMap<u16, f64>,
) -> HashMap<u16, Box<[(u16, f64)]>> {
Expand All @@ -85,7 +81,7 @@ impl Params {
.filter(|(j, _)| *i != **j)
.map(|(j, bias_j)| {
// FIXME: I do the same calculation twice: for `(i, j)` and `(j, i)`.
(*j, self.calculate_similarity(*bias_i, &votes[i], *bias_j, &votes[j]))
(*j, Self::calculate_similarity(*bias_i, &votes[i], *bias_j, &votes[j]))
})
.collect();
(*i, similarities)
Expand All @@ -95,28 +91,18 @@ impl Params {

/// Calculate similarity between two vehicles, specified by their respective biases
/// and votes sorted by account ID.
fn calculate_similarity(
&self,
bias_i: f64,
votes_i: &[&Vote],
bias_j: f64,
votes_j: &[&Vote],
) -> f64 {
fn calculate_similarity(bias_i: f64, votes_i: &[&Vote], bias_j: f64, votes_j: &[&Vote]) -> f64 {
let mut dot_product = 0.0;
let mut norm2_i = 0.0;
let mut norm2_j = 0.0;

for either in merge_join_by(votes_i, votes_j, |i, j| i.account_id.cmp(&j.account_id)) {
match either {
EitherOrBoth::Left(vote_i) => {
if self.enable_damping {
norm2_i += (vote_i.rating - bias_i).powi(2);
}
norm2_i += (vote_i.rating - bias_i).powi(2);
}
EitherOrBoth::Right(vote_j) => {
if self.enable_damping {
norm2_j += (vote_j.rating - bias_j).powi(2);
}
norm2_j += (vote_j.rating - bias_j).powi(2);
}
EitherOrBoth::Both(vote_i, vote_j) => {
let diff_i = vote_i.rating - bias_i;
Expand All @@ -128,11 +114,7 @@ impl Params {
}
}

if norm2_i >= f64::EPSILON && norm2_j >= f64::EPSILON {
dot_product / norm2_i.sqrt() / norm2_j.sqrt()
} else {
0.0
}
dot_product / norm2_i.sqrt() / norm2_j.sqrt()
}

/// Sort each vehicle's similar vehicles by decreasing similarity.
Expand Down
1 change: 0 additions & 1 deletion src/trainer/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ pub fn search(
if new_mrr > current_mrr {
info!(
%new_mrr,
enable_damping = new_params.enable_damping,
n_neighbors = new_params.n_neighbors,
include_negative = new_params.include_negative,
"🎉 Improved",
Expand Down

0 comments on commit 7976973

Please sign in to comment.