Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions src/libstore/daemon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1025,19 +1025,20 @@ void processConnection(
#endif

/* Exchange the greeting. */
WorkerProto::Version clientVersion =
auto [protoVersion, features] =
WorkerProto::BasicServerConnection::handshake(
to, from, PROTOCOL_VERSION);
to, from, PROTOCOL_VERSION, WorkerProto::allFeatures);

if (clientVersion < 0x10a)
if (protoVersion < 0x10a)
throw Error("the Nix client version is too old");

WorkerProto::BasicServerConnection conn;
conn.to = std::move(to);
conn.from = std::move(from);
conn.protoVersion = clientVersion;
conn.protoVersion = protoVersion;
conn.features = features;

auto tunnelLogger = new TunnelLogger(conn.to, clientVersion);
auto tunnelLogger = new TunnelLogger(conn.to, protoVersion);
auto prevLogger = nix::logger;
// FIXME
if (!recursive)
Expand Down
10 changes: 8 additions & 2 deletions src/libstore/remote-store.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,11 @@ void RemoteStore::initConnection(Connection & conn)
StringSink saved;
TeeSource tee(conn.from, saved);
try {
conn.protoVersion = WorkerProto::BasicClientConnection::handshake(
conn.to, tee, PROTOCOL_VERSION);
auto [protoVersion, features] = WorkerProto::BasicClientConnection::handshake(
conn.to, tee, PROTOCOL_VERSION,
WorkerProto::allFeatures);
conn.protoVersion = protoVersion;
conn.features = features;
} catch (SerialisationError & e) {
/* In case the other side is waiting for our input, close
it. */
Expand All @@ -88,6 +91,9 @@ void RemoteStore::initConnection(Connection & conn)

static_cast<WorkerProto::ClientHandshakeInfo &>(conn) = conn.postHandshake(*this);

for (auto & feature : conn.features)
debug("negotiated feature '%s'", feature);

auto ex = conn.processStderrReturn();
if (ex) std::rethrow_exception(ex);
}
Expand Down
51 changes: 45 additions & 6 deletions src/libstore/worker-protocol-connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

namespace nix {

const std::set<WorkerProto::Feature> WorkerProto::allFeatures{};

WorkerProto::BasicClientConnection::~BasicClientConnection()
{
try {
Expand Down Expand Up @@ -137,8 +139,21 @@ void WorkerProto::BasicClientConnection::processStderr(bool * daemonException, S
}
}

WorkerProto::Version
WorkerProto::BasicClientConnection::handshake(BufferedSink & to, Source & from, WorkerProto::Version localVersion)
static std::set<WorkerProto::Feature>
intersectFeatures(const std::set<WorkerProto::Feature> & a, const std::set<WorkerProto::Feature> & b)
{
std::set<WorkerProto::Feature> res;
for (auto & x : a)
if (b.contains(x))
res.insert(x);
return res;
}

std::tuple<WorkerProto::Version, std::set<WorkerProto::Feature>> WorkerProto::BasicClientConnection::handshake(
BufferedSink & to,
Source & from,
WorkerProto::Version localVersion,
const std::set<WorkerProto::Feature> & supportedFeatures)
{
to << WORKER_MAGIC_1 << localVersion;
to.flush();
Expand All @@ -153,19 +168,43 @@ WorkerProto::BasicClientConnection::handshake(BufferedSink & to, Source & from,
if (GET_PROTOCOL_MINOR(daemonVersion) < 10)
throw Error("the Nix daemon version is too old");

return std::min(daemonVersion, localVersion);
auto protoVersion = std::min(daemonVersion, localVersion);

/* Exchange features. */
std::set<WorkerProto::Feature> daemonFeatures;
if (GET_PROTOCOL_MINOR(protoVersion) >= 38) {
to << supportedFeatures;
to.flush();
daemonFeatures = readStrings<std::set<WorkerProto::Feature>>(from);
}

return {protoVersion, intersectFeatures(daemonFeatures, supportedFeatures)};
}

WorkerProto::Version
WorkerProto::BasicServerConnection::handshake(BufferedSink & to, Source & from, WorkerProto::Version localVersion)
std::tuple<WorkerProto::Version, std::set<WorkerProto::Feature>> WorkerProto::BasicServerConnection::handshake(
BufferedSink & to,
Source & from,
WorkerProto::Version localVersion,
const std::set<WorkerProto::Feature> & supportedFeatures)
{
unsigned int magic = readInt(from);
if (magic != WORKER_MAGIC_1)
throw Error("protocol mismatch");
to << WORKER_MAGIC_2 << localVersion;
to.flush();
auto clientVersion = readInt(from);
return std::min(clientVersion, localVersion);

auto protoVersion = std::min(clientVersion, localVersion);

/* Exchange features. */
std::set<WorkerProto::Feature> clientFeatures;
if (GET_PROTOCOL_MINOR(protoVersion) >= 38) {
clientFeatures = readStrings<std::set<WorkerProto::Feature>>(from);
to << supportedFeatures;
to.flush();
}

return {protoVersion, intersectFeatures(clientFeatures, supportedFeatures)};
}

WorkerProto::ClientHandshakeInfo WorkerProto::BasicClientConnection::postHandshake(const StoreDirConfig & store)
Expand Down
27 changes: 23 additions & 4 deletions src/libstore/worker-protocol-connection.hh
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ struct WorkerProto::BasicConnection
*/
WorkerProto::Version protoVersion;

/**
* The set of features that both sides support.
*/
std::set<Feature> features;

/**
* Coercion to `WorkerProto::ReadConn`. This makes it easy to use the
* factored out serve protocol serializers with a
Expand Down Expand Up @@ -72,8 +77,8 @@ struct WorkerProto::BasicClientConnection : WorkerProto::BasicConnection
/**
* Establishes connection, negotiating version.
*
* @return the version provided by the other side of the
* connection.
* @return the minimum version supported by both sides and the set
* of protocol features supported by both sides.
*
* @param to Taken by reference to allow for various error handling
* mechanisms.
Expand All @@ -82,8 +87,15 @@ struct WorkerProto::BasicClientConnection : WorkerProto::BasicConnection
* handling mechanisms.
*
* @param localVersion Our version which is sent over
*
* @param features The protocol features that we support
*/
static Version handshake(BufferedSink & to, Source & from, WorkerProto::Version localVersion);
// FIXME: this should probably be a constructor.
static std::tuple<Version, std::set<Feature>> handshake(
BufferedSink & to,
Source & from,
WorkerProto::Version localVersion,
const std::set<Feature> & supportedFeatures);

/**
* After calling handshake, must call this to exchange some basic
Expand Down Expand Up @@ -138,8 +150,15 @@ struct WorkerProto::BasicServerConnection : WorkerProto::BasicConnection
* handling mechanisms.
*
* @param localVersion Our version which is sent over
*
* @param features The protocol features that we support
*/
static WorkerProto::Version handshake(BufferedSink & to, Source & from, WorkerProto::Version localVersion);
// FIXME: this should probably be a constructor.
static std::tuple<Version, std::set<Feature>> handshake(
BufferedSink & to,
Source & from,
WorkerProto::Version localVersion,
const std::set<Feature> & supportedFeatures);

/**
* After calling handshake, must call this to exchange some basic
Expand Down
8 changes: 7 additions & 1 deletion src/libstore/worker-protocol.hh
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ namespace nix {
#define WORKER_MAGIC_1 0x6e697863
#define WORKER_MAGIC_2 0x6478696f

#define PROTOCOL_VERSION (1 << 8 | 37)
/* Note: you generally shouldn't change the protocol version. Define a
new `WorkerProto::Feature` instead. */
#define PROTOCOL_VERSION (1 << 8 | 38)
#define GET_PROTOCOL_MAJOR(x) ((x) & 0xff00)
#define GET_PROTOCOL_MINOR(x) ((x) & 0x00ff)

Expand Down Expand Up @@ -131,6 +133,10 @@ struct WorkerProto
{
WorkerProto::Serialise<T>::write(store, conn, t);
}

using Feature = std::string;

static const std::set<Feature> allFeatures;
};

enum struct WorkerProto::Op : uint64_t
Expand Down
49 changes: 38 additions & 11 deletions tests/unit/libstore/worker-protocol.cc
Original file line number Diff line number Diff line change
Expand Up @@ -658,15 +658,15 @@ TEST_F(WorkerProtoTest, handshake_log)
FdSink out { toServer.writeSide.get() };
FdSource in0 { toClient.readSide.get() };
TeeSource in { in0, toClientLog };
clientResult = WorkerProto::BasicClientConnection::handshake(
out, in, defaultVersion);
clientResult = std::get<0>(WorkerProto::BasicClientConnection::handshake(
out, in, defaultVersion, {}));
});

{
FdSink out { toClient.writeSide.get() };
FdSource in { toServer.readSide.get() };
WorkerProto::BasicServerConnection::handshake(
out, in, defaultVersion);
out, in, defaultVersion, {});
};

thread.join();
Expand All @@ -675,6 +675,33 @@ TEST_F(WorkerProtoTest, handshake_log)
});
}

TEST_F(WorkerProtoTest, handshake_features)
{
Pipe toClient, toServer;
toClient.create();
toServer.create();

std::tuple<WorkerProto::Version, std::set<WorkerProto::Feature>> clientResult;

auto clientThread = std::thread([&]() {
FdSink out { toServer.writeSide.get() };
FdSource in { toClient.readSide.get() };
clientResult = WorkerProto::BasicClientConnection::handshake(
out, in, 123, {"bar", "aap", "mies", "xyzzy"});
});

FdSink out { toClient.writeSide.get() };
FdSource in { toServer.readSide.get() };
auto daemonResult = WorkerProto::BasicServerConnection::handshake(
out, in, 456, {"foo", "bar", "xyzzy"});

clientThread.join();

EXPECT_EQ(clientResult, daemonResult);
EXPECT_EQ(std::get<0>(clientResult), 123);
EXPECT_EQ(std::get<1>(clientResult), std::set<WorkerProto::Feature>({"bar", "xyzzy"}));
}

/// Has to be a `BufferedSink` for handshake.
struct NullBufferedSink : BufferedSink {
void writeUnbuffered(std::string_view data) override { }
Expand All @@ -686,8 +713,8 @@ TEST_F(WorkerProtoTest, handshake_client_replay)
NullBufferedSink nullSink;

StringSource in { toClientLog };
auto clientResult = WorkerProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion);
auto clientResult = std::get<0>(WorkerProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion, {}));

EXPECT_EQ(clientResult, defaultVersion);
});
Expand All @@ -705,13 +732,13 @@ TEST_F(WorkerProtoTest, handshake_client_truncated_replay_throws)
if (len < 8) {
EXPECT_THROW(
WorkerProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion),
nullSink, in, defaultVersion, {}),
EndOfFile);
} else {
// Not sure why cannot keep on checking for `EndOfFile`.
EXPECT_THROW(
WorkerProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion),
nullSink, in, defaultVersion, {}),
Error);
}
}
Expand All @@ -734,17 +761,17 @@ TEST_F(WorkerProtoTest, handshake_client_corrupted_throws)
// magic bytes don't match
EXPECT_THROW(
WorkerProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion),
nullSink, in, defaultVersion, {}),
Error);
} else if (idx < 8 || idx >= 12) {
// Number out of bounds
EXPECT_THROW(
WorkerProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion),
nullSink, in, defaultVersion, {}),
SerialisationError);
} else {
auto ver = WorkerProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion);
auto ver = std::get<0>(WorkerProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion, {}));
// `std::min` of this and the other version saves us
EXPECT_EQ(ver, defaultVersion);
}
Expand Down