diff --git a/compiler_opt/rl/log_reader.py b/compiler_opt/rl/log_reader.py index e2f63aed..88f0101a 100644 --- a/compiler_opt/rl/log_reader.py +++ b/compiler_opt/rl/log_reader.py @@ -61,7 +61,8 @@ import json import math -from typing import Any, BinaryIO, Dict, Generator, List, Optional +from typing import Any, BinaryIO, Dict, Generator, List, Optional, Union +import numpy as np import tensorflow as tf _element_type_name_map = { @@ -86,6 +87,11 @@ } +def convert_dtype_to_ctype(dtype: str) -> Union[type, tf.dtypes.DType]: + """Public interface for the _dtype_to_ctype dict.""" + return _dtype_to_ctype[dtype] + + def create_tensorspec(d: Dict[str, Any]) -> tf.TensorSpec: name: str = d['name'] shape: List[int] = [int(e) for e in d['shape']] @@ -120,6 +126,12 @@ def __init__(self, spec: tf.TensorSpec, buffer: bytes): def spec(self): return self._spec + def to_numpy(self) -> np.ndarray: + return np.frombuffer( + self._buffer, + dtype=convert_dtype_to_ctype(self._spec.dtype), + count=self._len) + def _set_view(self): # c_char_p is a nul-terminated string, so the more appropriate cast here # would be POINTER(c_char), but unfortunately, c_char_p is the only @@ -205,11 +217,15 @@ def _enumerate_log_from_stream( score=score) +def read_log_from_file(f) -> Generator[ObservationRecord, None, None]: + header = _read_header(f) + if header: + yield from _enumerate_log_from_stream(f, header) + + def read_log(fname: str) -> Generator[ObservationRecord, None, None]: with open(fname, 'rb') as f: - header = _read_header(f) - if header: - yield from _enumerate_log_from_stream(f, header) + yield from read_log_from_file(f) def _add_feature(se: tf.train.SequenceExample, spec: tf.TensorSpec, diff --git a/compiler_opt/rl/log_reader_test.py b/compiler_opt/rl/log_reader_test.py index e093ff35..467eeec0 100644 --- a/compiler_opt/rl/log_reader_test.py +++ b/compiler_opt/rl/log_reader_test.py @@ -22,6 +22,7 @@ from google.protobuf import text_format # pytype: disable=pyi-error from typing import BinaryIO +import numpy as np import tensorflow as tf @@ -38,26 +39,25 @@ def write_buff(f: BinaryIO, buffer: list, ct): def write_context_marker(f: BinaryIO, name: str): - f.write(nl) f.write(json_to_bytes({'context': name})) + f.write(nl) def write_observation_marker(f: BinaryIO, obs_idx: int): - f.write(nl) f.write(json_to_bytes({'observation': obs_idx})) + f.write(nl) -def begin_features(f: BinaryIO): +def write_nl(f: BinaryIO): f.write(nl) def write_outcome_marker(f: BinaryIO, obs_idx: int): - f.write(nl) f.write(json_to_bytes({'outcome': obs_idx})) + f.write(nl) def create_example(fname: str, nr_contexts=1): - t0_val = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6] t1_val = [1, 2, 3] s = [1.2] @@ -69,12 +69,12 @@ def create_example(fname: str, nr_contexts=1): 'name': 'tensor_name2', 'port': 0, 'shape': [2, 3], - 'type': 'float' + 'type': 'float', }, { 'name': 'tensor_name1', 'port': 0, 'shape': [3, 1], - 'type': 'int64_t' + 'type': 'int64_t', }], 'score': { 'name': 'reward', @@ -83,29 +83,30 @@ def create_example(fname: str, nr_contexts=1): 'type': 'float' } })) + write_nl(f) for ctx_id in range(nr_contexts): t0_val = [v + ctx_id * 10 for v in t0_val] t1_val = [v + ctx_id * 10 for v in t1_val] write_context_marker(f, f'context_nr_{ctx_id}') write_observation_marker(f, 0) - begin_features(f) write_buff(f, t0_val, ctypes.c_float) write_buff(f, t1_val, ctypes.c_int64) + write_nl(f) write_outcome_marker(f, 0) - begin_features(f) write_buff(f, s, ctypes.c_float) + write_nl(f) t0_val = [v + 1 for v in t0_val] t1_val = [v + 1 for v in t1_val] s[0] += 1 write_observation_marker(f, 1) - begin_features(f) write_buff(f, t0_val, ctypes.c_float) write_buff(f, t1_val, ctypes.c_int64) + write_nl(f) write_outcome_marker(f, 1) - begin_features(f) write_buff(f, s, ctypes.c_float) + write_nl(f) class LogReaderTest(tf.test.TestCase): @@ -155,6 +156,19 @@ def test_read_log(self): obs_id += 1 self.assertEqual(obs_id, 2) + def test_to_numpy(self): + logfile = self.create_tempfile() + create_example(logfile) + t0_val = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6] + t1_val = [1, 2, 3] + for record in log_reader.read_log(logfile): + np.testing.assert_allclose(record.feature_values[0].to_numpy(), + np.array(t0_val)) + np.testing.assert_allclose(record.feature_values[1].to_numpy(), + np.array(t1_val)) + t0_val = [v + 1 for v in t0_val] + t1_val = [v + 1 for v in t1_val] + def test_seq_example_conversion(self): logfile = self.create_tempfile() create_example(logfile, nr_contexts=2)