-
Notifications
You must be signed in to change notification settings - Fork 865
/
Copy pathrecord_writer.py
47 lines (35 loc) · 1.27 KB
/
record_writer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
"""
Write TFRecord-framed records to a file; used here to write TensorBoard event files.
The code was borrowed from https://github.com/TeamHG-Memex/tensorboard_logger
"""
import re
import struct
from .crc32c import crc32c
# A valid name must begin with an ASCII letter, a digit, or a dot.
_VALID_OP_NAME_START = re.compile(r'^[A-Za-z0-9.]')
# Maximal runs of characters allowed anywhere in a name:
# letters, digits, underscore, dot, hyphen, and slash.
_VALID_OP_NAME_PART = re.compile(r'[A-Za-z0-9_.\-/]+')
class RecordWriter(object):
    """Append length-prefixed, checksummed records to a binary file.

    Each call to ``write`` emits one record laid out as:

        uint64  payload length            (little-endian)
        uint32  masked CRC32C of the length bytes
        bytes   payload
        uint32  masked CRC32C of the payload

    The writer can be used as a context manager; the file is closed on exit.
    """

    def __init__(self, path, flush_secs=2):
        """Open ``path`` for binary writing, truncating any existing file.

        Args:
            path: Filesystem path of the output file.
            flush_secs: Stored only; the writer currently flushes after every
                record (see TODO below).
        """
        # Name bookkeeping is not used inside this class; presumably it is
        # read by callers elsewhere. Kept for compatibility. TODO confirm.
        self._name_to_tf_name = {}
        self._tf_names = set()
        self.path = path
        self.flush_secs = flush_secs  # TODO: flush every flush_secs, not on every write.
        self._writer = open(path, 'wb')

    def write(self, event_str):
        """Write ``event_str`` (a bytes payload) as one framed record and flush."""
        w = self._writer.write
        # Explicit '<' forces little-endian, standard-size packing. The record
        # format is defined as little-endian, so native order ('Q'/'I') would
        # produce unreadable files on big-endian hosts.
        header = struct.pack('<Q', len(event_str))
        w(header)
        w(struct.pack('<I', masked_crc32c(header)))
        w(event_str)
        w(struct.pack('<I', masked_crc32c(event_str)))
        self._writer.flush()

    def flush(self):
        """Flush buffered bytes to the underlying file."""
        self._writer.flush()

    def close(self):
        """Flush and close the underlying file. Safe to call more than once."""
        self._writer.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
def masked_crc32c(data):
    """Return the masked CRC32C checksum of ``data`` as an unsigned 32-bit int.

    The raw CRC32C is rotated right by 15 bits and offset by the constant
    0xa282ead8, with all arithmetic wrapped to 32 bits.
    """
    mask32 = 0xffffffff
    crc = crc32c(data) & mask32
    rotated = (crc >> 15) | ((crc << 17) & mask32)
    return (rotated + 0xa282ead8) & mask32
def u32(x):
    """Truncate integer ``x`` to its low 32 bits (unsigned)."""
    low_32_bits = (1 << 32) - 1
    return x & low_32_bits
def make_valid_tf_name(name):
    """Coerce ``name`` into a string made only of allowed name characters.

    A name with a disallowed first character gets a leading '.' rather than
    losing that character. Every run of disallowed characters is then
    collapsed into a single '_' between the allowed segments.
    """
    # Prefix instead of stripping, so no original content is discarded.
    candidate = name if _VALID_OP_NAME_START.match(name) else '.' + name
    allowed_runs = _VALID_OP_NAME_PART.findall(candidate)
    return '_'.join(allowed_runs)