diff --git a/app.py b/app.py index 7b1194c..8ab1f09 100644 --- a/app.py +++ b/app.py @@ -44,11 +44,9 @@ def set_SAE(self, sae_name): self.cfg_dict = cfg_dict def _get_sae_out_and_feature_activations(self): - # given the words in steering_vectore_prompt, the SAE predicts that the neurons(aka features) in activateCache will be activated + # given the words in steering_vector_prompt, the SAE predicts that the neurons(aka features) in activateCache will be activated sv_logits, activationCache = self.model.run_with_cache(self.steering_vector_prompt, prepend_bos=True) sv_feature_acts = self.sae.encode(activationCache[self.sae.cfg.hook_name]) - # get top_k of 1 - # self.sae_out = sae.decode(sv_feature_acts) return self.sae.decode(sv_feature_acts), sv_feature_acts def _hooked_generate(self, prompt_batch, fwd_hooks, seed=None, **kwargs): @@ -101,12 +99,10 @@ def _get_steering_hooks(self): def _run_generate(self, example_prompt, steering_on: bool): self.model.reset_hooks() - steer_hooks = self._get_steering_hooks() - editing_hooks = [ (self.sae_id, steer_hook) for steer_hook in steer_hooks] - # editing_hooks = [(self.sae_id, steer_hook)] - # ^^change this to support steer_hooks being a list of steer_hooks - print(f"steering by {len(editing_hooks)} hooks") if steering_on: + steer_hooks = self._get_steering_hooks() + editing_hooks = [ (self.sae_id, steer_hook) for steer_hook in steer_hooks] + print(f"steering by {len(editing_hooks)} hooks") res = self._hooked_generate([example_prompt] * 3, editing_hooks, seed=None, **self.sampling_kwargs) else: tokenized = self.model.to_tokens([example_prompt]) @@ -129,12 +125,12 @@ def generate(self, message: str, steering_on: bool): -MODEL = "gemma-2b" -PRETRAINED_SAE = "gemma-2b-res-jb" +# MODEL = "gemma-2b" +# PRETRAINED_SAE = "gemma-2b-res-jb" MODEL = "gpt2-small" PRETRAINED_SAE = "gpt2-small-res-jb" LAYER = 10 -chatbot_model = Inference(MODEL,PRETRAINED_SAE, LAYER) +chatbot_model = Inference(MODEL, PRETRAINED_SAE, LAYER) import time @@ -187,7 +183,6 @@ def slow_echo_steering(message, history): with gr.Row(): temp = gr.Slider(0, 5, 1, label="Temperature", info="Temperature is..", interactive=True) - # Set up an action when the sliders change temp.change(chatbot_model.set_temperature, inputs=[temp], outputs=[]) coeff.change(chatbot_model.set_coeff, inputs=[coeff], outputs=[]) chatbot_model.set_steering_vector_prompt(steering_prompt.value)