Skip to content

Commit

Permalink
Added partial function
Browse files Browse the repository at this point in the history
  • Loading branch information
thequackdaddy committed Mar 4, 2017
1 parent 72e9725 commit 460a6f9
Showing 1 changed file with 103 additions and 0 deletions.
103 changes: 103 additions & 0 deletions patsy/design_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from patsy.constraint import linear_constraint
from patsy.contrasts import ContrastMatrix
from patsy.desc import ModelDesc, Term
from collections import OrderedDict

class FactorInfo(object):
"""A FactorInfo object is a simple class that provides some metadata about
Expand Down Expand Up @@ -684,6 +685,49 @@ def var_names(self, eval_env=0):
else:
return {}

def partial(self, columns, product=False):
"""Returns a partial prediction array where only the variables in the
dict ``columns`` are tranformed per the :class:`DesignInfo`
transformations. The terms that are not influenced by ``columns``
return as zero.
This is useful to perform a partial prediction on unseen data and to
view marginal differences in factors.
:arg columns: A dict with the keys as the column names for the marginal
predictions desired and values as the marginal values to be predicted.
:arg product: When `True`, the resturned numpy array represents the
Cartesian product of the values ``columns``.
:returns: A numpy array of the partial design matrix.
"""
from .highlevel import dmatrix
if product:
columns = _column_product(columns)
rows = None
for col in columns:
if rows and rows != len(columns[col]):
raise ValueError('all columns must be of same length')
rows = len(columns[col])
parts = []
for term, subterm in six.iteritems(self.term_codings):
term_vars = term.var_names()
present = True
for term_var in term_vars:
if term_var not in columns:
present = False
if present and (term.name() != 'Intercept'):
# This seems like an inelegent way to not having the Intercept
# in the output
di = self.subset('0 + {}'.format(term.name()))
parts.append(dmatrix(di, columns))
else:
num_columns = np.sum(s.num_columns for s in subterm)
dm = np.zeros((rows, num_columns))
parts.append(dm)
return np.hstack(parts)

@classmethod
def from_array(cls, array_like, default_column_prefix="column"):
"""Find or construct a DesignInfo appropriate for a given array_like.
Expand Down Expand Up @@ -1230,3 +1274,62 @@ def test_design_matrix():
repr(DesignMatrix(np.zeros((1, 0))))
repr(DesignMatrix(np.zeros((0, 1))))
repr(DesignMatrix(np.zeros((0, 0))))


def test_DesignInfo_partial():
from .highlevel import dmatrix
from numpy.testing import assert_allclose
a = np.array(['a', 'b', 'a', 'b', 'a', 'a', 'b', 'a'])
b = np.array([1, 3, 2, 4, 1, 3, 1, 1])
c = np.array([4, 3, 2, 1, 6, 4, 2, 1])
dm = dmatrix('a + bs(b, df=3, degree=3) + np.log(c)')
x = np.zeros((3, 6))
x[1, 1] = 1
y = dm.design_info.partial({'a': ['a', 'b', 'a']})
assert_allclose(x, y)

x = np.zeros((2, 6))
x[1, 1] = 1
x[1, 5] = np.log(3)
p = OrderedDict([('a', ['a', 'b']), ('c', [1, 3])])
y = dm.design_info.partial(p)
assert_allclose(x, y)

x = np.zeros((4, 6))
x[2, 1] = 1
x[3, 1] = 1
x[1, 5] = np.log(3)
x[3, 5] = np.log(3)
y = dm.design_info.partial(p, product=True)
assert_allclose(x, y)

dm = dmatrix('a * c')
y = dm.design_info.partial(p)
x = np.array([[0, 0, 1, 0], [0, 1, 3, 3]])
assert_allclose(x, y)

from nose.tools import assert_raises
assert_raises(ValueError, dm.design_info.partial, {'a': ['a', 'b'],
'b': [1, 2, 3]})


def _column_product(columns):
from itertools import product
cols = []
values = []
for col, value in six.iteritems(columns):
cols.append(col)
values.append(value)
values = [value for value in product(*values)]
values = [value for value in zip(*values)]
return OrderedDict([(col, list(value))
for col, value in zip(cols, values)])


def test_column_product():
x = OrderedDict([('a', [1, 2, 3]), ('b', ['a', 'b'])])
y = OrderedDict([('a', [1, 1, 2, 2, 3, 3]),
('b', ['a', 'b', 'a', 'b', 'a', 'b'])])
x = _column_product(x)
assert x['a'] == y['a']
assert x['b'] == y['b']

0 comments on commit 460a6f9

Please sign in to comment.