I have a text sentiment polarity prediction model, roughly structured as RoBERTa + CNN. Now, I want to use InterpretML to explain its prediction results. My code is as follows:
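```python
from interpret.glassbox import ExplainableBoostingClassifier
import numpy as np
from tensorflow.keras.preprocessing.sequence import pad_sequences


def analysis_interpret(target_name: str, text_list: list, sentiment_list: list):
    ebm = ExplainableBoostingClassifier()
    # DataGenerator is my model's own text-processing class (defined elsewhere);
    # it maps each text to RoBERTa token IDs.
    data_generator = DataGenerator(text_list, sentiment_list)
    # Flatten each text's token IDs into one row
    X_train = [np.ravel(arr) for arr in data_generator.input_ids]
    # Pad all rows to a common length; pad_sequences already returns a NumPy array
    X_train = pad_sequences(X_train)
    # Labels predicted by my model, which the EBM is trained to mimic
    y_train = sentiment_list
    ebm.fit(X_train, y_train)
    ebm_local = ebm.explain_local(X_train, y_train)
```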
Here, `DataGenerator` is my model's text-processing class. For now I'm using RoBERTa's tokenizer to map the text to the token IDs the model needs, and `y_train` holds the labels predicted by my model. After `ebm_local = ebm.explain_local(X_train, y_train)`, how can I get the importance of each word? I have seen people use `ebm_local.get_local_importance_dict()`, but I can't find this method in version 0.5.1.
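One possible route, sketched under stated assumptions: the explanation object returned by `explain_local` has a `data(i)` method that returns a dict for sample `i`, whose parallel `names` and `scores` lists hold each term's contribution. The sketch below assumes that dict layout and the auto-generated names EBM gives columns of a bare NumPy array (`feature_0000`, `feature_0001`, ..., with pairwise terms named like `feature_0000 & feature_0003`). It folds those term scores back onto token positions. `token_importances` is a hypothetical helper, not part of InterpretML, and splitting a pairwise term evenly between its two positions is a heuristic choice.

```python
import re
from typing import Callable, Iterable, Sequence

import numpy as np

# Assumption: EBM's default names for NumPy input, e.g. "feature_0007",
# and "feature_0003 & feature_0007" for a pairwise interaction term.
_FEATURE_RE = re.compile(r"feature_(\d+)")


def token_importances(
    local_data: dict,
    token_ids: Sequence[int],
    id_to_token: Callable[[int], str],
    skip_ids: Iterable[int] = (),
) -> list:
    """Hypothetical helper: map one EBM local explanation to (position, token, score).

    local_data is assumed to be the dict from ebm_local.data(i), with parallel
    "names" and "scores" lists; token_ids is the padded row X_train[i].
    """
    skip = set(skip_ids)
    per_position = {}
    for name, score in zip(local_data["names"], local_data["scores"]):
        positions = [int(p) for p in _FEATURE_RE.findall(name)]
        if not positions:
            continue  # term name not in the assumed format
        # Heuristic: split a pairwise term's score evenly across its positions.
        # np.asarray also handles per-class score arrays (multiclass).
        share = np.asarray(score, dtype=float) / len(positions)
        for pos in positions:
            per_position[pos] = per_position.get(pos, 0.0) + share

    result = []
    for pos in sorted(per_position):
        tid = int(token_ids[pos])
        if tid in skip:
            continue  # e.g. padding or special-token IDs
        result.append((pos, id_to_token(tid), per_position[pos]))
    return result
```

With the objects from the question, a call might look like `token_importances(ebm_local.data(0), X_train[0], tokenizer.convert_ids_to_tokens, skip_ids={0})`, where `tokenizer` is assumed to be the RoBERTa tokenizer used inside `DataGenerator`, and `skip_ids={0}` drops the zeros `pad_sequences` inserts by default (depending on the tokenizer, ID 0 may also be a special token). Because each EBM feature here is a token *position* rather than a word, a score means "this ID at this position", and the same word at two positions gets two separate scores.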