forked from clbustos/statsample
-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #51 from lokeshh/cat_data
Regression with categorical data and introducing formula language
- Loading branch information
Showing
9 changed files
with
472 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
require 'statsample/formula/formula' | ||
|
||
module Statsample | ||
# Fits a multiple regression described by a formula string.
#
# The formula is parsed by FormulaWrapper, every canonical token is turned
# into dataframe columns through Token#to_df, and the merged columns are
# handed to Statsample::Regression.multiple.
class FitModel
  # @param formula [String] formula text passed on to FormulaWrapper
  # @param df dataframe holding the vectors named in the formula
  #   (presumably a Daru::DataFrame -- confirm against callers)
  # @param opts [Hash] options forwarded unchanged to
  #   Statsample::Regression.multiple
  def initialize(formula, df, opts = {})
    @formula = FormulaWrapper.new formula, df
    @df = df
    @opts = opts
  end

  # @return the fitted model; it is fitted on first access and then reused
  def model
    @model || fit_model
  end

  # Runs the fitted model's own +predict+ on +new_data+ after converting it
  # to the same column layout that was used for fitting.
  # @param new_data dataframe with the predictor vectors
  def predict(new_data)
    model.predict(df_for_prediction(new_data))
  end

  # @param df dataframe with the predictor vectors
  # @return the design columns for +df+ (no response column)
  def df_for_prediction(df)
    canonicalize_df(df)
  end

  # @return the design columns for the training dataframe plus the response
  #   vector stored under its original name
  def df_for_regression
    response = @formula.y.value
    design = canonicalize_df(@df)
    design[response] = @df[response]
    design
  end

  # Expands every canonical token except a leading intercept token ('1')
  # into columns and merges them.
  # @param orig_df dataframe the tokens are expanded against
  # @return the merged dataframe, or nil when no token is left after the
  #   intercept has been skipped
  def canonicalize_df(orig_df)
    tokens = @formula.canonical_tokens
    head = tokens.first
    # Array#drop builds a new array, so the formula keeps its intercept token
    # for canonical_to_s and for every later call of this method.
    tokens = tokens.drop(1) if head && head.value == '1'
    tokens.map { |token| token.to_df orig_df }.reduce(&:merge)
  end

  # Fits the regression and caches the result in @model.
  def fit_model
    # TODO: Add support for inclusion/exclusion of intercept
    @model = Statsample::Regression.multiple(
      df_for_regression,
      @formula.y.value,
      @opts
    )
  end
end
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,306 @@ | ||
module Statsample | ||
# Groups the right-hand-side terms of a formula by the numeric
# (non-categorical) terms they interact with, reduces every group with
# Formula, and then puts the numeric terms back onto the reduced tokens.
class FormulaWrapper
  attr_reader :tokens, :y, :canonical_tokens

  # Initializes formula wrapper object to parse a given formula into
  # some tokens which do not overlap one another.
  # @note Specify 0 as a term in the formula if you do not want constant
  #   to be included in the parsed formula
  # @param [String] formula to parse, written as 'lhs~term+term+...'
  # @param [Daru::DataFrame] df dataframe required to know what vectors
  #   are numerical
  # @raise [ArgumentError] if the formula has no right-hand side
  # @example
  #   df = Daru::DataFrame.from_csv 'spec/data/df.csv'
  #   df.to_category 'c', 'd', 'e'
  #   formula = Statsample::FormulaWrapper.new 'y~a+d:c', df
  #   formula.canonical_to_s
  #   #=> "1+c(-)+d(-):c+a"
  def initialize(formula, df)
    @df = df
    # @y stores the LHS term, i.e. the name of the vector to be predicted;
    # @tokens stores the RHS terms of the formula
    @y, *@tokens = split_to_tokens(formula)
    @tokens = @tokens.uniq.sort
    manage_constant_term
    @canonical_tokens = non_redundant_tokens
  end

  # Returns canonical tokens in a readable form.
  # @return [String] canonical tokens joined with '+'
  # @note 'y~a+b(-)' means 'a' exists in full rank expansion
  #   and 'b(-)' exists in reduced rank expansion
  # @example
  #   df = Daru::DataFrame.from_csv 'spec/data/df.csv'
  #   df.to_category 'c', 'd', 'e'
  #   formula = Statsample::FormulaWrapper.new 'y~a+d:c', df
  #   formula.canonical_to_s
  #   #=> "1+c(-)+d(-):c+a"
  def canonical_to_s
    canonical_tokens.join '+'
  end

  # Returns tokens to produce non-redundant design matrix
  # @return [Array] array of tokens that do not produce redundant matrix
  def non_redundant_tokens
    # TODO: An enhancement
    # Right now x:c appears as c:x
    split_to_groups.flat_map do |numeric, group|
      reduced = Formula.new(strip_numeric(group, numeric)).canonical_tokens
      add_numeric reduced, numeric
    end
  end

  private

  # Removes intercept token if term '0' is found in the formula.
  # Intercept token remains if term '1' is found.
  # If neither term '0' nor term '1' is found, the intercept token is added.
  def manage_constant_term
    intercept = Token.new('1')
    no_intercept = Token.new('0')
    constant_given = @tokens.include?(intercept) ||
                     @tokens.include?(no_intercept)
    @tokens.unshift intercept unless constant_given
    @tokens.delete no_intercept
  end

  # Groups the tokens based on the numerical terms
  # they are interacting with.
  # @return [Hash] numeric term names (Array) => tokens sharing them (Array)
  def split_to_groups
    @tokens.group_by { |t| extract_numeric t }
  end

  # Add numeric interaction term which was removed earlier
  # @param [Array] tokens tokens on which to add numerical terms
  # @param [Array] numeric array of numeric terms to add
  # @return [Array] new tokens carrying the numeric terms again
  def add_numeric(tokens, numeric)
    tokens.map do |t|
      terms = t.interact_terms + numeric
      # A bare intercept with nothing to add stays a bare intercept.
      next Token.new('1') if terms == ['1']

      Token.new terms.reject { |i| i == '1' }.join(':'), t.full
    end
  end

  # Strip numerical interacting terms
  # @param [Array] tokens tokens from which to strip numeric
  # @param [Array] numeric array of numeric terms to strip from tokens
  # @return [Array] array of tokens with stripped numerical terms; a token
  #   made only of numeric terms becomes the intercept token '1'
  def strip_numeric(tokens, numeric)
    tokens.map do |t|
      terms = t.interact_terms - numeric
      terms = ['1'] if terms.empty?
      Token.new terms.join(':')
    end
  end

  # Extract numeric interacting terms
  # @param [Statsample::Token] token from which to extract numeric terms
  # @return [Array] array of numerical terms
  def extract_numeric(token)
    terms = token.interact_terms
    return [] if terms == ['1']
    terms.reject { |t| @df[t].category? }
  end

  # Splits 'lhs~a+b' into tokens, ignoring all whitespace.
  # @param [String] formula formula text
  # @return [Array] the LHS token followed by one token per RHS term
  # @raise [ArgumentError] if there is no '~' or nothing follows it
  def split_to_tokens(formula)
    lhs_term, rhs = formula.gsub(/\s+/, '').split '~'
    if rhs.nil?
      raise ArgumentError,
            "formula must look like 'y~a+b', got #{formula.inspect}"
    end
    ([lhs_term] + rhs.split('+')).map { |t| Token.new t }
  end
end
|
||
# Turns a list of tokens into a list whose expansions do not repeat each
# other: every token is expanded, pieces already produced by earlier tokens
# are dropped, and the remaining pieces are merged back wherever Token#add
# allows it.
class Formula
  attr_reader :tokens, :canonical_tokens

  # @param tokens [Array] tokens to reduce; the reduction runs immediately
  def initialize(tokens)
    @tokens = tokens
    @canonical_tokens = parse_formula
  end

  # @return [String] the canonical tokens joined with '+'
  def canonical_to_s
    canonical_tokens.join '+'
  end

  private

  # Walks the tokens in order, collecting for each one only the pieces that
  # the tokens accepted so far do not already provide.
  def parse_formula
    @tokens.each_with_object([]) do |token, accepted|
      accepted.concat add_non_redundant_elements(token, accepted)
    end
  end

  # @param token the token being added
  # @param result_so_far [Array] tokens accepted before this one
  # @return [Array] tokens to append for +token+
  def add_non_redundant_elements(token, result_so_far)
    # The intercept is taken as it is, without expansion.
    return [token] if token.value == '1'

    fresh = token.expand
    already_covered = result_so_far.flat_map(&:expand)
    contract_if_possible(fresh - already_covered)
  end

  # Repeatedly replaces the first pair of tokens that can be added together
  # by their sum until no such pair is left. Modifies +tokens+ in place.
  # @return [Array] the contracted tokens, sorted
  def contract_if_possible(tokens)
    loop do
      merged = nil
      pair = tokens.combination(2).find { |a, b| merged = a.add(b) }
      break unless pair

      tokens.delete pair.first
      tokens.delete pair.last
      tokens << merged
    end
    tokens.sort
  end
end
|
||
# One term of a formula: either a single term ('a', '1') or an interaction
# of terms separated by ':' ('a:b'). Every interacting term carries a flag in
# +full+; a false flag is printed as '(-)' and is passed on as the +full+
# option when the term is coded into dataframe columns.
class Token
  attr_reader :full, :interact_terms

  # @param value [String] term text; ':' separates interacting terms
  # @param full [Boolean, Array] one flag used for every term, or an array
  #   of per-term flags (missing trailing flags default to true)
  def initialize(value, full = true)
    @interact_terms = value.include?(':') ? value.split(':') : [value]
    @full = coerce_full full
  end

  # @return [String] the interacting terms joined with ':'
  def value
    interact_terms.join(':')
  end

  # @return [Integer] number of interacting terms, 0 for the token '1'
  def size
    # TODO: Return size 1 for value '1' also
    # Can't do this at the moment because have to make
    # changes in sorting first
    value == '1' ? 0 : interact_terms.size
  end

  # Tries to merge this token with +other+ into a single token.
  #   ANYTHING + FACTOR- : ANYTHING = FACTOR : ANYTHING
  #   ANYTHING + ANYTHING : FACTOR- = ANYTHING : FACTOR
  #   1 + single term               = that term with full set to true
  # @param other [Token]
  # @return [Token, nil] the merged token, or nil when no rule applies
  def add(other)
    # Always evaluate with the smaller token as the receiver.
    return other.add(self) if size > other.size
    return Token.new(other.value, true) if value == '1' && other.size == 1
    return unless size == 1 && other.size == 2

    left, right = other.interact_terms
    left_full, right_full = other.full
    if right == value && right_full == full.first && left_full == false
      Token.new("#{left}:#{value}", [true, right_full])
    elsif left == value && left_full == full.first && right_full == false
      Token.new("#{value}:#{right}", [left_full, true])
    end
  end

  # Tokens are equal when both their text and their flags match.
  def ==(other)
    value == other.value &&
      full == other.full
  end

  alias eql? ==

  # Consistent with #eql? so tokens work with uniq, Array#- and Hash keys.
  def hash
    value.hash ^ full.hash
  end

  # Orders tokens by #size only.
  def <=>(other)
    size <=> other.size
  end

  # @return [String] the terms joined with ':', each term whose flag is
  #   false being suffixed with '(-)'
  def to_s
    interact_terms
      .zip(full)
      .map { |term, is_full| is_full ? term : "#{term}(-)" }
      .join ':'
  end

  # Splits the token into the intercept plus its terms (and, for a two-term
  # interaction, the interaction itself), all with their flags set to false.
  # @return [Array, nil] the pieces; nil for more than two interacting terms
  def expand
    case size
    when 0
      [self]
    when 1
      [Token.new('1'), Token.new(value, false)]
    when 2
      a, b = interact_terms
      [Token.new('1'), Token.new(a, false), Token.new(b, false),
       Token.new("#{a}:#{b}", [false, false])]
    end
  end

  # Builds the dataframe columns for this token from the vectors of +df+.
  # @param df dataframe whose vectors respond to category?
  #   (presumably a Daru::DataFrame -- confirm against callers)
  # @return the resulting dataframe; nil when #size is neither 1 nor 2
  def to_df(df)
    case size
    when 1
      if df[value].category?
        df[value].contrast_code full: full.first
      else
        Daru::DataFrame.new value => df[value].to_a
      end
    when 2
      to_df_when_interaction(df)
    end
  end

  private

  # Normalises the +full+ argument to one flag per interacting term.
  def coerce_full(value)
    return [value] * @interact_terms.size unless value.is_a? Array

    value + Array.new((@interact_terms.size - value.size), true)
  end

  # Picks the coding for a two-term interaction depending on which of the
  # two vectors report category?.
  def to_df_when_interaction(df)
    case interact_terms.map { |t| df[t].category? }
    when [true, true]
      df.interact_code(interact_terms, full)
    when [false, false]
      to_df_numeric_interact_with_numeric df
    when [true, false]
      to_df_category_interact_with_numeric df
    when [false, true]
      to_df_numeric_interact_with_category df
    end
  end

  # Single column named after the token holding the product of both vectors.
  def to_df_numeric_interact_with_numeric(df)
    Daru::DataFrame.new value => (df[interact_terms.first] *
      df[interact_terms.last]).to_a
  end

  # One column per contrast-coded vector of the first term, each multiplied
  # by the numeric second term and named '<coded name>:<numeric name>'.
  def to_df_category_interact_with_numeric(df)
    a, b = interact_terms
    Daru::DataFrame.new(
      df[a].contrast_code(full: full.first)
        .map { |dv| ["#{dv.name}:#{b}", (dv * df[b]).to_a] }
        .to_h
    )
  end

  # One column per contrast-coded vector of the second term, each multiplied
  # by the numeric first term and named '<numeric name>:<coded name>'.
  def to_df_numeric_interact_with_category(df)
    a, b = interact_terms
    Daru::DataFrame.new(
      df[b].contrast_code(full: full.last)
        .map { |dv| ["#{a}:#{dv.name}", (dv * df[a]).to_a] }
        .to_h
    )
  end
end
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.