Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DMLIV/DRIV #460

Merged
merged 22 commits into from
Aug 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -326,8 +326,8 @@ from econml.iv.dr import LinearIntentToTreatDRIV
from sklearn.ensemble import GradientBoostingRegressor, GradientBoostingClassifier
from sklearn.linear_model import LinearRegression

est = LinearIntentToTreatDRIV(model_Y_X=GradientBoostingRegressor(),
model_T_XZ=GradientBoostingClassifier(),
est = LinearIntentToTreatDRIV(model_y_xw=GradientBoostingRegressor(),
model_t_xwz=GradientBoostingClassifier(),
flexible_model_effect=GradientBoostingRegressor())
est.fit(Y, T, Z=Z, X=X) # OLS inference by default
treatment_effects = est.effect(X_test)
Expand Down
7 changes: 5 additions & 2 deletions doc/reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,7 @@ Double Machine Learning (DML) IV
.. autosummary::
:toctree: _autosummary

econml.iv.dml.DMLATEIV
econml.iv.dml.ProjectedDMLATEIV
econml.iv.dml.OrthoIV
econml.iv.dml.DMLIV
econml.iv.dml.NonParamDMLIV

Expand All @@ -79,6 +78,10 @@ Doubly Robust (DR) IV
.. autosummary::
:toctree: _autosummary

econml.iv.dr.DRIV
econml.iv.dr.LinearDRIV
econml.iv.dr.SparseLinearDRIV
econml.iv.dr.ForestDRIV
econml.iv.dr.IntentToTreatDRIV
econml.iv.dr.LinearIntentToTreatDRIV

Expand Down
7 changes: 5 additions & 2 deletions econml/dml/_rlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, freq_weight=None, sam
cache_values=cache_values,
inference=inference)

def score(self, Y, T, X=None, W=None):
def score(self, Y, T, X=None, W=None, sample_weight=None):
"""
Score the fitted CATE model on a new data set. Generates nuisance parameters
for the new data set based on the fitted residual nuisance models created at fit time.
Expand All @@ -391,14 +391,17 @@ def score(self, Y, T, X=None, W=None):
Features for each sample
W: optional(n, d_w) matrix or None (Default=None)
Controls for each sample
sample_weight: optional(n,) vector or None (Default=None)
Weights for each sample


Returns
-------
score: float
The MSE of the final CATE model on the new data.
"""
# Replacing score from _OrthoLearner, to enforce Z=None and improve the docstring
return super().score(Y, T, X=X, W=W)
return super().score(Y, T, X=X, W=W, sample_weight=sample_weight)

@property
def rlearner_model_final_(self):
Expand Down
2 changes: 1 addition & 1 deletion econml/dml/dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ def _gen_model_final(self):
def _gen_rlearner_model_final(self):
return _FinalWrapper(self._gen_model_final(), self.fit_cate_intercept, self._gen_featurizer(), False)

# override only so that we can update the docstring to indicate support for `StatsModelsInference`
# override only so that we can update the docstring to indicate support for `LinearModelFinalInference`
@_deprecate_positional("X and W should be passed by keyword only. In a future release "
"we will disallow passing X and W by position.", ['X', 'W'])
def fit(self, Y, T, X=None, W=None, *, sample_weight=None, freq_weight=None, sample_var=None, groups=None,
Expand Down
11 changes: 8 additions & 3 deletions econml/dr/_drlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ def refit_final(self, *, inference='auto'):
return super().refit_final(inference=inference)
refit_final.__doc__ = _OrthoLearner.refit_final.__doc__

def score(self, Y, T, X=None, W=None):
def score(self, Y, T, X=None, W=None, sample_weight=None):
"""
Score the fitted CATE model on a new data set. Generates nuisance parameters
for the new data set based on the fitted residual nuisance models created at fit time.
Expand All @@ -523,14 +523,16 @@ def score(self, Y, T, X=None, W=None):
Features for each sample
W: optional(n, d_w) matrix or None (Default=None)
Controls for each sample
sample_weight: optional(n,) vector or None (Default=None)
Weights for each sample

Returns
-------
score: float
The MSE of the final CATE model on the new data.
"""
# Replacing score from _OrthoLearner, to enforce Z=None and improve the docstring
return super().score(Y, T, X=X, W=W)
return super().score(Y, T, X=X, W=W, sample_weight=sample_weight)

@property
def multitask_model_cate(self):
Expand Down Expand Up @@ -693,7 +695,7 @@ class LinearDRLearner(StatsModelsCateEstimatorDiscreteMixin, DRLearner):
\\theta_t(X) = \\left\\langle \\theta_t, \\phi(X) \\right\\rangle + \\beta_t

where :math:`\\phi(X)` is the outcome features of the featurizers, or `X` if featurizer is None. :math:`\\beta_t`
is a an intercept of the CATE, which is included if ``fit_cate_intercept=True`` (Default). It fits this by
is an intercept of the CATE, which is included if ``fit_cate_intercept=True`` (Default). It fits this by
running a standard ordinary linear regression (OLS), regressing the doubly robust outcome differences on X:

.. math ::
Expand Down Expand Up @@ -1472,6 +1474,9 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, groups=None,
-------
self
"""
if X is None:
raise ValueError("This estimator does not support X=None!")

return super().fit(Y, T, X=X, W=W,
sample_weight=sample_weight, groups=groups,
cache_values=cache_values, inference=inference)
Expand Down
83 changes: 0 additions & 83 deletions econml/iv/_nuisance_wrappers.py

This file was deleted.

9 changes: 5 additions & 4 deletions econml/iv/dml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@

"""

from ._dml import DMLATEIV, ProjectedDMLATEIV, DMLIV, NonParamDMLIV
from ._dml import OrthoIV, DMLIV, NonParamDMLIV, DMLATEIV, ProjectedDMLATEIV

__all__ = ["DMLATEIV",
"ProjectedDMLATEIV",
__all__ = ["OrthoIV",
"DMLIV",
"NonParamDMLIV"]
"NonParamDMLIV",
"DMLATEIV",
"ProjectedDMLATEIV"]
Loading