diff --git a/src/cmd/vocab.rs b/src/cmd/vocab.rs index e2551dc7..01a353c1 100644 --- a/src/cmd/vocab.rs +++ b/src/cmd/vocab.rs @@ -110,9 +110,6 @@ vocab cooc options: --distrib Compute directed distributional similarity metrics instead. --min-count Minimum number of co-occurrence count to be included in the result. [default: 1] - --complete Compute the complete chi2 & G2 metrics, instead of their approximation - based on the first cell of the contingency matrix. This - is of course more costly to compute. Common options: -h, --help Display this message @@ -143,7 +140,6 @@ struct Args { flag_forward: bool, flag_distrib: bool, flag_min_count: usize, - flag_complete: bool, flag_output: Option, flag_no_headers: bool, flag_delimiter: Option, @@ -342,10 +338,6 @@ pub fn run(argv: &[&str]) -> CliResult<()> { }; if args.flag_distrib { - if args.flag_complete { - unimplemented!(); - } - let output_headers: [&[u8]; 5] = [b"token1", b"token2", b"count", b"sdI", b"sdG2"]; wtr.write_record(output_headers)?; @@ -357,9 +349,8 @@ pub fn run(argv: &[&str]) -> CliResult<()> { ]; wtr.write_record(output_headers)?; - cooccurrences.for_each_cooc_record(args.flag_min_count, args.flag_complete, |r| { - wtr.write_byte_record(r) - })?; + cooccurrences + .for_each_cooc_record(args.flag_min_count, |r| wtr.write_byte_record(r))?; } return Ok(wtr.flush()?); @@ -765,17 +756,17 @@ fn compute_npmi(xy: usize, n: usize, pmi: f64) -> f64 { } } -#[inline] -fn compute_simplified_chi2_and_g2(x: usize, y: usize, xy: usize, n: usize) -> (f64, f64) { - // This version does not take into account the full contingency matrix. - let observed = xy as f64; - let expected = x as f64 * y as f64 / n as f64; +// #[inline] +// fn compute_simplified_chi2_and_g2(x: usize, y: usize, xy: usize, n: usize) -> (f64, f64) { +// // This version does not take into account the full contingency matrix. +// let observed = xy as f64; +// let expected = x as f64 * y as f64 / n as f64; - ( - (observed - expected).powi(2) / expected, - 2.0 * observed * (observed / expected).ln(), - ) -} +// ( +// (observed - expected).powi(2) / expected, +// 2.0 * observed * (observed / expected).ln(), +// ) +// } #[inline] fn compute_simplified_g2(x: usize, y: usize, xy: usize, n: usize) -> f64 { @@ -787,35 +778,82 @@ fn compute_simplified_g2(x: usize, y: usize, xy: usize, n: usize) -> f64 { } // NOTE: see code in issue https://github.com/medialab/xan/issues/295 +// NOTE: it is possible to approximate chi2 and G2 for co-occurrences by +// only computing the (observed_11, expected_11) part related to the first +// cell of the contingency matrix. This works very well for chi2, but +// is a little bit more fuzzy for G2. fn compute_chi2_and_g2(x: usize, y: usize, xy: usize, n: usize) -> (f64, f64) { + // This can be 0 if some item is present in all co-occurrences! let not_x = (n - x) as f64; let not_y = (n - y) as f64; - let nf = n as f64; let observed_11 = xy as f64; - let observed_12 = (x - xy) as f64; - let observed_21 = (y - xy) as f64; - let observed_22 = (n - (x + y) + xy) as f64; + let observed_12 = (x - xy) as f64; // Is 0 if x only co-occurs with y + let observed_21 = (y - xy) as f64; // Is 0 if y only co-occurs with x + + // NOTE: with few co-occurrences, self loops can produce a negative + // outcome... + let observed_22 = ((n + xy) as i64 - (x + y) as i64) as f64; + + let nf = n as f64; - let expected_11 = x as f64 * y as f64 / nf; + let expected_11 = x as f64 * y as f64 / nf; // Cannot be 0 let expected_12 = x as f64 * not_y / nf; let expected_21 = y as f64 * not_x / nf; let expected_22 = not_x * not_y / nf; + debug_assert!( + observed_11 >= 0.0 + && observed_12 >= 0.0 + && observed_21 >= 0.0 + // && observed_22 >= 0.0 + && expected_11 >= 0.0 + && expected_12 >= 0.0 + && expected_21 >= 0.0 + && expected_22 >= 0.0 + ); + let chi2_11 = (observed_11 - expected_11).powi(2) / expected_11; let chi2_12 = (observed_12 - expected_12).powi(2) / expected_12; let chi2_21 = (observed_21 - expected_21).powi(2) / expected_21; let chi2_22 = (observed_22 - expected_22).powi(2) / expected_22; let g2_11 = observed_11 * (observed_11 / expected_11).ln(); - let g2_12 = observed_12 * (observed_12 / expected_12).ln(); - let g2_21 = observed_21 * (observed_21 / expected_21).ln(); - let g2_22 = observed_22 * (observed_22 / expected_22).ln(); - - ( - chi2_11 + chi2_12 + chi2_21 + chi2_22, - 2.0 * (g2_11 + g2_12 + g2_21 + g2_22), - ) + let g2_12 = if observed_12 == 0.0 { + 0.0 + } else { + observed_12 * (observed_12 / expected_12).ln() + }; + let g2_21 = if observed_21 == 0.0 { + 0.0 + } else { + observed_21 * (observed_21 / expected_21).ln() + }; + let g2_22 = if observed_22 <= 0.0 { + 0.0 + } else { + observed_22 * (observed_22 / expected_22).ln() + }; + + let mut chi2 = chi2_11 + chi2_12 + chi2_21 + chi2_22; + let mut g2 = 2.0 * (g2_11 + g2_12 + g2_21 + g2_22); + + // Dealing with degenerate cases that happen when the number + // of co-occurrences is very low, or when some item dominates + // the distribution. + if chi2.is_nan() { + chi2 = 0.0; + } + + if chi2.is_infinite() { + chi2 = chi2_11; + } + + if g2.is_infinite() { + g2 = g2_11; + } + + (chi2, g2) } #[derive(Debug)] @@ -912,12 +950,7 @@ impl Cooccurrences { target_entry.gcf += 1; } - fn for_each_cooc_record( - self, - min_count: usize, - complete: bool, - mut callback: F, - ) -> Result<(), E> + fn for_each_cooc_record(self, min_count: usize, mut callback: F) -> Result<(), E> where F: FnMut(&csv::ByteRecord) -> Result<(), E>, { @@ -938,11 +971,7 @@ impl Cooccurrences { let xy = *count; // chi2/G2 computations - let (chi2, g2) = if complete { - compute_chi2_and_g2(x, y, xy, n) - } else { - compute_simplified_chi2_and_g2(x, y, xy, n) - }; + let (chi2, g2) = compute_chi2_and_g2(x, y, xy, n); // PMI-related computations let pmi = compute_pmi(x, y, xy, n); @@ -953,13 +982,7 @@ impl Cooccurrences { csv_record.push_field(&target_entry.token); csv_record.push_field(count.to_string().as_bytes()); csv_record.push_field(chi2.to_string().as_bytes()); - - if g2.is_nan() { - csv_record.push_field(b""); - } else { - csv_record.push_field(g2.to_string().as_bytes()); - } - + csv_record.push_field(g2.to_string().as_bytes()); csv_record.push_field(pmi.to_string().as_bytes()); csv_record.push_field(npmi.to_string().as_bytes()); diff --git a/tests/test_vocab.rs b/tests/test_vocab.rs index e6e3eb1a..cd4eff81 100644 --- a/tests/test_vocab.rs +++ b/tests/test_vocab.rs @@ -231,7 +231,7 @@ fn vocab_cooc_sep_no_doc() { let expected = vec![ svec!["token1", "token2", "count", "chi2", "G2", "pmi", "npmi"], - svec!["cat", "cat", "1", "2.25", "-2.772588722239781", "-2", "-1"], + svec!["cat", "cat", "1", "2.25", "-1.3862943611198906", "-2", "-1"], svec!["cat", "dog", "2", "0", "0", "0", "0"], svec!["cat", "rabbit", "1", "0", "0", "0", "0"], ]; @@ -262,7 +262,7 @@ fn vocab_cooc_no_sep() { let expected = vec![ svec!["token1", "token2", "count", "chi2", "G2", "pmi", "npmi"], - svec!["cat", "cat", "1", "2.25", "-2.772588722239781", "-2", "-1"], + svec!["cat", "cat", "1", "2.25", "-1.3862943611198906", "-2", "-1"], svec!["cat", "dog", "2", "0", "0", "0", "0"], svec!["cat", "rabbit", "1", "0", "0", "0", "0"], ]; @@ -294,7 +294,7 @@ fn vocab_cooc_no_sep_window() { let expected = vec![ svec!["token1", "token2", "count", "chi2", "G2", "pmi", "npmi"], - svec!["cat", "cat", "1", "2.25", "-2.772588722239781", "-2", "-1"], + svec!["cat", "cat", "1", "2.25", "-1.3862943611198906", "-2", "-1"], svec!["cat", "dog", "2", "0", "0", "0", "0"], svec!["cat", "rabbit", "1", "0", "0", "0", "0"], ];