
Commit 6203618: transfer the codebase to group organization


56 files changed, +6189 -0 lines changed

.gitignore

+1
```
**/__pycache__
```

README.md

+90
# ESAE

## Installation

Please ensure that you have [Conda](https://conda.io/projects/conda/en/latest/user-guide/install/index.html) installed on your system, as it is required for managing dependencies in this project.

To create and activate the Conda environment from the provided `environment.yml` file, please run the following commands:

```sh
conda env create -f environment.yml
conda activate esae
```

Before proceeding, please update the system paths specified in `source/__init__.py` to match your configuration. These paths are used for storing the datasets, model checkpoints, and other artifacts.

```python
from pathlib import Path

workspace = Path("/data/user_data/haok/esae")
workspace.mkdir(mode=0o770, parents=True, exist_ok=True)

import os

os.environ["HF_HOME"] = "/data/user_data/haok/huggingface"
```

## Overview

To train a Sparse Autoencoder (SAE), the first step is to download the dataset and compute the embeddings that will later be reconstructed. Initialize all datasets in this repository using the following command:

```sh
python3 -m source.dataset.msMarco
```
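
Once the datasets are initialized, document IDs can be consumed through the dataset classes exported from `source/dataset`. The sketch below is illustrative only: it assumes `MsMarcoDataset` implements the `Dataset` interface described later in this README and can be constructed without arguments, which may differ from the actual code.

```python
# Illustrative sketch, not code from the repository. Assumes MsMarcoDataset
# (exported by source/dataset/__init__.py) implements didIter and takes no
# constructor arguments; adjust to the real signature if it differs.
from source.dataset import MsMarcoDataset

dataset = MsMarcoDataset()
for didBatch in dataset.didIter(batchSize=256):  # 256 is an arbitrary batch size
    print(len(didBatch), didBatch[:3])  # batch size and a few document IDs
    break  # inspect only the first batch
```
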

### Running Experiments

To experiment with different hyperparameters or model configurations, refer to the experiments in the `source/model/archive` directory. Each file in this directory contains a specific model setup.

To create a new experiment, add a new file with your desired hyperparameters and configuration. Once it is ready, run the following command, replacing `{version}` with your file name:

```sh
python3 -m source.model.{version}
```

For example, if your new experiment file is `source/model/240825A.py`, you would run:

```sh
python3 -m source.model.240825A
```

Model checkpoints are automatically saved under `{workspace}/model/{version}/state/`, where `workspace` is the path specified in `source/__init__.py`. This makes it easy to manage and retrieve your experiment results.
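
For instance, the checkpoint directory for the example version `240825A` above can be assembled from the `workspace` path. This is a small sketch; the file names stored inside `state/` depend on the training code and are not listed here.

```python
# Sketch: build the checkpoint directory for one experiment version using the
# {workspace}/model/{version}/state/ layout described above.
from pathlib import Path
from source import workspace

version = "240825A"
stateDir = Path(workspace, "model", version, "state")
print(sorted(stateDir.glob("*")))  # list whatever checkpoint files exist
```
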

### Evaluating Performance

## Quality Assurance

### Standardized Interface

To ensure a clean and reusable codebase, this repository follows best practices by defining its interfaces in `source/interface.py`. All core components implement standardized interfaces that promote consistency and modularity. For instance, the `Dataset` class defines a blueprint that any dataset must follow by implementing the `didIter` method. This method enables clients to iterate over all document IDs in batches.

Here's an example:

```python
from abc import ABC, abstractmethod
from typing import Iterator, List

class Dataset(ABC):
    name: DatasetName  # identifier for the concrete dataset

    @abstractmethod
    def didIter(self, batchSize: int) -> Iterator[List[int]]:
        """
        Iterate over the document IDs.

        :param batchSize: The batch size for each iteration.
        :return: The iterator over the document IDs.
        """
        raise NotImplementedError
```
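
As an illustration only (not code from the repository), a new dataset would subclass `Dataset` and yield its document IDs in batches. In the self-contained sketch below, `DatasetName` is stood in by a plain string and `ToyDataset` uses hard-coded IDs; a real implementation would read its document IDs from disk or an index.

```python
# Hypothetical, self-contained sketch of implementing and consuming the
# interface above. DatasetName is stood in by str here; in the repository it
# is whatever type source/interface.py defines.
from abc import ABC, abstractmethod
from typing import Iterator, List

DatasetName = str  # stand-in for the repository's DatasetName type


class Dataset(ABC):
    name: DatasetName

    @abstractmethod
    def didIter(self, batchSize: int) -> Iterator[List[int]]:
        raise NotImplementedError


class ToyDataset(Dataset):
    name = "toy"  # placeholder identifier

    def didIter(self, batchSize: int) -> Iterator[List[int]]:
        docIds = list(range(10))  # stand-in for real document IDs
        for i in range(0, len(docIds), batchSize):
            yield docIds[i : i + batchSize]


for batch in ToyDataset().didIter(batchSize=4):
    print(batch)  # [0, 1, 2, 3], then [4, 5, 6, 7], then [8, 9]
```
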

### Testing Locally

To ensure maintainability, this codebase is fully type-checked with mypy and thoroughly tested with pytest. As new components are integrated into the interface, please ensure that corresponding test cases are added. Place your test cases under the relevant directories to keep the test suite comprehensive and organized.

You can run the following commands to perform these checks:

```sh
mypy source
pytest source
```
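
For example, a test for the batch-iteration contract might look like the sketch below. The file placement and the use of `MsMarcoDataset` are assumptions; the constructor arguments and the datasets you test depend on the actual code.

```python
# Hypothetical test, e.g. placed next to the dataset it covers.
# Assumes MsMarcoDataset can be constructed without arguments and that its
# data has already been initialized (see the Overview section).
from source.dataset import MsMarcoDataset


def test_didIter_respects_batch_size() -> None:
    dataset = MsMarcoDataset()
    batch = next(dataset.didIter(batchSize=8))
    assert isinstance(batch, list)
    assert 0 < len(batch) <= 8
```
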

acquire.sh

+8
```sh
#!/usr/bin/bash

# Acquire a compute node with 32 CPUs, 96GB memory, and 1 A6000 GPU.
# Please run this script with tmux to avoid losing the session.
srun \
    --partition=long --time=07-00:00:00 \
    --cpus-per-task=32 --mem=96GB --gres=gpu:A6000:1 \
    --pty bash
```

environment.yml

+27
```yaml
name: esae

channels:
  - pytorch
  - nvidia
  - conda-forge

dependencies:
  - python=3.12
  - pytorch=2.4.0
  - pytorch-cuda=12.4
  - numpy=1.26.4
  - rich=13.7.1
  - pyarrow==16.1.0
  - mypy==1.10.0
  - aiofiles==22.1.0
  - aiohttp==3.9.5
  - transformers==4.44.2
  - pytest==7.4.4
  - pytest-asyncio==0.20.3
  - types-aiofiles==24.1.0.20240626
  - attrs==23.1.0
  - wandb==0.16.6
  - faiss-gpu==1.8.0
  - blobfile==3.0.0
  - treevizer=0.2.4
  - elasticsearch
```

mypy.ini

+11
```ini
[mypy]
exclude = source/model/archive/

[mypy-transformers.*]
ignore_missing_imports = True

[mypy-pyarrow.*]
ignore_missing_imports = True

[mypy-faiss.*]
ignore_missing_imports = True
```

pytest.ini

+3
```ini
[pytest]
filterwarnings =
    ignore::UserWarning
```

reval.sh

+12
```sh
echo "Reconstruct: 240919A"
python3 -m source.interpret.retrieval.reconstruct 240919A
echo "Reconstruct: 240919B"
python3 -m source.interpret.retrieval.reconstruct 240919B
echo "Reconstruct: 240919C"
python3 -m source.interpret.retrieval.reconstruct 240919C
echo "Reconstruct: 240919D"
python3 -m source.interpret.retrieval.reconstruct 240919D
echo "Reconstruct: 240919E"
python3 -m source.interpret.retrieval.reconstruct 240919E
echo "Reconstruct: 240919F"
python3 -m source.interpret.retrieval.reconstruct 240919F
```

source/__init__.py

+18
```python
from pathlib import Path

# Shared workspace for datasets, model checkpoints, and other artifacts.
workspace = Path("/data/group_data/cx_group/esae")
workspace.mkdir(mode=0o770, parents=True, exist_ok=True)

import os

# Keep the Hugging Face cache inside the workspace.
os.environ["HF_HOME"] = Path(workspace, "huggingface").as_posix()

from rich.console import Console

# Project-wide console; hide file paths and always show timestamps in logs.
console = Console(width=80)
console._log_render.show_path = False
console._log_render.omit_repeated_times = False

import warnings

warnings.filterwarnings("ignore")
```

source/dataset/__init__.py

+7
```python
from pathlib import Path
from source import workspace

workspace = Path(workspace, "dataset")
workspace.mkdir(mode=0o770, parents=True, exist_ok=True)

from source.dataset.msMarco import MsMarcoDataset
```
