Skip to content

Commit

Permalink
Refactor bit_reader to use policy
Browse files Browse the repository at this point in the history
  • Loading branch information
KredeGC committed Dec 16, 2023
1 parent f3116e5 commit f8ef797
Show file tree
Hide file tree
Showing 14 changed files with 97 additions and 130 deletions.
7 changes: 3 additions & 4 deletions include/bitstream/stream/bit_measure.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@

#include <cstdint>
#include <cstring>
#include <limits>
#include <memory>
#include <type_traits>

namespace bitstream
{
class bit_reader;

/**
* @brief A stream for writing objects tightly into a buffer
* @note Does not take ownership of the buffer
Expand All @@ -31,13 +30,13 @@ namespace bitstream
*/
bit_measure() noexcept :
m_NumBitsWritten(0),
m_TotalBits(0) {}
m_TotalBits((std::numeric_limits<uint32_t>::max)()) {}

/**
* @brief Construct a writer pointing to the given byte array with @p num_bytes size
* @param num_bytes The number of bytes in the array
*/
explicit bit_measure(uint32_t num_bytes) noexcept :
bit_measure(uint32_t num_bytes) noexcept :
m_NumBitsWritten(0),
m_TotalBits(num_bytes * 8) {}

Expand Down
114 changes: 42 additions & 72 deletions include/bitstream/stream/bit_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "byte_buffer.h"
#include "serialize_traits.h"
#include "stream_traits.h"

#include <cstdint>
#include <cstring>
Expand All @@ -18,63 +19,34 @@ namespace bitstream
* @brief A stream for reading objects from a tightly packed buffer
* @note Does not take ownership of the buffer
*/
template<typename Policy>
class bit_reader
{
public:
static constexpr bool writing = false;
static constexpr bool reading = true;

/**
* @brief Default construct a reader pointing to a null buffer
* @brief Construct a reader with the parameters passed to the underlying policy
* @param ...args The arguments to pass to the policy
*/
bit_reader() noexcept :
m_Buffer(nullptr),
m_NumBitsRead(0),
m_TotalBits(0),
m_Scratch(0),
m_ScratchBits(0),
m_WordIndex(0) {}

/**
* @brief Construct a reader pointing to the given byte array with @p num_bits
* @param bytes The byte array to read from. Should be 4-byte aligned if possible. The size of the array must be a multiple of 4
* @param num_bits The maximum number of bits that we can read
*/
explicit bit_reader(const void* bytes, uint32_t num_bits) noexcept :
m_Buffer(static_cast<const uint32_t*>(bytes)),
m_NumBitsRead(0),
m_TotalBits(num_bits),
m_Scratch(0),
m_ScratchBits(0),
m_WordIndex(0) {}

/**
* @brief Construct a reader pointing to the given @p buffer
* @param buffer The buffer to read from
* @param num_bits The maximum number of bits that we can read
*/
template<size_t Size>
explicit bit_reader(byte_buffer<Size>& buffer, uint32_t num_bits) noexcept :
m_Buffer(reinterpret_cast<uint32_t*>(buffer.Bytes)),
m_NumBitsRead(0),
m_TotalBits(num_bits),
template<typename... Ts,
typename = std::enable_if_t<std::is_constructible_v<Policy, Ts...>>>
bit_reader(Ts&&... args)
noexcept(std::is_nothrow_constructible_v<Policy, Ts...>) :
m_Policy(std::forward<Ts>(args) ...),
m_Scratch(0),
m_ScratchBits(0),
m_WordIndex(0) {}

bit_reader(const bit_reader&) = delete;

bit_reader(bit_reader&& other) noexcept :
m_Buffer(other.m_Buffer),
m_NumBitsRead(other.m_NumBitsRead),
m_TotalBits(other.m_TotalBits),
m_Policy(std::move(other.m_Policy)),
m_Scratch(other.m_Scratch),
m_ScratchBits(other.m_ScratchBits),
m_WordIndex(other.m_WordIndex)
{
other.m_Buffer = nullptr;
other.m_NumBitsRead = 0;
other.m_TotalBits = 0;
other.m_Scratch = 0;
other.m_ScratchBits = 0;
other.m_WordIndex = 0;
Expand All @@ -84,16 +56,11 @@ namespace bitstream

bit_reader& operator=(bit_reader&& rhs) noexcept
{
m_Buffer = rhs.m_Buffer;
m_NumBitsRead = rhs.m_NumBitsRead;
m_TotalBits = rhs.m_TotalBits;
m_Policy = std::move(rhs.m_Policy);
m_Scratch = rhs.m_Scratch;
m_ScratchBits = rhs.m_ScratchBits;
m_WordIndex = rhs.m_WordIndex;

rhs.m_Buffer = nullptr;
rhs.m_NumBitsRead = 0;
rhs.m_TotalBits = 0;
rhs.m_Scratch = 0;
rhs.m_ScratchBits = 0;
rhs.m_WordIndex = 0;
Expand All @@ -105,39 +72,39 @@ namespace bitstream
* @brief Returns the buffer that this reader is currently serializing from
* @return The buffer
*/
[[nodiscard]] const uint8_t* get_buffer() const noexcept { return reinterpret_cast<const uint8_t*>(m_Buffer); }
[[nodiscard]] const uint8_t* get_buffer() const noexcept { return reinterpret_cast<const uint8_t*>(m_Policy.get_buffer()); }

/**
* @brief Returns the number of bits which have been read from the buffer
* @return The number of bits which have been read
*/
[[nodiscard]] uint32_t get_num_bits_serialized() const noexcept { return m_NumBitsRead; }
[[nodiscard]] uint32_t get_num_bits_serialized() const noexcept { return m_Policy.get_num_bits_serialized(); }

/**
* @brief Returns the number of bytes which have been read from the buffer
* @return The number of bytes which have been read
*/
[[nodiscard]] uint32_t get_num_bytes_serialized() const noexcept { return m_NumBitsRead > 0U ? ((m_NumBitsRead - 1U) / 8U + 1U) : 0U; }
[[nodiscard]] uint32_t get_num_bytes_serialized() const noexcept { return get_num_bits_serialized() > 0U ? ((get_num_bits_serialized() - 1U) / 8U + 1U) : 0U; }

/**
* @brief Returns whether the @p num_bits be read from the buffer
* @param num_bits The number of bits to test
* @return Whether the number of bits can be read from the buffer
*/
[[nodiscard]] bool can_serialize_bits(uint32_t num_bits) const noexcept { return m_NumBitsRead + num_bits <= m_TotalBits; }
[[nodiscard]] bool can_serialize_bits(uint32_t num_bits) const noexcept { return m_Policy.can_serialize_bits(num_bits); }

/**
* @brief Returns the number of bits which have not been read yet
* @note The same as get_total_bits() - get_num_bits_serialized()
* @return The remaining space in the buffer
*/
[[nodiscard]] uint32_t get_remaining_bits() const noexcept { return m_TotalBits - m_NumBitsRead; }
[[nodiscard]] uint32_t get_remaining_bits() const noexcept { return get_total_bits() - get_num_bits_serialized(); }

/**
* @brief Returns the size of the buffer, in bits
* @return The size of the buffer, in bits
*/
[[nodiscard]] uint32_t get_total_bits() const noexcept { return m_TotalBits; }
[[nodiscard]] uint32_t get_total_bits() const noexcept { return m_Policy.get_total_bits(); }

/**
* @brief Reads the first 32 bits of the buffer and compares it to a checksum of the @p protocol_version and the rest of the buffer
Expand All @@ -146,21 +113,23 @@ namespace bitstream
*/
[[nodiscard]] bool serialize_checksum(uint32_t protocol_version) noexcept
{
BS_ASSERT(m_NumBitsRead == 0);
BS_ASSERT(get_num_bits_serialized() == 0);

BS_ASSERT(can_serialize_bits(32U));

uint32_t num_bytes = (m_TotalBits - 1U) / 8U + 1U;
uint32_t num_bytes = (get_total_bits() - 1U) / 8U + 1U;
const uint32_t* buffer = m_Policy.get_buffer();

// Generate checksum to compare against
uint32_t generated_checksum = utility::crc_uint32(reinterpret_cast<const uint8_t*>(&protocol_version), reinterpret_cast<const uint8_t*>(m_Buffer + 1), num_bytes - 4);
uint32_t generated_checksum = utility::crc_uint32(reinterpret_cast<const uint8_t*>(&protocol_version), reinterpret_cast<const uint8_t*>(buffer + 1), num_bytes - 4);

// Advance the reader by the size of the checksum (32 bits / 1 word)
m_WordIndex++;
m_NumBitsRead += 32U;

BS_ASSERT(m_Policy.extend(32U));

// Read the checksum
uint32_t checksum = *m_Buffer;
uint32_t checksum = *buffer;

// Compare the checksum
return generated_checksum == checksum;
Expand All @@ -173,11 +142,13 @@ namespace bitstream
*/
[[nodiscard]] bool pad_to_size(uint32_t num_bytes) noexcept
{
BS_ASSERT(num_bytes * 8U <= m_TotalBits);
BS_ASSERT(num_bytes * 8U >= m_NumBitsRead);
uint32_t num_bits_read = get_num_bits_serialized();

BS_ASSERT(num_bytes * 8U >= num_bits_read);

uint32_t remainder = (num_bytes * 8U - m_NumBitsRead) % 32U;
BS_ASSERT(can_serialize_bits(num_bytes * 8U - num_bits_read));

uint32_t remainder = (num_bytes * 8U - num_bits_read) % 32U;
uint32_t zero;

// Test the last word more carefully, as it may have data
Expand All @@ -187,7 +158,7 @@ namespace bitstream
BS_ASSERT(status && zero == 0);
}

uint32_t offset = m_NumBitsRead / 32;
uint32_t offset = get_num_bits_serialized() / 32;
uint32_t max = num_bytes / 4;

// Test for zeros in padding
Expand Down Expand Up @@ -217,13 +188,13 @@ namespace bitstream
*/
[[nodiscard]] bool align() noexcept
{
uint32_t remainder = m_NumBitsRead % 8U;
uint32_t remainder = get_num_bits_serialized() % 8U;
if (remainder != 0U)
{
uint32_t zero;
bool status = serialize_bits(zero, 8U - remainder);

BS_ASSERT(status && zero == 0U && m_NumBitsRead % 8U == 0U);
BS_ASSERT(status && zero == 0U && get_num_bits_serialized() % 8U == 0U);
}

return true;
Expand All @@ -239,24 +210,23 @@ namespace bitstream
{
BS_ASSERT(num_bits > 0U && num_bits <= 32U);

BS_ASSERT(can_serialize_bits(num_bits));
BS_ASSERT(m_Policy.extend(num_bits));

// Fast path
if (num_bits == 32U && m_ScratchBits == 0U)
{
const uint32_t* ptr = m_Buffer + m_WordIndex;
const uint32_t* ptr = m_Policy.get_buffer() + m_WordIndex;

value = utility::to_big_endian32(*ptr);

m_NumBitsRead += num_bits;
m_WordIndex++;

return true;
}

if (m_ScratchBits < num_bits)
{
const uint32_t* ptr = m_Buffer + m_WordIndex;
const uint32_t* ptr = m_Policy.get_buffer() + m_WordIndex;

uint64_t ptr_value = static_cast<uint64_t>(utility::to_big_endian32(*ptr)) << (32U - m_ScratchBits);
m_Scratch |= ptr_value;
Expand All @@ -269,7 +239,6 @@ namespace bitstream

m_Scratch <<= num_bits;
m_ScratchBits -= num_bits;
m_NumBitsRead += num_bits;

return true;
}
Expand All @@ -292,10 +261,11 @@ namespace bitstream

if (m_ScratchBits % 32U == 0U && num_words > 0U)
{
BS_ASSERT(m_Policy.extend(num_words * 32U));

// If the read buffer is word-aligned, just memcpy it
std::memcpy(word_buffer, m_Buffer + m_WordIndex, num_words * 4U);
std::memcpy(word_buffer, m_Policy.get_buffer() + m_WordIndex, num_words * 4U);

m_NumBitsRead += num_words * 32U;
m_WordIndex += num_words;
}
else
Expand Down Expand Up @@ -361,12 +331,12 @@ namespace bitstream
}

private:
const uint32_t* m_Buffer;
uint32_t m_NumBitsRead;
uint32_t m_TotalBits;
Policy m_Policy;

uint64_t m_Scratch;
uint32_t m_ScratchBits;
uint32_t m_WordIndex;
};

using fixed_bit_reader = bit_reader<fixed_policy>;
}
2 changes: 1 addition & 1 deletion include/bitstream/stream/bit_writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ namespace bitstream
uint32_t m_WordIndex;
};

using fixed_bit_writer = bit_writer<fixed_policy<true>>;
using fixed_bit_writer = bit_writer<fixed_policy>;

template<typename T>
using growing_bit_writer = bit_writer<growing_policy<T>>;
Expand Down
22 changes: 10 additions & 12 deletions include/bitstream/stream/stream_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,20 @@

#include <cstddef>
#include <cstdint>
#include <limits>
#include <type_traits>

namespace bitstream
{
template<bool Writing>
struct fixed_policy
{
using void_ptr = std::conditional_t<Writing, void*, const void*>;
using buffer_ptr = std::conditional_t<Writing, uint32_t*, const uint32_t*>;

/**
* @brief Construct a stream pointing to the given byte array with @p num_bytes size
* @param bytes The byte array to serialize to/from. Must be 4-byte aligned and the size must be a multiple of 4
* @param num_bytes The number of bytes in the array
*/
fixed_policy(void_ptr buffer, uint32_t num_bits) noexcept :
m_Buffer(static_cast<buffer_ptr>(buffer)),
fixed_policy(void* buffer, uint32_t num_bits) noexcept :
m_Buffer(static_cast<uint32_t*>(buffer)),
m_NumBitsSerialized(0),
m_TotalBits(num_bits) {}

Expand All @@ -31,7 +28,7 @@ namespace bitstream
*/
template<size_t Size>
fixed_policy(byte_buffer<Size>& buffer, uint32_t num_bits) noexcept :
m_Buffer(reinterpret_cast<buffer_ptr>(buffer.Bytes)),
m_Buffer(reinterpret_cast<uint32_t*>(buffer.Bytes)),
m_NumBitsSerialized(0),
m_TotalBits(num_bits) {}

Expand All @@ -41,13 +38,13 @@ namespace bitstream
*/
template<size_t Size>
fixed_policy(byte_buffer<Size>& buffer) noexcept :
m_Buffer(reinterpret_cast<buffer_ptr>(buffer.Bytes)),
m_Buffer(reinterpret_cast<uint32_t*>(buffer.Bytes)),
m_NumBitsSerialized(0),
m_TotalBits(Size * 8) {}

buffer_ptr get_buffer() const noexcept { return m_Buffer; }
uint32_t* get_buffer() const noexcept { return m_Buffer; }

// TODO: Transition to size_t
// TODO: Transition sizes to size_t
uint32_t get_num_bits_serialized() const noexcept { return m_NumBitsSerialized; }

bool can_serialize_bits(uint32_t num_bits) const noexcept { return m_NumBitsSerialized + num_bits <= m_TotalBits; }
Expand All @@ -61,7 +58,8 @@ namespace bitstream
return status;
}

buffer_ptr m_Buffer;
uint32_t* m_Buffer;
// TODO: Transition sizes to size_t
uint32_t m_NumBitsSerialized;
uint32_t m_TotalBits;
};
Expand All @@ -79,7 +77,7 @@ namespace bitstream

bool can_serialize_bits(uint32_t num_bits) const noexcept { return true; }

uint32_t get_total_bits() const noexcept { return std::numeric_limits<uint32_t>::max(); }
uint32_t get_total_bits() const noexcept { return (std::numeric_limits<uint32_t>::max)(); }

bool extend(uint32_t num_bits)
{
Expand Down
Loading

0 comments on commit f8ef797

Please sign in to comment.