Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 59 additions & 2 deletions src/confighttp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,8 @@ namespace confighttp {
// TODO: Input Validation
pt::read_json(ss, inputTree);
std::string pin = inputTree.get<std::string>("pin");
outputTree.put("status", nvhttp::pin(pin));
std::string name = inputTree.get<std::string>("name");
outputTree.put("status", nvhttp::pin(pin, name));
}
catch (std::exception &e) {
BOOST_LOG(warning) << "SavePin: "sv << e.what();
Expand All @@ -717,6 +718,60 @@ namespace confighttp {
response->write(data.str());
});
nvhttp::erase_all_clients();
proc::proc.terminate();
outputTree.put("status", true);
}

void
unpair(resp_https_t response, req_https_t request) {
if (!authenticate(response, request)) return;

print_req(request);

std::stringstream ss;
ss << request->content.rdbuf();

pt::ptree inputTree, outputTree;

auto g = util::fail_guard([&]() {
std::ostringstream data;
pt::write_json(data, outputTree);
response->write(data.str());
});

try {
// TODO: Input Validation
pt::read_json(ss, inputTree);
std::string uuid = inputTree.get<std::string>("uuid");
outputTree.put("status", nvhttp::unpair_client(uuid));
}
catch (std::exception &e) {
BOOST_LOG(warning) << "Unpair: "sv << e.what();
outputTree.put("status", false);
outputTree.put("error", e.what());
return;
}
}

void
listClients(resp_https_t response, req_https_t request) {
if (!authenticate(response, request)) return;

print_req(request);

pt::ptree named_certs = nvhttp::get_all_clients();

pt::ptree outputTree;

outputTree.put("status", false);

auto g = util::fail_guard([&]() {
std::ostringstream data;
pt::write_json(data, outputTree);
response->write(data.str());
});

outputTree.add_child("named_certs", named_certs);
outputTree.put("status", true);
}

Expand Down Expand Up @@ -765,7 +820,9 @@ namespace confighttp {
server.resource["^/api/restart$"]["POST"] = restart;
server.resource["^/api/password$"]["POST"] = savePassword;
server.resource["^/api/apps/([0-9]+)$"]["DELETE"] = deleteApp;
server.resource["^/api/clients/unpair$"]["POST"] = unpairAll;
server.resource["^/api/clients/unpair-all$"]["POST"] = unpairAll;
server.resource["^/api/clients/list$"]["GET"] = listClients;
server.resource["^/api/clients/unpair$"]["POST"] = unpair;
server.resource["^/api/apps/close$"]["POST"] = closeApp;
server.resource["^/api/covers/upload$"]["POST"] = uploadCover;
server.resource["^/images/sunshine.ico$"]["GET"] = getFaviconImage;
Expand Down
176 changes: 132 additions & 44 deletions src/nvhttp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,15 @@ namespace nvhttp {
std::string pkey;
} conf_intern;

struct named_cert_t {
std::string name;
std::string uuid;
std::string cert;
};

struct client_t {
std::string uniqueID;
std::vector<std::string> certs;
std::vector<named_cert_t> named_devices;
};

struct pair_session_t {
Expand All @@ -145,7 +151,7 @@ namespace nvhttp {

// uniqueID, session
std::unordered_map<std::string, pair_session_t> map_id_sess;
std::unordered_map<std::string, client_t> map_id_client;
client_t client_root;
std::atomic<uint32_t> session_id_counter;

using args_t = SimpleWeb::CaseInsensitiveMultimap;
Expand Down Expand Up @@ -189,22 +195,18 @@ namespace nvhttp {
root.erase("root"s);

root.put("root.uniqueid", http::unique_id);
auto &nodes = root.add_child("root.devices", pt::ptree {});
for (auto &[_, client] : map_id_client) {
pt::ptree node;

node.put("uniqueid"s, client.uniqueID);

pt::ptree cert_nodes;
for (auto &cert : client.certs) {
pt::ptree cert_node;
cert_node.put_value(cert);
cert_nodes.push_back(std::make_pair(""s, cert_node));
}
node.add_child("certs"s, cert_nodes);

nodes.push_back(std::make_pair(""s, node));
client_t &client = client_root;
pt::ptree node;

pt::ptree named_cert_nodes;
for (auto &named_cert : client.named_devices) {
pt::ptree named_cert_node;
named_cert_node.put("name"s, named_cert.name);
named_cert_node.put("cert"s, named_cert.cert);
named_cert_node.put("uuid"s, named_cert.uuid);
named_cert_nodes.push_back(std::make_pair(""s, named_cert_node));
}
root.add_child("root.named_devices"s, named_cert_nodes);

try {
pt::write_json(config::nvhttp.file_state, root);
Expand All @@ -223,48 +225,79 @@ namespace nvhttp {
return;
}

pt::ptree root;
pt::ptree tree;
try {
pt::read_json(config::nvhttp.file_state, root);
pt::read_json(config::nvhttp.file_state, tree);
}
catch (std::exception &e) {
BOOST_LOG(error) << "Couldn't read "sv << config::nvhttp.file_state << ": "sv << e.what();

return;
}

auto unique_id_p = root.get_optional<std::string>("root.uniqueid");
auto unique_id_p = tree.get_optional<std::string>("root.uniqueid");
if (!unique_id_p) {
// This file doesn't contain moonlight credentials
http::unique_id = uuid_util::uuid_t::generate().string();
return;
}
http::unique_id = std::move(*unique_id_p);

auto device_nodes = root.get_child("root.devices");

for (auto &[_, device_node] : device_nodes) {
auto uniqID = device_node.get<std::string>("uniqueid");
auto &client = map_id_client.emplace(uniqID, client_t {}).first->second;

client.uniqueID = uniqID;
auto root = tree.get_child("root");
client_t client;

// Import from old format
if (root.get_child_optional("devices")) {
auto device_nodes = root.get_child("devices");
for (auto &[_, device_node] : device_nodes) {
auto uniqID = device_node.get<std::string>("uniqueid");

if (device_node.count("certs")) {
for (auto &[_, el] : device_node.get_child("certs")) {
named_cert_t named_cert;
named_cert.name = ""s;
named_cert.cert = el.get_value<std::string>();
named_cert.uuid = uuid_util::uuid_t::generate().string();
client.named_devices.emplace_back(named_cert);
client.certs.emplace_back(named_cert.cert);
}
}
}
}

for (auto &[_, el] : device_node.get_child("certs")) {
client.certs.emplace_back(el.get_value<std::string>());
if (root.count("named_devices")) {
for (auto &[_, el] : root.get_child("named_devices")) {
named_cert_t named_cert;
named_cert.name = el.get_child("name").get_value<std::string>();
named_cert.cert = el.get_child("cert").get_value<std::string>();
named_cert.uuid = el.get_child("uuid").get_value<std::string>();
client.named_devices.emplace_back(named_cert);
client.certs.emplace_back(named_cert.cert);
}
}

// Empty certificate chain and import certs from file
cert_chain.clear();
for (auto &cert : client.certs) {
cert_chain.add(crypto::x509(cert));
}
for (auto &named_cert : client.named_devices) {
cert_chain.add(crypto::x509(named_cert.cert));
}

client_root = client;
}

void
update_id_client(const std::string &uniqueID, std::string &&cert, op_e op) {
switch (op) {
case op_e::ADD: {
auto &client = map_id_client[uniqueID];
client_t &client = client_root;
client.certs.emplace_back(std::move(cert));
client.uniqueID = uniqueID;
} break;
case op_e::REMOVE:
map_id_client.erase(uniqueID);
client_t client;
client_root = client;
break;
}

Expand Down Expand Up @@ -579,15 +612,16 @@ namespace nvhttp {
/**
* @brief Compare the user supplied pin to the Moonlight pin.
* @param pin The user supplied pin.
* @param name The user supplied name.
* @return `true` if the pin is correct, `false` otherwise.
*
* EXAMPLES:
* ```cpp
* bool pin_status = nvhttp::pin("1234");
* bool pin_status = nvhttp::pin("1234", "laptop");
* ```
*/
bool
pin(std::string pin) {
pin(std::string pin, std::string name) {
pt::ptree tree;
if (map_id_sess.empty()) {
return false;
Expand All @@ -613,6 +647,14 @@ namespace nvhttp {
auto &sess = std::begin(map_id_sess)->second;
getservercert(sess, tree, pin);

// set up named cert
client_t &client = client_root;
named_cert_t named_cert;
named_cert.name = name;
named_cert.cert = sess.client.cert;
named_cert.uuid = uuid_util::uuid_t::generate().string();
client.named_devices.emplace_back(named_cert);

// response to the request for pin
std::ostringstream data;
pt::write_xml(data, tree);
Expand Down Expand Up @@ -645,9 +687,7 @@ namespace nvhttp {
auto clientID = args.find("uniqueid"s);

if (clientID != std::end(args)) {
if (auto it = map_id_client.find(clientID->second); it != std::end(map_id_client)) {
pair_status = 1;
}
pair_status = 1;
}
}

Expand Down Expand Up @@ -742,6 +782,20 @@ namespace nvhttp {
response->close_connection_after_response = true;
}

pt::ptree
get_all_clients() {
pt::ptree named_cert_nodes;
client_t &client = client_root;
for (auto &named_cert : client.named_devices) {
pt::ptree named_cert_node;
named_cert_node.put("name"s, named_cert.name);
named_cert_node.put("uuid"s, named_cert.uuid);
named_cert_nodes.push_back(std::make_pair(""s, named_cert_node));
}

return named_cert_nodes;
}

void
applist(resp_https_t response, req_https_t request) {
print_req<SimpleWeb::HTTPS>(request);
Expand Down Expand Up @@ -1020,12 +1074,6 @@ namespace nvhttp {
conf_intern.pkey = file_handler::read_file(config::nvhttp.pkey.c_str());
conf_intern.servercert = file_handler::read_file(config::nvhttp.cert.c_str());

for (auto &[_, client] : map_id_client) {
for (auto &cert : client.certs) {
cert_chain.add(crypto::x509(cert));
}
}

auto add_cert = std::make_shared<safe::queue_t<crypto::x509_t>>(30);

// resume doesn't always get the parameter "localAudioPlayMode"
Expand Down Expand Up @@ -1149,8 +1197,48 @@ namespace nvhttp {
*/
void
erase_all_clients() {
map_id_client.clear();
client_t client;
client_root = client;
cert_chain.clear();
save_state();
}

/**
* @brief Remove single client.
*
* EXAMPLES:
* ```cpp
* nvhttp::unpair_client("4D7BB2DD-5704-A405-B41C-891A022932E1");
* ```
*/
int
unpair_client(std::string uuid) {
int removed = 0;
client_t &client = client_root;
for (auto it = client.named_devices.begin(); it != client.named_devices.end();) {
if ((*it).uuid == uuid) {
// Find matching cert and remove it
for (auto cert = client.certs.begin(); cert != client.certs.end();) {
if ((*cert) == (*it).cert) {
cert = client.certs.erase(cert);
removed++;
}
else {
++cert;
}
}

// And then remove the named cert
it = client.named_devices.erase(it);
removed++;
}
else {
++it;
}
}

save_state();
load_state();
return removed;
}
} // namespace nvhttp
9 changes: 8 additions & 1 deletion src/nvhttp.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
// standard includes
#include <string>

// lib includes
#include <boost/property_tree/ptree.hpp>

// local includes
#include "thread_safe.h"

Expand Down Expand Up @@ -43,7 +46,11 @@ namespace nvhttp {
void
start();
bool
pin(std::string pin);
pin(std::string pin, std::string name);
int
unpair_client(std::string uniqueid);
boost::property_tree::ptree
get_all_clients();
void
erase_all_clients();
} // namespace nvhttp
Loading