diff --git a/source/extensions/transport_sockets/alts/BUILD b/source/extensions/transport_sockets/alts/BUILD index 28cc6960e7154..da086ff2d6e2f 100644 --- a/source/extensions/transport_sockets/alts/BUILD +++ b/source/extensions/transport_sockets/alts/BUILD @@ -24,6 +24,21 @@ envoy_cc_library( ], ) +envoy_cc_library( + name = "tsi_frame_protector", + srcs = [ + "tsi_frame_protector.cc", + ], + hdrs = [ + "tsi_frame_protector.h", + ], + repository = "@envoy", + deps = [ + ":grpc_tsi_wrapper", + "//source/common/buffer:buffer_lib", + ], +) + envoy_cc_library( name = "tsi_handshaker", srcs = [ diff --git a/source/extensions/transport_sockets/alts/tsi_frame_protector.cc b/source/extensions/transport_sockets/alts/tsi_frame_protector.cc new file mode 100644 index 0000000000000..1cb8cc22494b1 --- /dev/null +++ b/source/extensions/transport_sockets/alts/tsi_frame_protector.cc @@ -0,0 +1,77 @@ +#include "extensions/transport_sockets/alts/tsi_frame_protector.h" + +#include "common/common/assert.h" + +namespace Envoy { +namespace Extensions { +namespace TransportSockets { +namespace Alts { + +// TODO(lizan): tune size later +static constexpr uint32_t BUFFER_SIZE = 16384; + +TsiFrameProtector::TsiFrameProtector(CFrameProtectorPtr&& frame_protector) + : frame_protector_(std::move(frame_protector)) {} + +tsi_result TsiFrameProtector::protect(Buffer::Instance& input, Buffer::Instance& output) { + ASSERT(frame_protector_); + + unsigned char protected_buffer[BUFFER_SIZE]; + while (input.length() > 0) { + auto* message_bytes = reinterpret_cast(input.linearize(input.length())); + size_t protected_buffer_size = BUFFER_SIZE; + size_t processed_message_size = input.length(); + tsi_result result = + tsi_frame_protector_protect(frame_protector_.get(), message_bytes, &processed_message_size, + protected_buffer, &protected_buffer_size); + if (result != TSI_OK) { + ASSERT(result != TSI_INVALID_ARGUMENT && result != TSI_UNIMPLEMENTED); + return result; + } + output.add(protected_buffer, protected_buffer_size); + input.drain(processed_message_size); + } + + // TSI may buffer some of the input internally. Flush its buffer to protected_buffer. + size_t still_pending_size; + do { + size_t protected_buffer_size = BUFFER_SIZE; + tsi_result result = tsi_frame_protector_protect_flush( + frame_protector_.get(), protected_buffer, &protected_buffer_size, &still_pending_size); + if (result != TSI_OK) { + ASSERT(result != TSI_INVALID_ARGUMENT && result != TSI_UNIMPLEMENTED); + return result; + } + output.add(protected_buffer, protected_buffer_size); + } while (still_pending_size > 0); + + return TSI_OK; +} + +tsi_result TsiFrameProtector::unprotect(Buffer::Instance& input, Buffer::Instance& output) { + ASSERT(frame_protector_); + + unsigned char unprotected_buffer[BUFFER_SIZE]; + + while (input.length() > 0) { + auto* message_bytes = reinterpret_cast(input.linearize(input.length())); + size_t unprotected_buffer_size = BUFFER_SIZE; + size_t processed_message_size = input.length(); + tsi_result result = tsi_frame_protector_unprotect(frame_protector_.get(), message_bytes, + &processed_message_size, unprotected_buffer, + &unprotected_buffer_size); + if (result != TSI_OK) { + ASSERT(result != TSI_INVALID_ARGUMENT && result != TSI_UNIMPLEMENTED); + return result; + } + output.add(unprotected_buffer, unprotected_buffer_size); + input.drain(processed_message_size); + } + + return TSI_OK; +} + +} // namespace Alts +} // namespace TransportSockets +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/transport_sockets/alts/tsi_frame_protector.h b/source/extensions/transport_sockets/alts/tsi_frame_protector.h new file mode 100644 index 0000000000000..ac2fe1fc8f7f2 --- /dev/null +++ b/source/extensions/transport_sockets/alts/tsi_frame_protector.h @@ -0,0 +1,49 @@ +#pragma once + +#include "envoy/buffer/buffer.h" + +#include "extensions/transport_sockets/alts/grpc_tsi.h" + +namespace Envoy { +namespace Extensions { +namespace TransportSockets { +namespace Alts { + +/** + * A C++ wrapper for tsi_frame_protector interface. + * For detail of tsi_frame_protector, see + * https://github.com/grpc/grpc/blob/v1.10.0/src/core/tsi/transport_security_interface.h#L70 + * + * TODO(lizan): migrate to tsi_zero_copy_grpc_protector for further optimization + */ +class TsiFrameProtector final { +public: + explicit TsiFrameProtector(CFrameProtectorPtr&& frame_protector); + + /** + * Wrapper for tsi_frame_protector_protect + * @param input supplies the input data to protect, the method will drain it when it is processed. + * @param output supplies the buffer where the protected data will be stored. + * @return tsi_result the status. + */ + tsi_result protect(Buffer::Instance& input, Buffer::Instance& output); + + /** + * Wrapper for tsi_frame_protector_unprotect + * @param input supplies the input data to unprotect, the method will drain it when it is + * processed. + * @param output supplies the buffer where the unprotected data will be stored. + * @return tsi_result the status. + */ + tsi_result unprotect(Buffer::Instance& input, Buffer::Instance& output); + +private: + CFrameProtectorPtr frame_protector_; +}; + +typedef std::unique_ptr TsiFrameProtectorPtr; + +} // namespace Alts +} // namespace TransportSockets +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/transport_sockets/alts/BUILD b/test/extensions/transport_sockets/alts/BUILD index 87a6a480ff3ff..171a97fd28e43 100644 --- a/test/extensions/transport_sockets/alts/BUILD +++ b/test/extensions/transport_sockets/alts/BUILD @@ -11,6 +11,16 @@ load( envoy_package() +envoy_extension_cc_test( + name = "tsi_frame_protector_test", + srcs = ["tsi_frame_protector_test.cc"], + extension_name = "envoy.transport_sockets.alts", + deps = [ + "//source/extensions/transport_sockets/alts:tsi_frame_protector", + "//test/mocks/buffer:buffer_mocks", + ], +) + envoy_extension_cc_test( name = "tsi_handshaker_test", srcs = ["tsi_handshaker_test.cc"], diff --git a/test/extensions/transport_sockets/alts/tsi_frame_protector_test.cc b/test/extensions/transport_sockets/alts/tsi_frame_protector_test.cc new file mode 100644 index 0000000000000..2604e837c8cd7 --- /dev/null +++ b/test/extensions/transport_sockets/alts/tsi_frame_protector_test.cc @@ -0,0 +1,150 @@ +#include "common/buffer/buffer_impl.h" + +#include "extensions/transport_sockets/alts/tsi_frame_protector.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "src/core/tsi/fake_transport_security.h" + +namespace Envoy { +namespace Extensions { +namespace TransportSockets { +namespace Alts { + +using testing::InSequence; +using testing::Invoke; +using testing::NiceMock; +using testing::SaveArg; +using testing::Test; +using testing::_; +using namespace std::string_literals; + +/** + * Test with fake frame protector. The protected frame header is 4 byte length (little endian, + * include header itself) and following the body. + */ +class TsiFrameProtectorTest : public Test { +public: + TsiFrameProtectorTest() + : raw_frame_protector_(tsi_create_fake_frame_protector(nullptr)), + frame_protector_(CFrameProtectorPtr{raw_frame_protector_}) {} + +protected: + tsi_frame_protector* raw_frame_protector_; + TsiFrameProtector frame_protector_; +}; + +TEST_F(TsiFrameProtectorTest, Protect) { + { + Buffer::OwnedImpl input, encrypted; + input.add("foo"); + + EXPECT_EQ(TSI_OK, frame_protector_.protect(input, encrypted)); + EXPECT_EQ("\x07\0\0\0foo"s, encrypted.toString()); + } + + { + Buffer::OwnedImpl input, encrypted; + input.add("foo"); + + EXPECT_EQ(TSI_OK, frame_protector_.protect(input, encrypted)); + EXPECT_EQ("\x07\0\0\0foo"s, encrypted.toString()); + + input.add("bar"); + EXPECT_EQ(TSI_OK, frame_protector_.protect(input, encrypted)); + EXPECT_EQ("\x07\0\0\0foo\x07\0\0\0bar"s, encrypted.toString()); + } + + { + Buffer::OwnedImpl input, encrypted; + input.add(std::string(20000, 'a')); + + EXPECT_EQ(TSI_OK, frame_protector_.protect(input, encrypted)); + + // fake frame protector will split long buffer to 2 "encrypted" frames with length 16K. + std::string expected = + "\0\x40\0\0"s + std::string(16380, 'a') + "\x28\x0e\0\0"s + std::string(3620, 'a'); + EXPECT_EQ(expected, encrypted.toString()); + } +} + +TEST_F(TsiFrameProtectorTest, ProtectError) { + const tsi_frame_protector_vtable* vtable = raw_frame_protector_->vtable; + tsi_frame_protector_vtable mock_vtable = *raw_frame_protector_->vtable; + mock_vtable.protect = [](tsi_frame_protector*, const unsigned char*, size_t*, unsigned char*, + size_t*) { return TSI_INTERNAL_ERROR; }; + raw_frame_protector_->vtable = &mock_vtable; + + Buffer::OwnedImpl input, encrypted; + input.add("foo"); + + EXPECT_EQ(TSI_INTERNAL_ERROR, frame_protector_.protect(input, encrypted)); + + raw_frame_protector_->vtable = vtable; +} + +TEST_F(TsiFrameProtectorTest, ProtectFlushError) { + const tsi_frame_protector_vtable* vtable = raw_frame_protector_->vtable; + tsi_frame_protector_vtable mock_vtable = *raw_frame_protector_->vtable; + mock_vtable.protect_flush = [](tsi_frame_protector*, unsigned char*, size_t*, size_t*) { + return TSI_INTERNAL_ERROR; + }; + raw_frame_protector_->vtable = &mock_vtable; + + Buffer::OwnedImpl input, encrypted; + input.add("foo"); + + EXPECT_EQ(TSI_INTERNAL_ERROR, frame_protector_.protect(input, encrypted)); + + raw_frame_protector_->vtable = vtable; +} + +TEST_F(TsiFrameProtectorTest, Unprotect) { + { + Buffer::OwnedImpl input, decrypted; + input.add("\x07\0\0\0bar"s); + + EXPECT_EQ(TSI_OK, frame_protector_.unprotect(input, decrypted)); + EXPECT_EQ("bar", decrypted.toString()); + } + + { + Buffer::OwnedImpl input, decrypted; + input.add("\x0a\0\0\0foo"s); + + EXPECT_EQ(TSI_OK, frame_protector_.unprotect(input, decrypted)); + EXPECT_EQ("", decrypted.toString()); + + input.add("bar"); + EXPECT_EQ(TSI_OK, frame_protector_.unprotect(input, decrypted)); + EXPECT_EQ("foobar", decrypted.toString()); + } + + { + Buffer::OwnedImpl input, decrypted; + input.add("\0\x40\0\0"s + std::string(16380, 'a')); + input.add("\x28\x0e\0\0"s + std::string(3620, 'a')); + + EXPECT_EQ(TSI_OK, frame_protector_.unprotect(input, decrypted)); + EXPECT_EQ(std::string(20000, 'a'), decrypted.toString()); + } +} +TEST_F(TsiFrameProtectorTest, UnprotectError) { + const tsi_frame_protector_vtable* vtable = raw_frame_protector_->vtable; + tsi_frame_protector_vtable mock_vtable = *raw_frame_protector_->vtable; + mock_vtable.unprotect = [](tsi_frame_protector*, const unsigned char*, size_t*, unsigned char*, + size_t*) { return TSI_INTERNAL_ERROR; }; + raw_frame_protector_->vtable = &mock_vtable; + + Buffer::OwnedImpl input, decrypted; + input.add("\x0a\0\0\0foo"s); + + EXPECT_EQ(TSI_INTERNAL_ERROR, frame_protector_.unprotect(input, decrypted)); + + raw_frame_protector_->vtable = vtable; +} + +} // namespace Alts +} // namespace TransportSockets +} // namespace Extensions +} // namespace Envoy