diff --git a/rosbag2_compression/CMakeLists.txt b/rosbag2_compression/CMakeLists.txt index 849cd4301..040e7be56 100644 --- a/rosbag2_compression/CMakeLists.txt +++ b/rosbag2_compression/CMakeLists.txt @@ -51,7 +51,8 @@ target_compile_definitions(${PROJECT_NAME}_zstd add_library(${PROJECT_NAME} SHARED - src/rosbag2_compression/compression_options.cpp) + src/rosbag2_compression/compression_options.cpp + src/rosbag2_compression/sequential_compression_reader.cpp) target_include_directories(${PROJECT_NAME} PUBLIC $ @@ -127,10 +128,14 @@ if(BUILD_TESTING) ament_add_gmock(test_compression_options test/rosbag2_compression/test_compression_options.cpp) - if(TARGET test_compression_options) - target_include_directories(test_compression_options PRIVATE include) - target_link_libraries(test_compression_options ${PROJECT_NAME}) - endif() + target_include_directories(test_compression_options PUBLIC include) + target_link_libraries(test_compression_options ${PROJECT_NAME}) + + ament_add_gmock(test_sequential_compression_reader + test/rosbag2_compression/test_sequential_compression_reader.cpp) + target_include_directories(test_sequential_compression_reader PUBLIC include) + target_link_libraries(test_sequential_compression_reader ${PROJECT_NAME}) + ament_target_dependencies(test_sequential_compression_reader rosbag2_cpp) endif() ament_package() diff --git a/rosbag2_compression/include/rosbag2_compression/sequential_compression_reader.hpp b/rosbag2_compression/include/rosbag2_compression/sequential_compression_reader.hpp new file mode 100644 index 000000000..bcd132132 --- /dev/null +++ b/rosbag2_compression/include/rosbag2_compression/sequential_compression_reader.hpp @@ -0,0 +1,153 @@ +// Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ROSBAG2_COMPRESSION__SEQUENTIAL_COMPRESSION_READER_HPP_ +#define ROSBAG2_COMPRESSION__SEQUENTIAL_COMPRESSION_READER_HPP_ + +#include +#include +#include + +#include "rosbag2_compression/base_decompressor_interface.hpp" +#include "rosbag2_compression/compression_options.hpp" + +#include "rosbag2_cpp/converter.hpp" +#include "rosbag2_cpp/reader_interfaces/base_reader_interface.hpp" +#include "rosbag2_cpp/serialization_format_converter_factory.hpp" +#include "rosbag2_cpp/serialization_format_converter_factory_interface.hpp" + +#include "rosbag2_storage/metadata_io.hpp" +#include "rosbag2_storage/storage_factory.hpp" +#include "rosbag2_storage/storage_factory_interface.hpp" +#include "rosbag2_storage/storage_interfaces/read_only_interface.hpp" + +#include "visibility_control.hpp" +#include "zstd_decompressor.hpp" + +#ifdef _WIN32 +# pragma warning(push) +# pragma warning(disable:4251) +#endif + +namespace rosbag2_compression +{ + +class ROSBAG2_COMPRESSION_PUBLIC SequentialCompressionReader + : public ::rosbag2_cpp::reader_interfaces::BaseReaderInterface +{ +public: + explicit SequentialCompressionReader( + std::unique_ptr storage_factory = + std::make_unique(), + std::shared_ptr converter_factory = + std::make_shared(), + std::unique_ptr metadata_io = + std::make_unique()); + + virtual ~SequentialCompressionReader(); + + void open( + const rosbag2_cpp::StorageOptions & storage_options, + const rosbag2_cpp::ConverterOptions & converter_options) override; + + void reset() override; + + bool has_next() override; + + std::shared_ptr read_next() override; + + std::vector get_all_topics_and_types() override; + + /** + * Ask whether there is another database file to read from the list of relative + * file paths. + * + * \return true if there are still files to read in the list + */ + virtual bool has_next_file() const; + + /** + * Return the relative file path pointed to by the current file iterator. + */ + virtual std::string get_current_file() const; + + /** + * Return the URI of the current file (i.e. no extensions). + */ + virtual std::string get_current_uri() const; + +protected: + /** + * Increment the current file iterator to point to the next file in the list of relative file + * paths. + * + * Expected usage: + * if (has_next_file()) load_next_file(); + */ + virtual void load_next_file(); + + /** + * Initializes the decompressor if a compression mode is specified in the metadata. + * + * \throws std::invalid_argument If compression format doesn't exist. + */ + virtual void setup_decompression(); + + /** + * Set the file path currently pointed to by the iterator. + * + * \param file Relative path to the file. + */ + virtual void set_current_file(const std::string & file); + +private: + /** + * Checks if all topics in the bagfile have the same RMW serialization format. + * Currently a bag file can only be played if all topics have the same serialization format. + * + * \param topics Vector of TopicInformation with metadata. + * \throws runtime_error if any topic has a different serialization format from the rest. + */ + virtual void check_topics_serialization_formats( + const std::vector & topics); + + /** + * Checks if the serialization format of the converter factory is the same as that of the storage + * factory. + * If not, changes the serialization format of the converter factory to use the serialization + * format of the storage factory. + * + * \param converter_serialization_format + * \param storage_serialization_format + */ + virtual void check_converter_serialization_format( + const std::string & converter_serialization_format, + const std::string & storage_serialization_format); + + std::unique_ptr storage_factory_{}; + std::shared_ptr converter_factory_{}; + std::shared_ptr storage_{}; + std::unique_ptr converter_{}; + std::unique_ptr decompressor_{}; + std::unique_ptr metadata_io_{}; + rosbag2_storage::BagMetadata metadata_{}; + std::vector file_paths_{}; // List of database files. + std::vector::iterator current_file_iterator_{}; // Index of file to read from + rosbag2_compression::CompressionMode compression_mode_{ + rosbag2_compression::CompressionMode::NONE}; +}; + +} // namespace rosbag2_compression + +#endif // ROSBAG2_COMPRESSION__SEQUENTIAL_COMPRESSION_READER_HPP_ diff --git a/rosbag2_compression/src/rosbag2_compression/sequential_compression_reader.cpp b/rosbag2_compression/src/rosbag2_compression/sequential_compression_reader.cpp new file mode 100644 index 000000000..46181e643 --- /dev/null +++ b/rosbag2_compression/src/rosbag2_compression/sequential_compression_reader.cpp @@ -0,0 +1,216 @@ +// Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rosbag2_compression/sequential_compression_reader.hpp" + +#include +#include +#include +#include +#include + +#include "rcpputils/filesystem_helper.hpp" + +#include "rosbag2_compression/compression_options.hpp" +#include "rosbag2_compression/zstd_decompressor.hpp" +#include "logging.hpp" + + +namespace rosbag2_compression +{ + +SequentialCompressionReader::SequentialCompressionReader( + std::unique_ptr storage_factory, + std::shared_ptr converter_factory, + std::unique_ptr metadata_io) +: storage_factory_{std::move(storage_factory)}, + converter_factory_{std::move(converter_factory)}, + metadata_io_{std::move(metadata_io)} +{} + +SequentialCompressionReader::~SequentialCompressionReader() +{ + reset(); +} + +void SequentialCompressionReader::reset() +{ + storage_.reset(); +} + +void SequentialCompressionReader::setup_decompression() +{ + compression_mode_ = rosbag2_compression::compression_mode_from_string(metadata_.compression_mode); + if (compression_mode_ != rosbag2_compression::CompressionMode::NONE) { + if (metadata_.compression_format == "zstd") { + decompressor_ = std::make_unique(); + // Decompress the first file so that it is readable. + ROSBAG2_COMPRESSION_LOG_DEBUG_STREAM("Decompressing " << get_current_file().c_str()); + set_current_file(decompressor_->decompress_uri(get_current_file())); + } else { + std::stringstream err; + err << "Unsupported compression format " << metadata_.compression_format; + throw std::invalid_argument{err.str()}; + } + } else { + throw std::invalid_argument{ + "SequentialCompressionReader requires a CompressionMode that is not NONE!"}; + } +} + +void SequentialCompressionReader::open( + const rosbag2_cpp::StorageOptions & storage_options, + const rosbag2_cpp::ConverterOptions & converter_options) +{ + if (metadata_io_->metadata_file_exists(storage_options.uri)) { + metadata_ = metadata_io_->read_metadata(storage_options.uri); + if (metadata_.relative_file_paths.empty()) { + ROSBAG2_COMPRESSION_LOG_WARN("No file paths were found in metadata."); + return; + } + file_paths_ = metadata_.relative_file_paths; + current_file_iterator_ = file_paths_.begin(); + setup_decompression(); + + storage_ = storage_factory_->open_read_only( + *current_file_iterator_, metadata_.storage_identifier); + if (!storage_) { + throw std::runtime_error{"No storage could be initialized. Abort"}; + } + } else { + throw std::runtime_error{"Compression is not supported for legacy bag files."}; + } + const auto & topics = metadata_.topics_with_message_count; + if (topics.empty()) { + ROSBAG2_COMPRESSION_LOG_WARN("No topics were listed in metadata."); + return; + } + + // Currently a bag file can only be played if all topics have the same serialization format. + check_topics_serialization_formats(topics); + check_converter_serialization_format( + converter_options.output_serialization_format, + topics[0].topic_metadata.serialization_format); +} + +bool SequentialCompressionReader::has_next() +{ + if (storage_) { + // If there's no new message, check if there's at least another file to read and update storage + // to read from there. Otherwise, check if there's another message. + if (!storage_->has_next() && has_next_file()) { + load_next_file(); + storage_ = storage_factory_->open_read_only( + *current_file_iterator_, metadata_.storage_identifier); + } + return storage_->has_next(); + } + throw std::runtime_error{"Bag is not open. Call open() before reading."}; +} + +std::shared_ptr SequentialCompressionReader::read_next() +{ + if (storage_ && decompressor_) { + auto message = storage_->read_next(); + if (compression_mode_ == rosbag2_compression::CompressionMode::MESSAGE) { + decompressor_->decompress_serialized_bag_message(message.get()); + } + return converter_ ? converter_->convert(message) : message; + } + throw std::runtime_error{"Bag is not open. Call open() before reading."}; +} + +std::vector SequentialCompressionReader::get_all_topics_and_types() +{ + if (storage_) { + return storage_->get_all_topics_and_types(); + } + throw std::runtime_error{"Bag is not open. Call open() before reading."}; +} + +bool SequentialCompressionReader::has_next_file() const +{ + // Handle case where bagfile is not split + if (current_file_iterator_ == file_paths_.end()) { + return false; + } + + return current_file_iterator_ + 1 != file_paths_.end(); +} + +void SequentialCompressionReader::load_next_file() +{ + if (current_file_iterator_ == file_paths_.end()) { + throw std::runtime_error{"Cannot load next file; already on last file!"}; + } + + if (decompressor_) { + throw std::runtime_error{"Bagfile is not opened"}; + } + + ++current_file_iterator_; + if (compression_mode_ == rosbag2_compression::CompressionMode::FILE) { + ROSBAG2_COMPRESSION_LOG_DEBUG_STREAM("Decompressing " << get_current_file().c_str()); + *current_file_iterator_ = decompressor_->decompress_uri(get_current_file()); + } +} + +std::string SequentialCompressionReader::get_current_uri() const +{ + const auto current_file = get_current_file(); + const auto current_uri = rcpputils::fs::remove_extension(current_file); + return current_uri.string(); +} + +std::string SequentialCompressionReader::get_current_file() const +{ + return *current_file_iterator_; +} + +void SequentialCompressionReader::set_current_file(const std::string & file) +{ + *current_file_iterator_ = file; +} + +void SequentialCompressionReader::check_topics_serialization_formats( + const std::vector & topics) +{ + const auto & storage_serialization_format = + topics[0].topic_metadata.serialization_format; + + for (const auto & topic : topics) { + if (topic.topic_metadata.serialization_format != storage_serialization_format) { + throw std::runtime_error{ + "Topics with different rwm serialization format have been found. " + "All topics must have the same serialization format."}; + } + } +} + +void SequentialCompressionReader::check_converter_serialization_format( + const std::string & converter_serialization_format, + const std::string & storage_serialization_format) +{ + if (converter_serialization_format != storage_serialization_format) { + converter_ = std::make_unique( + storage_serialization_format, + converter_serialization_format, + converter_factory_); + const auto topics = storage_->get_all_topics_and_types(); + for (const auto & topic_with_type : topics) { + converter_->add_topic(topic_with_type.name, topic_with_type.type); + } + } +} +} // namespace rosbag2_compression diff --git a/rosbag2_compression/test/rosbag2_compression/test_sequential_compression_reader.cpp b/rosbag2_compression/test/rosbag2_compression/test_sequential_compression_reader.cpp new file mode 100644 index 000000000..15752b91b --- /dev/null +++ b/rosbag2_compression/test/rosbag2_compression/test_sequential_compression_reader.cpp @@ -0,0 +1,103 @@ +// Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include +#include +#include +#include + +#include "rosbag2_compression/sequential_compression_reader.hpp" + +#include "rosbag2_cpp/reader.hpp" + +#include "../../rosbag2_cpp/test/rosbag2_cpp/mock_converter_factory.hpp" +#include "../../rosbag2_cpp/test/rosbag2_cpp/mock_metadata_io.hpp" +#include "../../rosbag2_cpp/test/rosbag2_cpp/mock_storage.hpp" +#include "../../rosbag2_cpp/test/rosbag2_cpp/mock_storage_factory.hpp" + +using namespace testing; // NOLINT + +class SequentialCompressionReaderTest : public Test +{ +public: + SequentialCompressionReaderTest() + : storage_factory_{std::make_unique>()}, + storage_{std::make_shared>()}, + converter_factory_{std::make_shared>()}, + metadata_io_{std::make_unique>()}, + storage_serialization_format_{"rmw1_format"} + { + topic_with_type_ = rosbag2_storage::TopicMetadata{ + "topic", "test_msgs/BasicTypes", storage_serialization_format_}; + auto topics_and_types = std::vector{topic_with_type_}; + auto message = std::make_shared(); + message->topic_name = topic_with_type_.name; + + ON_CALL(*storage_, get_all_topics_and_types()).WillByDefault(Return(topics_and_types)); + ON_CALL(*storage_, read_next()).WillByDefault(Return(message)); + ON_CALL(*storage_factory_, open_read_only(_, _)).WillByDefault(Return(storage_)); + } + + std::unique_ptr> storage_factory_; + std::shared_ptr> storage_; + std::shared_ptr> converter_factory_; + std::unique_ptr> metadata_io_; + std::unique_ptr reader_; + std::string storage_serialization_format_; + rosbag2_storage::TopicMetadata topic_with_type_; +}; + +TEST_F(SequentialCompressionReaderTest, open_throws_if_unsupported_compressor) +{ + rosbag2_storage::BagMetadata metadata; + metadata.relative_file_paths = {"/path/to/storage"}; + metadata.topics_with_message_count.push_back({{topic_with_type_}, 1}); + metadata.compression_format = "bad_format"; + metadata.compression_mode = + rosbag2_compression::compression_mode_to_string(rosbag2_compression::CompressionMode::FILE); + EXPECT_CALL(*metadata_io_, read_metadata(_)).WillRepeatedly(Return(metadata)); + EXPECT_CALL(*metadata_io_, metadata_file_exists(_)).WillRepeatedly(Return(true)); + + auto sequential_reader = std::make_unique( + std::move(storage_factory_), converter_factory_, std::move(metadata_io_)); + + reader_ = std::make_unique(std::move(sequential_reader)); + EXPECT_THROW( + reader_->open(rosbag2_cpp::StorageOptions(), {"", storage_serialization_format_}), + std::invalid_argument); +} + +TEST_F(SequentialCompressionReaderTest, open_supports_zstd_compressor) +{ + rosbag2_storage::BagMetadata metadata; + metadata.relative_file_paths = {"/path/to/storage"}; + metadata.topics_with_message_count.push_back({{topic_with_type_}, 1}); + metadata.compression_format = "zstd"; + metadata.compression_mode = + rosbag2_compression::compression_mode_to_string(rosbag2_compression::CompressionMode::FILE); + ON_CALL(*metadata_io_, read_metadata(_)).WillByDefault(Return(metadata)); + ON_CALL(*metadata_io_, metadata_file_exists(_)).WillByDefault(Return(true)); + + auto sequential_reader = std::make_unique( + std::move(storage_factory_), converter_factory_, std::move(metadata_io_)); + + reader_ = std::make_unique(std::move(sequential_reader)); + // Throws runtime_error b/c compressor can't read + // TODO(piraka9011): Use a compression factory in reader. + EXPECT_THROW( + reader_->open(rosbag2_cpp::StorageOptions(), {"", storage_serialization_format_}), + std::runtime_error); +}