diff --git a/.gitignore b/.gitignore
new file mode 100644
index 000000000..ed8ebf583
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+__pycache__
\ No newline at end of file
diff --git a/easy_transformer/EasyTransformer.py b/easy_transformer/EasyTransformer.py
index ecd43b371..3c7273970 100644
--- a/easy_transformer/EasyTransformer.py
+++ b/easy_transformer/EasyTransformer.py
@@ -244,6 +244,7 @@ def forward(self, x):
         resid_pre = self.hook_resid_pre(x)  # [batch, pos, d_model]
         attn_out = self.hook_attn_out(self.attn(self.ln1(resid_pre)))  # [batch, pos, d_model]
         resid_mid = self.hook_resid_mid(resid_pre + attn_out)  # [batch, pos, d_model]
+        mlp_out = self.hook_mlp_out(self.mlp(self.ln2(resid_mid)))  # [batch, pos, d_model]
         resid_post = self.hook_resid_post(resid_mid + mlp_out)  # [batch, pos, d_model]
         return resid_post
 
diff --git a/easy_transformer/experiments.py b/easy_transformer/experiments.py
index 1056a5ab4..3898e54cf 100644
--- a/easy_transformer/experiments.py
+++ b/easy_transformer/experiments.py
@@ -261,15 +261,41 @@ def get_target(self, layer, head):
 
 
 class EasyAblation(EasyExperiment):
-    def __init__(self, model: EasyTransformer, config: AblationConfig, metric: ExperimentMetric):
+    """
+    Run an ablation experiment according to the config object.
+    Pass a non-None semantic_indices to average across differing index positions
+    (probably of limited use currently; see test_experiments for one usage).
+    """
+
+    def __init__(self, model: EasyTransformer, config: AblationConfig, metric: ExperimentMetric, semantic_indices=None):
         super().__init__(model, config, metric)
-        assert type(config) == AblationConfig
+        assert "AblationConfig" in str(type(config))
         if self.cfg.mean_dataset is None and config.compute_means:
             self.cfg.mean_dataset = self.metric.dataset
         if self.cfg.cache_means and self.cfg.compute_means:
             self.get_all_mean()
+        self.semantic_indices = semantic_indices
 
     def run_ablation(self):
+        if self.semantic_indices is not None:
+            # Recompute a cache over the mean dataset, then overwrite the cached
+            # means at each semantic position with a cross-prompt average there.
+            cache = {}
+            self.model.reset_hooks()
+            self.model.cache_all(cache)
+            self.model(self.cfg.mean_dataset)
+            dataset_length = len(self.cfg.mean_dataset)
+            for semantic_symbol, semantic_indices in self.semantic_indices.items():
+                for hk in self.model.hook_dict.keys():
+                    if "attn.hook_result" not in hk:
+                        continue
+                    self.mean_cache[hk][list(range(dataset_length)), semantic_indices] = einops.repeat(
+                        torch.mean(
+                            cache[hk][list(range(dataset_length)), semantic_indices], dim=0
+                        ).clone(),
+                        "... -> s ...",
+                        s=dataset_length,
+                    )
         return self.run_experiment()
 
     def get_hook(self, layer, head=None):
diff --git a/easy_transformer/hook_points.py b/easy_transformer/hook_points.py
index 650394d1c..fd91c2cf9 100644
--- a/easy_transformer/hook_points.py
+++ b/easy_transformer/hook_points.py
@@ -164,7 +164,6 @@ def run_with_hooks(
             reset_hooks_end (bool): If True, all hooks are removed at the end (ie, including those added in this run)
             clear_contexts (bool): If True, clears hook contexts whenever hooks are reset
 
-        Note that if we want to use backward hooks, we need to set reset_hooks_end to be False, so the backward hooks are still there - this function only runs a forward pass.
         """
 
 
diff --git a/easy_transformer/tests/test_experiments.py b/easy_transformer/tests/test_experiments.py
new file mode 100644
index 000000000..fea74fdaa
--- /dev/null
+++ b/easy_transformer/tests/test_experiments.py
@@ -0,0 +1,139 @@
+# Import stuff
+from typing import Callable, Union, List, Tuple
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import numpy as np
+import einops
+from tqdm import tqdm
+
+from easy_transformer.hook_points import HookedRootModule, HookPoint
+from easy_transformer.utils import (
+    gelu_new,
+    to_numpy,
+    get_corner,
+    print_gpu_mem,
+    get_sample_from_dataset,
+)
+from easy_transformer.EasyTransformer import EasyTransformer
+from easy_transformer.experiments import (
+    ExperimentMetric,
+    AblationConfig,
+    EasyAblation,
+)
+
+
+def test_semantic_ablation():
+    """
+    Compute semantic ablation manually, and then
+    via the experiments.py machinery, and check
+    that the two results agree.
+    """
+
+    # So we don't have to add the IOI dataset object to this library...
+    ioi_text_prompts = [
+        "Then, Christina and Samantha were working at the grocery store. Samantha decided to give a kiss to Christina",
+        "Then, Samantha and Christina were working at the grocery store. Christina decided to give a kiss to Samantha",
+        "When Timothy and Dustin got a kiss at the grocery store, Dustin decided to give it to Timothy",
+        "When Dustin and Timothy got a kiss at the grocery store, Timothy decided to give it to Dustin",
+    ]
+    ioi_io_ids = [33673, 34778, 22283, 37616]
+    ioi_s_ids = [34778, 33673, 37616, 22283]
+    ioi_end_idx = [18, 18, 17, 17]
+    semantic_indices = {"IO": [2, 2, 1, 1], "S": [4, 4, 3, 3], "S2": [12, 12, 12, 12]}
+    L = len(ioi_text_prompts)
+
+    def logit_diff(model, text_prompts):
+        """Difference between the IO and the S logits (at the "to" token)"""
+        logits = model(text_prompts).detach()
+        IO_logits = logits[torch.arange(len(text_prompts)), ioi_end_idx, ioi_io_ids]
+        S_logits = logits[torch.arange(len(text_prompts)), ioi_end_idx, ioi_s_ids]
+        return (IO_logits - S_logits).mean().detach().cpu()
+
+    model = EasyTransformer("gpt2", use_attn_result=True)
+    if torch.cuda.is_available():
+        model.to("cuda")
+
+    # Compute via the experiments.py machinery
+    metric = ExperimentMetric(
+        metric=logit_diff, dataset=ioi_text_prompts, relative_metric=True
+    )
+    config = AblationConfig(
+        abl_type="mean",
+        mean_dataset=ioi_text_prompts,
+        target_module="attn_head",
+        head_circuit="result",
+        cache_means=True,
+        verbose=True,
+    )
+    abl = EasyAblation(model, config, metric, semantic_indices=semantic_indices)
+    result = abl.run_ablation()
+
+    # Compute manually, replacing activations with explicit hooks
+    model.reset_hooks()
+    cache = {}
+    model.cache_all(cache)
+    logits = model(ioi_text_prompts)
+    io_logits = logits[list(range(L)), ioi_end_idx, ioi_io_ids]
+    s_logits = logits[list(range(L)), ioi_end_idx, ioi_s_ids]
+    diff_logits = io_logits - s_logits
+    avg_logits = diff_logits.mean()
+    max_seq_length = cache["hook_embed"].shape[1]
+    assert list(cache["hook_embed"].shape) == [
+        L,
+        max_seq_length,
+        model.cfg["d_model"],
+    ], cache["hook_embed"].shape
+    average_activations = {}
+    for key in cache.keys():
+        if "attn.hook_result" not in key:
+            continue
+        tens = cache[key].detach().cpu()
+        avg_tens = torch.mean(tens, dim=0, keepdim=False)
+        cache[key] = einops.repeat(avg_tens, "... -> s ...", s=L)
+
+        for thing in ["IO", "S", "S2"]:
+            thing_average = (
+                tens[list(range(L)), semantic_indices[thing], :, :]
+                .detach()
+                .cpu()
+                .mean(dim=0)
+            )
+            cache[key][
+                list(range(L)), semantic_indices[thing], :, :
+            ] = thing_average.clone()
+    diffs = torch.zeros((model.cfg["n_layers"], model.cfg["n_heads"]))
+    diffs += avg_logits.item()
+    for layer in tqdm(range(model.cfg["n_layers"])):
+        for head in range(model.cfg["n_heads"]):
+            new_val = (
+                cache[f"blocks.{layer}.attn.hook_result"][:, :, head, :]
+                .detach()
+                .clone()
+            )
+
+            def ablate_my_head(x, hook):
+                x[:, :, head, :] = new_val
+                return x
+
+            model.reset_hooks()
+            new_logits = model.run_with_hooks(
+                ioi_text_prompts,
+                fwd_hooks=[(f"blocks.{layer}.attn.hook_result", ablate_my_head)],
+            )
+
+            new_io_logits = new_logits[list(range(L)), ioi_end_idx, ioi_io_ids]
+            new_s_logits = new_logits[list(range(L)), ioi_end_idx, ioi_s_ids]
+            new_diff_logits = new_io_logits - new_s_logits
+            new_avg_logits = new_diff_logits.mean()
+            diffs[layer][head] /= new_avg_logits.item()
+    diffs -= 1.0
+
+    assert torch.allclose(
+        diffs, result, rtol=1e-4, atol=1e-4
+    ), f"{get_corner(diffs)}, {get_corner(result)}"
+
+
+if __name__ == "__main__":
+    test_semantic_ablation()
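
A note for readers skimming the experiments.py change above: plain mean ablation replaces a head's output at every position with that position's average over the mean dataset, while the new semantic_indices branch additionally averages across positions that play the same semantic role (e.g. the IO name) but sit at different indices in different prompts. The sketch below is illustrative only and not part of the diff; the dummy shapes follow attn.hook_result ([batch, pos, n_heads, d_head]) and the position list is made up.

    # Minimal sketch (not part of the diff) of semantic mean-ablation on a
    # dummy activation tensor shaped like attn.hook_result.
    import torch
    import einops

    batch, pos, n_heads, d_head = 4, 6, 2, 3
    acts = torch.randn(batch, pos, n_heads, d_head)  # stand-in for attn.hook_result

    # Plain mean ablation: average over prompts, independently at each position;
    # clone() takes a defensive copy before the in-place edit below.
    mean_cache = einops.repeat(acts.mean(dim=0), "... -> s ...", s=batch).clone()

    # Semantic averaging: the token of interest sits at a different position in
    # each prompt, so pool across those per-prompt positions and write the
    # pooled mean back at each prompt's own position.
    semantic_indices = {"IO": [2, 2, 1, 1]}  # hypothetical per-prompt positions
    for _symbol, idxs in semantic_indices.items():
        sem_mean = acts[list(range(batch)), idxs].mean(dim=0)  # [n_heads, d_head]
        mean_cache[list(range(batch)), idxs] = sem_mean

This mirrors the loop over self.semantic_indices in run_ablation, just without the model and hook bookkeeping.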
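
On the arithmetic at the end of the manual path in the test: diffs starts at the unablated average logit difference, each entry is divided by the corresponding head's ablated value, and one is subtracted, so diffs[layer][head] = baseline / ablated - 1. The final torch.allclose check is what ties this to the relative_metric=True convention of ExperimentMetric. In miniature, with made-up numbers:

    # Made-up numbers illustrating the relative metric the test reconstructs.
    baseline = 3.0  # average IO-minus-S logit diff with no ablation
    ablated = 2.5   # the same metric with one attention head mean-ablated
    relative = baseline / ablated - 1.0
    # Positive when ablating the head lowered a positive logit diff, i.e. the
    # head was helping the model prefer the indirect object.
    assert abs(relative - 0.2) < 1e-9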