diff --git a/app.py b/app.py
index 7b1194c..9363456 100644
--- a/app.py
+++ b/app.py
@@ -43,12 +43,21 @@ def set_SAE(self, sae_name):
         self.sae = sae
         self.cfg_dict = cfg_dict
 
+    def get_feature_info(self):
+        projection_onto_unembed = self.sae.W_dec @ self.model.W_U
+        # get the top ten vocabulary tokens associated with each feature
+        WORD_COUNT = 10
+        _, inds = torch.topk(projection_onto_unembed, WORD_COUNT, dim=1)
+
+        _, sv_feature_acts = self._get_sae_out_and_feature_activations()
+        features = self._get_features(sv_feature_acts)
+        associated_words = [self.model.to_str_tokens(inds[f]) for f in features]
+        return associated_words
+
     def _get_sae_out_and_feature_activations(self):
-        # given the words in steering_vectore_prompt, the SAE predicts that the neurons(aka features) in activateCache will be activated
+        # given the words in steering_vector_prompt, the SAE predicts which neurons (aka features) in activationCache will be activated
         sv_logits, activationCache = self.model.run_with_cache(self.steering_vector_prompt, prepend_bos=True)
         sv_feature_acts = self.sae.encode(activationCache[self.sae.cfg.hook_name])
-        # get top_k of 1
-        # self.sae_out = sae.decode(sv_feature_acts)
         return self.sae.decode(sv_feature_acts), sv_feature_acts
 
     def _hooked_generate(self, prompt_batch, fwd_hooks, seed=None, **kwargs):
@@ -69,9 +78,7 @@ def _get_features(self, sv_feature_activations):
         # return torch.topk(sv_feature_acts, 1).indices.tolist()
         features = torch.topk(sv_feature_activations, 1).indices
         print(f'features that align with the text prompt: {features}')
-        print("pump the features into the tool that gives you the words associated with each feature")
-        return features
-
+        return features[0]
 
     def _get_steering_hook(self, feature, sae_out):
         coeff = self.coeff
@@ -93,7 +100,7 @@ def _get_steering_hooks(self):
         # and not use the seperate function _get_steering_hook()
         sae_out, sv_feature_acts = self._get_sae_out_and_feature_activations()
         features = self._get_features(sv_feature_acts)
-        steering_hooks = [self._get_steering_hook(feature, sae_out) for feature in features[0]]
+        steering_hooks = [self._get_steering_hook(feature, sae_out) for feature in features]
 
         return steering_hooks
 
@@ -101,12 +108,10 @@ def _get_steering_hooks(self):
     def _run_generate(self, example_prompt, steering_on: bool):
 
         self.model.reset_hooks()
-        steer_hooks = self._get_steering_hooks()
-        editing_hooks = [ (self.sae_id, steer_hook) for steer_hook in steer_hooks]
-        # editing_hooks = [(self.sae_id, steer_hook)]
-        # ^^change this to support steer_hooks being a list of steer_hooks
-        print(f"steering by {len(editing_hooks)} hooks")
         if steering_on:
+            steer_hooks = self._get_steering_hooks()
+            editing_hooks = [(self.sae_id, steer_hook) for steer_hook in steer_hooks]
+            print(f"steering by {len(editing_hooks)} hooks")
             res = self._hooked_generate([example_prompt] * 3, editing_hooks, seed=None, **self.sampling_kwargs)
         else:
             tokenized = self.model.to_tokens([example_prompt])
@@ -129,12 +134,12 @@ def generate(self, message: str, steering_on: bool):
 
 
 
-MODEL = "gemma-2b"
-PRETRAINED_SAE = "gemma-2b-res-jb"
+# MODEL = "gemma-2b"
+# PRETRAINED_SAE = "gemma-2b-res-jb"
 MODEL = "gpt2-small"
 PRETRAINED_SAE = "gpt2-small-res-jb"
 LAYER = 10
-chatbot_model = Inference(MODEL,PRETRAINED_SAE, LAYER)
+chatbot_model = Inference(MODEL, PRETRAINED_SAE, LAYER)
 
 import time
 
@@ -153,6 +158,15 @@ def slow_echo_steering(message, history):
             time.sleep(0.01)
             yield result[: i + 1]
 
+def populate_related_features():
+    features = chatbot_model.get_feature_info()
+    print(features)
+    return features[0]
+    # for feature in features:
+    #     for i in range(len(feature)):
+    #         time.sleep(0.01)
+    #         yield feature[: i + 1]
+
 with gr.Blocks() as demo:
     with gr.Row():
         gr.Markdown("*STANDARD HEXTER BOT*")
@@ -182,12 +196,14 @@
         )
     with gr.Row():
        steering_prompt = gr.Textbox(label="Steering prompt", value="Golden Gate Bridge")
+        found_features = gr.Textbox(label="Found Features")
+        find_features = gr.Button("Find Related Features")
+        find_features.click(fn=populate_related_features, inputs=None, outputs=found_features)
     with gr.Row():
         coeff = gr.Slider(1, 1000, 300, label="Coefficient", info="Coefficient is..", interactive=True)
     with gr.Row():
         temp = gr.Slider(0, 5, 1, label="Temperature", info="Temperature is..", interactive=True)
 
-    # Set up an action when the sliders change
     temp.change(chatbot_model.set_temperature, inputs=[temp], outputs=[])
     coeff.change(chatbot_model.set_coeff, inputs=[coeff], outputs=[])
     chatbot_model.set_steering_vector_prompt(steering_prompt.value)
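---

Reviewer note: get_feature_info boils down to projecting the SAE decoder directions onto the model's unembedding matrix and taking a per-feature top-k over the vocabulary. Below is a minimal standalone sketch of that step, assuming the TransformerLens/SAELens shape conventions (W_dec: n_features x d_model, W_U: d_model x d_vocab); the helper name and toy shapes are illustrative only, not part of this PR:

import torch

def top_tokens_per_feature(W_dec: torch.Tensor, W_U: torch.Tensor, k: int = 10) -> torch.Tensor:
    # Project each SAE decoder direction onto the unembedding matrix, then
    # keep the k vocabulary tokens that align most strongly with each feature.
    projection_onto_unembed = W_dec @ W_U                      # (n_features, d_vocab)
    _, inds = torch.topk(projection_onto_unembed, k, dim=1)    # (n_features, k)
    return inds

# Toy usage with random weights; a real call would pass sae.W_dec and model.W_U,
# then map the indices to strings with model.to_str_tokens(...).
W_dec = torch.randn(16, 64)     # 16 hypothetical features, d_model = 64
W_U = torch.randn(64, 1000)     # hypothetical vocabulary of 1000 tokens
print(top_tokens_per_feature(W_dec, W_U).shape)                # torch.Size([16, 10])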