Skip to content

Commit

Permalink
Merge pull request #189 from facebookresearch/async-reset
Browse files Browse the repository at this point in the history
[env] Add initial observations to StartSession calls.
  • Loading branch information
ChrisCummins authored Apr 21, 2021
2 parents 08da918 + a3f0d31 commit 39f027c
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 21 deletions.
13 changes: 12 additions & 1 deletion compiler_gym/envs/compiler_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,11 @@ def reset( # pylint: disable=arguments-differ
if self.action_space_name
else 0
),
observation_space=(
[self.observation_space.index]
if self.observation_space
else None
),
),
)
except (ServiceError, ServiceTransportError, TimeoutError) as e:
Expand Down Expand Up @@ -749,7 +754,13 @@ def reset( # pylint: disable=arguments-differ
self.episode_reward = 0

if self.observation_space:
return self.observation[self.observation_space.id]
if len(reply.observation) != 1:
raise OSError(
f"Expected one observation from service, received {len(reply.observation)}"
)
return self.observation.spaces[self.observation_space.id].translate(
reply.observation[0]
)

def step(self, action: Union[int, Iterable[int]]) -> step_t:
"""Take a step.
Expand Down
14 changes: 12 additions & 2 deletions compiler_gym/envs/llvm/service/LlvmService.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,19 @@ Status LlvmService::StartSession(ServerContext* /* unused */, const StartSession
RETURN_IF_ERROR(util::intToEnum(request->action_space(), &actionSpace));

// Construct the environment.
reply->set_session_id(nextSessionId_);
sessions_[nextSessionId_] =
auto session =
std::make_unique<LlvmSession>(std::move(benchmark), actionSpace, workingDirectory_);

// Compute the initial observations.
for (int i = 0; i < request->observation_space_size(); ++i) {
LlvmObservationSpace observationSpace;
RETURN_IF_ERROR(util::intToEnum(request->observation_space(i), &observationSpace));
auto observation = reply->add_observation();
RETURN_IF_ERROR(session->getObservation(observationSpace, observation));
}

reply->set_session_id(nextSessionId_);
sessions_[nextSessionId_] = std::move(session);
++nextSessionId_;

return Status::OK;
Expand Down
2 changes: 1 addition & 1 deletion compiler_gym/envs/llvm/service/LlvmSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ class LlvmSession {
// since the start of the session. This is just for logging and has no effect.
inline int actionCount() const { return actionCount_; }

protected:
// Run the requested action.
[[nodiscard]] grpc::Status runAction(LlvmAction action, StepReply* reply);

// Compute the requested observation.
[[nodiscard]] grpc::Status getObservation(LlvmObservationSpace space, Observation* reply);

protected:
// Run the given pass, possibly modifying the underlying LLVM module.
void runPass(llvm::Pass* pass, StepReply* reply);
void runPass(llvm::FunctionPass* pass, StepReply* reply);
Expand Down
4 changes: 4 additions & 0 deletions compiler_gym/service/proto/compiler_gym_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ message StartSessionRequest {
// space that is to be used for this session. Once set, the action space
// cannot be changed for the duration of the session.
int32 action_space = 2;
// A list of indices into the GetSpacesReply.observation_space_list
repeated int32 observation_space = 3;
}

message StartSessionReply {
Expand All @@ -85,6 +87,8 @@ message StartSessionReply {
// space and replace it with this one. Else, the action space remains
// unchanged.
ActionSpace new_action_space = 3;
// Observed states after completing the action.
repeated Observation observation = 4;
}

// ===========================================================================
Expand Down
15 changes: 12 additions & 3 deletions examples/example_compiler_gym_service/service_cc/ExampleService.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,18 @@ Status ExampleService::StartSession(ServerContext* /* unused*/, const StartSessi
const auto actionSpace = actionSpaces[request->action_space()];

// Create the new compilation session given.
auto session = std::make_unique<ExampleCompilationSession>(benchmark, actionSpace);

// Generate initial observations.
for (int i = 0; i < request->observation_space_size(); ++i) {
RETURN_IF_ERROR(rangeCheck(request->observation_space(i), 0,
static_cast<int32_t>(getObservationSpaces().size()) - 1));
RETURN_IF_ERROR(
session->getObservation(request->observation_space(i), reply->add_observation()));
}

reply->set_session_id(nextSessionId_);
sessions_[nextSessionId_] = std::make_unique<ExampleCompilationSession>(benchmark, actionSpace);
sessions_[nextSessionId_] = std::move(session);
++nextSessionId_;

return Status::OK;
Expand Down Expand Up @@ -170,8 +180,7 @@ Status ExampleCompilationSession::Step(const StepRequest* request, StepReply* re
for (int i = 0; i < request->observation_space_size(); ++i) {
RETURN_IF_ERROR(rangeCheck(request->observation_space(i), 0,
static_cast<int32_t>(getObservationSpaces().size()) - 1));
auto observation = reply->add_observation();
RETURN_IF_ERROR(getObservation(request->observation_space(i), observation));
RETURN_IF_ERROR(getObservation(request->observation_space(i), reply->add_observation()));
}

return Status::OK;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ class ExampleCompilationSession {

[[nodiscard]] grpc::Status Step(const StepRequest* request, StepReply* reply);

private:
grpc::Status getObservation(int32_t observationSpace, Observation* reply);

private:
const std::string benchmark_;
ActionSpace actionSpace_;
};
Expand Down
35 changes: 22 additions & 13 deletions examples/example_compiler_gym_service/service_py/example_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,16 @@ def __init__(self, benchmark: str):
# Do any of the set up required to start a compilation "session".
self.benchmark = benchmark

def set_observation(self, observation_space, observation):
logging.debug("Compute observation %d", observation_space)
if observation_space == 0: # ir
observation.string_value = "Hello, world!"
elif observation_space == 1: # features
observation.int64_list.value[:] = [0, 0, 0]
elif observation_space == 1: # runtime
observation.scalar_double = 0
return observation

def step(self, request: proto.StepRequest, context) -> proto.StepReply:
reply = proto.StepReply()

Expand All @@ -108,14 +118,7 @@ def step(self, request: proto.StepRequest, context) -> proto.StepReply:
# Compute a list of observations from the user. Each value is an index
# into the OBSERVATION_SPACES list.
for observation_space in request.observation_space:
logging.debug("Compute observation %d", observation_space)
observation = reply.observation.add()
if observation_space == 0: # ir
observation.string_value = "Hello, world!"
elif observation_space == 1: # features
observation.int64_list.value[:] = [0, 0, 0]
elif observation_space == 1: # runtime
observation.scalar_double = 0
self.set_observation(observation_space, reply.observation.add())

return reply

Expand Down Expand Up @@ -163,6 +166,7 @@ def StartSession(
) -> proto.StartSessionReply:
"""Create a new compilation session."""
logging.debug("StartSession(benchmark=%s)", request.benchmark)
reply = proto.StartSessionReply()

if not request.benchmark:
benchmark = "foo" # Pick a default benchmark is none was requested.
Expand All @@ -175,11 +179,16 @@ def StartSession(
return

session = CompilationSession(benchmark=benchmark)
session_id = len(self.sessions)
self.sessions[session_id] = session
return proto.StartSessionReply(
session_id=session_id, benchmark=session.benchmark
)

# Generate the initial observations.
for observation_space in request.observation_space:
session.set_observation(observation_space, reply.observation.add())

reply.session_id = len(self.sessions)
reply.benchmark = session.benchmark
self.sessions[reply.session_id] = session

return reply

def EndSession(
self, request: proto.EndSessionRequest, context
Expand Down

0 comments on commit 39f027c

Please sign in to comment.