ENH implement fair representation learner #1478
base: main
Conversation
I think that the CI is failing for the …
I am back :) Wow, that is a big PR, thanks! I will look into the Windows issue, so we can separate it from this PR and have a clean CI.
@taharallouche it seems that the error was made less severe so the builds pass now, and the library maintainers are keeping an eye on it. CI now builds with …
I get notifications, so I'm just popping in here without a proper review, to hopefully make your day easier. The sklearn 1.6 compatibility fixes were working for what we had, trying to add minimal extra code, but if you are using some functions that we didn't cover, you might want to port some of the code found here: https://github.com/sklearn-compat/sklearn-compat. If it is too much and becomes annoying, we can decide to vendor the whole library. It is a bit of a tedious thing; feel free to ping me on Discord about anything, I'd love to support if I can! I wasn't clear enough the other day about the …
EDIT: False alarm. Whoops, it seems … (it's breaking my new estimator as well, so I'll be looking into it...)
Hi @taharallouche, the main branch passes fully with …
Yay for those changes fixing most of the issues! It seems there are still some issues with the …
Hi @fairlearn/fairlearn-maintainers, this is now ready for a deeper review! I will sign myself up to do one next week; I have already looked into the branch a bit. A second reviewer is needed. @hildeweerts, as you created the issue, would you be up for a review too?
After a first pass, I'm unsure what a good API for this method should look like and/or how to categorize it properly. Intuitively, I wouldn't expect a "pre-processing" algorithm to have a "predict" method (or to need a fall-back estimator).
According to the paper, the intuition behind the proposed representation learning approach is to pre-process the data into an alternative representation, which could theoretically be reused by a different party for classification. Yet the main experiments (section 4.2) use the probabilities derived from the prototypes directly for classification, which seems more like an end-to-end representation learning approach than a pre-processing approach...
The chosen representation, which maps instances in the dataset probabilistically to prototypes, is a bit “weird” in the sense that it's not entirely clear how it would be used directly as input for a particular classification algorithm. In Section 4.3, they seem to only use the probabilities as a representation and not the learned prototypes (please correct me if I'm wrong).
I'm not sure if we need to change the implementation so much as the documentation. E.g. we might want to reconsider calling this a pre-processing method, make it clear in the API docs/user guide how the predicted probabilities are derived, show how intermediate transformations could be used in a pipeline with a different classification algorithm, etc.
Thoughts? @fairlearn/fairlearn-maintainers
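To make the pipeline point above concrete, here is a minimal sketch of what that usage could look like. The `FairRepresentationLearner` name and its `sensitive_features` fit kwarg are taken from this PR; the import path, data, and group variable are placeholders:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
# from fairlearn... import FairRepresentationLearner  # import path per this PR

X, y = make_classification(n_samples=200, random_state=0)  # placeholder data
sensitive = (X[:, 0] > 0).astype(int)                      # placeholder groups

pipe = Pipeline([
    ("frl", FairRepresentationLearner(n_prototypes=4)),
    ("clf", LogisticRegression()),
])
# Pipeline routes step-prefixed fit params to the matching step:
pipe.fit(X, y, frl__sensitive_features=sensitive)
y_pred = pipe.predict(X)
```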
.. _fair_representation_learner:

Fair Representation Learner
I'm not sure how we should name this module. On the one hand, fair representation learning has become a category of fair-ml algorithms in the literature; on the other hand, naming something "FairRepresentationLearner" suggests the representations are actually "fair" in a meaningful way, a claim we generally try to avoid.
Reading the paper, I thought of these intermediate representations as the minimal possible representation that balances utility with decoupling as much as possible from the sensitive groups the individuals belong to. They use the term "sanitised" at some point, and although I don't really like the term itself, maybe "Sanitised Intermediate Representation Learner" would fit here? It is a mouthful though, and a bit too on the nose :)
performance.

The model minimizes a loss function that consists of three terms: the reconstruction error,
the classification error, and the statistical-parity error.
For consistency across our docs we might consider calling this something like "an approximation of the demographic parity difference".
- the classification error, and the statistical-parity error.
+ the classification error, and an approximation of the demographic parity difference.
The formulation is not exactly the same as the DPD, as it considers the difference in the probability of mapping to a prototype rather than the target variable, but it seems more consistent with the rest of our docs.
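For reference, if I remember the original paper (Zemel et al., 2013) correctly, the loss has the form

$$
L = A_x \, L_x + A_y \, L_y + A_z \, L_z, \qquad
L_z = \sum_{k=1}^{K} \left| M_k^{+} - M_k^{-} \right|,
$$

where $M_k^{\pm}$ is the mean probability of mapping to prototype $k$ within each of the two sensitive groups, so $L_z$ penalizes differences in prototype assignment rather than in the predicted target.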
n_prototypes : int, default=2
    Number of prototypes to use in the latent representation.

Ax : float, default=1.0
I think we typically don't use capitals for parameters in our docs (unless it's an array); perhaps something like alpha / beta / gamma would be more consistent with the naming in other classes?
    Number of prototypes to use in the latent representation.

Ax : float, default=1.0
    Weight for the reconstruction error term in the objective function.
- Weight for the reconstruction error term in the objective function.
+ Weight of the reconstruction error term in the objective function.
    Weight for the reconstruction error term in the objective function.

Ay : float, default=1.0
    Weight for the classification error term in the objective function.
- Weight for the classification error term in the objective function.
+ Weight of the classification error term in the objective function.
random_state : int, np.random.RandomState, or None, default=None
    Seed or random number generator for reproducibility.

optimizer : Literal["L-BFGS-B", "Nelder-Mead", "Powell", "SLSQP", "TNC", "trust-constr",
Do we have any advice on which optimizers make the most sense in which scenarios? (If not, that's also perfectly fine.)
tol : float, default=1e-6
    Convergence tolerance for the optimization algorithm.

max_iter : int, default=1000
Similar to the above: do we have advice on `max_iter` for different optimizers?
max_iter : int, default=1000
    Maximum number of iterations for the optimization algorithm.

Attributes
This made me realize we haven't added attributes to the docstrings of any of the other modules, lol... I would consider dropping the user-defined attributes, but perhaps leave the others (`n_iter_`, etc.).
Should we consider adding attributes to all of our docstrings? @fairlearn/fairlearn-maintainers
`n_iter_`, `max_iter`, `_classes` and a few others are part of the necessary scikit-learn compatibility API; I'm not sure whether that speaks in favour or not. scikit-learn adds attributes to their docs, and I can see how they make the code more accessible with that extra bit of information available. Tools like Copilot or Cursor can now successfully add these automatically for you, and update them too (of course you'd need to proofread), so it isn't a huge maintenance burden. I personally am OK with either way.
_fall_back_classifier : LogisticRegression or None
    Fallback classifier used when no sensitive features are provided.

Methods
This part is not really necessary IMO, as each method has its own docstring.
I agree! Also, the rendered documentation exposes all public methods automatically.
    The target values.

sensitive_features : array-like or None, default=None
    Sensitive features to be considered whose groups will be used to enforce statistical
- Sensitive features to be considered whose groups will be used to enforce statistical
+ Sensitive features to be considered whose groups will be used to enforce demographic
"Enforce" sounds perhaps a bit stronger than what the method does; perhaps something like improve/promote/increase/etc.?
Before I start the review, because I believe this PR will be affected too: I am working on a PR that will address the …
Thank you again for all the hard work! Since we cleared out most of the sklearn issues before, I only have general stylistic nitpicks and a few questions. I want to think a bit about whether we could test something else too, but we can also add more tests in a subsequent PR as we discuss reproducibility further for all the methods.
FairRepresentationLearner(max_iter=10, n_prototypes=4)
>>> X_train_transformed = frl.transform(X_train)
>>> X_test_transformed = frl.transform(X_test)
>>> y_hat = frl.predict(X_test)
Since I know that you are really great with these things, is there any visualisation/plotting that could be interesting to add here? Something simple?
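For instance, one simple idea (purely a sketch: it assumes `transform` returns the prototype-assignment probabilities, and reuses `frl` and `X_test` from the example above plus a hypothetical `sensitive_test` array):

```python
import matplotlib.pyplot as plt
import numpy as np

Z_test = frl.transform(X_test)  # assumed shape: (n_samples, n_prototypes)
for group in np.unique(sensitive_test):
    mask = sensitive_test == group
    # Mean assignment probability per prototype, per sensitive group.
    plt.plot(Z_test[mask].mean(axis=0), marker="o", label=f"group {group}")
plt.xlabel("prototype index")
plt.ylabel("mean assignment probability")
plt.legend()
plt.show()
```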
    expect_sensitive_features=False,
    enforce_binary_labels=True,
)
assert sensitive_features is None or isinstance(sensitive_features, pd.Series)
From a best-practices perspective, and given the limitations that some testing libraries bring, `assert` is best used only for testing and debugging. Gently erroring out when the code shouldn't continue, instead of breaking abruptly, would be better; an if/else could maybe substitute. That could merge with line 285 maybe?
Another question here before reading the rest: why does the type need to be `Series`? Asking for future refactorings.
    self, X, y, sensitive_features: pd.Series, random_state: np.random.RandomState
):
    """
    Minimize the loss given the sensitive features.
I was reading the code and slowly putting the method together in my mind as I read, and I thought that two or three sentences summarising the "how" of the optimisation, as a docstring under the existing one, might be nice for our future selves :)
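Something along these lines, for example (the wording is illustrative, pieced together from the surrounding diff, not taken from the PR):

```python
def _minimize_loss(self, X, y, sensitive_features, random_state):
    """Minimize the loss given the sensitive features.

    Roughly: the prototypes, the per-dimension distance weights (alpha),
    and the prototype label predictions are flattened into a single
    parameter vector and handed to scipy.optimize.minimize with the
    configured ``optimizer``, ``tol``, and ``max_iter``; the objective is
    the weighted sum of the reconstruction, classification, and parity
    terms (Ax, Ay, Az).
    """
    ...
```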
    + self._prototype_dim  # alpha: the weight of each dimension in the distance computation
)

def objective(x: np.ndarray, X, y) -> float:
This might be a personal preference, but I find the function a tad too long to be a nested one. It could live right above as a private method; that would make the optimisation a bit more readable.
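i.e., something of this shape (a sketch only; `_optimizer_size` comes from the diff below, the zero initialization is a placeholder, and the objective body is stubbed out):

```python
import numpy as np
from scipy.optimize import minimize

class FairRepresentationLearner:
    def _objective(self, x: np.ndarray, X, y) -> float:
        # Unpack the flattened parameter vector and return the weighted sum
        # of the reconstruction, classification, and parity losses.
        ...

    def _minimize_loss(self, X, y, sensitive_features, random_state):
        return minimize(
            self._objective,
            x0=np.zeros(self._optimizer_size),  # placeholder initialization
            args=(X, y),
            method=self.optimizer,
            tol=self.tol,
            options={"maxiter": self.max_iter},
        )
```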
)

def objective(x: np.ndarray, X, y) -> float:
    assert x.shape == (self._optimizer_size,)
Same for the assert here as above.
Az: float = 1.0,
random_state: int | np.random.RandomState | None = None,
optimizer: Literal[
    "L-BFGS-B", "Nelder-Mead", "Powell", "SLSQP", "TNC", "trust-constr", "COBYLA", "COBYQA"
I remember the mention of `L-BFGS` for optimisation, but where do the other names come from? I see them in the `scipy` documentation; are they just the compatible alternatives?
options={"maxiter": self.max_iter}, | ||
) | ||
except Exception as optimization_error: | ||
raise RuntimeError("The loss minimization failed.") from optimization_error |
Could this error message be improved by adding what folks can do to keep it from failing, for example increasing `max_iter`?
Description
Tackles #1026
Hello! Here are the most important details, imo, regarding this implementation:
Comparison to original paper and AIF360 implementation
- The weights `w` are bound to be in …
- The `L-BFGS` … one that the paper uses, among a couple of other things to adapt.

Fallback Estimator
In order to be sklearn-compatible, `.fit` should work with the default `sensitive_features` kwarg value, which is `None`. To achieve that, I made the choice to fall back to fitting a regular `LogisticRegression` if no `sensitive_features` are passed. The transformation in such cases will be the identity function, and the predictions are those of the fallback logistic regressor.
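To illustrate the described fallback behavior (a sketch; the class name comes from this PR, while `X_train`, `y_train`, and `X_test` are placeholders):

```python
import numpy as np

frl = FairRepresentationLearner()
frl.fit(X_train, y_train)  # no sensitive_features -> fits the fallback LogisticRegression
assert np.allclose(frl.transform(X_train), X_train)  # transform is the identity
y_pred = frl.predict(X_test)  # predictions delegate to the fallback classifier
```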
Unit-tests
I have to say I feel the number of unit tests is small compared to the other modules, but I ran out of ideas for code paths to cover, especially since the sklearn estimator checks already test many edge cases.
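For reference, the sklearn estimator checks mentioned here are typically wired up with scikit-learn's standard helper (assuming that is what this PR uses):

```python
from sklearn.utils.estimator_checks import parametrize_with_checks

# Runs the full battery of scikit-learn API conformance checks.
@parametrize_with_checks([FairRepresentationLearner()])
def test_sklearn_compatible_estimator(estimator, check):
    check(estimator)
```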
Tests
Documentation
Screenshots