Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[env] Add initial observations to StartSession calls. #189

Merged
merged 1 commit into from
Apr 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion compiler_gym/envs/compiler_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,11 @@ def reset( # pylint: disable=arguments-differ
if self.action_space_name
else 0
),
observation_space=(
[self.observation_space.index]
if self.observation_space
else None
),
),
)
except (ServiceError, ServiceTransportError, TimeoutError) as e:
Expand Down Expand Up @@ -749,7 +754,13 @@ def reset( # pylint: disable=arguments-differ
self.episode_reward = 0

if self.observation_space:
return self.observation[self.observation_space.id]
if len(reply.observation) != 1:
raise OSError(
f"Expected one observation from service, received {len(reply.observation)}"
)
return self.observation.spaces[self.observation_space.id].translate(
reply.observation[0]
)

def step(self, action: Union[int, Iterable[int]]) -> step_t:
"""Take a step.
Expand Down
14 changes: 12 additions & 2 deletions compiler_gym/envs/llvm/service/LlvmService.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,19 @@ Status LlvmService::StartSession(ServerContext* /* unused */, const StartSession
RETURN_IF_ERROR(util::intToEnum(request->action_space(), &actionSpace));

// Construct the environment.
reply->set_session_id(nextSessionId_);
sessions_[nextSessionId_] =
auto session =
std::make_unique<LlvmSession>(std::move(benchmark), actionSpace, workingDirectory_);

// Compute the initial observations.
for (int i = 0; i < request->observation_space_size(); ++i) {
LlvmObservationSpace observationSpace;
RETURN_IF_ERROR(util::intToEnum(request->observation_space(i), &observationSpace));
auto observation = reply->add_observation();
RETURN_IF_ERROR(session->getObservation(observationSpace, observation));
}

reply->set_session_id(nextSessionId_);
sessions_[nextSessionId_] = std::move(session);
++nextSessionId_;

return Status::OK;
Expand Down
2 changes: 1 addition & 1 deletion compiler_gym/envs/llvm/service/LlvmSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ class LlvmSession {
// since the start of the session. This is just for logging and has no effect.
inline int actionCount() const { return actionCount_; }

protected:
// Run the requested action.
[[nodiscard]] grpc::Status runAction(LlvmAction action, StepReply* reply);

// Compute the requested observation.
[[nodiscard]] grpc::Status getObservation(LlvmObservationSpace space, Observation* reply);

protected:
// Run the given pass, possibly modifying the underlying LLVM module.
void runPass(llvm::Pass* pass, StepReply* reply);
void runPass(llvm::FunctionPass* pass, StepReply* reply);
Expand Down
4 changes: 4 additions & 0 deletions compiler_gym/service/proto/compiler_gym_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ message StartSessionRequest {
// space that is to be used for this session. Once set, the action space
// cannot be changed for the duration of the session.
int32 action_space = 2;
// A list of indices into the GetSpacesReply.observation_space_list
repeated int32 observation_space = 3;
}

message StartSessionReply {
Expand All @@ -85,6 +87,8 @@ message StartSessionReply {
// space and replace it with this one. Else, the action space remains
// unchanged.
ActionSpace new_action_space = 3;
// Observed states after completing the action.
repeated Observation observation = 4;
}

// ===========================================================================
Expand Down
15 changes: 12 additions & 3 deletions examples/example_compiler_gym_service/service_cc/ExampleService.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,18 @@ Status ExampleService::StartSession(ServerContext* /* unused*/, const StartSessi
const auto actionSpace = actionSpaces[request->action_space()];

// Create the new compilation session given.
auto session = std::make_unique<ExampleCompilationSession>(benchmark, actionSpace);

// Generate initial observations.
for (int i = 0; i < request->observation_space_size(); ++i) {
RETURN_IF_ERROR(rangeCheck(request->observation_space(i), 0,
static_cast<int32_t>(getObservationSpaces().size()) - 1));
RETURN_IF_ERROR(
session->getObservation(request->observation_space(i), reply->add_observation()));
}

reply->set_session_id(nextSessionId_);
sessions_[nextSessionId_] = std::make_unique<ExampleCompilationSession>(benchmark, actionSpace);
sessions_[nextSessionId_] = std::move(session);
++nextSessionId_;

return Status::OK;
Expand Down Expand Up @@ -170,8 +180,7 @@ Status ExampleCompilationSession::Step(const StepRequest* request, StepReply* re
for (int i = 0; i < request->observation_space_size(); ++i) {
RETURN_IF_ERROR(rangeCheck(request->observation_space(i), 0,
static_cast<int32_t>(getObservationSpaces().size()) - 1));
auto observation = reply->add_observation();
RETURN_IF_ERROR(getObservation(request->observation_space(i), observation));
RETURN_IF_ERROR(getObservation(request->observation_space(i), reply->add_observation()));
}

return Status::OK;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ class ExampleCompilationSession {

[[nodiscard]] grpc::Status Step(const StepRequest* request, StepReply* reply);

private:
grpc::Status getObservation(int32_t observationSpace, Observation* reply);

private:
const std::string benchmark_;
ActionSpace actionSpace_;
};
Expand Down
35 changes: 22 additions & 13 deletions examples/example_compiler_gym_service/service_py/example_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,16 @@ def __init__(self, benchmark: str):
# Do any of the set up required to start a compilation "session".
self.benchmark = benchmark

def set_observation(self, observation_space, observation):
logging.debug("Compute observation %d", observation_space)
if observation_space == 0: # ir
observation.string_value = "Hello, world!"
elif observation_space == 1: # features
observation.int64_list.value[:] = [0, 0, 0]
elif observation_space == 1: # runtime
observation.scalar_double = 0
return observation

def step(self, request: proto.StepRequest, context) -> proto.StepReply:
reply = proto.StepReply()

Expand All @@ -108,14 +118,7 @@ def step(self, request: proto.StepRequest, context) -> proto.StepReply:
# Compute a list of observations from the user. Each value is an index
# into the OBSERVATION_SPACES list.
for observation_space in request.observation_space:
logging.debug("Compute observation %d", observation_space)
observation = reply.observation.add()
if observation_space == 0: # ir
observation.string_value = "Hello, world!"
elif observation_space == 1: # features
observation.int64_list.value[:] = [0, 0, 0]
elif observation_space == 1: # runtime
observation.scalar_double = 0
self.set_observation(observation_space, reply.observation.add())

return reply

Expand Down Expand Up @@ -163,6 +166,7 @@ def StartSession(
) -> proto.StartSessionReply:
"""Create a new compilation session."""
logging.debug("StartSession(benchmark=%s)", request.benchmark)
reply = proto.StartSessionReply()

if not request.benchmark:
benchmark = "foo" # Pick a default benchmark is none was requested.
Expand All @@ -175,11 +179,16 @@ def StartSession(
return

session = CompilationSession(benchmark=benchmark)
session_id = len(self.sessions)
self.sessions[session_id] = session
return proto.StartSessionReply(
session_id=session_id, benchmark=session.benchmark
)

# Generate the initial observations.
for observation_space in request.observation_space:
session.set_observation(observation_space, reply.observation.add())

reply.session_id = len(self.sessions)
reply.benchmark = session.benchmark
self.sessions[reply.session_id] = session

return reply

def EndSession(
self, request: proto.EndSessionRequest, context
Expand Down