Add support for document masking during training (#661)
epwalsh authored Jul 19, 2024
1 parent b45002e commit 4e00460
Showing 12 changed files with 238 additions and 16 deletions.
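This commit threads a new data.generate_doc_lengths option through the data pipeline: per-instance document lengths are derived from EOS positions, collated into each batch, and made available so flash-attn can restrict attention to within each document. A minimal sketch of turning the option on from Python (illustrative only, not part of the commit; the CHANGELOG entry below shows the equivalent CLI override):

from olmo.config import DataConfig

# Equivalent to passing --data.generate_doc_lengths on the training command line.
data_cfg = DataConfig(generate_doc_lengths=True)
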
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added support for document masking via flash-attn during training with `--data.generate_doc_lengths`.
- Added config options for `model.norm_after`, `model.scale_emb_init`, and `auxiliary_loss_multiplier` (used with zloss).
- Added scripts for running experiments on qk_norm, norm reordering, and zloss.

1 change: 1 addition & 0 deletions olmo/config.py
@@ -594,6 +594,7 @@ class DataConfig(BaseConfig):
label_mask_paths: Optional[List[str]] = None
pad_direction: PaddingDirection = PaddingDirection.right
generate_attention_mask: bool = False
generate_doc_lengths: bool = False
num_workers: int = 0
drop_last: bool = False
pin_memory: bool = False
4 changes: 3 additions & 1 deletion olmo/data/__init__.py
@@ -6,7 +6,7 @@
from ..aliases import PathOrStr
from ..config import DataConfig, TrainConfig
from ..exceptions import OLMoConfigurationError
from ..torch_util import barrier, get_global_rank, get_world_size, is_distributed
from ..torch_util import barrier, get_global_rank, get_world_size
from .collator import DataCollator
from .iterable_dataset import IterableDataset
from .memmap_dataset import MemMapDataset
@@ -40,7 +40,9 @@ def build_memmap_dataset(
metadata=metadata,
include_instance_metadata=include_instance_metadata,
pad_token_id=train_config.model.pad_token_id,
eos_token_id=train_config.model.eos_token_id,
generate_attention_mask=data_config.generate_attention_mask,
generate_doc_lengths=data_config.generate_doc_lengths,
label_mask_paths=cast(Optional[List[PathOrStr]], data_config.label_mask_paths),
instance_filter_config=data_config.instance_filter,
)
15 changes: 15 additions & 0 deletions olmo/data/collator.py
@@ -30,6 +30,10 @@ def __call__(self, items: Union[List[Dict[str, Any]], List[torch.Tensor]]) -> Di
all_indices = []
all_metadata = []
all_instance_mask = []
all_doc_lens = []
all_max_doc_lens = []
max_docs = max((len(x["doc_lens"]) if isinstance(x, dict) and "doc_lens" in x else 0 for x in items))

for x in items:
input_ids = x["input_ids"] if isinstance(x, dict) else x
if not isinstance(input_ids, torch.Tensor):
@@ -103,6 +107,13 @@ def __call__(self, items: Union[List[Dict[str, Any]], List[torch.Tensor]]) -> Di
if instance_mask is not None:
all_instance_mask.append(torch.tensor(instance_mask))

# Document lengths.
doc_lens = x.get("doc_lens") if isinstance(x, dict) else None
if doc_lens is not None:
doc_pad_shape = (0, max_docs - len(doc_lens))
all_doc_lens.append(F.pad(doc_lens, doc_pad_shape, value=0))
all_max_doc_lens.append(int(doc_lens.max()))

# Metadata.
metadata = x.get("metadata") if isinstance(x, dict) else None
if metadata is not None:
@@ -119,6 +130,10 @@ def __call__(self, items: Union[List[Dict[str, Any]], List[torch.Tensor]]) -> Di
out["index"] = torch.stack(all_indices)
if all_instance_mask:
out["instance_mask"] = torch.stack(all_instance_mask)
if all_doc_lens:
out["doc_lens"] = torch.stack(all_doc_lens)
if all_max_doc_lens:
out["max_doc_lens"] = all_max_doc_lens
if all_metadata:
out["metadata"] = all_metadata

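The collator now pads each instance's doc_lens tensor out to the largest document count in the batch before stacking, and records each instance's longest document in max_doc_lens. A small standalone sketch of that padding logic (made-up lengths, not the collator itself):

import torch
import torch.nn.functional as F

# Two instances: one packed with 3 documents, one with 2.
doc_lens_per_item = [torch.tensor([5, 9, 2]), torch.tensor([10, 6])]
max_docs = max(len(d) for d in doc_lens_per_item)

doc_lens = torch.stack([F.pad(d, (0, max_docs - len(d)), value=0) for d in doc_lens_per_item])
max_doc_lens = [int(d.max()) for d in doc_lens_per_item]

print(doc_lens)      # tensor([[ 5,  9,  2],
                     #         [10,  6,  0]])
print(max_doc_lens)  # [9, 10]
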
19 changes: 15 additions & 4 deletions olmo/data/memmap_dataset.py
@@ -1,7 +1,7 @@
from __future__ import annotations

from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import numpy as np
import torch
@@ -12,7 +12,7 @@
from ..aliases import PathOrStr
from ..config import InstanceFilterConfig
from ..util import _get_s3_client, file_size, get_bytes_range
from .util import find_periodic_sequences
from .util import find_periodic_sequences, get_document_lengths

__all__ = ["MemMapDataset"]

@@ -47,20 +47,25 @@ def __init__(
self,
*paths: PathOrStr,
chunk_size: int = 1024,
memmap_dtype=np.uint16,
memmap_dtype: Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]] = np.uint16,
metadata: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None,
include_instance_metadata: bool = True,
generate_attention_mask: bool = False,
generate_doc_lengths: bool = False,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
label_mask_paths: Optional[List[PathOrStr]] = None,
instance_filter_config: Optional[InstanceFilterConfig] = None,
):
if not paths:
raise ValueError("At least one path is required")

if generate_attention_mask and not pad_token_id:
if generate_attention_mask and pad_token_id is None:
raise ValueError("'pad_token_id' is required for 'generate_attention_mask'")

if generate_doc_lengths and eos_token_id is None:
raise ValueError("'eos_token_id' is required for 'generate_cu_doc_lengths'")

if label_mask_paths and len(label_mask_paths) != len(paths):
raise ValueError("There must be the same number of 'label_mask_paths' as there are 'paths'")

@@ -79,7 +84,9 @@ def __init__(
self.dtype = memmap_dtype
self._include_instance_metadata = include_instance_metadata
self._generate_attention_mask = generate_attention_mask
self._generate_doc_lengths = generate_doc_lengths
self._pad_token_id = pad_token_id
self._eos_token_id = eos_token_id
self.instance_filter_config = instance_filter_config

@property
@@ -207,6 +214,10 @@ def __getitem__(self, index: int) -> Dict[str, Any]:
attn_mask.masked_fill_(input_ids == self._pad_token_id, 0)
out["attention_mask"] = attn_mask

if self._generate_doc_lengths:
assert self._eos_token_id is not None
out["doc_lens"] = get_document_lengths(input_ids, self._eos_token_id)

return out

def __add__(self, other: MemMapDataset) -> MemMapDataset:
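With generate_doc_lengths enabled, each item the dataset yields carries a doc_lens tensor alongside input_ids. A hedged sketch of reading one packed instance directly (the token file path and token ids are hypothetical; in training this dataset is built from the config via build_memmap_dataset):

from olmo.data import MemMapDataset

dataset = MemMapDataset(
    "/data/tokens/part-000.npy",  # hypothetical memory-mapped token file
    chunk_size=2048,
    pad_token_id=1,               # hypothetical token ids
    eos_token_id=0,
    generate_doc_lengths=True,
)

item = dataset[0]
item["input_ids"]  # the packed token chunk
item["doc_lens"]   # per-document token counts within that chunk, via get_document_lengths
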
12 changes: 12 additions & 0 deletions olmo/data/util.py
@@ -1,6 +1,7 @@
from typing import Generator, List, NamedTuple

import numpy as np
import torch


def find_end_first_consecutive_true(arr: np.ndarray) -> int:
@@ -116,3 +117,14 @@ def find_periodic_sequences(
# cannot accurately determine the period of a sequence that repeats
# less than 3 times with this algorithm
yield out


def get_document_lengths(input_ids: torch.Tensor, eos_token_id: int) -> torch.Tensor:
doc_boundaries = torch.cat(
[
torch.tensor([-1], dtype=torch.int32),
(input_ids == eos_token_id).nonzero(as_tuple=True)[0].to(dtype=torch.int32),
torch.tensor([] if input_ids[-1] == eos_token_id else [input_ids.shape[0] - 1], dtype=torch.int32),
]
)
return doc_boundaries[1:] - doc_boundaries[:-1]
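
Worked example: get_document_lengths treats each EOS position as the end of a document (and, when the chunk does not end with EOS, treats the final position as the end of a trailing document), then takes consecutive differences of those boundaries. A quick standalone check with made-up token ids:

import torch

from olmo.data.util import get_document_lengths

eos = 50279  # hypothetical EOS token id
# Three packed documents: 4 tokens (incl. EOS), 3 tokens (incl. EOS), 2 tokens (no trailing EOS).
input_ids = torch.tensor([11, 12, 13, eos, 21, 22, eos, 31, 32])

print(get_document_lengths(input_ids, eos))  # tensor([4, 3, 2], dtype=torch.int32)

Presumably the trainer converts these per-document lengths into cumulative sequence lengths (a leading zero plus a running sum) for flash-attn's variable-length attention kernels; that plumbing lives outside the files shown in this diff.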
