Model comparison v2 #1672
@@ -509,6 +514,8 @@ def make_json_formatted_for_single_chart(mutant_features,
was sent for inference. The length of that field should be the same length
of mutant_features.
index_to_mutate: The index of the feature being mutated for this chart.
model_id: Index of the model corresponding to inference_results_proto to
distinguish between multiple models for model comparison
this isn't used anywhere in the method?
you are right, removed it.
<div class="flex">
<template is="dom-repeat" items="{{featureValueThreshold.threshold}}">
indent the paper-slider
done
immediate-value="{{overallThreshold}}" value="[[overallThreshold]]">
</paper-slider>
<template is="dom-repeat" items="{{overallThresholds}}" as="overallThreshold">
<paper-slider class="slider" editable=true min="0" max="1" step="0.01"
add label for sliders when more than one model?
added.
class="conf-matrix"></tf-confusion-matrix>
</tf-confusion-matrix>
<div>
<template is="dom-repeat" items="[[inferenceStats_]]" index-as="modelInd">
label conf matrices when more than one model
added.
// series
self = this;
const mapped = {};
_.forEach(data, function(modelValues, modelInd){
i know the code was like this before, but if you remove function() {} and replace with () => {} here and below, then you don't need the self=this, can just use this. also if you were to use "self", should be "const self"
done.
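The suggestion above can be sketched as follows. This is a minimal illustration, not the plugin's code: the `chart` object, its `prefix` field, and both method names are hypothetical, and `Array.prototype.forEach` stands in for the `_.forEach` call in the diff so the snippet has no dependencies.

```javascript
// Hypothetical stand-in for the chart component (assumption, not plugin code).
const chart = {
  prefix: 'model',

  // Arrow callback: `this` is inherited from labelSeries, so no alias is needed.
  labelSeries(data) {
    const mapped = {};
    data.forEach((modelValues, modelInd) => {
      mapped[`${this.prefix} ${modelInd + 1}`] = modelValues;
    });
    return mapped;
  },

  // Classic callback: `this` is rebound inside function() {}, so an alias is
  // required, and it should be declared with const rather than left global.
  labelSeriesWithSelf(data) {
    const self = this;
    const mapped = {};
    data.forEach(function(modelValues, modelInd) {
      mapped[`${self.prefix} ${modelInd + 1}`] = modelValues;
    });
    return mapped;
  },
};
```

Both methods build the same mapping; the arrow version simply removes the need for the alias.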
some quick initial comments
getNumberOfModels_: function(){
return this.modelName.split(',').length;
},
it's strange that each time we want this we have to do this string split. i think numModels should be a data member that is set/updated at an appropriate time during inference or after the model name is set, and then used without needing recalculating.
You are right, moved it to variable numModels which gets updated when user clicks Accept in settings pane.
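A rough sketch of that change, assuming a plain class in place of the Polymer element: `ModelSettings` and `accept` are hypothetical names, and only `modelName` and `numModels` come from the discussion above.

```javascript
// Hypothetical settings holder (assumption); the real plugin is a Polymer element.
class ModelSettings {
  constructor() {
    this.modelName = '';
    this.numModels = 0;
  }

  // Called once when the user accepts the settings pane, so the
  // comma-separated model names are split only here.
  accept(modelName) {
    this.modelName = modelName;
    this.numModels = modelName.split(',').length;
  }
}

const settings = new ModelSettings();
settings.accept('model_a,model_b');
```

Callers then read `settings.numModels` directly instead of recomputing it from the string on every access.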
plotStats.push(inferenceStats.faceted[key]);
plotThresholds.push(modelThresholds[modelInd].threshold)
}
this.plotPr(
indenting issue here maybe?
done.
inferenceMap[this.inferences.indices[i]][0] =
for (let modelNum = 0; modelNum < this.inferences.results.length;
modelNum++){
const result = this.inferences.results[modelNum].regressionResult;
indenting issue
done.
modelInd < self.getNumberOfModels_();
modelInd++) {
item[self.strWithModelIndex_(inferenceLabelStr, modelInd)] =
self.labelVocab[
whitespace
done.
return;
}
// Binary classifier case.
let threshold = this.overallThreshold;
let threshold = this.overallThresholds;
probably better if this is called "thresholds" as it is a list
updated.
more comments
Thank you James! Should be all set with the comments.
mutant_examples, serving_bundle)
return make_json_formatted_for_single_chart(mutant_features,
charts = []
for model_id, serving_bundle in enumerate(serving_bundles):
remove model_id and enumerate call as model_id isn't used
good catch, removed!
@@ -222,8 +221,9 @@
}
.pr-line-chart {
height: 150px;
width: 250px;
margin: 6px 6px 6px 6px;
if you want margin of 6px on all sides you can just do "margin: 6px" - same rules for padding as well. see https://www.w3schools.com/cssref/pr_margin.asp
thanks, done.
transform: rotate(270deg);
font-size: 12px;
color: #3c4043;
where'd this color come from?
Mahima added it yesterday.
this.examplesAndInferences[0].inferences == null ||
this.examplesAndInferences[0].inferences[0].length
!= this.numModels ||
(this.inferenceStats_.length != this.numModels &&
maybe add comments about this check, as it's gotten complex with these new cases
added description.
// console.log('overallthresholds', this.overallThresholds)
// console.log('inferencestats', this.inferenceStats_)
// console.log('inferences', this.inferences.results)
// console.log('examplesandinferences', this.examplesAndInferences)
remove commented-out debug logs
oops, done!
},
];
// If there is more than 1 model, show model number tooltip
if (data.length >= 2){
nit: space before {
done.
},
formatChartKey: function(labelKey, modelInd, numberOfModels) {
if (numberOfModels == 1){
space before {
done.
looks great. some minor comments.
also, have you tested it with multiple multiclass models (can just treat our binary models as multiclass to test that) and multiple regression models?
Did another pass. Added a binary check for ROC curve plotting in updateInferenceStats. Before, it only checked for the existence of the chart, but if someone switches from binary to multi-class using the inference settings window without refreshing the page, the chart elements still exist in the DOM, which causes an error in plotPR: thresholds don't exist for non-binary models, so inferenceStats.faceted[key] is an object rather than an array and the map call in plotPR throws. Added "restamp" to the binary/multi-class/regression dom-ifs to delete unused chunks when the model type changes. There is currently an issue with multi-class selection in the inference settings panel: switching from a binary to a multi-class model without refreshing the page may trigger a (number of thresholds) x (number of thresholds) confusion matrix (100 x 100). Without "restamp" this cripples the GUI after that event. With "restamp", the confusion matrix gets cleared after creation, so the issue is temporary. Deferring the fix to a future PR as it is independent of model comparison.
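The guard described in this comment might look roughly like the sketch below. `canPlotPr` and its parameters are hypothetical names chosen for illustration; the only detail taken from the comment is that faceted stats are arrays in the binary case and plain objects otherwise.

```javascript
// Hypothetical helper (assumption): decides whether the PR/ROC chart can be
// drawn. In the binary case each faceted entry is an array of per-threshold
// stats; in other cases it is a plain object, and calling map on it would throw.
function canPlotPr(chartExists, isBinary, facetedStats) {
  if (!chartExists || !isBinary) {
    return false;
  }
  return Object.values(facetedStats).every((stats) => Array.isArray(stats));
}
```

Checking the model type as well as the chart's existence covers the case where stale chart elements remain in the DOM after the model type changes.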
Thanks for the clear PR description—this is very helpful.
Second version of model comparison for what-if-tool plugin.
Motivation for features / changes
PR #1589 added the ability to load two models and compare their inference scores in Facets Dive panel for binary classification and Datapoint Editor tab for classification.
This PR extends the model comparison mode to Global/Partial PDP, Fairness and Performance panels. It also adds comparing regression and multi-class models in all supported panels.
All model prediction related features in facets dive (Inference Value, Inference Error etc.) now also have multiple model support.
The only feature that does not support multiple models is the closest counterfactual button, which still computes its results on model 1.
Technical description of changes
Some variables that used to handle single models became arrays indexed by model number.
Some string keys for facets features now have model number appended to them.
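As an illustration of these two changes, here is a minimal sketch. `keyWithModelIndex` and `attachInferenceScores` are hypothetical helpers, and the exact key format (a space followed by a 1-based model number) is an assumption; the plugin's own strWithModelIndex_ and formatChartKey may format keys differently.

```javascript
// Hypothetical key format (assumption): append a 1-based model number only
// when more than one model is loaded.
function keyWithModelIndex(key, modelInd, numModels) {
  return numModels === 1 ? key : `${key} ${modelInd + 1}`;
}

// Per-model values arrive as an array indexed by model number and are
// written onto the item under model-specific keys.
function attachInferenceScores(item, scoresByModel) {
  scoresByModel.forEach((score, modelInd) => {
    const key = keyWithModelIndex('Inference score', modelInd, scoresByModel.length);
    item[key] = score;
  });
  return item;
}
```

With a single model the key is unchanged, so single-model behavior stays the same under this scheme.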
Add the ability to compare two models for Global PDP and PDP plots:
Now the Global PDP and PDP plots show one line per model per label in the same chart.
(Previously they only showed one line per label.)
Performance and Fairness panel:
Show one ROC curve per model within the same chart. (Previously showed ROC curve for single model.)
Show one threshold and confusion matrix per model. (Previously showed threshold and confusion matrix for single model.)
Add the ability to load two models for regression:
Facets dive features and MSEs in Performance and Fairness panel work for two regression models.
Minor:
Screenshots of UI changes
all features now support two models
features old
new performance and fairness tab
old performance and fairness tab
pdp shows multiple models
same for multi-class case
categorical pdp shows each model as separate bar and labels original value in text
all regression features are available in facets
regression performance shows both models
Detailed steps to verify changes work correctly (as executed by you)
Tested all offline demos, multi/single model for binary and multi-class models (using the settings panel multi-class option with binary models).
Tested multiple model regression with age demo by duplicating the first model.
Testing involves visually verifying the functionality of all panels by clicking buttons and verifying the outputs and checking for exceptions in the browser console.
Alternate designs / implementations considered
N/A