added transformers_config for passing arguments to the transformer (#268)

* added transformers config

* changed def config to include new transformers config

* fixed quotation marks in config

* removed weird symbol

* added attention in data classes

* fixed keyerror

* ibid

* added pass of config to forward

* ibid

* fix for init

* fixed tensors in forward

* removed default for attention and added to_doc fix for attn

* reformatted to black (accidentally reformatted via autopep8)

* added def to transformerdata

* bugfixes - don't get why this does not use the default argument here though

* removed default trfconfig from trfmodel

* updated dummy transformer

* fixed tests

* added Tok2VecTransformer.v2

* changed typing

* fixed type hint

* fixed name of transformer_tok2vec_v2, added def

* fixed default config to match name change

* remove ds

* renamed transformers_config

Co-authored-by: KennethEnevoldsen <[email protected]>
KennethEnevoldsen committed Jul 8, 2021
1 parent 21e80a6 commit 01b67ef
Showing 7 changed files with 99 additions and 24 deletions.
40 changes: 38 additions & 2 deletions spacy_transformers/architectures.py
@@ -13,7 +13,7 @@ def transformer_listener_tok2vec_v1(
) -> Model[List[Doc], List[Floats2d]]:
"""Create a 'TransformerListener' layer, which will connect to a Transformer
component earlier in the pipeline.
The layer takes a list of Doc objects as input, and produces a list of
2d arrays as output, with each array having one row per token. Most spaCy
models expect a sublayer with this signature, making it easy to connect them
@@ -46,7 +46,7 @@ def transformer_listener_tok2vec_v1(
def transformer_tok2vec_v1(
name: str,
get_spans,
tokenizer_config,
tokenizer_config: dict,
pooling: Model[Ragged, Floats2d],
grad_factor: float = 1.0,
) -> Model[List[Doc], List[Floats2d]]:
@@ -74,6 +74,42 @@ def transformer_tok2vec_v1(
)


@registry.architectures.register("spacy-transformers.Tok2VecTransformer.v2")
def transformer_tok2vec_v2(
name: str,
get_spans,
tokenizer_config: dict,
transformer_config: dict,
pooling: Model[Ragged, Floats2d],
grad_factor: float = 1.0,
) -> Model[List[Doc], List[Floats2d]]:
"""Use a transformer as a "Tok2Vec" layer directly. This does not allow
multiple components to share the transformer weights, and does not allow
the transformer to set annotations into the `Doc` object, but it's a
simpler solution if you only need the transformer within one component.
get_spans (Callable[[List[Doc]], List[List[Span]]]): A function to extract
spans from the batch of Doc objects. See the "TransformerModel" layer
for details.
tokenizer_config (dict): Settings to pass to the transformers tokenizer.
transformer_config (dict): Settings to pass to the transformer's forward pass.
pooling (Model[Ragged, Floats2d]): A reduction layer used to calculate
the token vectors based on zero or more wordpiece vectors. If in doubt,
mean pooling (see `thinc.layers.reduce_mean`) is usually a good choice.
grad_factor (float): Reweight gradients from the component before passing
them to the transformer. You can set this to 0 to "freeze" the transformer
weights with respect to the component, or to make it learn more slowly.
Leaving it at 1.0 is usually fine.
"""
return chain(
TransformerModel(name, get_spans, tokenizer_config, transformer_config),
split_trf_batch(),
trfs2arrays(pooling, grad_factor),
)



registry.architectures.register(
"spacy-transformers.TransformerModel.v1", func=TransformerModel
)
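
For context, a minimal usage sketch (not part of this commit): resolving the newly registered "spacy-transformers.Tok2VecTransformer.v2" architecture from a Thinc config block. The window/stride values and the reduce_mean.v1 pooling layer are illustrative assumptions rather than values taken from this diff.

# Sketch only: building the new Tok2Vec layer from a config string.
import spacy
import spacy_transformers  # noqa: F401  -- ensures the architectures above are registered
from thinc.api import Config

cfg_str = """
[model]
@architectures = "spacy-transformers.Tok2VecTransformer.v2"
name = "roberta-base"
tokenizer_config = {"use_fast": true}
transformer_config = {"output_attentions": true}
grad_factor = 1.0

[model.get_spans]
@span_getters = "spacy-transformers.strided_spans.v1"
window = 128
stride = 96

[model.pooling]
@layers = "reduce_mean.v1"
"""

resolved = spacy.registry.resolve(Config().from_str(cfg_str))
tok2vec = resolved["model"]  # Model[List[Doc], List[Floats2d]]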
14 changes: 12 additions & 2 deletions spacy_transformers/data_classes.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Optional, List, Dict, Any
from typing import Optional, List, Dict, Any, Tuple
import torch
import numpy
from transformers.tokenization_utils import BatchEncoding
@@ -155,11 +155,14 @@ class TransformerData:
wordpieces: WordpieceBatch
tensors: List[FloatsXd]
align: Ragged
attention: Optional[Tuple[FloatsXd, ...]] = None

@classmethod
def empty(cls) -> "TransformerData":
align = Ragged(numpy.zeros((0,), dtype="i"), numpy.zeros((0,), dtype="i"))
return cls(wordpieces=WordpieceBatch.empty(), tensors=[], align=align)
return cls(
wordpieces=WordpieceBatch.empty(), tensors=[], align=align, attention=None
)

@classmethod
def zeros(cls, length: int, width: int, *, xp=numpy) -> "TransformerData":
@@ -247,6 +250,7 @@ class FullTransformerBatch:
wordpieces: WordpieceBatch
tensors: List[torch.Tensor]
align: Ragged
attention: Optional[Tuple[torch.Tensor, ...]] = None
cached_doc_data: Optional[List[TransformerData]] = None

@classmethod
@@ -259,6 +263,7 @@ def empty(cls, nr_docs) -> "FullTransformerBatch":
wordpieces=WordpieceBatch.empty(),
tensors=[],
align=align,
attention=None,
cached_doc_data=doc_data,
)

@@ -312,11 +317,16 @@ def split_by_doc(self) -> List[TransformerData]:
doc_tokens = self.wordpieces[start:end]
doc_align = self.align[start_i:end_i]
doc_align.data = doc_align.data - prev_tokens
if self.attention:
attn = [torch2xp(t[start:end]) for t in self.attention]
else:
attn = None
outputs.append(
TransformerData(
wordpieces=doc_tokens,
tensors=[torch2xp(t[start:end]) for t in self.tensors],
align=doc_align,
attention=attn,
)
)
prev_tokens += doc_tokens.input_ids.size
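
A short sketch (not from the diff) of what the new attention field looks like from user code, assuming a pipeline that stores TransformerData on the Doc via the existing `doc._.trf_data` extension; whether it is populated depends on the "output_attentions" setting introduced elsewhere in this commit.

# Sketch only: the attention field added above, as seen from user code.
from spacy_transformers.data_classes import TransformerData

empty = TransformerData.empty()
assert empty.attention is None  # new default added in this commit

# With {"output_attentions": true} in the transformer config, each per-doc
# TransformerData produced by split_by_doc() is expected to carry one array
# per transformer layer, sliced to that doc's wordpieces:
#     doc._.trf_data.attention  ->  Tuple[FloatsXd, ...] or None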
32 changes: 21 additions & 11 deletions spacy_transformers/layers/transformer_model.py
@@ -15,7 +15,7 @@


def TransformerModel(
name: str, get_spans: Callable, tokenizer_config: dict
name: str, get_spans: Callable, tokenizer_config: dict = {}, transformer_config: dict = {}
) -> Model[List[Doc], FullTransformerBatch]:
"""
get_spans (Callable[[List[Doc]], List[Span]]):
@@ -25,6 +25,7 @@ def TransformerModel(
overlap, and you can also omit sections of the Doc if they are not
relevant.
tokenizer_config (dict): Settings to pass to the transformers tokenizer.
transformer_config (dict): Settings to pass to the transformers forward pass.
"""

return Model(
@@ -38,6 +39,7 @@ def init(model: Model, X=None, Y=None):
"get_spans": get_spans,
"name": name,
"tokenizer_config": tokenizer_config,
"transformer_config": transformer_config,
"set_transformer": set_pytorch_transformer,
"has_transformer": False,
"flush_cache_chance": 0.0,
@@ -75,7 +77,8 @@ def init(model: Model, X=None, Y=None):
return
name = model.attrs["name"]
tok_cfg = model.attrs["tokenizer_config"]
tokenizer, transformer = huggingface_from_pretrained(name, tok_cfg)
trf_cfg = model.attrs["transformer_config"]
tokenizer, transformer = huggingface_from_pretrained(name, tok_cfg, trf_cfg)
model.attrs["tokenizer"] = tokenizer
model.attrs["set_transformer"](model, transformer)
# Call the model with a batch of inputs to infer the width
@@ -89,26 +92,23 @@ def init(model: Model, X=None, Y=None):
for doc_spans in nested_spans:
flat_spans.extend(doc_spans)
token_data = huggingface_tokenize(
model.attrs["tokenizer"],
[span.text for span in flat_spans]
model.attrs["tokenizer"], [span.text for span in flat_spans]
)
wordpieces = WordpieceBatch.from_batch_encoding(token_data)
align = get_alignment(
flat_spans,
wordpieces.strings, model.attrs["tokenizer"].all_special_tokens
flat_spans, wordpieces.strings, model.attrs["tokenizer"].all_special_tokens
)
wordpieces, align = truncate_oversize_splits(
wordpieces, align, tokenizer.model_max_length
)
else:
texts = ["hello world", "foo bar"]
token_data = huggingface_tokenize(
model.attrs["tokenizer"],
texts
)
token_data = huggingface_tokenize(model.attrs["tokenizer"], texts)
wordpieces = WordpieceBatch.from_batch_encoding(token_data)
model.layers[0].initialize(X=wordpieces)
tensors = model.layers[0].predict(wordpieces)
if trf_cfg["output_attentions"] is True:
tensors = tensors[:-1] # remove attention
t_i = find_last_hidden(tensors)
model.set_dim("nO", tensors[t_i].shape[-1])

@@ -118,6 +118,7 @@ def forward(
) -> Tuple[FullTransformerBatch, Callable]:
tokenizer = model.attrs["tokenizer"]
get_spans = model.attrs["get_spans"]
trf_config = model.attrs["transformer_config"]
transformer = model.layers[0]

nested_spans = get_spans(docs)
@@ -142,8 +143,17 @@
tensors, bp_tensors = transformer(wordpieces, is_train)
if "logger" in model.attrs:
log_gpu_memory(model.attrs["logger"], "after forward")
if ("output_attentions" in trf_config) and (trf_config["output_attentions"] is True):
attn = tensors[-1]
tensors = tensors[:-1]
else:
attn = None
output = FullTransformerBatch(
spans=nested_spans, wordpieces=wordpieces, tensors=tensors, align=align
spans=nested_spans,
wordpieces=wordpieces,
tensors=tensors,
align=align,
attention=attn,
)
if "logger" in model.attrs:
log_gpu_memory(model.attrs["logger"], "return from forward")
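
The tensors[-1] handling above relies on the way HuggingFace models append attention weights to their output when output_attentions is enabled. A standalone sketch of that behaviour (assuming tuple outputs, i.e. return_dict=False), not taken from this repository:

# Sketch: why the last element is peeled off when output_attentions is set.
# Assumes a HuggingFace model returning a plain tuple (return_dict=False).
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModel.from_pretrained("distilbert-base-uncased")
batch = tokenizer(["hello world"], return_tensors="pt")

with torch.no_grad():
    outputs = model(**batch, output_attentions=True, return_dict=False)

hidden = outputs[:-1]   # hidden-state tensors, as before
attn = outputs[-1]      # tuple with one (batch, heads, seq, seq) tensor per layer
print(len(attn), attn[0].shape)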
15 changes: 11 additions & 4 deletions spacy_transformers/pipeline_component.py
@@ -30,6 +30,7 @@
@architectures = "spacy-transformers.TransformerModel.v1"
name = "roberta-base"
tokenizer_config = {"use_fast": true}
transformer_config = {"output_attentions": false}
[transformer.model.get_spans]
@span_getters = "spacy-transformers.strided_spans.v1"
@@ -143,7 +144,9 @@ def add_listener(self, listener: TransformerListener, component_name: str) -> No
if self.model.has_dim("nO") and listener.has_dim("nO") is None:
listener.set_dim("nO", self.model.get_dim("nO"))

def remove_listener(self, listener: TransformerListener, component_name: str) -> bool:
def remove_listener(
self, listener: TransformerListener, component_name: str
) -> bool:
"""Remove a listener for a downstream component. Usually internals."""
if component_name in self.listener_map:
if listener in self.listener_map[component_name]:
@@ -167,7 +170,10 @@ def find_listeners(self, component) -> None:
names = ("*", self.name)
if isinstance(getattr(component, "model", None), Model):
for node in component.model.walk():
if isinstance(node, TransformerListener) and node.upstream_name in names:
if (
isinstance(node, TransformerListener)
and node.upstream_name in names
):
self.add_listener(node, component.name)

def __call__(self, doc: Doc) -> Doc:
@@ -296,7 +302,8 @@ def accumulate_gradient(d_trf_datas: List[TransformerData]):
nonlocal d_tensors
for i, d_trf_data in enumerate(d_trf_datas):
for d_tensor in d_trf_data.tensors:
losses[self.name] += float((d_tensor ** 2).sum()) # type: ignore
# type: ignore
losses[self.name] += float((d_tensor ** 2).sum())
if i >= len(d_tensors):
d_tensors.append(d_trf_data.tensors)
else:
@@ -389,7 +396,7 @@ def from_disk(
def load_model(p):
p = Path(p).absolute()
tokenizer, transformer = huggingface_from_pretrained(
p, self.model.attrs["tokenizer_config"]
p, self.model.attrs["tokenizer_config"], self.model.attrs["transformer_config"]
)
self.model.attrs["tokenizer"] = tokenizer
self.model.attrs["set_transformer"](self.model, transformer)
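
A hedged sketch of turning the new setting on for the pipeline component, assuming spaCy's usual deep-merge of add_pipe config overrides with the default config shown above; treat the exact override path and merge behaviour as assumptions.

# Sketch, assuming config overrides are merged into the defaults shown above.
import spacy
import spacy_transformers  # noqa: F401

nlp = spacy.blank("en")
nlp.add_pipe(
    "transformer",
    config={"model": {"transformer_config": {"output_attentions": True}}},
)
nlp.initialize()

doc = nlp("Attention weights should now be available downstream.")
# If the merge behaves as assumed, doc._.trf_data.attention holds per-layer
# attention arrays; with the default {"output_attentions": false} it is None.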
4 changes: 3 additions & 1 deletion spacy_transformers/tests/test_model_wrapper.py
@@ -28,7 +28,9 @@ def name(request):

@pytest.fixture(scope="session")
def trf_model(name):
model = TransformerModel(name, get_doc_spans, {"use_fast": True})
model = TransformerModel(
name, get_doc_spans, {"use_fast": True}, {"output_attentions": False}
)
model.initialize()
return model

7 changes: 6 additions & 1 deletion spacy_transformers/tests/util.py
@@ -116,7 +116,11 @@ def _forward(model, tokens, is_train):
tensors.append(torch.zeros(*shape))
return tensors, lambda d_tensors: tokens

return Model("dummy-transformer", _forward, attrs={"width": width, "depth": depth})
return Model(
"dummy-transformer",
_forward,
attrs={"width": width, "depth": depth},
)


def DummyTransformer(
@@ -132,6 +136,7 @@ def DummyTransformer(
"tokenizer": DummyTokenizer(),
"grad_factor": 1.0,
"flush_cache_chance": 0.0,
"transformer_config": {}
},
dims={"nO": width},
)
11 changes: 8 additions & 3 deletions spacy_transformers/util.py
@@ -1,5 +1,6 @@
from typing import List, Dict, Union
from pathlib import Path
from functools import partial
import random
from transformers import AutoModel, AutoTokenizer
from transformers.tokenization_utils import BatchEncoding
@@ -16,20 +17,24 @@
# fmt: on


def huggingface_from_pretrained(source: Union[Path, str], config: Dict):
def huggingface_from_pretrained(
source: Union[Path, str], tok_config: Dict, trf_config: Dict
):
"""Create a Huggingface transformer model from pretrained weights. Will
download the model if it is not already downloaded.
source (Union[str, Path]): The name of the model or a path to it, such as
'bert-base-cased'.
config (dict): Settings to pass to the tokenizer.
tok_config (dict): Settings to pass to the tokenizer.
trf_config (dict): Settings to pass to the transformer.
"""
if hasattr(source, "absolute"):
str_path = str(source.absolute())
else:
str_path = source
tokenizer = AutoTokenizer.from_pretrained(str_path, **config)
tokenizer = AutoTokenizer.from_pretrained(str_path, **tok_config)
transformer = AutoModel.from_pretrained(str_path)
transformer.forward = partial(transformer.forward, **trf_config)
ops = get_current_ops()
if isinstance(ops, CupyOps):
transformer.cuda()
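
A small usage sketch of the updated helper (not part of the diff); the settings in the third argument end up bound onto every forward() call through functools.partial, as in the change above.

# Sketch: the new three-argument call.
from spacy_transformers.util import huggingface_from_pretrained

tokenizer, transformer = huggingface_from_pretrained(
    "roberta-base",
    {"use_fast": True},            # tokenizer settings
    {"output_attentions": True},   # transformer forward-pass settings
)
# Every subsequent transformer(...) / transformer.forward(...) call now
# implicitly passes output_attentions=True.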
