-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathstash.diff
41 lines (41 loc) · 1.43 KB
/
stash.diff
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
diff --git a/tsaplay/utils/addons.py b/tsaplay/utils/addons.py
index fef0c92..4203b11 100644
--- a/tsaplay/utils/addons.py
+++ b/tsaplay/utils/addons.py
@@ -146,7 +146,7 @@ def conf_matrix(model, features, labels, spec, params):
depth=params["_n_out_classes"],
),
num_classes=params["_n_out_classes"],
- )
+ ),
}
)
eval_hooks += [
@@ -217,7 +217,26 @@ def scalars(model, features, labels, spec, params):
),
"auc": tf.metrics.auc(
labels=tf.one_hot(indices=labels, depth=params["_n_out_classes"]),
- predictions=spec.predictions["probabilities"],
+ predictions=tf.Print(
+ input_=spec.predictions["probabilities"],
+ data=[
+ spec.predictions["probabilities"],
+ tf.concat(
+ [
+ features["left"],
+ features["target"],
+ features["right"],
+ ],
+ axis=1,
+ ),
+ # features["left"],
+ tf.expand_dims(features["target"], axis=2),
+ # features["right"],
+ labels,
+ ],
+ message="",
+ summarize=5000,
+ ),
name="auc_op",
),
}