Skip to content

Commit

Permalink
[NFC] Log reader changes for MLGO environments. (#242)
Browse files Browse the repository at this point in the history
* Log reader changes for MLGO environments.

* Addressing comments.

* Fix a few more spurious changes.

* Tuple -> Union.

* Replace raw_bytes with to_numpy.

* Type hint for to_numpy.

* NDArray -> ndarray.
  • Loading branch information
jacob-hegna authored May 16, 2023
1 parent c963a8a commit 9d00bcf
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 15 deletions.
24 changes: 20 additions & 4 deletions compiler_opt/rl/log_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@
import json
import math

from typing import Any, BinaryIO, Dict, Generator, List, Optional
from typing import Any, BinaryIO, Dict, Generator, List, Optional, Union
import numpy as np
import tensorflow as tf

_element_type_name_map = {
Expand All @@ -86,6 +87,11 @@
}


def convert_dtype_to_ctype(dtype: str) -> Union[type, tf.dtypes.DType]:
"""Public interface for the _dtype_to_ctype dict."""
return _dtype_to_ctype[dtype]


def create_tensorspec(d: Dict[str, Any]) -> tf.TensorSpec:
name: str = d['name']
shape: List[int] = [int(e) for e in d['shape']]
Expand Down Expand Up @@ -120,6 +126,12 @@ def __init__(self, spec: tf.TensorSpec, buffer: bytes):
def spec(self):
return self._spec

def to_numpy(self) -> np.ndarray:
return np.frombuffer(
self._buffer,
dtype=convert_dtype_to_ctype(self._spec.dtype),
count=self._len)

def _set_view(self):
# c_char_p is a nul-terminated string, so the more appropriate cast here
# would be POINTER(c_char), but unfortunately, c_char_p is the only
Expand Down Expand Up @@ -205,11 +217,15 @@ def _enumerate_log_from_stream(
score=score)


def read_log_from_file(f) -> Generator[ObservationRecord, None, None]:
header = _read_header(f)
if header:
yield from _enumerate_log_from_stream(f, header)


def read_log(fname: str) -> Generator[ObservationRecord, None, None]:
with open(fname, 'rb') as f:
header = _read_header(f)
if header:
yield from _enumerate_log_from_stream(f, header)
yield from read_log_from_file(f)


def _add_feature(se: tf.train.SequenceExample, spec: tf.TensorSpec,
Expand Down
36 changes: 25 additions & 11 deletions compiler_opt/rl/log_reader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from google.protobuf import text_format # pytype: disable=pyi-error
from typing import BinaryIO

import numpy as np
import tensorflow as tf


Expand All @@ -38,26 +39,25 @@ def write_buff(f: BinaryIO, buffer: list, ct):


def write_context_marker(f: BinaryIO, name: str):
f.write(nl)
f.write(json_to_bytes({'context': name}))
f.write(nl)


def write_observation_marker(f: BinaryIO, obs_idx: int):
f.write(nl)
f.write(json_to_bytes({'observation': obs_idx}))
f.write(nl)


def begin_features(f: BinaryIO):
def write_nl(f: BinaryIO):
f.write(nl)


def write_outcome_marker(f: BinaryIO, obs_idx: int):
f.write(nl)
f.write(json_to_bytes({'outcome': obs_idx}))
f.write(nl)


def create_example(fname: str, nr_contexts=1):

t0_val = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
t1_val = [1, 2, 3]
s = [1.2]
Expand All @@ -69,12 +69,12 @@ def create_example(fname: str, nr_contexts=1):
'name': 'tensor_name2',
'port': 0,
'shape': [2, 3],
'type': 'float'
'type': 'float',
}, {
'name': 'tensor_name1',
'port': 0,
'shape': [3, 1],
'type': 'int64_t'
'type': 'int64_t',
}],
'score': {
'name': 'reward',
Expand All @@ -83,29 +83,30 @@ def create_example(fname: str, nr_contexts=1):
'type': 'float'
}
}))
write_nl(f)
for ctx_id in range(nr_contexts):
t0_val = [v + ctx_id * 10 for v in t0_val]
t1_val = [v + ctx_id * 10 for v in t1_val]
write_context_marker(f, f'context_nr_{ctx_id}')
write_observation_marker(f, 0)
begin_features(f)
write_buff(f, t0_val, ctypes.c_float)
write_buff(f, t1_val, ctypes.c_int64)
write_nl(f)
write_outcome_marker(f, 0)
begin_features(f)
write_buff(f, s, ctypes.c_float)
write_nl(f)

t0_val = [v + 1 for v in t0_val]
t1_val = [v + 1 for v in t1_val]
s[0] += 1

write_observation_marker(f, 1)
begin_features(f)
write_buff(f, t0_val, ctypes.c_float)
write_buff(f, t1_val, ctypes.c_int64)
write_nl(f)
write_outcome_marker(f, 1)
begin_features(f)
write_buff(f, s, ctypes.c_float)
write_nl(f)


class LogReaderTest(tf.test.TestCase):
Expand Down Expand Up @@ -155,6 +156,19 @@ def test_read_log(self):
obs_id += 1
self.assertEqual(obs_id, 2)

def test_to_numpy(self):
logfile = self.create_tempfile()
create_example(logfile)
t0_val = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
t1_val = [1, 2, 3]
for record in log_reader.read_log(logfile):
np.testing.assert_allclose(record.feature_values[0].to_numpy(),
np.array(t0_val))
np.testing.assert_allclose(record.feature_values[1].to_numpy(),
np.array(t1_val))
t0_val = [v + 1 for v in t0_val]
t1_val = [v + 1 for v in t1_val]

def test_seq_example_conversion(self):
logfile = self.create_tempfile()
create_example(logfile, nr_contexts=2)
Expand Down

0 comments on commit 9d00bcf

Please sign in to comment.