|
40 | 40 |
|
41 | 41 |
|
42 | 42 | def model(input):
|
43 |
| - update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) |
| 43 | + update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS) |
44 | 44 | print("update_ops:{}".format(update_ops))
|
45 | 45 | with tf.control_dependencies(update_ops):
|
46 |
| - return tf.one_hot(tf.argmax(input, 1), 4, axis=-1) |
| 46 | + return tf.one_hot(tf.argmax(input=input, axis=1), 4, axis=-1) |
47 | 47 |
|
48 | 48 |
|
49 |
| -with tf.Session() as sess: |
50 |
| - input = tf.placeholder(tf.float32, [None, 4]) |
| 49 | +with tf.compat.v1.Session() as sess: |
| 50 | + input = tf.compat.v1.placeholder(tf.float32, [None, 4]) |
51 | 51 | # print(sess.run(tf.cast(t1, tf.bool)))
|
52 | 52 | # print(sess.run(tf.argmax(t2, 1)))
|
53 |
| - onehot = tf.one_hot(tf.argmax(t2, 1), 4, axis=-1) |
| 53 | + onehot = tf.one_hot(tf.argmax(input=t2, axis=1), 4, axis=-1) |
54 | 54 | print(sess.run(onehot))
|
55 | 55 | print(sess.run(tf.cast(onehot, tf.bool)))
|
56 | 56 | # tf.one_hot(tf.argmax(self.prediction, 1), size, axis = -1),
|
57 | 57 | # print([3, 3, 3]+t1+t3)
|
58 | 58 |
|
59 |
| - r1, _ = tf.metrics.recall( |
| 59 | + r1, _ = tf.compat.v1.metrics.recall( |
60 | 60 | labels=labels,
|
61 | 61 | predictions=model(input),
|
62 | 62 | weights=mask1,
|
63 |
| - updates_collections=tf.GraphKeys.UPDATE_OPS) |
64 |
| - p1, _ = tf.metrics.precision( |
| 63 | + updates_collections=tf.compat.v1.GraphKeys.UPDATE_OPS) |
| 64 | + p1, _ = tf.compat.v1.metrics.precision( |
65 | 65 | labels=labels,
|
66 | 66 | predictions=model(input),
|
67 | 67 | weights=mask1,
|
68 |
| - updates_collections=tf.GraphKeys.UPDATE_OPS) |
69 |
| - r2, _ = tf.metrics.recall( |
| 68 | + updates_collections=tf.compat.v1.GraphKeys.UPDATE_OPS) |
| 69 | + r2, _ = tf.compat.v1.metrics.recall( |
70 | 70 | labels=labels,
|
71 | 71 | predictions=model(input),
|
72 | 72 | weights=mask2,
|
73 |
| - updates_collections=tf.GraphKeys.UPDATE_OPS) |
74 |
| - p2, _ = tf.metrics.precision( |
| 73 | + updates_collections=tf.compat.v1.GraphKeys.UPDATE_OPS) |
| 74 | + p2, _ = tf.compat.v1.metrics.precision( |
75 | 75 | labels=labels,
|
76 | 76 | predictions=model(input),
|
77 | 77 | weights=mask2,
|
78 |
| - updates_collections=tf.GraphKeys.UPDATE_OPS) |
79 |
| - sess.run(tf.global_variables_initializer()) |
80 |
| - sess.run(tf.local_variables_initializer()) |
| 78 | + updates_collections=tf.compat.v1.GraphKeys.UPDATE_OPS) |
| 79 | + sess.run(tf.compat.v1.global_variables_initializer()) |
| 80 | + sess.run(tf.compat.v1.local_variables_initializer()) |
81 | 81 | # sess.run([r_op, p_op])
|
82 | 82 |
|
83 | 83 | sess.run(model(input), feed_dict={input: logits})
|
|
0 commit comments