Skip to content

Commit 457136f

Browse files
committed
Try speeding up CPU ops
1 parent a415f73 commit 457136f

File tree

6 files changed

+238
-27
lines changed

6 files changed

+238
-27
lines changed

data_analyzer.py

+1
Original file line numberDiff line numberDiff line change
@@ -231,3 +231,4 @@ def hadamard_matrix(n):
231231
plt.yscale("log")
232232
plt.legend()
233233
# %%
234+

download_stuff.py

+25-10
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from huggingface_hub import HfApi, hf_hub_download
22
from tqdm.auto import tqdm
33
import os
4+
import argparse
5+
from pathlib import Path
46

57

6-
# Thanks Deepseek
78
def download_subdirectory(repo_id, subdir, local_dir, repo_type="model"):
89
print(f"Downloading {subdir} from {repo_id} to {local_dir}")
910
# Initialize HfApi
@@ -32,13 +33,27 @@ def download_subdirectory(repo_id, subdir, local_dir, repo_type="model"):
3233
)
3334
print(f"Downloaded {len(target_files)} files to {local_dir}")
3435

35-
# Usage
36-
37-
repo_id = "nev/flux1-saes"
38-
for directory_name in "sae_double_l18_img".split():
39-
download_subdirectory(
40-
repo_id=repo_id,
41-
subdir=directory_name,
42-
local_dir="somewhere"
43-
)
36+
def main():
    """Parse command-line options and download each requested directory.

    Options:
        --repo    Hugging Face repository ID.
        --dirs    One or more directories inside the repository.
        --output  Local directory that receives the files.
    """
    cli = argparse.ArgumentParser(description="Download directories from Hugging Face")
    cli.add_argument(
        "--repo",
        type=str,
        default="dmitriihook/flux1-saes",
        help="HuggingFace repository ID",
    )
    cli.add_argument(
        "--dirs",
        type=str,
        nargs="+",
        default=["maxacts_itda_50k_256/itda_new_data", "maxacts_itda_50k_256"],
        help="Directories to download",
    )
    cli.add_argument(
        "--output",
        type=str,
        default="somewhere",
        help="Local directory to save files",
    )
    options = cli.parse_args()

    output_root = Path(options.output)
    for subdir in options.dirs:
        # Pre-create the parent of the mirrored path; the download itself is
        # still pointed at the output root.
        (output_root / subdir).parent.mkdir(parents=True, exist_ok=True)
        download_subdirectory(
            repo_id=options.repo,
            subdir=subdir,
            local_dir=options.output,
        )
57+
58+
if __name__ == "__main__":
59+
main()

main.py

+151-1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,81 @@
4141
break
4242
app, rt = fast_app(hdrs=plotly_headers)
4343

44+
45+
# Add a function to compute spatial metrics for a feature
46+
def compute_spatial_metrics(feature_id):
    """Compute spatial metrics for a feature, averaged over its images.

    Rows from ``scored_storage.get_rows(feature_id)`` are unpacked as
    ``((idx, h, w), score)`` and scattered into one ``HEIGHT x WIDTH`` grid
    per image ``idx``. Images whose grid sums to zero, or that have no
    positive cell, are ignored.

    Returns:
        ``None`` when no image contributes. Otherwise a dict holding the
        per-image means of:

        * ``spatial_spread`` -- activation-weighted mean distance from the
          activation-weighted centre of the grid;
        * ``concentration_ratio`` -- share of the total activation held by
          the strongest quarter (at least one) of the positive cells;
        * ``activation_area`` -- fraction of grid cells that are positive,
          in the range [0, 1];

        plus ``num_images``, the number of images that contributed.
    """
    # Rebuild one dense activation grid per image from the sparse rows.
    grids = {}
    for (idx, h, w), score in scored_storage.get_rows(feature_id):
        grid = grids.get(idx)
        if grid is None:
            grid = grids[idx] = np.zeros((HEIGHT, WIDTH), dtype=float)
        grid[h, w] = score

    # The coordinate grids depend only on the image size, so build them once
    # rather than once per image.
    h_indices, w_indices = np.indices((HEIGHT, WIDTH))

    spreads, concentrations, areas = [], [], []
    for grid in grids.values():
        total_activation = grid.sum()
        active_values = grid[grid > 0]

        # Skip images with no net activation or no positive cell.
        if total_activation == 0 or active_values.size == 0:
            continue

        if total_activation > 0:
            # Activation-weighted centre of mass.
            center_h = np.sum(h_indices * grid) / total_activation
            center_w = np.sum(w_indices * grid) / total_activation

            # Activation-weighted mean distance from that centre.
            distances = np.sqrt((h_indices - center_h) ** 2 + (w_indices - center_w) ** 2)
            spread = np.sum(distances * grid) / total_activation

            # Share of the total activation held by the top 25% of the
            # positive cells (always at least one cell).
            quarter_point = max(1, active_values.size // 4)
            strongest = np.sort(active_values)[::-1][:quarter_point]
            concentration = np.sum(strongest) / total_activation
        else:
            # Net-negative grid: the weighted metrics are not meaningful, so
            # they are reported as 0.
            spread = 0.0
            concentration = 0.0

        spreads.append(float(spread))
        concentrations.append(float(concentration))
        areas.append(active_values.size / (HEIGHT * WIDTH))

    if not spreads:
        return None

    # Aggregate metrics across images.
    return {
        "spatial_spread": float(np.mean(spreads)),
        "concentration_ratio": float(np.mean(concentrations)),
        "activation_area": float(np.mean(areas)),
        "num_images": len(spreads),
    }
115+
116+
# Cache for spatial metrics to avoid recomputation
117+
spatial_metrics_cache = {}
118+
44119
@rt("/cached_image/{image_id}")
45120
def cached_image(image_id: int):
46121
img_path = image_cache_dir / f"{image_id}.jpg"
@@ -79,6 +154,8 @@ def top_features():
79154
Br(),
80155
H1(f"Spatial sparsity: {spatial_sparsity():.3f}"),
81156
Br(),
157+
P(A("View Spatial Metrics", href="/spatial_metrics")),
158+
Br(),
82159
*[Card(
83160
P(f"Feature {i}, Frequency: {frequencies[i]:.5f}, Max: {maxima[i]}"),
84161
A("View Max Acts", href=f"/maxacts/{i}")
@@ -133,6 +210,15 @@ def maxacts(feature_id: int):
133210
# Add score to the corresponding location in the grid
134211
grouped_rows[key][h, w] = score
135212

213+
# Compute spatial metrics for this feature if not already cached
214+
if feature_id not in spatial_metrics_cache:
215+
spatial_metrics_cache[feature_id] = compute_spatial_metrics(feature_id)
216+
217+
metrics = spatial_metrics_cache[feature_id]
218+
metrics_display = ""
219+
if metrics:
220+
metrics_display = f"Spatial Spread: {metrics['spatial_spread']:.3f}, Concentration: {metrics['concentration_ratio']:.3f}, Active Area: {metrics['activation_area']:.3f}"
221+
136222
# Prepare images and cards
137223
imgs = []
138224
for idx, grid in sorted(grouped_rows.items(), key=lambda x: x[1].max(), reverse=True)[:20]:
@@ -191,10 +277,73 @@ def maxacts(feature_id: int):
191277

192278
return Div(
193279
P(A("<- Go back", href="/top_features")),
280+
H2(f"Feature {feature_id} Spatial Metrics: {metrics_display}"),
194281
Div(*imgs, style="display: flex; flex-wrap: wrap; gap: 20px; justify-content: center"),
195282
style="padding: 20px"
196283
)
197284

285+
# Add a new endpoint to view spatial metrics for all features
286+
@rt("/spatial_metrics")
def spatial_metrics_view():
    """Render the spatial-metrics overview page.

    Features whose maximum (``scored_storage.key_maxima()``) exceeds 4 are
    considered. Their metrics come from ``compute_spatial_metrics`` and are
    memoised in ``spatial_metrics_cache``. The page shows a scatter plot of
    activation area against concentration ratio, followed by cards for the
    50 features with the smallest activation area.

    ``activation_area`` is stored as a fraction in [0, 1]; it is converted
    to a percentage only for display.
    """
    maxima = scored_storage.key_maxima()

    # Keep only features with significant activations.
    features = np.arange(len(scored_storage))[maxima > 4]

    # Compute metrics for all selected features (with caching).
    all_metrics = []
    for feature_id in features:
        # Use a plain int so cache keys match those used by the
        # /maxacts/{feature_id} route.
        feature_id = int(feature_id)
        if feature_id not in spatial_metrics_cache:
            spatial_metrics_cache[feature_id] = compute_spatial_metrics(feature_id)

        metrics = spatial_metrics_cache[feature_id]
        if metrics:
            all_metrics.append({
                "feature_id": feature_id,
                "spatial_spread": metrics["spatial_spread"],
                "concentration_ratio": metrics["concentration_ratio"],
                "activation_area": metrics["activation_area"],
                "num_images": metrics["num_images"],
            })

    # Sort by activation area (smallest, i.e. most concentrated, first).
    all_metrics.sort(key=lambda m: m["activation_area"])

    # Scatter plot of activation area (as a percentage) vs concentration ratio.
    scatter_plot = plotly2fasthtml(px.scatter(
        x=[m["activation_area"] * 100 for m in all_metrics],
        y=[m["concentration_ratio"] for m in all_metrics],
        hover_name=[f"Feature {m['feature_id']}" for m in all_metrics],
        labels={"x": "Activation Area (% of image)", "y": "Concentration Ratio"},
        title="Spatial Concentration Analysis",
    ))

    # Cards for the 50 features with the smallest activation area.
    feature_cards = [
        Card(
            P(f"Feature {m['feature_id']}"),
            P(f"Concentration: {m['concentration_ratio']:.3f}"),
            # The ".2%" format multiplies the fraction by 100 and appends "%".
            P(f"Active Area: {m['activation_area']:.2%}"),
            P(f"Spatial Spread: {m['spatial_spread']:.3f}"),
            A("View Max Acts", href=f"/maxacts/{m['feature_id']}"),
            style="width: 200px; margin: 10px;",
        )
        for m in all_metrics[:50]
    ]

    return Div(
        H1("Spatial Metrics Analysis"),
        P(A("<- Go back", href="/top_features")),
        Br(),
        scatter_plot,
        Br(),
        H2("Most Concentrated Features (Lowest Activation Area)"),
        Div(*feature_cards, style="display: flex; flex-wrap: wrap; justify-content: center;"),
        style="padding: 20px;",
    )
346+
198347
NUM_PROMPTS = 4
199348

200349
@rt("/gen_image", methods=["GET"])
@@ -248,8 +397,9 @@ def home():
248397
H1("fae"),
249398
H2("SAE"),
250399
P(A("Top features", href="/top_features")),
400+
P(A("Spatial Metrics", href="/spatial_metrics")),
251401
P(A("Generator", href="/gen_image")),
252402
style="padding: 5em"
253403
)
254404

255-
serve()
405+
serve()

run_all.sh

+31-1
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,35 @@ set -e
1616

1717
# 768 = 512 + (256 = 16*16 = 256/16 * 256/16)
1818
# 1536 = 512 + (1024 = 32*32 = 512/16 * 512/16)
19+
# Run with default settings
20+
# uv run python run_fae.py
1921

20-
uv run python -m src.fae.sae_trainer --train_mode=False -seq_len=1536 --batch_size=4 --restore_from=somewhere/sae_double_l18_img
22+
# Specify a different cache path
23+
# uv run python run_fae.py --cache-path="somewhere/other_cache_dir"
24+
25+
# Change image dimensions
26+
# uv run python run_fae.py --width=768 --height=768
27+
28+
# Use a different port
29+
# uv run python run_fae.py --port=5002
30+
31+
# Only compute and output metrics without starting the server
32+
# uv run python run_fae.py --metrics-only
33+
34+
# Clear the image cache before starting
35+
# uv run python run_fae.py --clear-image-cache
36+
37+
# uv run python -m src.fae.sae_trainer --train_mode=False -seq_len=1536 --batch_size=4 --restore_from=somewhere/sae_double_l18_img
38+
# uv run python -m src.fae.sae_trainer --train_mode=True -seq_len=1536 --timesteps=4 --batch_size=4 --restore_from=somewhere/sae_double_l18_img
39+
# uv run python -m src.fae.sae_trainer
40+
# uv run python -m src.fae.sae_trainer --train_mode=False
41+
42+
# uv run python -m src.fae.sae_trainer --layer=18 --block_type=double --train_mode=False --restore_from=somewhere/sae_double_l18_img
43+
# uv run python -m src.fae.sae_trainer --layer=12 --block_type=double --sae_train_every=1
44+
45+
# uv run python -m src.fae.sae_trainer --layer=3 --block_type=double
46+
# uv run python -m src.fae.sae_trainer --layer=3 --block_type=double --train_mode=False
47+
# uv run python -m src.fae.sae_trainer --layer=6 --block_type=double
48+
# uv run python -m src.fae.sae_trainer --layer=6 --block_type=double --train_mode=False
49+
uv run python -m src.fae.sae_trainer --layer=15 --block_type=double
50+
uv run python -m src.fae.sae_trainer --layer=15 --block_type=double --train_mode=False

src/fae/sae_common.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import jax
2+
import equinox as eqx
23
import jax.numpy as jnp
34
from dataclasses import dataclass
45
from jaxtyping import Array, Float, UInt
@@ -143,6 +144,7 @@ def width_and_height(self):
143144
width_and_height = math.isqrt(width_height_product)
144145
return width_and_height, width_and_height
145146

147+
@eqx.filter_jit
146148
def cut_up(
147149
self, training_data: Float[Array, "*batch seq_len d_model"]
148150
) -> Float[Array, "full_batch_size d_model"]:

src/fae/sae_trainer.py

+28-15
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,7 @@ def main(*, restore: bool = False,
631631
**(dict(sae_train_every=1,
632632
sae_batch_size_multiplier=1)
633633
if not train_mode else {}),
634+
site = (block_type, layer),
634635
**extra_config_items,
635636
)
636637
if not train_mode and not restore:
@@ -660,6 +661,17 @@ def main(*, restore: bool = False,
660661
cycle_detected = False
661662
activation_cache = []
662663
# normalization_hack = None
664+
665+
@jax.jit
def process_reaped(reaped):
    """Turn reaped residual activations into flat SAE training data.

    For ``"double"`` blocks the text and image residual streams are
    concatenated along axis -2; otherwise the single-block residual is used.
    The two leading axes are then merged and the result is passed through
    ``config.cut_up``.

    ``block_type`` and ``config`` are taken from the enclosing scope rather
    than passed as arguments, so the ``if`` below is an ordinary Python
    branch evaluated while tracing.
    """
    if block_type == "double":
        training_data = jnp.concatenate(
            (reaped["double.resid.txt"], reaped["double.resid.img"]), axis=-2)
    else:
        training_data = reaped["single.resid"]
    # Merge the two leading axes. The caller documents the input as
    # (timesteps, batch, sequence_length, d_model), which would give
    # (timesteps * batch, sequence_length, d_model) here.
    training_data = training_data.reshape(-1, *training_data.shape[2:])
    return config.cut_up(training_data)
674+
663675
for step, prompts in zip(range(config.n_steps), chunked(prompts_iterator, config.batch_size)):
664676
if len(prompts) < config.batch_size:
665677
logger.warning("End of dataset")
@@ -686,46 +698,47 @@ def main(*, restore: bool = False,
686698
reaped = outputs[1].reaped # (timesteps, batch, sequence_length, d_model)
687699
assert isinstance(images, jnp.ndarray) # to silence mypy
688700
logger.add(sys.stderr, level="INFO")
689-
if block_type == "double":
690-
training_data = jnp.concatenate((reaped[f"double.resid.txt"], reaped[f"double.resid.img"]), axis=-2)
691-
else:
692-
training_data = reaped[f"single.resid"]
693-
training_data = training_data.reshape(-1, *training_data.shape[2:]) # (timesteps, sequence_length, d_model)
694-
training_data = config.cut_up(training_data)
701+
training_data = process_reaped(reaped)
695702
take_data = int(len(training_data) * config.use_data_fraction)
696703
if config.transfer_to_cpu:
697-
training_data = np.asarray(training_data)
704+
training_data = np.array(training_data)
698705
np.random.shuffle(training_data)
699706
training_data = training_data[:take_data]
700707
else:
701708
training_data = jax.random.permutation(key, training_data)[:take_data]
702-
training_data = jax.device_put(
703-
training_data,
704-
jax.sharding.NamedSharding(sae_trainer.mesh, jax.sharding.PartitionSpec("dp", *((None,) * (training_data.ndim - 1)))))
705709
activation_cache.append(training_data)
706710
if len(activation_cache) >= config.sae_train_every:
707711
assert config.sae_train_every % config.sae_batch_size_multiplier == 0
708712
for inner_step in range(config.sae_train_every // config.sae_batch_size_multiplier):
709713
if train_mode:
710-
cache_data = np.concatenate(activation_cache, axis=0)
711-
if config.transfer_to_cpu and inner_step == 0:
712-
np.random.shuffle(cache_data)
714+
if config.transfer_to_cpu:
715+
cache_data = np.concatenate(activation_cache, axis=0)
716+
if inner_step == 0:
717+
np.random.shuffle(cache_data)
718+
else:
719+
if len(activation_cache) == 1:
720+
cache_data = activation_cache[0]
721+
else:
722+
cache_data = jnp.concatenate(activation_cache, axis=0)
723+
# cache_data = jax.jit(lambda x: jnp.concatenate(x, axis=0))(activation_cache)
713724
if len(cache_data) == config.train_batch_size:
714725
training_data, activation_cache = cache_data, []
715726
else:
716727
activation_cache = [cache_data[config.train_batch_size:]]
717728
training_data = cache_data[:config.train_batch_size]
729+
if config.transfer_to_cpu:
730+
training_data = jnp.asarray(training_data)
718731
else:
719732
assert config.sae_batch_size_multiplier == config.sae_train_every == 1
720-
if int(sae_trainer.sae_trainer.sae_logic.info.n_steps) == 0 and config.use_pca:
733+
if int(sae_trainer.sae_trainer.sae_logic.info.n_steps) == 0 and config.use_pca and config.do_update:
721734
sae_trainer.sae_trainer = eqx.tree_at(
722735
lambda x: x.sae_logic.info.whitening_matrix,
723736
sae_trainer.sae_trainer,
724737
compute_whitening(training_data)
725738
)
726739
sae_outputs = sae_trainer.step(
727740
jax.device_put(
728-
jnp.asarray(training_data),
741+
training_data,
729742
jax.sharding.NamedSharding(sae_trainer.mesh, jax.sharding.PartitionSpec("dp", None))))
730743
if not train_mode:
731744
sae_weights, sae_indices = map(

0 commit comments

Comments
 (0)