
Commit 4e00460

Add support for document masking during training (#661)
1 parent b45002e · commit 4e00460
12 files changed: +238 -16 lines

CHANGELOG.md (+1)

@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
+- Added support for document masking via flash-attn during training with `--data.generate_doc_lengths`.
 - Added config options for `model.norm_after`, `model.scale_emb_init`, and `auxiliary_loss_multiplier` (used with zloss).
 - Added scripts for running experiments on qk_norm, norm reordering, and zloss.
 

olmo/config.py (+1)

@@ -594,6 +594,7 @@ class DataConfig(BaseConfig):
     label_mask_paths: Optional[List[str]] = None
     pad_direction: PaddingDirection = PaddingDirection.right
     generate_attention_mask: bool = False
+    generate_doc_lengths: bool = False
     num_workers: int = 0
     drop_last: bool = False
     pin_memory: bool = False
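For programmatic use, a minimal sketch constructing the config directly (the data path is a placeholder):

from olmo.config import DataConfig

data_config = DataConfig(
    paths=["/data/tokens/part-000.npy"],  # placeholder path
    generate_doc_lengths=True,  # the new flag added above
)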

olmo/data/__init__.py (+3 -1)

@@ -6,7 +6,7 @@
 from ..aliases import PathOrStr
 from ..config import DataConfig, TrainConfig
 from ..exceptions import OLMoConfigurationError
-from ..torch_util import barrier, get_global_rank, get_world_size, is_distributed
+from ..torch_util import barrier, get_global_rank, get_world_size
 from .collator import DataCollator
 from .iterable_dataset import IterableDataset
 from .memmap_dataset import MemMapDataset
@@ -40,7 +40,9 @@ def build_memmap_dataset(
         metadata=metadata,
         include_instance_metadata=include_instance_metadata,
         pad_token_id=train_config.model.pad_token_id,
+        eos_token_id=train_config.model.eos_token_id,
         generate_attention_mask=data_config.generate_attention_mask,
+        generate_doc_lengths=data_config.generate_doc_lengths,
         label_mask_paths=cast(Optional[List[PathOrStr]], data_config.label_mask_paths),
         instance_filter_config=data_config.instance_filter,
     )
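Putting the wiring together, a rough usage sketch; the positional call signature and config path are assumptions, not shown in this commit:

from olmo.config import TrainConfig
from olmo.data import build_memmap_dataset

cfg = TrainConfig.load("configs/official/OLMo-1B.yaml")  # illustrative path
cfg.data.generate_doc_lengths = True
# eos_token_id is picked up from cfg.model, per the change above.
dataset = build_memmap_dataset(cfg, cfg.data)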

olmo/data/collator.py (+15)

@@ -30,6 +30,10 @@ def __call__(self, items: Union[List[Dict[str, Any]], List[torch.Tensor]]) -> Di
         all_indices = []
         all_metadata = []
         all_instance_mask = []
+        all_doc_lens = []
+        all_max_doc_lens = []
+        max_docs = max((len(x["doc_lens"]) if isinstance(x, dict) and "doc_lens" in x else 0 for x in items))
+
         for x in items:
             input_ids = x["input_ids"] if isinstance(x, dict) else x
             if not isinstance(input_ids, torch.Tensor):
@@ -103,6 +107,13 @@ def __call__(self, items: Union[List[Dict[str, Any]], List[torch.Tensor]]) -> Di
             if instance_mask is not None:
                 all_instance_mask.append(torch.tensor(instance_mask))
 
+            # Document lengths.
+            doc_lens = x.get("doc_lens") if isinstance(x, dict) else None
+            if doc_lens is not None:
+                doc_pad_shape = (0, max_docs - len(doc_lens))
+                all_doc_lens.append(F.pad(doc_lens, doc_pad_shape, value=0))
+                all_max_doc_lens.append(int(doc_lens.max()))
+
             # Metadata.
             metadata = x.get("metadata") if isinstance(x, dict) else None
             if metadata is not None:
@@ -119,6 +130,10 @@ def __call__(self, items: Union[List[Dict[str, Any]], List[torch.Tensor]]) -> Di
             out["index"] = torch.stack(all_indices)
         if all_instance_mask:
             out["instance_mask"] = torch.stack(all_instance_mask)
+        if all_doc_lens:
+            out["doc_lens"] = torch.stack(all_doc_lens)
+        if all_max_doc_lens:
+            out["max_doc_lens"] = all_max_doc_lens
         if all_metadata:
             out["metadata"] = all_metadata
 
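To make the padding behavior concrete, a standalone sketch of what the collator does with ragged doc_lens (values are illustrative):

import torch
import torch.nn.functional as F

# Two instances with different document counts are right-padded with zeros
# so they stack into a single (batch_size, max_docs) tensor.
doc_lens = [
    torch.tensor([3, 2, 2], dtype=torch.int32),
    torch.tensor([5, 2], dtype=torch.int32),
]
max_docs = max(len(d) for d in doc_lens)
batch = torch.stack([F.pad(d, (0, max_docs - len(d)), value=0) for d in doc_lens])
# batch -> tensor([[3, 2, 2],
#                  [5, 2, 0]], dtype=torch.int32)
# Per-instance maxima, as collected in all_max_doc_lens: [3, 5]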

olmo/data/memmap_dataset.py (+15 -4)

@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from copy import deepcopy
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
 
 import numpy as np
 import torch
@@ -12,7 +12,7 @@
 from ..aliases import PathOrStr
 from ..config import InstanceFilterConfig
 from ..util import _get_s3_client, file_size, get_bytes_range
-from .util import find_periodic_sequences
+from .util import find_periodic_sequences, get_document_lengths
 
 __all__ = ["MemMapDataset"]
 
@@ -47,20 +47,25 @@ def __init__(
         self,
         *paths: PathOrStr,
         chunk_size: int = 1024,
-        memmap_dtype=np.uint16,
+        memmap_dtype: Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]] = np.uint16,
         metadata: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None,
         include_instance_metadata: bool = True,
         generate_attention_mask: bool = False,
+        generate_doc_lengths: bool = False,
         pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
         label_mask_paths: Optional[List[PathOrStr]] = None,
         instance_filter_config: Optional[InstanceFilterConfig] = None,
     ):
         if not paths:
             raise ValueError("At least one path is required")
 
-        if generate_attention_mask and not pad_token_id:
+        if generate_attention_mask and pad_token_id is None:
             raise ValueError("'pad_token_id' is required for 'generate_attention_mask'")
 
+        if generate_doc_lengths and eos_token_id is None:
+            raise ValueError("'eos_token_id' is required for 'generate_doc_lengths'")
+
         if label_mask_paths and len(label_mask_paths) != len(paths):
             raise ValueError("There must be the same number of 'label_mask_paths' as there are 'paths'")
 
@@ -79,7 +84,9 @@ def __init__(
         self.dtype = memmap_dtype
         self._include_instance_metadata = include_instance_metadata
         self._generate_attention_mask = generate_attention_mask
+        self._generate_doc_lengths = generate_doc_lengths
         self._pad_token_id = pad_token_id
+        self._eos_token_id = eos_token_id
         self.instance_filter_config = instance_filter_config
 
     @property
@@ -207,6 +214,10 @@ def __getitem__(self, index: int) -> Dict[str, Any]:
             attn_mask.masked_fill_(input_ids == self._pad_token_id, 0)
             out["attention_mask"] = attn_mask
 
+        if self._generate_doc_lengths:
+            assert self._eos_token_id is not None
+            out["doc_lens"] = get_document_lengths(input_ids, self._eos_token_id)
+
         return out
 
     def __add__(self, other: MemMapDataset) -> MemMapDataset:
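A hypothetical end-to-end read (the path and token id are placeholders, not from this commit):

from olmo.data import MemMapDataset

dataset = MemMapDataset(
    "/data/tokens/part-000.npy",  # placeholder path
    chunk_size=2048,
    generate_doc_lengths=True,
    eos_token_id=50279,  # placeholder; must match the tokenizer that built the memmap
)
item = dataset[0]
# item["input_ids"]: the 2048-token chunk
# item["doc_lens"]: int32 lengths of each EOS-delimited document in the chunk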

olmo/data/util.py (+12)

@@ -1,6 +1,7 @@
 from typing import Generator, List, NamedTuple
 
 import numpy as np
+import torch
 
 
 def find_end_first_consecutive_true(arr: np.ndarray) -> int:
@@ -116,3 +117,14 @@ def find_periodic_sequences(
         # cannot accurately determine the period of a sequence that repeats
         # less than 3 times with this algorithm
         yield out
+
+
+def get_document_lengths(input_ids: torch.Tensor, eos_token_id: int) -> torch.Tensor:
+    doc_boundaries = torch.cat(
+        [
+            torch.tensor([-1], dtype=torch.int32),
+            (input_ids == eos_token_id).nonzero(as_tuple=True)[0].to(dtype=torch.int32),
+            torch.tensor([] if input_ids[-1] == eos_token_id else [input_ids.shape[0] - 1], dtype=torch.int32),
+        ]
+    )
+    return doc_boundaries[1:] - doc_boundaries[:-1]
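A quick sanity check of the new helper (token values are illustrative, with eos_token_id=1): each EOS token is counted with the document it terminates, and a trailing unterminated document is still measured.

import torch

input_ids = torch.tensor([3, 4, 1, 5, 1, 6, 7])
print(get_document_lengths(input_ids, eos_token_id=1))
# tensor([3, 2, 2], dtype=torch.int32) -> docs [3, 4, 1], [5, 1], and the partial [6, 7]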
