Skip to content

Commit

Permalink
Added support for loading URLs in CompilerEnvStateReader.read_paths()
Browse files Browse the repository at this point in the history
  • Loading branch information
thecoblack committed Jun 7, 2022
1 parent 9ad5fbb commit f105d45
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 7 deletions.
22 changes: 15 additions & 7 deletions compiler_gym/compiler_env_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
# LICENSE file in the root directory of this source tree.
"""This module defines a class to represent a compiler environment state."""
import csv
import re
import sys
from io import StringIO
from typing import Iterable, List, Optional, TextIO

import requests
from pydantic import BaseModel, Field, validator

from compiler_gym.datasets.uri import BenchmarkUri
Expand All @@ -23,10 +26,7 @@ class CompilerEnvState(BaseModel):

benchmark: str = Field(
allow_mutation=False,
examples=[
"benchmark://cbench-v1/crc32",
"generator://csmith-v0/0",
],
examples=["benchmark://cbench-v1/crc32", "generator://csmith-v0/0",],
)
"""The URI of the benchmark used for this episode."""

Expand All @@ -37,9 +37,7 @@ class CompilerEnvState(BaseModel):
"""The walltime of the episode in seconds. Must be non-negative."""

reward: Optional[float] = Field(
required=False,
default=None,
allow_mutation=True,
required=False, default=None, allow_mutation=True,
)
"""The cumulative reward for this episode. Optional."""

Expand Down Expand Up @@ -229,6 +227,16 @@ def read_paths(paths: Iterable[str]) -> Iterable[CompilerEnvState]:
for path in paths:
if path == "-":
yield from iter(CompilerEnvStateReader(sys.stdin))
elif (
re.match(r"^(http|https)://[a-zA-Z0-9.-_/]+(\.csv)$", path) is not None
):
response: requests.Response = requests.get(path)
if response.status_code == 200:
yield from iter(CompilerEnvStateReader(StringIO(response.text)))
else:
raise requests.exceptions.InvalidURL(
f"Url {path} content could not be obtained"
)
else:
with open(path) as f:
yield from iter(CompilerEnvStateReader(f))
88 changes: 88 additions & 0 deletions tests/compiler_env_state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pathlib import Path

import pytest
import requests
from pydantic import ValidationError as PydanticValidationError

from compiler_gym import CompilerEnvState, CompilerEnvStateWriter
Expand Down Expand Up @@ -321,5 +322,92 @@ def test_state_serialize_deserialize_equality_no_reward():
assert state_from_csv.commandline == "-a -b -c"


def test_read_paths_stdin(monkeypatch):
monkeypatch.setattr(
"sys.stdin",
StringIO(
"benchmark,reward,walltime,commandline\n"
"benchmark://cbench-v0/foo,2.0,5.0,-a -b -c\n"
),
)
reader = CompilerEnvStateReader.read_paths(["-"])
assert list(reader) == [
CompilerEnvState(
benchmark="benchmark://cbench-v0/foo",
walltime=5,
commandline="-a -b -c",
reward=2,
)
]


def test_read_paths_file(tmp_path):
file_dir = f"{tmp_path}/test.csv"
with open(file_dir, "w") as csv_file:
csv_file.write(
"benchmark,reward,walltime,commandline\n"
"benchmark://cbench-v0/foo,2.0,5.0,-a -b -c\n"
)
reader = CompilerEnvStateReader.read_paths([file_dir])
assert list(reader) == [
CompilerEnvState(
benchmark="benchmark://cbench-v0/foo",
walltime=5,
commandline="-a -b -c",
reward=2,
)
]


def test_read_paths_url(monkeypatch):
urls = ["https://compilergym.ai/benchmarktest.csv"]

class MockResponse:
def __init__(self, text, status_code):
self.text = text
self.status_code = status_code

def ok_mock_response(*args, **kwargs):
return MockResponse(
(
"benchmark,reward,walltime,commandline\n"
"benchmark://cbench-v0/foo,2.0,5.0,-a -b -c\n"
),
200,
)

monkeypatch.setattr(requests, "get", ok_mock_response)
reader = CompilerEnvStateReader.read_paths(urls)
assert list(reader) == [
CompilerEnvState(
benchmark="benchmark://cbench-v0/foo",
walltime=5,
commandline="-a -b -c",
reward=2,
)
]

def bad_mock_response(*args, **kwargs):
return MockResponse("", 404)

monkeypatch.setattr(requests, "get", bad_mock_response)
with pytest.raises(requests.exceptions.InvalidURL):
reader = CompilerEnvStateReader.read_paths(urls)
list(reader)


def test_read_paths_bad_inputs():
bad_dirs = [
"/fake/directory/file.csv",
"fake/directory/file.csv",
"https://www.compilergym.ai/benchmark",
"htts://www.compilergym.ai/benchmark.csv",
"htts://www.compilergym.ai/benchmark",
]
with pytest.raises(FileNotFoundError):
reader = CompilerEnvStateReader.read_paths(bad_dirs)
list(reader)


if __name__ == "__main__":
main()

0 comments on commit f105d45

Please sign in to comment.