-
Notifications
You must be signed in to change notification settings - Fork 140
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
7dce2b0
commit 24caf7c
Showing
5 changed files
with
429 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Builds the gym HTTP binding: the static library lib/gym_binding.a and the
# random_agent example program. Compiler and linker flags for jsoncpp and
# libcurl are obtained from pkg-config.
OBJDIR=.obj
LIBDIR=lib
CFLAGS=`pkg-config --cflags jsoncpp libcurl` -std=c++11 -Wall
LIBS  =`pkg-config --libs jsoncpp libcurl` -lstdc++

all: dirs random_agent $(LIBDIR)/gym_binding.a
.PHONY: clean all dirs

# Link with the C++ driver because every object is compiled from C++.
random_agent: $(OBJDIR)/random_agent.o $(OBJDIR)/gym_binding.o
	g++ -o $@ $^ $(LIBS)

# Prerequisites after "|" are order-only: the directory must exist before the
# recipe runs, but its timestamp never forces a rebuild and it is not part of
# $^. This keeps parallel builds and single-target builds from racing "dirs".
$(LIBDIR)/gym_binding.a: $(OBJDIR)/gym_binding.o | $(LIBDIR)
	ar rcs $@ $^

# Objects are rebuilt when the public header changes.
$(OBJDIR)/%.o: %.cpp include/gym/gym.h | $(OBJDIR)
	g++ $(CFLAGS) -c $< -o $@

dirs: $(OBJDIR) $(LIBDIR)

$(OBJDIR) $(LIBDIR):
	mkdir -p $@

clean:
	rm -rf $(OBJDIR) $(LIBDIR)/* random_agent
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,298 @@ | ||
#include "include/gym/gym.h"
#include <boost/enable_shared_from_this.hpp>

#include <curl/curl.h>

#include <json/value.h>
#include <json/reader.h>

#include <stdio.h>
#include <cassert>
#include <random>
#include <stdexcept>
#include <string>
#include <vector>
|
||
namespace Gym { | ||
|
||
static bool verbose = false; | ||
|
||
static std::random_device rd; | ||
static std::mt19937 rand_generator(rd()); | ||
|
||
std::vector<float> Space::sample() | ||
{ | ||
if (type==DISCRETE) { | ||
std::uniform_int_distribution<int> randint(0, discreet_n-1); | ||
std::vector<float> r(1, 0.0f); | ||
r[0] = randint(rand_generator); | ||
return r; | ||
} | ||
|
||
assert(type==BOX); | ||
std::uniform_real_distribution<float> rand(0.0f, 1.0f); | ||
int sz = 1; | ||
for (int dim: box_shape) | ||
sz *= dim; | ||
assert((int)box_high.size()==sz); | ||
assert((int)box_low.size()==sz); | ||
|
||
std::vector<float> r(sz, 0.0f); | ||
for (int c=0; c<sz; ++c) | ||
r[c] = (box_high[c]-box_low[c])*rand(rand_generator) + box_low[c]; | ||
return r; | ||
} | ||
|
||
static | ||
std::string require(const Json::Value& v, const std::string& k) | ||
{ | ||
if (!v.isObject() || !v.isMember(k)) | ||
throw std::runtime_error("cannot find required parameter '" + k + "'"); | ||
return v[k].asString(); | ||
} | ||
|
||
static | ||
boost::shared_ptr<Space> space_from_json(const Json::Value& j) | ||
{ | ||
boost::shared_ptr<Space> r(new Space); | ||
Json::Value v = j["info"]; | ||
std::string type = require(v, "name"); | ||
if (type=="Discrete") { | ||
r->type = Space::DISCRETE; | ||
r->discreet_n = v["n"].asInt(); // will throw runtime_error if cannot be converted to int | ||
|
||
} else if (type=="Box") { | ||
r->type = Space::BOX; | ||
Json::Value shape = v["shape"]; | ||
Json::Value low = v["low"]; | ||
Json::Value high = v["high"]; | ||
if (!shape.isArray() || !low.isArray() || !high.isArray()) | ||
throw std::runtime_error("cannot parse box space (1)"); | ||
int l1 = low.size(); | ||
int l2 = high.size(); | ||
int ls = shape.size(); | ||
int sz = 1; | ||
for (int s=0; s<ls; ++s) { | ||
int e = shape[s].asInt(); | ||
r->box_shape.push_back(e); | ||
sz *= e; | ||
} | ||
if (sz != l1 || l1 != l2) | ||
throw std::runtime_error("cannot parse box space (2)"); | ||
r->box_low.resize(sz); | ||
r->box_high.resize(sz); | ||
for (int i=0; i<sz; ++i) { | ||
r->box_low[i] = low[i].asFloat(); | ||
r->box_high[i] = high[i].asFloat(); | ||
} | ||
|
||
} else { | ||
throw std::runtime_error("unknown space type '" + type + "'"); | ||
} | ||
|
||
return r; | ||
} | ||
|
||
|
||
// curl | ||
|
||
// Write callback in the shape libcurl expects for CURLOPT_WRITEFUNCTION:
// appends the size*nmemb bytes found at `buffer` to the std::string that
// `userp` points to and returns the number of bytes consumed.
static
std::size_t curl_save_to_string(void* buffer, std::size_t size, std::size_t nmemb, void* userp)
{
	const std::size_t total = size*nmemb;
	std::string& sink = *static_cast<std::string*>(userp);
	const char* chunk = static_cast<const char*>(buffer);
	sink.append(chunk, total);
	return total;
}
|
||
// Client implementation that exchanges JSON with a gym HTTP server through a
// single libcurl easy handle which is reused for every request.
class ClientReal: public Client, public boost::enable_shared_from_this<ClientReal> {
public:
	std::string addr; // host part of the URL; requests go to "http://" + addr + route
	int port;         // handed to libcurl as CURLOPT_PORT on every request

	boost::shared_ptr<CURL> h;             // easy handle, released with curl_easy_cleanup
	boost::shared_ptr<curl_slist> headers; // extra headers sent with POST requests
	std::vector<char> curl_error_buf;      // CURL_ERROR_SIZE bytes registered as CURLOPT_ERRORBUFFER

	// Creates and configures the easy handle.
	// Throws std::runtime_error when libcurl cannot allocate the handle or the
	// header list.
	ClientReal()
	: port(0)
	{
		CURL* c = curl_easy_init();
		if (!c)
			throw std::runtime_error("cannot initialize curl");
		// Take ownership right away so the handle is released if anything
		// below throws.
		h.reset(c, curl_easy_cleanup);

		// curl_easy_setopt is variadic: numeric options must be passed as long
		// and pointer options as real pointers, never as a plain int or bool.
		curl_easy_setopt(c, CURLOPT_NOSIGNAL, 1L);
		curl_easy_setopt(c, CURLOPT_CONNECTTIMEOUT_MS, 3000L);
		curl_easy_setopt(c, CURLOPT_IPRESOLVE, static_cast<long>(CURL_IPRESOLVE_V4));
		curl_easy_setopt(c, CURLOPT_FOLLOWLOCATION, 1L);
		// NOTE(review): certificate verification is switched off. Requests
		// built in this class start on "http://", so this only matters when a
		// redirect leads to https; confirm that this is intended.
		curl_easy_setopt(c, CURLOPT_SSL_VERIFYPEER, 0L);
		curl_easy_setopt(c, CURLOPT_SSL_VERIFYHOST, 0L);
		curl_easy_setopt(c, CURLOPT_WRITEFUNCTION, &curl_save_to_string);
		curl_error_buf.assign(CURL_ERROR_SIZE, 0);
		curl_easy_setopt(c, CURLOPT_ERRORBUFFER, curl_error_buf.data());

		curl_slist* list = curl_slist_append(nullptr, "Content-Type: application/json");
		if (!list)
			throw std::runtime_error("cannot allocate curl header list");
		headers.reset(list, curl_slist_free_all);
	}

	// The easy handle keeps a pointer into curl_error_buf, so a copy would
	// share the handle while owning a different buffer.
	ClientReal(const ClientReal&) = delete;
	ClientReal& operator=(const ClientReal&) = delete;

	// Sends a GET request for `route` and returns the decoded JSON object.
	// Throws std::runtime_error on transport, HTTP or JSON errors.
	Json::Value GET(const std::string& route)
	{
		const std::string url = "http://" + addr + route;
		if (verbose) printf("GET %s\n", url.c_str());
		// The handle may still be configured for POST by an earlier request.
		curl_easy_setopt(h.get(), CURLOPT_HTTPGET, 1L);
		curl_easy_setopt(h.get(), CURLOPT_HTTPHEADER, static_cast<curl_slist*>(nullptr));
		return perform(url);
	}

	// Sends `post_data` as the body of a POST request to `route` and returns
	// the decoded JSON object.
	// Throws std::runtime_error on transport, HTTP or JSON errors.
	Json::Value POST(const std::string& route, const std::string& post_data)
	{
		const std::string url = "http://" + addr + route;
		if (verbose) printf("POST %s\n%s\n", url.c_str(), post_data.c_str());
		curl_easy_setopt(h.get(), CURLOPT_POST, 1L);
		curl_easy_setopt(h.get(), CURLOPT_POSTFIELDS, post_data.c_str());
		curl_easy_setopt(h.get(), CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(post_data.size()));
		curl_easy_setopt(h.get(), CURLOPT_HTTPHEADER, headers.get());
		return perform(url);
	}

	// Runs the transfer configured on the handle against `url` and decodes the
	// reply. Shared tail of GET and POST.
	Json::Value perform(const std::string& url)
	{
		std::string answer;
		curl_easy_setopt(h.get(), CURLOPT_URL, url.c_str());
		curl_easy_setopt(h.get(), CURLOPT_PORT, static_cast<long>(port));
		curl_easy_setopt(h.get(), CURLOPT_WRITEDATA, &answer);

		// Clear any message left by a previous transfer so that an empty
		// buffer reliably means "no detail available".
		curl_error_buf[0] = 0;
		const CURLcode r = curl_easy_perform(h.get());
		if (r != CURLE_OK) throw std::runtime_error(curl_error_text(r));

		Json::Value j;
		throw_server_error_or_response_code(answer, j);
		return j;
	}

	// Returns the detail message libcurl stored in the error buffer, or the
	// generic description of `code` when the buffer is empty.
	std::string curl_error_text(CURLcode code) const
	{
		if (curl_error_buf[0])
			return curl_error_buf.data();
		return curl_easy_strerror(code);
	}

	// Parses `answer` into `j` and checks the HTTP status of the last transfer.
	// Throws std::runtime_error when
	//  - the status is not 200 (using the server's "message" member if the
	//    body is a JSON object that has one), or
	//  - the status is 200 but the body is not a JSON object.
	void throw_server_error_or_response_code(const std::string& answer, Json::Value& j)
	{
		long response_code = 0;
		const CURLcode r = curl_easy_getinfo(h.get(), CURLINFO_RESPONSE_CODE, &response_code);
		// The error buffer was cleared before the transfer and is not relied
		// on here; describe the failure from the return code instead.
		if (r != CURLE_OK) throw std::runtime_error(curl_easy_strerror(r));
		if (verbose) printf("%i\n%s\n", (int)response_code, answer.c_str());

		std::string parse_error;
		Json::Reader jr;
		if (!jr.parse(answer, j, false)) {
			parse_error = jr.getFormattedErrorMessages();
			parse_error += "original json that caused error: " + answer;
		} else if (!j.isObject()) {
			parse_error = "top level json is not an object, ";
			parse_error += "original json that caused error: " + answer;
		}

		if (response_code != 200 && j.isObject() && j.isMember("message")) {
			throw std::runtime_error(j["message"].asString());
		} else if (response_code != 200) {
			throw std::runtime_error("bad HTTP response code, and also cannot parse server message: " + answer);
		} else {
			// 200, but maybe invalid json
			if (!parse_error.empty())
				throw std::runtime_error(parse_error);
		}
	}

	boost::shared_ptr<Environment> make(const std::string& env_id) override;
};
|
||
// Factory declared in gym.h: allocates a ClientReal, stores the server address
// and port in it and hands it back through the Client interface.
boost::shared_ptr<Client> client_create(const std::string& addr, int port)
{
	boost::shared_ptr<ClientReal> made(new ClientReal);
	made->port = port;
	made->addr = addr;
	return made;
}
|
||
|
||
// environment | ||
|
||
class EnvironmentReal: public Environment { | ||
public: | ||
std::string instance_id; | ||
boost::shared_ptr<ClientReal> client; | ||
boost::shared_ptr<Space> space_act; | ||
boost::shared_ptr<Space> space_obs; | ||
|
||
boost::shared_ptr<Space> action_space() override | ||
{ | ||
if (!space_act) | ||
space_act = space_from_json(client->GET("/v1/envs/" + instance_id + "/action_space")); | ||
return space_act; | ||
} | ||
|
||
boost::shared_ptr<Space> observation_space() override | ||
{ | ||
if (!space_obs) | ||
space_obs = space_from_json(client->GET("/v1/envs/" + instance_id + "/observation_space")); | ||
return space_obs; | ||
} | ||
|
||
void observation_parse(const Json::Value& v, std::vector<float>& save_here) | ||
{ | ||
if (!v.isArray()) | ||
throw std::runtime_error("cannot parse observation, not an array"); | ||
int s = v.size(); | ||
save_here.resize(s); | ||
for (int i=0; i<s; ++i) | ||
save_here[i] = v[i].asFloat(); | ||
} | ||
|
||
void reset(State* save_initial_state_here) override | ||
{ | ||
Json::Value ans = client->POST("/v1/envs/" + instance_id + "/reset/", ""); | ||
observation_parse(ans["observation"], save_initial_state_here->observation); | ||
} | ||
|
||
void step(const std::vector<float>& action, bool render, State* save_state_here) override | ||
{ | ||
Json::Value act_json; | ||
boost::shared_ptr<Space> aspace = action_space(); | ||
if (aspace->type==Space::DISCRETE) { | ||
act_json["action"] = (int) action[0]; | ||
} else if (aspace->type==Space::BOX) { | ||
Json::Value& array = act_json["action"]; | ||
assert(action.size()==aspace->box_low.size()); // really assert, it's a programming error on C++ part | ||
for (int c=0; c<(int)action.size(); ++c) | ||
array[c] = action[c]; | ||
} else { | ||
assert(0); | ||
} | ||
act_json["render"] = render; | ||
Json::Value ans = client->POST("/v1/envs/" + instance_id + "/step/", act_json.toStyledString()); | ||
observation_parse(ans["observation"], save_state_here->observation); | ||
save_state_here->done = ans["done"].asBool(); | ||
save_state_here->reward = ans["reward"].asFloat(); | ||
} | ||
|
||
void monitor_start(const std::string& directory, bool force, bool resume) override | ||
{ | ||
Json::Value data; | ||
data["directory"] = directory; | ||
data["force"] = force; | ||
data["resume"] = resume; | ||
client->POST("/v1/envs/" + instance_id + "/monitor/start/", data.toStyledString()); | ||
} | ||
|
||
void monitor_stop() override | ||
{ | ||
client->POST("/v1/envs/" + instance_id + "/monitor/close/", ""); | ||
} | ||
}; | ||
|
||
// Asks the server to create an instance of `env_id` and wraps the returned
// instance id in an EnvironmentReal bound to this client.
// Throws std::runtime_error when the request fails or the reply carries no
// "instance_id".
boost::shared_ptr<Environment> ClientReal::make(const std::string& env_id)
{
	Json::Value request;
	request["env_id"] = env_id;
	const Json::Value reply = POST("/v1/envs/", request.toStyledString());

	const std::string created_id = require(reply, "instance_id");
	if (verbose) printf(" * created %s instance_id=%s\n", env_id.c_str(), created_id.c_str());

	boost::shared_ptr<EnvironmentReal> created(new EnvironmentReal);
	created->instance_id = created_id;
	created->client = shared_from_this();
	return created;
}
|
||
} // namespace |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
#ifndef __GYM_H__ | ||
#define __GYM_H__ | ||
// Caffe uses boost::shared_ptr (as opposed to std::shared_ptr), so do we. | ||
#include <boost/shared_ptr.hpp> | ||
#include <vector> | ||
|
||
namespace Gym { | ||
|
||
// Description of an action or observation space.
struct Space {
	enum SpaceType {
		DISCRETE,
		BOX,
	} type;

	// Starts as an empty DISCRETE space with n = 0, so no member is ever read
	// uninitialized. (A constructor is used to keep this header usable from
	// pre-C++11 code.)
	Space(): type(DISCRETE), discreet_n(0) {}

	std::vector<float> sample(); // Random vector that belongs to this space

	// BOX only.
	std::vector<int> box_shape; // Similar to Caffe blob shape, for example { 64, 96, 3 } for 96x64 rgb image.
	std::vector<float> box_high; // Upper bounds, flattened: one entry per element of the box
	std::vector<float> box_low;  // Lower bounds, flattened: one entry per element of the box

	// DISCRETE only: number of values; samples are integers in [0, discreet_n-1].
	int discreet_n;
};
|
||
// Result of Environment::reset() or Environment::step().
struct State {
	// Starts with zero reward and done == false, so a freshly declared State
	// can be inspected safely before the first step.
	State(): reward(0.0f), done(false) {}

	std::vector<float> observation; // get observation_space() to make sense of this data
	float reward; // reward reported for the last step
	bool done;    // true once the server reports the episode as finished
	std::string info;
};
|
||
class Environment { | ||
public: | ||
virtual boost::shared_ptr<Space> action_space() =0; | ||
virtual boost::shared_ptr<Space> observation_space() =0; | ||
|
||
virtual void reset(State* save_initial_state_here) =0; | ||
|
||
virtual void step(const std::vector<float>& action, bool render, State* save_state_here) =0; | ||
|
||
virtual void monitor_start(const std::string& directory, bool force, bool resume) =0; | ||
virtual void monitor_stop() =0; | ||
}; | ||
|
||
class Client { | ||
public: | ||
virtual boost::shared_ptr<Environment> make(const std::string& name) =0; | ||
}; | ||
|
||
extern boost::shared_ptr<Client> client_create(const std::string& addr, int port); | ||
|
||
} // namespace | ||
|
||
#endif // __GYM_H__ |
Oops, something went wrong.