Commit

Revert "Add steer feature explorer"
This reverts commit f76cae5.
seplee committed Jul 23, 2024
1 parent 9f6ce36 commit 7a91220
Showing 1 changed file with 16 additions and 33 deletions.
app.py: 49 changes (16 additions & 33 deletions)
@@ -43,22 +43,12 @@ def set_SAE(self, sae_name):
self.sae = sae
self.cfg_dict = cfg_dict

def get_feature_info(self):
projection_onto_unembed = self.sae.W_dec @ self.model.W_U
# get the top ten words associated with the given feature
WORD_COUNT = 10
_, inds = torch.topk(projection_onto_unembed, WORD_COUNT, dim=1)

_, sv_feature_acts = self._get_sae_out_and_feature_activations()
features = self._get_features(sv_feature_acts)
breakpoint();
associated_words = [self.model.to_str_tokens(inds[f]) for f in features]
return associated_words
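
For context, the deleted get_feature_info projects each SAE decoder direction onto the model's unembedding matrix and keeps the highest-scoring vocabulary tokens per feature. A minimal standalone sketch of that projection, assuming a TransformerLens HookedTransformer (model) and an SAELens SAE (sae) like the ones this class sets up:

    import torch

    def top_tokens_for_features(sae, model, feature_ids, k=10):
        # W_dec: (n_features, d_model); W_U: (d_model, d_vocab)
        vocab_scores = sae.W_dec @ model.W_U              # (n_features, d_vocab)
        _, top_token_ids = torch.topk(vocab_scores, k, dim=1)
        # decode the k highest-scoring token ids for each requested feature
        return [model.to_str_tokens(top_token_ids[f]) for f in feature_ids]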

def _get_sae_out_and_feature_activations(self):
# given the words in steering_vector_prompt, the SAE predicts that the neurons(aka features) in activateCache will be activated
# given the words in steering_vectore_prompt, the SAE predicts that the neurons(aka features) in activateCache will be activated
sv_logits, activationCache = self.model.run_with_cache(self.steering_vector_prompt, prepend_bos=True)
sv_feature_acts = self.sae.encode(activationCache[self.sae.cfg.hook_name])
# get top_k of 1
# self.sae_out = sae.decode(sv_feature_acts)
return self.sae.decode(sv_feature_acts), sv_feature_acts

def _hooked_generate(self, prompt_batch, fwd_hooks, seed=None, **kwargs):
@@ -79,7 +69,9 @@ def _get_features(self, sv_feature_activations):
# return torch.topk(sv_feature_acts, 1).indices.tolist()
features = torch.topk(sv_feature_activations, 1).indices
print(f'features that align with the text prompt: {features}')
return features[0]
print("pump the features into the tool that gives you the words associated with each feature")
return features
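
A note on shapes: the SAE encoder output sv_feature_acts is (batch, seq_len, n_features), so torch.topk(..., 1) over the last dimension picks one top feature index per token position, and features[0] selects those indices for the single prompt in the batch. A small check under those assumed dimensions:

    import torch

    sv_feature_acts = torch.randn(1, 5, 24576)           # 1 prompt, 5 tokens, 24576 features (assumed sizes)
    features = torch.topk(sv_feature_acts, 1).indices    # topk acts over the last dim by default
    print(features.shape)       # torch.Size([1, 5, 1])
    print(features[0].shape)    # torch.Size([5, 1]): one candidate feature per token position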


def _get_steering_hook(self, feature, sae_out):
coeff = self.coeff
@@ -101,18 +93,20 @@ def _get_steering_hooks(self):
# and not use the separate function _get_steering_hook()
sae_out, sv_feature_acts = self._get_sae_out_and_feature_activations()
features = self._get_features(sv_feature_acts)
steering_hooks = [self._get_steering_hook(feature, sae_out) for feature in features]
steering_hooks = [self._get_steering_hook(feature, sae_out) for feature in features[0]]

return steering_hooks
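
The body of _get_steering_hook sits outside this hunk; the usual recipe it points to is adding coeff times the chosen feature's decoder direction to the residual stream at the hooked layer during generation. A hedged sketch of one common form of such a hook (the details are assumptions, not this file's exact code):

    def make_steering_hook(sae, feature_id, coeff):
        steering_vector = sae.W_dec[feature_id]           # (d_model,) direction for this feature

        def steering_hook(resid, hook):
            # resid: (batch, seq_len, d_model) residual-stream activation at the hooked layer
            return resid + coeff * steering_vector

        return steering_hook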


def _run_generate(self, example_prompt, steering_on: bool):

self.model.reset_hooks()
steer_hooks = self._get_steering_hooks()
editing_hooks = [ (self.sae_id, steer_hook) for steer_hook in steer_hooks]
# editing_hooks = [(self.sae_id, steer_hook)]
# ^^change this to support steer_hooks being a list of steer_hooks
print(f"steering by {len(editing_hooks)} hooks")
if steering_on:
steer_hooks = self._get_steering_hooks()
editing_hooks = [ (self.sae_id, steer_hook) for steer_hook in steer_hooks]
print(f"steering by {len(editing_hooks)} hooks")
res = self._hooked_generate([example_prompt] * 3, editing_hooks, seed=None, **self.sampling_kwargs)
else:
tokenized = self.model.to_tokens([example_prompt])
@@ -135,12 +129,12 @@ def generate(self, message: str, steering_on: bool):



# MODEL = "gemma-2b"
# PRETRAINED_SAE = "gemma-2b-res-jb"
MODEL = "gemma-2b"
PRETRAINED_SAE = "gemma-2b-res-jb"
MODEL = "gpt2-small"
PRETRAINED_SAE = "gpt2-small-res-jb"
LAYER = 10
chatbot_model = Inference(MODEL, PRETRAINED_SAE, LAYER)
chatbot_model = Inference(MODEL,PRETRAINED_SAE, LAYER)
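
The Inference constructor itself is not part of this diff; with these constants, loading the model and SAE through TransformerLens and SAELens typically looks roughly like the sketch below. The exact sae_id string and the three-value return of SAE.from_pretrained are assumptions about the library version in use:

    from transformer_lens import HookedTransformer
    from sae_lens import SAE

    LAYER = 10
    model = HookedTransformer.from_pretrained("gpt2-small")
    sae, cfg_dict, _sparsity = SAE.from_pretrained(
        release="gpt2-small-res-jb",                      # assumed to match PRETRAINED_SAE
        sae_id=f"blocks.{LAYER}.hook_resid_pre",          # assumed hook point for layer 10
    )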


import time
@@ -159,15 +153,6 @@ def slow_echo_steering(message, history):
time.sleep(0.01)
yield result[: i + 1]
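
Gradio treats a generator callback as a streaming reply: each yielded value replaces the message shown so far, so yielding growing prefixes produces the typewriter effect used here. A self-contained sketch of the same pattern (function and variable names are illustrative):

    import time
    import gradio as gr

    def slow_echo(message, history):
        reply = f"echo: {message}"
        for i in range(len(reply)):
            time.sleep(0.01)
            yield reply[: i + 1]      # each yield overwrites the partial reply

    gr.ChatInterface(slow_echo).launch()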

def populate_related_features():
features = chatbot_model.get_feature_info()
print(features)
return features[0]
# for feature in features:
# for i in range(len(feature)):
# time.sleep(0.01)
# yield feature[: i + 1]

with gr.Blocks() as demo:
with gr.Row():
gr.Markdown("*STANDARD HEXTER BOT*")
@@ -197,14 +182,12 @@ def populate_related_features():
)
with gr.Row():
steering_prompt = gr.Textbox(label="Steering prompt", value="Golden Gate Bridge")
found_features = gr.Textbox(label="Found Features")
find_features = gr.Button("Find Related Features")
find_features.click(fn=populate_related_features,inputs=None, outputs=found_features)
with gr.Row():
coeff = gr.Slider(1, 1000, 300, label="Coefficient", info="Coefficient is..", interactive=True)
with gr.Row():
temp = gr.Slider(0, 5, 1, label="Temperature", info="Temperature is..", interactive=True)

# Set up an action when the sliders change
temp.change(chatbot_model.set_temperature, inputs=[temp], outputs=[])
coeff.change(chatbot_model.set_coeff, inputs=[coeff], outputs=[])
chatbot_model.set_steering_vector_prompt(steering_prompt.value)
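
The end of the file falls outside this hunk; a Blocks app like this one is usually finished by launching the demo, roughly as follows (assumed closing lines, not shown in this diff):

    if __name__ == "__main__":
        demo.launch()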
