Merged
58 commits
85d549a
softcapping
ArthurZucker Jun 28, 2024
eba5191
soft cap before the mask
ArthurZucker Jun 28, 2024
b9e4a54
style
ArthurZucker Jun 28, 2024
514a839
...
ArthurZucker Jun 28, 2024
7544feb
super nit
ArthurZucker Jun 28, 2024
be1b8c3
update
ArthurZucker Oct 21, 2024
0e0511f
fixes
ArthurZucker Oct 21, 2024
03ccc22
update
ArthurZucker Oct 21, 2024
bdda724
small issue with modular
ArthurZucker Oct 21, 2024
a2b6b12
fix modular imports
ArthurZucker Oct 21, 2024
9365c1b
update
ArthurZucker Oct 21, 2024
2108ee3
fixup
ArthurZucker Oct 21, 2024
520120a
simplify a hell lot
ArthurZucker Oct 21, 2024
314ed1f
simplify cleaning imports
ArthurZucker Oct 22, 2024
8830473
finish fixing
ArthurZucker Oct 22, 2024
e4c19d7
update our design
ArthurZucker Oct 22, 2024
7922210
nits
ArthurZucker Oct 22, 2024
fa1319d
Merge branch 'main' of github.com:huggingface/transformers into gemma…
ArthurZucker Nov 1, 2024
43c68f6
use a deprecation cycle
ArthurZucker Nov 1, 2024
1aec944
updates
ArthurZucker Nov 1, 2024
93b53ef
Fix modular (recursive deps need to always be computed after merges!)
Cyrilvallez Nov 1, 2024
6f3cabb
Merge branch 'gemma-capping' of github.com:huggingface/transformers i…
ArthurZucker Nov 1, 2024
a79c4a9
push
ArthurZucker Nov 1, 2024
4c6d299
fix
ArthurZucker Nov 1, 2024
607c45d
update
ArthurZucker Nov 1, 2024
4598bba
fix modular order
Cyrilvallez Nov 1, 2024
5727270
make fix-copies
ArthurZucker Nov 1, 2024
198b4c4
updates
ArthurZucker Nov 1, 2024
3d35151
update
ArthurZucker Nov 1, 2024
da050cd
?
ArthurZucker Nov 1, 2024
e02078c
don't compile for now
ArthurZucker Nov 1, 2024
5861bbf
?
ArthurZucker Nov 4, 2024
8c47da2
fix some stuff
ArthurZucker Nov 4, 2024
09a88d9
donc!
ArthurZucker Nov 4, 2024
c06b530
fix copies
ArthurZucker Nov 4, 2024
89e6f85
update
ArthurZucker Nov 4, 2024
152e0b7
fixup
ArthurZucker Nov 4, 2024
46d8fa7
Merge branch 'main' of github.com:huggingface/transformers into gemma…
ArthurZucker Nov 4, 2024
006e869
?
ArthurZucker Nov 4, 2024
159c65a
fix two tests
ArthurZucker Nov 4, 2024
56ea5b9
fix?
ArthurZucker Nov 4, 2024
4c3deb9
for now, don't use head info
ArthurZucker Nov 4, 2024
9e3609d
eager when output attention and sdpa or flash as it's the simplest be…
ArthurZucker Nov 4, 2024
21edaed
fix-copies
ArthurZucker Nov 4, 2024
b5d9819
revert sdpa check
ArthurZucker Nov 4, 2024
5a3dade
Apply suggestions from code review
ArthurZucker Nov 6, 2024
faf433b
Merge branch 'main' of github.com:huggingface/transformers into gemma…
ArthurZucker Nov 6, 2024
1da75e1
rebase, fix-copies and push
ArthurZucker Nov 6, 2024
aca9120
add a slow integration test
ArthurZucker Nov 6, 2024
8f1fc5e
update the test
ArthurZucker Nov 19, 2024
5be3bab
fix left padding issue
ArthurZucker Nov 19, 2024
3e5b87a
fix test
ArthurZucker Nov 19, 2024
0513aff
remove duplicate scaling
ArthurZucker Nov 19, 2024
480aff8
quality
ArthurZucker Nov 19, 2024
603fce8
Merge branch 'main' into gemma-capping
ArthurZucker Nov 19, 2024
2a765d6
add a small test and make sure it works
ArthurZucker Nov 19, 2024
fb184be
Merge branch 'gemma-capping' of github.com:huggingface/transformers i…
ArthurZucker Nov 19, 2024
6aba68c
2b
ArthurZucker Nov 19, 2024
1 change: 1 addition & 0 deletions src/transformers/modeling_utils.py
@@ -1519,6 +1519,7 @@ def _autoset_attn_implementation(
"eager",
"sdpa",
"flash_attention_2",
"flex_attention",
]:
message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
if cls._supports_flash_attn_2:
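This one-line addition registers `"flex_attention"` as an accepted value for `attn_implementation`, so `_autoset_attn_implementation` no longer rejects it. A hedged sketch of what that enables, mirroring the new Gemma 2 tests further down this diff (assumes torch >= 2.5 and access to the gated `google/gemma-2-2b` checkpoint):

```python
# Illustrative usage only; checkpoint id and prompt are taken from the tests below.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b",
    torch_dtype=torch.bfloat16,
    attn_implementation="flex_attention",  # now passes the whitelist check above
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
inputs = tokenizer("Hello I am doing", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20, do_sample=False)[0]))
```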
414 changes: 151 additions & 263 deletions src/transformers/models/gemma2/modeling_gemma2.py

Large diffs are not rendered by default.

418 changes: 187 additions & 231 deletions src/transformers/models/gemma2/modular_gemma2.py

Large diffs are not rendered by default.

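The two Gemma 2 files above are collapsed, but the commit history ("softcapping", "soft cap before the mask", flex attention support) indicates their core change: applying tanh logit soft-capping to the attention scores, which FlexAttention expresses as a `score_mod` callback. A minimal sketch of that idea, not the PR's actual code; the cap value of 50.0 and all names are illustrative, and torch >= 2.5 is assumed:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention


def softcap_score_mod(cap: float):
    # score_mod receives each raw attention score and may rewrite it in place
    def score_mod(score, batch, head, q_idx, kv_idx):
        return cap * torch.tanh(score / cap)

    return score_mod


query = key = value = torch.randn(1, 8, 128, 64)  # (batch, heads, seq_len, head_dim)
out = flex_attention(query, key, value, score_mod=softcap_score_mod(50.0))
print(out.shape)  # torch.Size([1, 8, 128, 64])
```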
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -209,6 +209,7 @@
is_torch_fp16_available_on_device,
is_torch_fx_available,
is_torch_fx_proxy,
is_torch_greater_or_equal,
is_torch_mlu_available,
is_torch_mps_available,
is_torch_musa_available,
8 changes: 8 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -929,6 +929,14 @@ def is_flash_attn_greater_or_equal(library_version: str):
return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version)


@lru_cache()
def is_torch_greater_or_equal(library_version: str):
if not _is_package_available("torch"):
return False

return version.parse(importlib.metadata.version("torch")) >= version.parse(library_version)


def is_torchdistx_available():
return _torchdistx_available

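A hedged usage sketch for the new helper: gating an optional import on the installed torch version. The `2.5` threshold is illustrative (FlexAttention ships with torch >= 2.5); the import path into `transformers.utils` is the one added in this PR.

```python
from transformers.utils import is_torch_greater_or_equal

if is_torch_greater_or_equal("2.5"):
    from torch.nn.attention.flex_attention import flex_attention
else:
    flex_attention = None  # caller falls back to eager/SDPA
```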
2 changes: 1 addition & 1 deletion tests/generation/test_utils.py
@@ -1496,7 +1496,7 @@ def _prepare_model_kwargs(input_ids, attention_mask, signature):
next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]

# They should result in very similar logits
self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=1e-5))
torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, atol=1e-5, rtol=1e-5)

@pytest.mark.generate
def test_past_key_values_format(self):
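The swap from `self.assertTrue(torch.allclose(...))` to `torch.testing.assert_close` keeps the same tolerance but produces a diagnostic failure message instead of a bare "False is not true". A small self-contained illustration (values made up):

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0, 2.0, 3.1])
try:
    torch.testing.assert_close(a, b, atol=1e-5, rtol=1e-5)
except AssertionError as e:
    print(e)  # e.g. "Mismatched elements: 1 / 3 ... Greatest absolute difference: 0.1 ..."
```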
60 changes: 44 additions & 16 deletions tests/models/gemma2/test_modeling_gemma2.py
@@ -199,19 +199,6 @@ def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_l
def test_sdpa_equivalence(self):
pass

def test_eager_attention_loaded_by_default(self):
"""Gemma 2 + SDPA = inferior results, because of the logit softcapping. Eager is the default."""
config, _ = self.model_tester.prepare_config_and_inputs_for_common()

# Usually we enable SDPA by default, but not for Gemma2
model = Gemma2Model(config)
self.assertTrue(model.config._attn_implementation == "eager")

# We can still force SDPA
config._attn_implementation = "sdpa"
model = Gemma2Model(config)
self.assertTrue(model.config._attn_implementation == "sdpa")


@slow
@require_torch_gpu
@@ -277,9 +264,30 @@ def test_model_9b_pipeline_bf16(self):
"Hi today I'm going to be talking about the history of the United States. The United States of America",
]

model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to(
torch_device
)
model = AutoModelForCausalLM.from_pretrained(
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention"
).to(torch_device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

output = pipe(self.input_text, max_new_tokens=20, do_sample=False, padding=True)

self.assertEqual(output[0][0]["generated_text"], EXPECTED_TEXTS[0])
self.assertEqual(output[1][0]["generated_text"], EXPECTED_TEXTS[1])

@require_read_token
def test_model_2b_pipeline_bf16_flex_attention(self):
# See https://github.com/huggingface/transformers/pull/31747 -- pipeline was broken for Gemma2 before this PR
model_id = "google/gemma-2-2b"
# EXPECTED_TEXTS should match the same non-pipeline test, minus the special tokens
EXPECTED_TEXTS = [
"Hello I am doing a project on the 1960s and I am trying to find out what the average",
"Hi today I'm going to be talking about the 10 best anime of all time.\n\n1",
]

model = AutoModelForCausalLM.from_pretrained(
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention"
).to(torch_device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

@@ -365,3 +373,23 @@ def test_export_static_cache(self):
)
ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)

@require_read_token
def test_model_9b_bf16_flex_attention(self):
model_id = "google/gemma-2-9b"
EXPECTED_TEXTS = [
"<bos>Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many",
"<pad><pad><bos>Hi today I'm going to be talking about the history of the United States. The United States of America",
]

model = AutoModelForCausalLM.from_pretrained(
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flex_attention"
).to(torch_device)

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)

output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=False)

self.assertEqual(output_text, EXPECTED_TEXTS)
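A sketch of why the expected strings in the new test begin with `<pad><pad><bos>` (see also the "fix left padding issue" commit): with left padding, the shorter prompt in the batch is padded on the left, and decoding with `skip_special_tokens=False` keeps those pad tokens. The prompts below are illustrative and the (gated) Gemma 2 tokenizer is assumed.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b", padding_side="left")
batch = tokenizer(["Hello I am doing a project", "Hi today"], return_tensors="pt", padding=True)
# The shorter row decodes roughly as "<pad>...<pad><bos>Hi today"
print(tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=False))
```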
58 changes: 33 additions & 25 deletions utils/modular_model_converter.py
@@ -153,37 +153,37 @@ def __init__(self, all_bases: Set[str]):
def leave_Attribute(self, original_node: cst.Attribute, updated_node: cst.Attribute) -> cst.CSTNode:
# Handle ClassB.call_to_method
if (
isinstance(original_node.value, cst.Name)
m.matches(original_node.value, m.Name())
and original_node.value.value in self.all_bases
and isinstance(original_node.attr, cst.Name)
and m.matches(original_node.attr, m.Name())
):
# Replace with super().call_to_method
return updated_node.with_changes(
value=cst.Call(cst.Name("super")),
)
# Handle ClassB().call_to_method
elif (
isinstance(original_node.value, cst.Call)
and isinstance(original_node.value.func, cst.Name)
m.matches(original_node.value, m.Call())
and m.matches(original_node.value.func, m.Name())
and original_node.value.func.value in self.all_bases
and isinstance(original_node.attr, cst.Name)
and m.matches(original_node.attr, m.Name())
):
# Replace with super().call_to_method
return updated_node.with_changes(func=cst.Attribute(value=cst.Call(func=cst.Name("super"))))
return updated_node

def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode:
# Check if the function being called is of the form ClassB().func_a or ClassB.func_a
if isinstance(original_node.func, cst.Attribute) and (
if m.matches(original_node.func, m.Attribute()) and (
# Match ClassB().func_a(...)
(
isinstance(original_node.func.value, cst.Call)
and isinstance(original_node.func.value.func, cst.Name)
m.matches(original_node.func.value, m.Call())
and m.matches(original_node.func.value.func, m.Name())
and original_node.func.value.func.value in self.all_bases
)
or
# Match ClassB.func_a(...)
(isinstance(original_node.func.value, cst.Name) and original_node.func.value.value in self.all_bases)
(m.matches(original_node.func.value, m.Name()) and original_node.func.value.value in self.all_bases)
):
# Check if the first argument is 'self', and remove it
if len(original_node.args) > 0 and m.matches(original_node.args[0].value, m.Name("self")):
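The hunk above swaps `isinstance` checks for libcst matcher predicates. A minimal, self-contained illustration of the equivalence (the parsed expression is made up for the example):

```python
import libcst as cst
import libcst.matchers as m

node = cst.parse_expression("ClassB.forward")  # an Attribute node

# For a plain node-type check the two styles agree...
assert isinstance(node.value, cst.Name) and m.matches(node.value, m.Name())
# ...but matchers compose into richer patterns, e.g. "an attribute accessed on ClassB":
assert m.matches(node, m.Attribute(value=m.Name("ClassB")))
```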
@@ -632,8 +632,10 @@ def leave_Module(self, node):
for id, node in self.global_nodes.items():
self.start_lines[id] = self.get_metadata(cst.metadata.PositionProvider, node).start.line

# Since we added every Name as part of `self.object_dependency_mapping`, we now remove those that
# are not part of the recorded objects (i.e. built-in variables, imports, etc)
def _restrict_dependencies_to_known_entities(self):
"""Since we added every Name as part of `self.object_dependency_mapping`, we need to remove those that
are not part of the recorded objects in `self.global_nodes` (i.e. built-in variables, imports, etc).
This should be called only after all merging operations have been finalized!!"""
global_objects = set(self.global_nodes.keys())
for object_name, dependencies in self.object_dependency_mapping.items():
self.object_dependency_mapping[object_name] = {dep for dep in dependencies if dep in global_objects}
@@ -814,6 +816,8 @@ def merge_modular_dependencies(self, classes, functions, assignments, object_map
# Correctly re-set the global nodes at this point
self.global_nodes.update(self.functions)
self.global_nodes.update(self.assignments)
# Restrict the dependency mappings to the known entities to avoid Python's built-ins
self._restrict_dependencies_to_known_entities()
# Create the global mapping of recursive dependencies for functions and assignments
self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies()

@@ -1142,22 +1146,20 @@ def visit_SimpleStatementLine(self, node):
if assigned_variable == "__all__":
self.all_all_to_add = split_all_assignment(node)
else:
self.current_assignment = assigned_variable
self.assignments[assigned_variable] = node

def leave_Module(self, node):
"""When we leave the modular file, we do the following in order:
1. compute the nested (recursive) function and assignment dependencies
2. for each modeling file found in the imports, rename it with the new model name, visit it, and update
1. for each modeling file found in the imports, rename it with the new model name, visit it, and update
its dependency graph with the new function and assignment definitions found in the modular
3. update the modular dependency graph with the imported functions and assignments (found when visiting the matching files)
2. update the modular dependency graph with the imported functions and assignments (found when visiting the matching files)
3. compute the nested (recursive) function and assignment dependencies
"""
# Takes care of finalizing our visit
super().leave_Module(node)

# 1. compute the nested (recursive) function and assignment dependencies
self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies()

# 2. for each modeling file found in the imports, rename it with the new model name, visit it, and update dependencies
# 1. for each modeling file found in the imports, rename it with the new model name, visit it, and update dependencies
self.visited_modules = {}
self.renamers = {}
for file, module in self.model_specific_modules.items():
@@ -1177,10 +1179,13 @@ def leave_Module(self, node):
# We record it so that we can rename classes later the exact same way
self.renamers[file] = renamer

# 3. in turn, we need to add the imported functions/assignments to the dependencies of the modular mapper, using the
# 2. in turn, we need to add the imported functions/assignments to the dependencies of the modular mapper, using the
# definitions found in the visited files
self.merge_model_specific_imports(self.visited_modules)

# 3. compute the nested (recursive) function and assignment dependencies
self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies()

# We need to keep track of which objects were imported directly into which modeling file to not add them wrongly later
# Note that we may visit several of the same file types, thus we save them per file type, not file
self.imported_objects_per_file = defaultdict(set)
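The reordering above implements the fix noted in the commit "recursive deps need to always be computed after merges": the transitive closure is only correct once the direct dependency mapping already contains every object gathered from the merged files. A hedged sketch of that closure with illustrative names, not the converter's actual data structures:

```python
def compute_recursive_dependencies(direct: dict) -> dict:
    # Transitive closure of a direct-dependency mapping via an explicit stack walk.
    recursive = {}
    for name, deps in direct.items():
        seen, stack = set(), list(deps)
        while stack:
            dep = stack.pop()
            if dep not in seen:
                seen.add(dep)
                stack.extend(direct.get(dep, ()))
        recursive[name] = seen
    return recursive


direct = {"Gemma2Model": {"Gemma2DecoderLayer"}, "Gemma2DecoderLayer": {"Gemma2Attention"}}
assert compute_recursive_dependencies(direct)["Gemma2Model"] == {"Gemma2DecoderLayer", "Gemma2Attention"}
```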
Expand All @@ -1200,9 +1205,9 @@ def merge_model_specific_imports(self, visited_modules):
if object_name in visited_module.functions and object_name not in self.functions:
self.functions[object_name] = visited_module.functions[object_name]
self.added_objects_file_mapping[object_name] = file
dependencies = visited_module.object_recursive_dependency_mapping.get(object_name, None)
dependencies = visited_module.object_dependency_mapping.get(object_name, None)
if dependencies is not None:
self.object_recursive_dependency_mapping[object_name] = dependencies
self.object_dependency_mapping[object_name] = dependencies
for dep in dependencies:
if dep not in self.global_nodes:
self.added_objects_file_mapping[dep] = file
@@ -1212,16 +1217,18 @@ def...
elif object_name in visited_module.assignments and object_name not in self.assignments:
self.assignments[object_name] = visited_module.assignments[object_name]
self.added_objects_file_mapping[object_name] = file
dependencies = visited_module.object_recursive_dependency_mapping.get(object_name, None)
dependencies = visited_module.object_dependency_mapping.get(object_name, None)
if dependencies is not None:
self.object_recursive_dependency_mapping[object_name] = dependencies
self.object_dependency_mapping[object_name] = dependencies
for dep in dependencies:
if dep not in self.global_nodes:
self.added_objects_file_mapping[dep] = file
self.assignments[dep] = visited_module.global_nodes[dep]

# Do not forget to re-assign all nodes after the merge
self.global_nodes = {**self.assignments, **self.classes, **self.functions}
# And restrict dependencies to those nodes only
self._restrict_dependencies_to_known_entities()

def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]:
"""Compute in which relative order the `missing_dependencies` should appear when the nodes are added to the final file that
Expand All @@ -1239,10 +1246,11 @@ def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]:
else:
original_dependencies.append(dep)
# Sort all lists according to the order in their respective file
all_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x])
all_dependencies = []
for file, dependencies in other_files_dependencies.items():
sorted_dependencies = sorted(dependencies, key=lambda x: self.start_lines_file_mapping[file][x])
all_dependencies += sorted_dependencies
all_dependencies += sorted(original_dependencies, key=lambda x: self.start_lines[x])

# Add all original node first, then merged ones (one file at a time)
for dep in all_dependencies:
@@ -1485,7 +1493,7 @@ def save_modeling_file(modular_file, converted_file):
parser = argparse.ArgumentParser()
parser.add_argument(
"--files_to_parse",
default=["src/transformers/models/gemma/modular_gemma.py"],
default=["src/transformers/models/gemma2/modular_gemma2.py"],
nargs="+",
help="A list of `modular_xxxx` files that should be converted to single model file",
)