Skip to content

Commit f5e2ade

Browse files
committed
add llm pretraining & eval
1 parent b8f8293 commit f5e2ade

File tree

2 files changed

+72
-0
lines changed

2 files changed

+72
-0
lines changed

README.md

+9
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,12 @@ rating_methods:
9696
type: inlink_count
9797
```
9898

## Pretraining and Evaluation

After running the crawler, the crawled document ids will be placed in `output_dir` in the configuration file. Run the following command to get the document texts:

```bash
python fetch_docs.py --cw22_root_path <clueweb22_root_path> --input_dir <document_ids_dir> --output_dir <document_texts_dir> --num_workers <num_workers>
```

Then you can use the [DCLM](https://github.com/mlfoundations/dclm/) framework to run LLM pretraining and evaluation.

fetch_docs.py

+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import argparse
2+
import glob
3+
import os
4+
from functools import partial
5+
from multiprocessing import Pool
6+
7+
from tqdm import tqdm
8+
9+
from corpus_interface import ClueWeb22Api
10+
11+
12+
def fetch(cw22_api: ClueWeb22Api, output_dir: str, input_file: str) -> None:
    """Fetch clean texts for every document id listed in *input_file*.

    Reads one ClueWeb22 document id per line from *input_file*, resolves
    each id to its clean text via *cw22_api*, and writes every non-empty
    text as one line of ``<output_dir>/<input_basename>.jsonl``.

    Args:
        cw22_api: ClueWeb22 API handle used to look up document texts.
        output_dir: Directory the output file is written into.
        input_file: Path to a text file with one document id per line.
    """
    basename_without_ext = os.path.splitext(os.path.basename(input_file))[0]
    # NOTE(review): assumes get_clean_text returns one JSON record per
    # document (hence the .jsonl extension) — confirm against corpus_interface.
    output_file = os.path.join(output_dir, f"{basename_without_ext}.jsonl")

    with open(input_file, "r") as f:
        # Strip whitespace and drop blank lines up front so we never
        # query the API with an empty id (and never need to re-strip
        # inside the write loop).
        doc_ids = [line.strip() for line in f if line.strip()]

    # The output file is created even when no document yields text,
    # matching the original behavior.
    with open(output_file, "w") as f:
        for doc_id in doc_ids:
            doc = cw22_api.get_clean_text(doc_id)
            # get_clean_text may return None (e.g. missing document);
            # skip those and any whose clean text is empty.
            if doc is not None:
                doc_stripped = doc.strip()
                if doc_stripped:
                    f.write(doc_stripped + "\n")
27+
28+
29+
def main():
    """Fetch clean texts for every id file produced by the crawler.

    Expects ``--input_dir`` to contain ``*.txt`` files with one document
    id per line. Writes one ``.jsonl`` file per input file into
    ``--output_dir`` (default: ``<input_dir>_docs``), fanning the work
    out over ``--num_workers`` processes. Refuses to run if the output
    directory already exists, to avoid clobbering earlier results.
    """
    parser = argparse.ArgumentParser()
    # required=True gives a clean argparse error instead of a TypeError
    # deep inside glob/ClueWeb22Api when a flag is omitted.
    parser.add_argument("--cw22_root_path", type=str, required=True,
                        help="Root path of the ClueWeb22 corpus.")
    parser.add_argument("--input_dir", type=str, required=True,
                        help="Directory containing *.txt files of document ids.")
    parser.add_argument("--output_dir", type=str, default=None,
                        help="Output directory (default: <input_dir>_docs).")
    parser.add_argument("--num_workers", type=int, default=1,
                        help="Number of parallel worker processes.")
    args = parser.parse_args()

    if args.output_dir is not None:
        output_dir = args.output_dir
    else:
        # Derive a sibling "<input_dir>_docs" directory, tolerating
        # trailing slashes on the user-supplied path.
        output_dir = f"{args.input_dir.rstrip('/')}_docs"
    if os.path.exists(output_dir):
        print(f"Output path {output_dir} already exists! Check again!")
        return
    os.makedirs(output_dir, exist_ok=True)

    cw22_api = ClueWeb22Api(args.cw22_root_path)

    # Sort for a deterministic processing order across runs.
    all_input_files = sorted(glob.glob(os.path.join(args.input_dir, "*.txt")))
    print(f"Number of files: {len(all_input_files)}")

    # Bind the API handle and output dir once; workers receive only the
    # per-file path.
    fetch_partial = partial(fetch, cw22_api, output_dir)
    with Pool(args.num_workers) as p:
        # imap (ordered, lazy) lets tqdm report progress as files finish.
        for _ in tqdm(p.imap(fetch_partial, all_input_files), total=len(all_input_files)):
            pass
60+
61+
62+
# Script entry point: only run when executed directly, not on import.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)