 from __future__ import annotations

 from copy import deepcopy
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, Union

 import numpy as np
 import torch

 from ..aliases import PathOrStr
 from ..config import InstanceFilterConfig
 from ..util import _get_s3_client, file_size, get_bytes_range
-from .util import find_periodic_sequences
+from .util import find_periodic_sequences, get_document_lengths


 __all__ = ["MemMapDataset"]

@@ -47,20 +47,25 @@ def __init__(
         self,
         *paths: PathOrStr,
         chunk_size: int = 1024,
-        memmap_dtype=np.uint16,
+        memmap_dtype: Union[Type[np.uint8], Type[np.uint16], Type[np.uint32], Type[np.uint64]] = np.uint16,
         metadata: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None,
         include_instance_metadata: bool = True,
         generate_attention_mask: bool = False,
+        generate_doc_lengths: bool = False,
         pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
         label_mask_paths: Optional[List[PathOrStr]] = None,
         instance_filter_config: Optional[InstanceFilterConfig] = None,
     ):
         if not paths:
             raise ValueError("At least one path is required")

-        if generate_attention_mask and not pad_token_id:
+        if generate_attention_mask and pad_token_id is None:
             raise ValueError("'pad_token_id' is required for 'generate_attention_mask'")

+        if generate_doc_lengths and eos_token_id is None:
+            raise ValueError("'eos_token_id' is required for 'generate_doc_lengths'")
+
         if label_mask_paths and len(label_mask_paths) != len(paths):
             raise ValueError("There must be the same number of 'label_mask_paths' as there are 'paths'")

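For context, a sketch of how the two new constructor arguments fit together; every concrete value below (paths, chunk size, token ID, import path) is a made-up placeholder, not something from this change:

```python
import numpy as np

from olmo.data.memmap_dataset import MemMapDataset  # import path is an assumption

# Hypothetical construction combining the new arguments:
dataset = MemMapDataset(
    "/data/tokens/part-000.npy",  # placeholder path
    "/data/tokens/part-001.npy",  # placeholder path
    chunk_size=2048,
    memmap_dtype=np.uint16,       # now constrained to the four unsigned int dtypes
    generate_doc_lengths=True,    # ask __getitem__ to emit per-document lengths
    eos_token_id=50279,           # placeholder ID; required with generate_doc_lengths=True
)
```

The validation above raises if `generate_doc_lengths=True` is passed without an `eos_token_id`, mirroring the existing `generate_attention_mask`/`pad_token_id` check.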
@@ -79,7 +84,9 @@ def __init__(
         self.dtype = memmap_dtype
         self._include_instance_metadata = include_instance_metadata
         self._generate_attention_mask = generate_attention_mask
+        self._generate_doc_lengths = generate_doc_lengths
         self._pad_token_id = pad_token_id
+        self._eos_token_id = eos_token_id
         self.instance_filter_config = instance_filter_config

     @property
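The stored `self.dtype` now carries the narrowed `memmap_dtype` annotation, which restricts token storage to NumPy's four unsigned integer types. As a minimal sketch (not part of this diff; the helper name is invented), a caller might pick the smallest dtype that can represent every token ID in the vocabulary:

```python
import numpy as np

def smallest_memmap_dtype(vocab_size: int):
    # Token IDs range over [0, vocab_size), so choose the narrowest unsigned
    # integer type whose maximum covers vocab_size - 1.
    for dtype in (np.uint8, np.uint16, np.uint32, np.uint64):
        if vocab_size - 1 <= np.iinfo(dtype).max:
            return dtype
    raise ValueError(f"vocab_size {vocab_size} exceeds the uint64 range")

# A ~50k-token vocabulary fits in np.uint16 (max 65535):
assert smallest_memmap_dtype(50_304) == np.uint16
```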
@@ -207,6 +214,10 @@ def __getitem__(self, index: int) -> Dict[str, Any]:
             attn_mask.masked_fill_(input_ids == self._pad_token_id, 0)
             out["attention_mask"] = attn_mask

+        if self._generate_doc_lengths:
+            assert self._eos_token_id is not None
+            out["doc_lens"] = get_document_lengths(input_ids, self._eos_token_id)
+
         return out

     def __add__(self, other: MemMapDataset) -> MemMapDataset:
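The new `doc_lens` entry comes from the `get_document_lengths` helper imported from `.util` above, whose implementation is not shown in this diff. A minimal sketch of what such a helper could look like, assuming each instance packs multiple documents separated by EOS tokens:

```python
import torch

def get_document_lengths(input_ids: torch.Tensor, eos_token_id: int) -> torch.Tensor:
    # Sketch only; the real helper lives in `.util`. Each EOS token closes a
    # document, and a trailing partial document without an EOS is also counted.
    eos_positions = (input_ids == eos_token_id).nonzero(as_tuple=True)[0]
    boundaries = torch.cat([eos_positions.new_full((1,), -1), eos_positions])
    if len(input_ids) == 0 or input_ids[-1] != eos_token_id:
        boundaries = torch.cat([boundaries, boundaries.new_full((1,), len(input_ids) - 1)])
    return boundaries[1:] - boundaries[:-1]

# With 0 as the EOS token:
# get_document_lengths(torch.tensor([5, 8, 0, 9, 3, 0, 7]), eos_token_id=0)
# -> tensor([3, 3, 1])
```

Per-document lengths like these are typically what block-diagonal ("intra-document") attention kernels consume, so tokens in one packed document don't attend to tokens from a neighboring one.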