Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions source/extensions/transport_sockets/alts/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,21 @@ envoy_cc_library(
],
)

envoy_cc_library(
name = "tsi_frame_protector",
srcs = [
"tsi_frame_protector.cc",
],
hdrs = [
"tsi_frame_protector.h",
],
repository = "@envoy",
deps = [
":grpc_tsi_wrapper",
"//source/common/buffer:buffer_lib",
],
)

envoy_cc_library(
name = "tsi_handshaker",
srcs = [
Expand Down
77 changes: 77 additions & 0 deletions source/extensions/transport_sockets/alts/tsi_frame_protector.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#include "extensions/transport_sockets/alts/tsi_frame_protector.h"

#include "common/common/assert.h"

namespace Envoy {
namespace Extensions {
namespace TransportSockets {
namespace Alts {

// TODO(lizan): tune size later
static constexpr uint32_t BUFFER_SIZE = 16384;

TsiFrameProtector::TsiFrameProtector(CFrameProtectorPtr&& frame_protector)
: frame_protector_(std::move(frame_protector)) {}

tsi_result TsiFrameProtector::protect(Buffer::Instance& input, Buffer::Instance& output) {
ASSERT(frame_protector_);

unsigned char protected_buffer[BUFFER_SIZE];
while (input.length() > 0) {
auto* message_bytes = reinterpret_cast<unsigned char*>(input.linearize(input.length()));
size_t protected_buffer_size = BUFFER_SIZE;
size_t processed_message_size = input.length();
tsi_result result =
tsi_frame_protector_protect(frame_protector_.get(), message_bytes, &processed_message_size,
protected_buffer, &protected_buffer_size);
if (result != TSI_OK) {
ASSERT(result != TSI_INVALID_ARGUMENT && result != TSI_UNIMPLEMENTED);
return result;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

}
output.add(protected_buffer, protected_buffer_size);
input.drain(processed_message_size);
}

// TSI may buffer some of the input internally. Flush its buffer to protected_buffer.
size_t still_pending_size;
do {
size_t protected_buffer_size = BUFFER_SIZE;
tsi_result result = tsi_frame_protector_protect_flush(
frame_protector_.get(), protected_buffer, &protected_buffer_size, &still_pending_size);
if (result != TSI_OK) {
ASSERT(result != TSI_INVALID_ARGUMENT && result != TSI_UNIMPLEMENTED);
return result;
}
output.add(protected_buffer, protected_buffer_size);
} while (still_pending_size > 0);

return TSI_OK;
}

tsi_result TsiFrameProtector::unprotect(Buffer::Instance& input, Buffer::Instance& output) {
ASSERT(frame_protector_);

unsigned char unprotected_buffer[BUFFER_SIZE];

while (input.length() > 0) {
auto* message_bytes = reinterpret_cast<unsigned char*>(input.linearize(input.length()));
size_t unprotected_buffer_size = BUFFER_SIZE;
size_t processed_message_size = input.length();
tsi_result result = tsi_frame_protector_unprotect(frame_protector_.get(), message_bytes,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sanity check, there's no flushing problem in this direction?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, there is no tsi_frame_protector_unprotect_flush.

&processed_message_size, unprotected_buffer,
&unprotected_buffer_size);
if (result != TSI_OK) {
ASSERT(result != TSI_INVALID_ARGUMENT && result != TSI_UNIMPLEMENTED);
return result;
}
output.add(unprotected_buffer, unprotected_buffer_size);
input.drain(processed_message_size);
}

return TSI_OK;
}

} // namespace Alts
} // namespace TransportSockets
} // namespace Extensions
} // namespace Envoy
49 changes: 49 additions & 0 deletions source/extensions/transport_sockets/alts/tsi_frame_protector.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#pragma once

#include "envoy/buffer/buffer.h"

#include "extensions/transport_sockets/alts/grpc_tsi.h"

namespace Envoy {
namespace Extensions {
namespace TransportSockets {
namespace Alts {

/**
* A C++ wrapper for tsi_frame_protector interface.
* For detail of tsi_frame_protector, see
* https://github.com/grpc/grpc/blob/v1.10.0/src/core/tsi/transport_security_interface.h#L70
*
* TODO(lizan): migrate to tsi_zero_copy_grpc_protector for further optimization
*/
class TsiFrameProtector final {
public:
explicit TsiFrameProtector(CFrameProtectorPtr&& frame_protector);

/**
* Wrapper for tsi_frame_protector_protect
* @param input supplies the input data to protect, the method will drain it when it is processed.
* @param output supplies the buffer where the protected data will be stored.
* @return tsi_result the status.
*/
tsi_result protect(Buffer::Instance& input, Buffer::Instance& output);

/**
* Wrapper for tsi_frame_protector_unprotect
* @param input supplies the input data to unprotect, the method will drain it when it is
* processed.
* @param output supplies the buffer where the unprotected data will be stored.
* @return tsi_result the status.
*/
tsi_result unprotect(Buffer::Instance& input, Buffer::Instance& output);

private:
CFrameProtectorPtr frame_protector_;
};

typedef std::unique_ptr<TsiFrameProtector> TsiFrameProtectorPtr;

} // namespace Alts
} // namespace TransportSockets
} // namespace Extensions
} // namespace Envoy
10 changes: 10 additions & 0 deletions test/extensions/transport_sockets/alts/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,16 @@ load(

envoy_package()

envoy_extension_cc_test(
name = "tsi_frame_protector_test",
srcs = ["tsi_frame_protector_test.cc"],
extension_name = "envoy.transport_sockets.alts",
deps = [
"//source/extensions/transport_sockets/alts:tsi_frame_protector",
"//test/mocks/buffer:buffer_mocks",
],
)

envoy_extension_cc_test(
name = "tsi_handshaker_test",
srcs = ["tsi_handshaker_test.cc"],
Expand Down
150 changes: 150 additions & 0 deletions test/extensions/transport_sockets/alts/tsi_frame_protector_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
#include "common/buffer/buffer_impl.h"

#include "extensions/transport_sockets/alts/tsi_frame_protector.h"

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "src/core/tsi/fake_transport_security.h"

namespace Envoy {
namespace Extensions {
namespace TransportSockets {
namespace Alts {

using testing::InSequence;
using testing::Invoke;
using testing::NiceMock;
using testing::SaveArg;
using testing::Test;
using testing::_;
using namespace std::string_literals;

/**
* Test with fake frame protector. The protected frame header is 4 byte length (little endian,
* include header itself) and following the body.
*/
class TsiFrameProtectorTest : public Test {
public:
TsiFrameProtectorTest()
: raw_frame_protector_(tsi_create_fake_frame_protector(nullptr)),
frame_protector_(CFrameProtectorPtr{raw_frame_protector_}) {}

protected:
tsi_frame_protector* raw_frame_protector_;
TsiFrameProtector frame_protector_;
};

TEST_F(TsiFrameProtectorTest, Protect) {
{
Buffer::OwnedImpl input, encrypted;
input.add("foo");

EXPECT_EQ(TSI_OK, frame_protector_.protect(input, encrypted));
EXPECT_EQ("\x07\0\0\0foo"s, encrypted.toString());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nits: comment where the prefix "\x07\0\0\0" comes from? I guess fake_frame_protector has a contract about how to generate it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is comment at top of this file (L23)

}

{
Buffer::OwnedImpl input, encrypted;
input.add("foo");

EXPECT_EQ(TSI_OK, frame_protector_.protect(input, encrypted));
EXPECT_EQ("\x07\0\0\0foo"s, encrypted.toString());

input.add("bar");
EXPECT_EQ(TSI_OK, frame_protector_.protect(input, encrypted));
EXPECT_EQ("\x07\0\0\0foo\x07\0\0\0bar"s, encrypted.toString());
}

{
Buffer::OwnedImpl input, encrypted;
input.add(std::string(20000, 'a'));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you have a const for protected_buffer_size, you can use a relative larger number here instead of hard coding 20000.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a coincidence that the frame size generated by fake frame protector is max at 16K and the buffer size is 16K, the hard code here is not reflecting the buffer size.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this test case is testing protect() with |input| size larger than the buffer. So this line can be something more general, like: input.add(std::string(BUFFER_SIZE + 3620, 'a')); Am I misunderstanding?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The frame size of fake frame protector output is define here, it is coincidence that is same as BUFFER_SIZE, but the test should be larger than the frame size, it is not depends on BUFFER_SIZE.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it testing against frame size which is not our implementation? I thought you meant to test protect() with input larger than |protected_buffer| which I think is worth to add if not covered yet.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant the expected output below (L63) is depends on frame size in fake frame protector, the input here is to exceed the BUFFER_SIZE and they are just happened to be same.

We don't have visibility of BUFFER_SIZE in test.cc so I'll have to define BUFFER_SIZE, input.add(std::string(BUFFER_SIZE + 3620, 'a')) won't work anyway.


EXPECT_EQ(TSI_OK, frame_protector_.protect(input, encrypted));

// fake frame protector will split long buffer to 2 "encrypted" frames with length 16K.
std::string expected =
"\0\x40\0\0"s + std::string(16380, 'a') + "\x28\x0e\0\0"s + std::string(3620, 'a');
EXPECT_EQ(expected, encrypted.toString());
}
}

TEST_F(TsiFrameProtectorTest, ProtectError) {
const tsi_frame_protector_vtable* vtable = raw_frame_protector_->vtable;
tsi_frame_protector_vtable mock_vtable = *raw_frame_protector_->vtable;
mock_vtable.protect = [](tsi_frame_protector*, const unsigned char*, size_t*, unsigned char*,
size_t*) { return TSI_INTERNAL_ERROR; };
raw_frame_protector_->vtable = &mock_vtable;

Buffer::OwnedImpl input, encrypted;
input.add("foo");

EXPECT_EQ(TSI_INTERNAL_ERROR, frame_protector_.protect(input, encrypted));

raw_frame_protector_->vtable = vtable;
}

TEST_F(TsiFrameProtectorTest, ProtectFlushError) {
const tsi_frame_protector_vtable* vtable = raw_frame_protector_->vtable;
tsi_frame_protector_vtable mock_vtable = *raw_frame_protector_->vtable;
mock_vtable.protect_flush = [](tsi_frame_protector*, unsigned char*, size_t*, size_t*) {
return TSI_INTERNAL_ERROR;
};
raw_frame_protector_->vtable = &mock_vtable;

Buffer::OwnedImpl input, encrypted;
input.add("foo");

EXPECT_EQ(TSI_INTERNAL_ERROR, frame_protector_.protect(input, encrypted));

raw_frame_protector_->vtable = vtable;
}

TEST_F(TsiFrameProtectorTest, Unprotect) {
{
Buffer::OwnedImpl input, decrypted;
input.add("\x07\0\0\0bar"s);

EXPECT_EQ(TSI_OK, frame_protector_.unprotect(input, decrypted));
EXPECT_EQ("bar", decrypted.toString());
}

{
Buffer::OwnedImpl input, decrypted;
input.add("\x0a\0\0\0foo"s);

EXPECT_EQ(TSI_OK, frame_protector_.unprotect(input, decrypted));
EXPECT_EQ("", decrypted.toString());

input.add("bar");
EXPECT_EQ(TSI_OK, frame_protector_.unprotect(input, decrypted));
EXPECT_EQ("foobar", decrypted.toString());
}

{
Buffer::OwnedImpl input, decrypted;
input.add("\0\x40\0\0"s + std::string(16380, 'a'));
input.add("\x28\x0e\0\0"s + std::string(3620, 'a'));

EXPECT_EQ(TSI_OK, frame_protector_.unprotect(input, decrypted));
EXPECT_EQ(std::string(20000, 'a'), decrypted.toString());
}
}
TEST_F(TsiFrameProtectorTest, UnprotectError) {
const tsi_frame_protector_vtable* vtable = raw_frame_protector_->vtable;
tsi_frame_protector_vtable mock_vtable = *raw_frame_protector_->vtable;
mock_vtable.unprotect = [](tsi_frame_protector*, const unsigned char*, size_t*, unsigned char*,
size_t*) { return TSI_INTERNAL_ERROR; };
raw_frame_protector_->vtable = &mock_vtable;

Buffer::OwnedImpl input, decrypted;
input.add("\x0a\0\0\0foo"s);

EXPECT_EQ(TSI_INTERNAL_ERROR, frame_protector_.unprotect(input, decrypted));

raw_frame_protector_->vtable = vtable;
}

} // namespace Alts
} // namespace TransportSockets
} // namespace Extensions
} // namespace Envoy