Skip to content

Commit

Permalink
Fix tunnel parsing exception handling. (#550)
Browse files Browse the repository at this point in the history
Rather than using STFATAL and exiting, bubble up tunnel parsing
exception to main and display the responsible option arg and print help
menu so the user need not dig into /tmp/etclient-* to determine why et
aborted.

Fixes #491
  • Loading branch information
jshort authored Dec 14, 2022
1 parent 91099f6 commit 10dd371
Show file tree
Hide file tree
Showing 7 changed files with 145 additions and 62 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,8 @@ add_library(
src/base/WinsockContext.hpp
src/base/SubprocessToString.hpp
src/base/SubprocessToString.cpp
src/base/TunnelUtils.hpp
src/base/TunnelUtils.cpp

${ET_HDRS}
${ET_SRCS}
Expand Down
1 change: 1 addition & 0 deletions codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ ignore:
- "src/terminal/TelemetryService*"
- "src/terminal/PsuedoTerminalConsole.hpp"
- "src/terminal/PsuedoUserTerminal.hpp"
- "src/terminal/*Main.cpp"
- "src/base/TcpSocketHandler*"
- "src/base/SubprocessToString*"
coverage:
Expand Down
66 changes: 66 additions & 0 deletions src/base/TunnelUtils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#include "TunnelUtils.hpp"

namespace et {
vector<PortForwardSourceRequest> parseRangesToRequests(const string& input) {
vector<PortForwardSourceRequest> pfsrs;
auto j = split(input, ',');
for (auto& pair : j) {
vector<string> sourceDestination = split(pair, ':');
if (sourceDestination.size() < 2) {
throw TunnelParseException(
"Tunnel argument must have source and destination between a ':'");
}
try {
if (sourceDestination[0].find_first_not_of("0123456789-") !=
string::npos &&
sourceDestination[1].find_first_not_of("0123456789-") !=
string::npos) {
PortForwardSourceRequest pfsr;
pfsr.set_environmentvariable(sourceDestination[0]);
pfsr.mutable_destination()->set_name(sourceDestination[1]);
pfsrs.push_back(pfsr);
} else if (sourceDestination[0].find('-') != string::npos &&
sourceDestination[1].find('-') != string::npos) {
vector<string> sourcePortRange = split(sourceDestination[0], '-');
int sourcePortStart = stoi(sourcePortRange[0]);
int sourcePortEnd = stoi(sourcePortRange[1]);

vector<string> destinationPortRange = split(sourceDestination[1], '-');
int destinationPortStart = stoi(destinationPortRange[0]);
int destinationPortEnd = stoi(destinationPortRange[1]);

if (sourcePortEnd - sourcePortStart !=
destinationPortEnd - destinationPortStart) {
throw TunnelParseException(
"source/destination port range must have same length");
} else {
int portRangeLength = sourcePortEnd - sourcePortStart + 1;
for (int i = 0; i < portRangeLength; ++i) {
PortForwardSourceRequest pfsr;
pfsr.mutable_source()->set_port(sourcePortStart + i);
pfsr.mutable_destination()->set_port(destinationPortStart + i);
pfsrs.push_back(pfsr);
}
}
} else if (sourceDestination[0].find('-') != string::npos ||
sourceDestination[1].find('-') != string::npos) {
throw TunnelParseException(
"Invalid port range syntax: if source is a range, "
"destination must be a range (and vice versa)");
} else {
PortForwardSourceRequest pfsr;
pfsr.mutable_source()->set_port(stoi(sourceDestination[0]));
pfsr.mutable_destination()->set_port(stoi(sourceDestination[1]));
pfsrs.push_back(pfsr);
}
} catch (const TunnelParseException& e) {
throw e;
} catch (const std::logic_error& lr) {
throw TunnelParseException("Invalid tunnel argument '" + input +
"': " + lr.what());
}
}
return pfsrs;
}

} // namespace et
20 changes: 20 additions & 0 deletions src/base/TunnelUtils.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#ifndef __ET_TUNNEL_UTILS__
#define __ET_TUNNEL_UTILS__

#include "ETerminal.pb.h"

namespace et {

vector<PortForwardSourceRequest> parseRangesToRequests(const string& input);

class TunnelParseException : public std::exception {
public:
TunnelParseException(const string& msg) : message(msg) {}
const char* what() const noexcept override { return message.c_str(); }

private:
std::string message = " ";
};

} // namespace et
#endif // __ET_TUNNEL_UTILS__
55 changes: 1 addition & 54 deletions src/terminal/TerminalClient.cpp
Original file line number Diff line number Diff line change
@@ -1,62 +1,9 @@
#include "TerminalClient.hpp"

#include "TelemetryService.hpp"
#include "TunnelUtils.hpp"

namespace et {
vector<PortForwardSourceRequest> parseRangesToRequests(const string& input) {
vector<PortForwardSourceRequest> pfsrs;
auto j = split(input, ',');
for (auto& pair : j) {
vector<string> sourceDestination = split(pair, ':');
try {
if (sourceDestination[0].find_first_not_of("0123456789-") !=
string::npos &&
sourceDestination[1].find_first_not_of("0123456789-") !=
string::npos) {
PortForwardSourceRequest pfsr;
pfsr.set_environmentvariable(sourceDestination[0]);
pfsr.mutable_destination()->set_name(sourceDestination[1]);
pfsrs.push_back(pfsr);
} else if (sourceDestination[0].find('-') != string::npos &&
sourceDestination[1].find('-') != string::npos) {
vector<string> sourcePortRange = split(sourceDestination[0], '-');
int sourcePortStart = stoi(sourcePortRange[0]);
int sourcePortEnd = stoi(sourcePortRange[1]);

vector<string> destinationPortRange = split(sourceDestination[1], '-');
int destinationPortStart = stoi(destinationPortRange[0]);
int destinationPortEnd = stoi(destinationPortRange[1]);

if (sourcePortEnd - sourcePortStart !=
destinationPortEnd - destinationPortStart) {
STFATAL << "source/destination port range mismatch";
exit(1);
} else {
int portRangeLength = sourcePortEnd - sourcePortStart + 1;
for (int i = 0; i < portRangeLength; ++i) {
PortForwardSourceRequest pfsr;
pfsr.mutable_source()->set_port(sourcePortStart + i);
pfsr.mutable_destination()->set_port(destinationPortStart + i);
pfsrs.push_back(pfsr);
}
}
} else if (sourceDestination[0].find('-') != string::npos ||
sourceDestination[1].find('-') != string::npos) {
STFATAL << "Invalid port range syntax: if source is range, "
"destination must be range";
} else {
PortForwardSourceRequest pfsr;
pfsr.mutable_source()->set_port(stoi(sourceDestination[0]));
pfsr.mutable_destination()->set_port(stoi(sourceDestination[1]));
pfsrs.push_back(pfsr);
}
} catch (const std::logic_error& lr) {
STFATAL << "Logic error: " << lr.what();
exit(1);
}
}
return pfsrs;
}

TerminalClient::TerminalClient(
shared_ptr<SocketHandler> _socketHandler,
Expand Down
27 changes: 19 additions & 8 deletions src/terminal/TerminalClientMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "PsuedoTerminalConsole.hpp"
#include "TelemetryService.hpp"
#include "TerminalClient.hpp"
#include "TunnelUtils.hpp"
#include "WinsockContext.hpp"

using namespace et;
Expand All @@ -22,6 +23,12 @@ bool ping(SocketEndpoint socketEndpoint,
return true;
}

void handleParseException(std::exception& e, cxxopts::Options& options) {
CLOG(INFO, "stdout") << "Exception: " << e.what() << "\n" << endl;
CLOG(INFO, "stdout") << options.help({}) << endl;
exit(1);
}

int main(int argc, char** argv) {
WinsockContext context;
string tmpDir = GetTempDirectory();
Expand Down Expand Up @@ -340,17 +347,21 @@ int main(int argc, char** argv) {
}
TelemetryService::get()->logToDatadog("Session Started", el::Level::Info,
__FILE__, __LINE__);
TerminalClient terminalClient(
clientSocket, clientPipeSocket, socketEndpoint, id, passkey, console,
is_jumphost, result.count("t") ? result["t"].as<string>() : "",
result.count("r") ? result["r"].as<string>() : "", forwardAgent,
sshSocket, keepaliveDuration);
string tunnel_arg =
result.count("tunnel") ? result["tunnel"].as<string>() : "";
string r_tunnel_arg = result.count("reversetunnel")
? result["reversetunnel"].as<string>()
: "";
TerminalClient terminalClient(clientSocket, clientPipeSocket,
socketEndpoint, id, passkey, console,
is_jumphost, tunnel_arg, r_tunnel_arg,
forwardAgent, sshSocket, keepaliveDuration);
terminalClient.run(result.count("command") ? result["command"].as<string>()
: "");
} catch (TunnelParseException& tpe) {
handleParseException(tpe, options);
} catch (cxxopts::OptionException& oe) {
CLOG(INFO, "stdout") << "Exception: " << oe.what() << "\n" << endl;
CLOG(INFO, "stdout") << options.help({}) << endl;
exit(1);
handleParseException(oe, options);
}

#ifdef WIN32
Expand Down
36 changes: 36 additions & 0 deletions test/TerminalTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "TerminalClient.hpp"
#include "TerminalServer.hpp"
#include "TestHeaders.hpp"
#include "TunnelUtils.hpp"

namespace et {
TEST_CASE("FakeConsoleTest", "[FakeConsoleTest]") {
Expand Down Expand Up @@ -255,6 +256,41 @@ class EndToEndTestFixture {
bool wasShutdown = false;
};

TEST_CASE("InvalidTunnelArgParsing", "[InvalidTunnelArgParsing]") {
REQUIRE_THROWS_WITH(
parseRangesToRequests("6010"),
Catch::Matchers::Contains("must have source and destination"));
REQUIRE_THROWS_WITH(parseRangesToRequests("6010-6012:7000"),
Catch::Matchers::Contains("must be a range"));
REQUIRE_THROWS_WITH(parseRangesToRequests("6010:7000-7010"),
Catch::Matchers::Contains("must be a range"));
REQUIRE_THROWS_WITH(parseRangesToRequests("6010-6012:7000-8000"),
Catch::Matchers::Contains("must have same length"));
}

TEST_CASE("ValidTunnelArgParsing", "[ValidTunnelArgParsing]") {
// Plain port1:port2 forward
auto pfsrs_single = parseRangesToRequests("6010:7010");
REQUIRE(pfsrs_single.size() == 1);
REQUIRE(pfsrs_single[0].has_source());
REQUIRE(pfsrs_single[0].has_destination());
REQUIRE((pfsrs_single[0].source().has_port() &&
pfsrs_single[0].source().port() == 6010));
REQUIRE((pfsrs_single[0].destination().has_port() &&
pfsrs_single[0].destination().port() == 7010));

// range src_port1-src_port2:dest_port1-dest_port2 forward
auto pfsrs_ranges = parseRangesToRequests("6010-6013:7010-7013");
REQUIRE(pfsrs_ranges.size() == 4);

// named pipe forward
auto pfsrs_named = parseRangesToRequests("envvar:/tmp/destination");
REQUIRE(pfsrs_named.size() == 1);
REQUIRE(!pfsrs_named[0].has_source());
REQUIRE(pfsrs_named[0].has_destination());
REQUIRE(pfsrs_named[0].has_environmentvariable());
}

TEST_CASE_METHOD(EndToEndTestFixture, "EndToEndTest", "[EndToEndTest]") {
readWriteTest("1234567890123456", routerSocketHandler, fakeUserTerminal,
serverEndpoint, clientSocketHandler, clientPipeSocketHandler,
Expand Down

0 comments on commit 10dd371

Please sign in to comment.