41
41
break
42
42
app , rt = fast_app (hdrs = plotly_headers )
43
43
44
+
45
+ # Add a function to compute spatial metrics for a feature
46
def compute_spatial_metrics(feature_id):
    """Compute spatial activation metrics for one feature, averaged over images.

    Rows for ``feature_id`` are read from ``scored_storage`` and scattered into
    one ``HEIGHT`` x ``WIDTH`` grid per image index. For every image that has
    at least one positive cell (and a non-zero grid sum) three metrics are
    computed:

    * ``spatial_spread``: activation-weighted mean Euclidean distance of the
      grid cells from the activation-weighted center of mass.
    * ``concentration_ratio``: share of the total activation held by the top
      quarter (at least one) of the positive cells.
    * ``activation_area``: fraction (0-1) of grid cells with a positive score.

    Args:
        feature_id: Feature key; converted to ``int`` before being passed to
            ``scored_storage.get_rows`` so NumPy integers and plain ints
            behave the same.

    Returns:
        A dict with the mean ``spatial_spread``, ``concentration_ratio`` and
        ``activation_area`` over the contributing images plus ``num_images``
        (how many images contributed), or ``None`` if no image contributed.
    """
    rows = scored_storage.get_rows(int(feature_id))

    # Scatter the scores into one activation grid per image index.
    grids = {}
    for (idx, h, w), score in rows:
        grid = grids.get(idx)
        if grid is None:
            grid = grids[idx] = np.zeros((HEIGHT, WIDTH), dtype=float)
        grid[h, w] = score

    # The coordinate grids depend only on the image size, so build them once
    # instead of once per image.
    h_indices, w_indices = np.indices((HEIGHT, WIDTH))
    num_cells = HEIGHT * WIDTH

    spreads = []
    concentrations = []
    areas = []
    for grid in grids.values():
        total_activation = grid.sum()
        active_values = grid[grid > 0]

        # Skip images without any usable activation.
        if total_activation == 0 or active_values.size == 0:
            continue

        if total_activation > 0:
            # Activation-weighted center of mass.
            center_h = np.sum(h_indices * grid) / total_activation
            center_w = np.sum(w_indices * grid) / total_activation

            # Spatial spread: weighted mean distance from the center of mass.
            distances = np.sqrt(
                (h_indices - center_h) ** 2 + (w_indices - center_w) ** 2
            )
            spread = np.sum(distances * grid) / total_activation

            # Concentration: share of the total held by the strongest quarter
            # of the positive cells (always at least one cell).
            strongest_first = np.sort(active_values)[::-1]
            quarter_point = max(1, len(strongest_first) // 4)
            concentration = (
                np.sum(strongest_first[:quarter_point]) / total_activation
            )
        else:
            # A non-positive total (negative scores outweigh the positive
            # ones) makes the weighted ratios meaningless; report zeros.
            spread = 0.0
            concentration = 0.0

        spreads.append(float(spread))
        concentrations.append(float(concentration))
        areas.append(active_values.size / num_cells)

    if not spreads:
        return None

    # Aggregate the per-image metrics with a plain mean.
    return {
        "spatial_spread": float(np.mean(spreads)),
        "concentration_ratio": float(np.mean(concentrations)),
        "activation_area": float(np.mean(areas)),
        "num_images": len(spreads),
    }
115
+
116
+ # Cache for spatial metrics to avoid recomputation
117
+ spatial_metrics_cache = {}
118
+
44
119
@rt ("/cached_image/{image_id}" )
45
120
def cached_image (image_id : int ):
46
121
img_path = image_cache_dir / f"{ image_id } .jpg"
@@ -79,6 +154,8 @@ def top_features():
79
154
Br (),
80
155
H1 (f"Spatial sparsity: { spatial_sparsity ():.3f} " ),
81
156
Br (),
157
+ P (A ("View Spatial Metrics" , href = "/spatial_metrics" )),
158
+ Br (),
82
159
* [Card (
83
160
P (f"Feature { i } , Frequency: { frequencies [i ]:.5f} , Max: { maxima [i ]} " ),
84
161
A ("View Max Acts" , href = f"/maxacts/{ i } " )
@@ -133,6 +210,15 @@ def maxacts(feature_id: int):
133
210
# Add score to the corresponding location in the grid
134
211
grouped_rows [key ][h , w ] = score
135
212
213
+ # Compute spatial metrics for this feature if not already cached
214
+ if feature_id not in spatial_metrics_cache :
215
+ spatial_metrics_cache [feature_id ] = compute_spatial_metrics (feature_id )
216
+
217
+ metrics = spatial_metrics_cache [feature_id ]
218
+ metrics_display = ""
219
+ if metrics :
220
+ metrics_display = f"Spatial Spread: { metrics ['spatial_spread' ]:.3f} , Concentration: { metrics ['concentration_ratio' ]:.3f} , Active Area: { metrics ['activation_area' ]:.3f} "
221
+
136
222
# Prepare images and cards
137
223
imgs = []
138
224
for idx , grid in sorted (grouped_rows .items (), key = lambda x : x [1 ].max (), reverse = True )[:20 ]:
@@ -191,10 +277,73 @@ def maxacts(feature_id: int):
191
277
192
278
return Div (
193
279
P (A ("<- Go back" , href = "/top_features" )),
280
+ H2 (f"Feature { feature_id } Spatial Metrics: { metrics_display } " ),
194
281
Div (* imgs , style = "display: flex; flex-wrap: wrap; gap: 20px; justify-content: center" ),
195
282
style = "padding: 20px"
196
283
)
197
284
285
+ # Add a new endpoint to view spatial metrics for all features
286
@rt("/spatial_metrics")
def spatial_metrics_view():
    """Render the spatial-metrics overview page.

    Selects every feature whose entry in ``scored_storage.key_maxima()``
    exceeds 4, computes (or reuses from ``spatial_metrics_cache``) its spatial
    metrics, and returns a page with a scatter plot of activation area against
    concentration ratio plus cards for the 50 features with the smallest
    activation area.

    ``activation_area`` is stored as a fraction in [0, 1]; it is converted to
    a percentage wherever it is shown with a percent label.

    Returns:
        A ``Div`` containing the heading, back link, scatter plot and cards.
    """
    maxima = scored_storage.key_maxima()

    # Keep only features whose maximum is above the threshold.
    # NOTE(review): 4 is presumably an empirical "significant activation"
    # cutoff - confirm against the one used by the other views.
    cond = maxima > 4
    features = np.arange(len(scored_storage))[cond]

    # Compute metrics for all selected features (with caching).
    metric_keys = (
        "spatial_spread",
        "concentration_ratio",
        "activation_area",
        "num_images",
    )
    all_metrics = []
    for feature_id in features:
        # Plain ints keep the cache keys identical to the ones used by routes
        # that receive the feature id as an ``int`` parameter.
        feature_id = int(feature_id)
        if feature_id not in spatial_metrics_cache:
            spatial_metrics_cache[feature_id] = compute_spatial_metrics(feature_id)

        metrics = spatial_metrics_cache[feature_id]
        if metrics:
            entry = {"feature_id": feature_id}
            entry.update({key: metrics[key] for key in metric_keys})
            all_metrics.append(entry)

    # Smallest activation area first, i.e. most spatially concentrated first.
    all_metrics.sort(key=lambda m: m["activation_area"])

    # Scatter plot of activation area (as a percentage, matching the axis
    # label) against concentration ratio.
    scatter_plot = plotly2fasthtml(px.scatter(
        x=[100 * m["activation_area"] for m in all_metrics],
        y=[m["concentration_ratio"] for m in all_metrics],
        hover_name=[f"Feature {m['feature_id']}" for m in all_metrics],
        labels={"x": "Activation Area (% of image)", "y": "Concentration Ratio"},
        title="Spatial Concentration Analysis",
    ))

    # Cards for the 50 most concentrated features.
    feature_cards = [
        Card(
            P(f"Feature {m['feature_id']}"),
            P(f"Concentration: {m['concentration_ratio']:.3f}"),
            # The "%" format type multiplies the fraction by 100.
            P(f"Active Area: {m['activation_area']:.3%}"),
            P(f"Spatial Spread: {m['spatial_spread']:.3f}"),
            A("View Max Acts", href=f"/maxacts/{m['feature_id']}"),
            style="width: 200px; margin: 10px;",
        )
        for m in all_metrics[:50]
    ]

    return Div(
        H1("Spatial Metrics Analysis"),
        P(A("<- Go back", href="/top_features")),
        Br(),
        scatter_plot,
        Br(),
        H2("Most Concentrated Features (Lowest Activation Area)"),
        Div(*feature_cards, style="display: flex; flex-wrap: wrap; justify-content: center;"),
        style="padding: 20px;",
    )
346
+
198
347
NUM_PROMPTS = 4
199
348
200
349
@rt ("/gen_image" , methods = ["GET" ])
@@ -248,8 +397,9 @@ def home():
248
397
H1 ("fae" ),
249
398
H2 ("SAE" ),
250
399
P (A ("Top features" , href = "/top_features" )),
400
+ P (A ("Spatial Metrics" , href = "/spatial_metrics" )),
251
401
P (A ("Generator" , href = "/gen_image" )),
252
402
style = "padding: 5em"
253
403
)
254
404
255
- serve ()
405
+ serve ()
0 commit comments