Skip to content

Commit

Permalink
Experimental C++ binding
Browse files Browse the repository at this point in the history
  • Loading branch information
olegklimov committed Jun 8, 2016
1 parent 7dce2b0 commit 24caf7c
Show file tree
Hide file tree
Showing 5 changed files with 429 additions and 0 deletions.
12 changes: 12 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,15 @@ tutorial/*.html

# PyCharm project files
.idea

# C++ binding
binding-cpp/random_agent
binding-cpp/lib/gym_binding.a
binding-cpp/.obj/*

# Qt Creator files for C++ binding
binding-cpp/gym-binding-cpp.config
binding-cpp/gym-binding-cpp.creator
binding-cpp/gym-binding-cpp.creator.user
binding-cpp/gym-binding-cpp.files
binding-cpp/gym-binding-cpp.includes
22 changes: 22 additions & 0 deletions binding-cpp/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Builds the C++ binding: the random_agent demo and the static library
# lib/gym_binding.a. Object files go to $(OBJDIR).
OBJDIR=.obj
CFLAGS=`pkg-config --cflags jsoncpp libcurl` -std=c++11 -Wall
LIBS  =`pkg-config --libs jsoncpp libcurl` -lstdc++

all: dirs random_agent lib/gym_binding.a
.PHONY: clean all dirs

# Link with the C++ driver so the C++ runtime is always pulled in.
random_agent: $(OBJDIR)/random_agent.o $(OBJDIR)/gym_binding.o
	$(CXX) -o $@ $^ $(LIBS)

# `| dirs` is an order-only prerequisite: the output directories must exist
# first, but they never force a rebuild and are not part of $^.
lib/gym_binding.a: $(OBJDIR)/gym_binding.o | dirs
	ar rcs $@ $^

# Objects are rebuilt when the public header changes.
$(OBJDIR)/%.o: %.cpp include/gym/gym.h | dirs
	$(CXX) $(CFLAGS) -c $< -o $@

dirs:
	mkdir -p $(OBJDIR) lib

clean:
	rm -rf $(OBJDIR) lib/*

298 changes: 298 additions & 0 deletions binding-cpp/gym_binding.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,298 @@
#include "include/gym/gym.h"
#include <boost/enable_shared_from_this.hpp>

#include <curl/curl.h>

#include <json/value.h>
#include <json/reader.h>

#include <stdio.h>
#include <random>

namespace Gym {

// When true, ClientReal prints each request and response with printf.
// No code visible in this file assigns it after this initialisation.
static bool verbose = false;

// Random engine shared by every Space::sample() call; it is seeded once,
// during static initialisation, from std::random_device.
static std::random_device rd;
static std::mt19937 rand_generator(rd());

// Returns a random element of this space, drawn with the file-level generator.
//  - DISCRETE: a one-element vector holding an integer uniformly chosen
//    from 0 .. discreet_n-1.
//  - BOX: one value per box element, each uniformly placed between the
//    matching box_low and box_high entries.
// Any other type, or bound vectors whose size differs from the product of
// box_shape, trips an assert.
std::vector<float> Space::sample()
{
    if (type==DISCRETE) {
        std::uniform_int_distribution<int> pick(0, discreet_n-1);
        return std::vector<float>(1, static_cast<float>(pick(rand_generator)));
    }

    assert(type==BOX);
    std::uniform_real_distribution<float> unit(0.0f, 1.0f);

    // Flattened element count is the product of all dimensions.
    int total = 1;
    for (std::vector<int>::const_iterator d = box_shape.begin(); d != box_shape.end(); ++d)
        total *= *d;
    assert((int)box_high.size()==total);
    assert((int)box_low.size()==total);

    std::vector<float> point(total, 0.0f);
    for (int i=0; i<total; ++i) {
        const float span = box_high[i]-box_low[i];
        point[i] = span*unit(rand_generator) + box_low[i];
    }
    return point;
}

// Returns member `k` of JSON object `v` converted to a string.
// Throws std::runtime_error when `v` is not an object or has no member `k`.
static
std::string require(const Json::Value& v, const std::string& k)
{
    const bool present = v.isObject() && v.isMember(k);
    if (present)
        return v[k].asString();
    throw std::runtime_error("cannot find required parameter '" + k + "'");
}

// Builds a Space from the server's JSON reply `j`, reading its "info" member.
//  - "name" == "Discrete": reads the integer "n", which must be positive.
//  - "name" == "Box": reads "shape", "low" and "high"; "low" and "high" must
//    be flat arrays whose length equals the product of the "shape" entries.
// Throws std::runtime_error when "name" is missing or unknown, when "n" is
// not positive, or when the box arrays are missing or inconsistent. JsonCpp
// conversion failures propagate as whatever JsonCpp raises.
static
boost::shared_ptr<Space> space_from_json(const Json::Value& j)
{
    boost::shared_ptr<Space> r(new Space);
    // Bind by const reference: the original copied whole JSON subtrees, and the
    // non-const operator[] silently inserted a null member for a missing key.
    const Json::Value& v = j["info"];
    std::string type = require(v, "name");
    if (type=="Discrete") {
        r->type = Space::DISCRETE;
        r->discreet_n = v["n"].asInt();
        // A missing "n" converts to 0; Space::sample() needs at least one choice.
        if (r->discreet_n < 1)
            throw std::runtime_error("cannot parse discrete space, 'n' must be positive");

    } else if (type=="Box") {
        r->type = Space::BOX;
        const Json::Value& shape = v["shape"];
        const Json::Value& low   = v["low"];
        const Json::Value& high  = v["high"];
        if (!shape.isArray() || !low.isArray() || !high.isArray())
            throw std::runtime_error("cannot parse box space (1)");
        int l1 = low.size();
        int l2 = high.size();
        int ls = shape.size();
        int sz = 1;
        r->box_shape.reserve(ls);
        for (int s=0; s<ls; ++s) {
            int e = shape[s].asInt();
            r->box_shape.push_back(e);
            sz *= e;
        }
        // Bounds must hold exactly one value per box element.
        if (sz != l1 || l1 != l2)
            throw std::runtime_error("cannot parse box space (2)");
        r->box_low.resize(sz);
        r->box_high.resize(sz);
        for (int i=0; i<sz; ++i) {
            r->box_low[i]  = low[i].asFloat();
            r->box_high[i] = high[i].asFloat();
        }

    } else {
        throw std::runtime_error("unknown space type '" + type + "'");
    }

    return r;
}


// curl

// Write callback in the shape libcurl expects: appends the size*nmemb bytes at
// `buffer` to the std::string that `userp` points to, and returns that byte
// count so the transfer continues.
static
std::size_t curl_save_to_string(void* buffer, std::size_t size, std::size_t nmemb, void* userp)
{
    const std::size_t total = size * nmemb;
    const char* chunk = static_cast<char*>(buffer);
    std::string& sink = *static_cast<std::string*>(userp);
    sink.append(chunk, total);
    return total;
}

// HTTP client for the gym server, built on one reusable libcurl easy handle.
// Every request goes to "http://" + addr + route on `port`. A reply is accepted
// only when the HTTP status is 200 and the body parses to a JSON object;
// anything else is reported as std::runtime_error.
class ClientReal: public Client, public boost::enable_shared_from_this<ClientReal> {
public:
    std::string addr;
    int port;

    // Declared before `h` on purpose: members are destroyed in reverse order of
    // declaration, so the easy handle is cleaned up while the error buffer and
    // the header list it refers to are still alive.
    std::vector<char> curl_error_buf;
    boost::shared_ptr<curl_slist> headers;
    boost::shared_ptr<CURL> h;

    // Creates and configures the easy handle.
    // Throws std::runtime_error if libcurl cannot create a handle.
    ClientReal()
        : port(0)
    {
        CURL* c = curl_easy_init();
        if (!c) throw std::runtime_error("curl_easy_init() failed");
        h.reset(c, &curl_easy_cleanup);

        // curl_easy_setopt is variadic: numeric options must be passed as long.
        curl_easy_setopt(c, CURLOPT_NOSIGNAL, 1L);
        curl_easy_setopt(c, CURLOPT_CONNECTTIMEOUT_MS, 3000L);
        curl_easy_setopt(c, CURLOPT_IPRESOLVE, static_cast<long>(CURL_IPRESOLVE_V4));
        curl_easy_setopt(c, CURLOPT_FOLLOWLOCATION, 1L);
        // NOTE(review): certificate and host verification are switched off.
        // Requests use http://, so this only matters if a redirect leads to
        // https; confirm that this is intended before relying on it.
        curl_easy_setopt(c, CURLOPT_SSL_VERIFYPEER, 0L);
        curl_easy_setopt(c, CURLOPT_SSL_VERIFYHOST, 0L);
        curl_easy_setopt(c, CURLOPT_WRITEFUNCTION, &curl_save_to_string);
        curl_error_buf.assign(CURL_ERROR_SIZE, 0);
        curl_easy_setopt(c, CURLOPT_ERRORBUFFER, curl_error_buf.data());
        headers.reset(curl_slist_append(nullptr, "Content-Type: application/json"), &curl_slist_free_all);
    }

    // Performs a GET on `route` and returns the parsed JSON object.
    // Throws std::runtime_error on transport, HTTP status or JSON errors.
    Json::Value GET(const std::string& route)
    {
        const std::string url = "http://" + addr + route;
        if (verbose) printf("GET %s\n", url.c_str());
        set_target(url);
        curl_easy_setopt(h.get(), CURLOPT_POST, 0L);
        curl_easy_setopt(h.get(), CURLOPT_HTTPHEADER, static_cast<curl_slist*>(nullptr));
        return perform_and_parse();
    }

    // Performs a POST of `post_data` (sent with the JSON content type) on
    // `route` and returns the parsed JSON object.
    // Throws std::runtime_error on transport, HTTP status or JSON errors.
    Json::Value POST(const std::string& route, const std::string& post_data)
    {
        const std::string url = "http://" + addr + route;
        if (verbose) printf("POST %s\n%s\n", url.c_str(), post_data.c_str());
        set_target(url);
        curl_easy_setopt(h.get(), CURLOPT_POST, 1L);
        // libcurl does not copy POSTFIELDS; `post_data` outlives the transfer below.
        curl_easy_setopt(h.get(), CURLOPT_POSTFIELDS, post_data.c_str());
        curl_easy_setopt(h.get(), CURLOPT_POSTFIELDSIZE_LARGE, static_cast<curl_off_t>(post_data.size()));
        curl_easy_setopt(h.get(), CURLOPT_HTTPHEADER, headers.get());
        return perform_and_parse();
    }

    // Parses `answer` into `j` and validates the last transfer.
    // Throws std::runtime_error with the server's "message" member when the
    // status is not 200 and the body is a JSON object carrying one, with a
    // generic text for any other non-200 reply, and with the parse error when
    // the status is 200 but the body is not a JSON object.
    void throw_server_error_or_response_code(const std::string& answer, Json::Value& j)
    {
        long response_code = 0;
        const CURLcode r = curl_easy_getinfo(h.get(), CURLINFO_RESPONSE_CODE, &response_code);
        if (r != CURLE_OK) throw std::runtime_error(curl_error_text(r));
        if (verbose) printf("%i\n%s\n", static_cast<int>(response_code), answer.c_str());

        std::string parse_error;
        Json::Reader jr;
        if (!jr.parse(answer, j, false)) {
            parse_error = jr.getFormattedErrorMessages();
            parse_error += "original json that caused error: " + answer;
        } else if (!j.isObject()) {
            // Newline keeps the two sentences from running together.
            parse_error = "top level json is not an object\n";
            parse_error += "original json that caused error: " + answer;
        }

        if (response_code != 200 && j.isObject() && j.isMember("message")) {
            throw std::runtime_error(j["message"].asString());
        } else if (response_code != 200) {
            throw std::runtime_error("bad HTTP response code, and also cannot parse server message: " + answer);
        } else {
            // 200, but maybe invalid json
            if (!parse_error.empty())
                throw std::runtime_error(parse_error);
        }
    }

    boost::shared_ptr<Environment> make(const std::string& env_id) override;

private:
    // Points the handle at `url` and the configured port.
    void set_target(const std::string& url)
    {
        curl_easy_setopt(h.get(), CURLOPT_URL, url.c_str());
        curl_easy_setopt(h.get(), CURLOPT_PORT, static_cast<long>(port));
    }

    // Runs the prepared transfer, collects the body and validates the reply.
    Json::Value perform_and_parse()
    {
        std::string answer;
        curl_easy_setopt(h.get(), CURLOPT_WRITEDATA, &answer);
        curl_error_buf[0] = 0; // drop text left over from an earlier failure
        const CURLcode r = curl_easy_perform(h.get());
        if (r != CURLE_OK) throw std::runtime_error(curl_error_text(r));

        Json::Value j;
        throw_server_error_or_response_code(answer, j);
        return j;
    }

    // Text for a failed libcurl call: the error buffer when libcurl filled it,
    // otherwise the generic description of `code`.
    std::string curl_error_text(CURLcode code) const
    {
        if (!curl_error_buf.empty() && curl_error_buf[0] != 0)
            return std::string(curl_error_buf.data());
        return std::string(curl_easy_strerror(code));
    }
};

// Creates a client whose requests will target `addr` on `port`.
// Only the handle is set up here; no request is sent by this function.
boost::shared_ptr<Client> client_create(const std::string& addr, int port)
{
    boost::shared_ptr<ClientReal> created(new ClientReal);
    created->port = port;
    created->addr = addr;
    return created;
}


// environment

// One environment instance on the server, addressed through
// "/v1/envs/<instance_id>/..." routes of the owning client.
// Errors from the client or from JSON conversion propagate as exceptions;
// misuse by the C++ caller is reported with assert.
class EnvironmentReal: public Environment {
public:
    std::string instance_id;
    boost::shared_ptr<ClientReal> client;
    boost::shared_ptr<Space> space_act; // filled on first action_space() call
    boost::shared_ptr<Space> space_obs; // filled on first observation_space() call

    // Returns the action space, fetching it from the server on first use.
    boost::shared_ptr<Space> action_space() override
    {
        if (!space_act)
            space_act = space_from_json(client->GET("/v1/envs/" + instance_id + "/action_space"));
        return space_act;
    }

    // Returns the observation space, fetching it from the server on first use.
    boost::shared_ptr<Space> observation_space() override
    {
        if (!space_obs)
            space_obs = space_from_json(client->GET("/v1/envs/" + instance_id + "/observation_space"));
        return space_obs;
    }

    // Converts the flat JSON array `v` to floats, replacing the contents of
    // `save_here`. Throws std::runtime_error when `v` is not an array.
    void observation_parse(const Json::Value& v, std::vector<float>& save_here)
    {
        if (!v.isArray())
            throw std::runtime_error("cannot parse observation, not an array");
        int s = v.size();
        save_here.resize(s);
        for (int i=0; i<s; ++i)
            save_here[i] = v[i].asFloat();
    }

    // Resets the environment and stores the first observation. Only the
    // observation member of the state is written.
    void reset(State* save_initial_state_here) override
    {
        assert(save_initial_state_here); // programming error on the C++ side
        Json::Value ans = client->POST("/v1/envs/" + instance_id + "/reset/", "");
        observation_parse(ans["observation"], save_initial_state_here->observation);
    }

    // Sends `action` and stores the resulting observation, reward and done
    // flag. A Discrete space uses action[0] truncated to int; a Box space
    // needs one value per box element.
    void step(const std::vector<float>& action, bool render, State* save_state_here) override
    {
        assert(save_state_here); // programming error on the C++ side
        Json::Value act_json;
        boost::shared_ptr<Space> aspace = action_space();
        if (aspace->type==Space::DISCRETE) {
            assert(!action.empty()); // action[0] is read below
            act_json["action"] = static_cast<int>(action[0]);
        } else if (aspace->type==Space::BOX) {
            // Start from an empty array so a zero-length action is sent as [], not null.
            Json::Value& array = (act_json["action"] = Json::Value(Json::arrayValue));
            assert(action.size()==aspace->box_low.size()); // really assert, it's a programming error on C++ part
            for (int c=0; c<(int)action.size(); ++c)
                array[c] = action[c];
        } else {
            assert(0);
        }
        act_json["render"] = render;
        Json::Value ans = client->POST("/v1/envs/" + instance_id + "/step/", act_json.toStyledString());
        observation_parse(ans["observation"], save_state_here->observation);
        save_state_here->done = ans["done"].asBool();
        save_state_here->reward = ans["reward"].asFloat();
    }

    // Asks the server to start monitoring into `directory`; `force` and
    // `resume` are forwarded unchanged.
    void monitor_start(const std::string& directory, bool force, bool resume) override
    {
        Json::Value data;
        data["directory"] = directory;
        data["force"] = force;
        data["resume"] = resume;
        client->POST("/v1/envs/" + instance_id + "/monitor/start/", data.toStyledString());
    }

    // Asks the server to close the monitor.
    void monitor_stop() override
    {
        client->POST("/v1/envs/" + instance_id + "/monitor/close/", "");
    }
};

// Asks the server to create an environment of kind `env_id` and returns an
// object bound to the "instance_id" from the reply. The returned environment
// keeps a shared reference to this client.
// Throws std::runtime_error when the request fails or the reply has no
// "instance_id".
boost::shared_ptr<Environment> ClientReal::make(const std::string& env_id)
{
    Json::Value request;
    request["env_id"] = env_id;
    const Json::Value reply = POST("/v1/envs/", request.toStyledString());
    const std::string new_id = require(reply, "instance_id");
    if (verbose)
        printf(" * created %s instance_id=%s\n", env_id.c_str(), new_id.c_str());

    boost::shared_ptr<EnvironmentReal> created(new EnvironmentReal);
    created->instance_id = new_id;
    created->client = shared_from_this();
    return created;
}

} // namespace
53 changes: 53 additions & 0 deletions binding-cpp/include/gym/gym.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#ifndef __GYM_H__
#define __GYM_H__
// Caffe uses boost::shared_ptr (as opposed to std::shared_ptr), so do we.
#include <boost/shared_ptr.hpp>
#include <string>
#include <vector>

namespace Gym {

// Description of an action or observation space.
// A default-constructed Space is DISCRETE with discreet_n == 0 and empty box
// vectors, so no member is ever read uninitialised.
struct Space {
    enum SpaceType {
        DISCRETE,
        BOX,
    };
    SpaceType type = DISCRETE;

    std::vector<float> sample(); // Random vector that belongs to this space

    // Used when type == BOX.
    std::vector<int> box_shape; // Similar to Caffe blob shape, for example { 64, 96, 3 } for 96x64 rgb image.
    std::vector<float> box_high;
    std::vector<float> box_low;

    // Used when type == DISCRETE: number of choices.
    int discreet_n = 0;
};

// Result of Environment::reset() or Environment::step().
// reward and done start at zero / false so they are safe to read even when a
// call only fills in the observation.
struct State {
    std::vector<float> observation; // get observation_space() to make sense of this data
    float reward = 0.0f;
    bool done = false;
    std::string info;
};

// Abstract interface to one environment instance.
class Environment {
public:
    // Virtual so an implementation can be destroyed safely through a
    // pointer to this interface.
    virtual ~Environment() {}

    virtual boost::shared_ptr<Space> action_space() =0;
    virtual boost::shared_ptr<Space> observation_space() =0;

    // Restarts the environment and writes the initial state to the given object.
    virtual void reset(State* save_initial_state_here) =0;

    // Applies `action` and writes the resulting state to the given object.
    virtual void step(const std::vector<float>& action, bool render, State* save_state_here) =0;

    virtual void monitor_start(const std::string& directory, bool force, bool resume) =0;
    virtual void monitor_stop() =0;
};

// Abstract interface to a server that can create environments.
class Client {
public:
    // Virtual so an implementation can be destroyed safely through a
    // pointer to this interface.
    virtual ~Client() {}

    // Creates an environment identified by `name` and returns a handle to it.
    virtual boost::shared_ptr<Environment> make(const std::string& name) =0;
};

// Creates a client for the server at `addr` (host without the "http://"
// prefix, which the implementation adds) listening on `port`.
extern boost::shared_ptr<Client> client_create(const std::string& addr, int port);

} // namespace

#endif // __GYM_H__
Loading

0 comments on commit 24caf7c

Please sign in to comment.