#!/usr/bin/env python3

+import copy
from collections import namedtuple
from typing import cast, List, Optional, Union

@@ -53,9 +54,11 @@ def __init__(self):

    def forward(self, input_ids, *args, **kwargs):
        emb = self.emb(input_ids)
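+        # Stand the running embeddings in for past_key_values: prepend the
+        # cache so logits are computed over the full sequence seen so far.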
+        if "past_key_values" in kwargs:
+            emb = torch.cat((kwargs["past_key_values"], emb), dim=1)
        logits = self.linear(self.trans(emb))
-        Result = namedtuple("Result", ["logits"])
-        return Result(logits=logits)
+        Result = namedtuple("Result", ["logits", "past_key_values"])
+        return Result(logits=logits, past_key_values=emb)

    def generate(self, input_ids, *args, mock_response=None, **kwargs):
        assert mock_response, "must mock response to use DummyLLM to generate"
@@ -64,16 +67,35 @@ def generate(self, input_ids, *args, mock_response=None, **kwargs):
            [input_ids, torch.tensor([response], device=self.device)], dim=1
        )

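+    # Mirrors the transformers GenerationMixin hook of the same name:
+    # carry the model's cache forward into the next decoding step.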
+    def _update_model_kwargs_for_generation(self, outputs, model_kwargs):
+        new_kwargs = copy.deepcopy(model_kwargs)
+        if hasattr(outputs, "past_key_values"):
+            new_kwargs["past_key_values"] = outputs.past_key_values
+        return new_kwargs
+
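+    # Mirrors GenerationMixin.prepare_inputs_for_generation: once a cache
+    # exists, feed the model only the tokens the cache does not yet cover.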
+    def prepare_inputs_for_generation(self, model_inp, **model_kwargs):
+        if "past_key_values" in model_kwargs:
+            emb_len = model_kwargs["past_key_values"].shape[1]
+            return {
+                "input_ids": model_inp[:, emb_len:],
+                "past_key_values": model_kwargs["past_key_values"],
+            }
+        return {"input_ids": model_inp}
+
    @property
    def device(self):
        return next(self.parameters()).device

@parameterized_class(
-    ("device",), [("cpu",), ("cuda",)] if torch.cuda.is_available() else [("cpu",)]
+    ("device", "use_cached_outputs"),
+    [("cpu", True), ("cpu", False), ("cuda", True), ("cuda", False)]
+    if torch.cuda.is_available()
+    else [("cpu", True), ("cpu", False)],
)
class TestLLMAttr(BaseTest):
    device: str
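+    # Parameterized so every test runs both with and without generation caching.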
+    use_cached_outputs: bool

    @parameterized.expand([(FeatureAblation,), (ShapleyValueSampling,)])
    def test_llm_attr(self, AttrClass) -> None:
@@ -83,7 +105,9 @@ def test_llm_attr(self, AttrClass) -> None:
        llm_attr = LLMAttribution(AttrClass(llm), tokenizer)

        inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
-        res = llm_attr.attribute(inp, "m n o p q")
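+        # Thread the parameterized flag through so both code paths are covered.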
+        res = llm_attr.attribute(
+            inp, "m n o p q", use_cached_outputs=self.use_cached_outputs
+        )

        self.assertEqual(res.seq_attr.shape, (4,))
        self.assertEqual(cast(Tensor, res.token_attr).shape, (5, 4))
@@ -100,7 +124,11 @@ def test_llm_attr_without_target(self) -> None:
        llm_fa = LLMAttribution(fa, tokenizer)

        inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
-        res = llm_fa.attribute(inp, gen_args={"mock_response": "x y z"})
+        res = llm_fa.attribute(
+            inp,
+            gen_args={"mock_response": "x y z"},
+            use_cached_outputs=self.use_cached_outputs,
+        )

        self.assertEqual(res.seq_attr.shape, (4,))
        self.assertEqual(cast(Tensor, res.token_attr).shape, (3, 4))
@@ -117,7 +145,9 @@ def test_llm_attr_fa_log_prob(self) -> None:
        llm_fa = LLMAttribution(fa, tokenizer, attr_target="log_prob")

        inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
-        res = llm_fa.attribute(inp, "m n o p q")
+        res = llm_fa.attribute(
+            inp, "m n o p q", use_cached_outputs=self.use_cached_outputs
+        )

        # With FeatureAblation, the seq attr in log_prob
        # equals the sum of the per-token attrs
@@ -132,7 +162,9 @@ def test_llm_attr_without_token(self, AttrClass) -> None:
        llm_fa = LLMAttribution(fa, tokenizer, attr_target="log_prob")

        inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
-        res = llm_fa.attribute(inp, "m n o p q")
+        res = llm_fa.attribute(
+            inp, "m n o p q", use_cached_outputs=self.use_cached_outputs
+        )

        self.assertEqual(res.seq_attr.shape, (4,))
        self.assertEqual(res.seq_attr.device.type, self.device)
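
For reference, a minimal sketch (not Captum's actual implementation) of the
cached decoding loop the two new hooks enable; greedy_decode_with_cache and
n_steps are illustrative names, assuming only the model interface exercised
by DummyLLM above:

import torch

def greedy_decode_with_cache(model, input_ids, n_steps):
    model_kwargs = {}
    for _ in range(n_steps):
        # With a cache present, prepare_inputs_for_generation slices off the
        # already-processed prefix so only new tokens are re-embedded.
        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
        outputs = model.forward(**model_inputs)
        # Greedily pick the next token from the last position's logits.
        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=1)
        # Persist the returned cache for the next decoding step.
        model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs)
    return input_ids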