Skip to content

Commit

Permalink
Merge pull request #1375 from tjkirch/gzip-user-data
Browse files Browse the repository at this point in the history
Actually merge #1366: gzip user data
  • Loading branch information
tjkirch authored Mar 9, 2021
2 parents dd5f0a1 + f4c9756 commit ae1dc2b
Show file tree
Hide file tree
Showing 8 changed files with 320 additions and 38 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,9 @@ Here's the user data to change the message of the day setting, as we did in the
motd = "my own value!"
```

If your user data is over the size limit of the platform (e.g. 16KiB for EC2), you can compress the contents with gzip.
(With [aws-cli](https://aws.amazon.com/cli/), you can use `--user-data fileb:///path/to/gz-file` to pass binary data.)

### Description of settings

Here we'll describe each setting you can change.
Expand Down
3 changes: 3 additions & 0 deletions sources/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions sources/api/early-boot-config/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ exclude = ["README.md"]
[dependencies]
apiclient = { path = "../apiclient" }
base64 = "0.13"
flate2 = { version = "1.0", default-features = false, features = ["rust_backend"] }
http = "0.2"
log = "0.4"
reqwest = { version = "0.10", default-features = false, features = ["blocking"] }
Expand All @@ -27,3 +28,7 @@ toml = "0.5"

[build-dependencies]
cargo-readme = "3.1"

[dev-dependencies]
hex-literal = "0.3"
lazy_static = "1.4"
226 changes: 226 additions & 0 deletions sources/api/early-boot-config/src/compression.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
//! This module supports reading from an input source that could be compressed or plain text.
//!
//! Currently gzip compression is supported.

use flate2::read::GzDecoder;
use std::fs::File;
use std::io::{BufReader, Chain, Cursor, ErrorKind, Read, Result, Take};
use std::path::Path;

/// "File magic" that indicates file type is stored in a few bytes at the start at the start of the
/// data. For now we only need two bytes for gzip, but if adding new formats, we'd need to read
/// more. (The simplest approach may be to read the max length for any format we need and compare
/// the appropriate prefix length.)
/// https://en.wikipedia.org/wiki/List_of_file_signatures
const MAGIC_LEN: usize = 2;

// We currently only support gzip, but it shouldn't be hard to add more.
/// These bytes are at the start of any gzip-compressed data.
const GZ_MAGIC: [u8; 2] = [0x1f, 0x8b];

/// This helper takes a slice of bytes representing UTF-8 text, which can optionally be
/// compressed, and returns an uncompressed string.
///
/// Compression detection (currently only gzip) is handled transparently by
/// `OptionalCompressionReader`; plain text passes through unchanged.
pub fn expand_slice_maybe(input: &[u8]) -> Result<String> {
    let mut expanded = String::new();
    OptionalCompressionReader::new(Cursor::new(input)).read_to_string(&mut expanded)?;
    Ok(expanded)
}

/// This helper takes the path to a file containing UTF-8 text, which can optionally be compressed,
/// and returns an uncompressed string of all its contents. File reads are done through BufReader.
pub fn expand_file_maybe<P>(path: P) -> Result<String>
where
    P: AsRef<Path>,
{
    // Buffer file reads; the decompression layer may issue many small reads.
    let buffered = BufReader::new(File::open(path.as_ref())?);
    let mut contents = String::new();
    OptionalCompressionReader::new(buffered).read_to_string(&mut contents)?;
    Ok(contents)
}

// =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^=

/// This type lets you wrap a `Read` whose data may or may not be compressed, and its `read()`
/// calls will uncompress the data if needed.
pub struct OptionalCompressionReader<R>(CompressionType<R>);

/// This represents the type of compression we've detected within a `Read`, or `Unknown` if we
/// haven't yet read any bytes to be able to detect it.
enum CompressionType<R> {
    /// This represents the starting state of the reader before we've read the magic bytes and
    /// detected any compression.
    ///
    /// We need ownership of the `Read` to construct one of the variants below, so we use an
    /// `Option` to allow `take`ing the value out, even if we only have a &mut reference in the
    /// `read` implementation. This is safe because detection is a one-time process and we know we
    /// construct this with Some value.
    Unknown(Option<R>),

    /// We haven't found recognizable compression; reads pass through unchanged (the wrapped
    /// `Peek` replays the magic bytes we consumed during detection, then the rest of the input).
    None(Peek<R>),

    /// We found gzip compression; reads go through a `GzDecoder` over the replayed input.
    Gz(GzDecoder<Peek<R>>),
}

/// `Peek` lets us read the starting bytes (the "magic") of an input `Read` but maintain those
/// bytes in an internal buffer. We Take the number of bytes we read (to handle reads shorter than
/// MAGIC_LEN) and Chain them together with the rest of the input, to represent the full input.
type Peek<T> = Chain<Take<Cursor<[u8; MAGIC_LEN]>>, T>;

impl<R: Read> OptionalCompressionReader<R> {
/// Build a new `OptionalCompressionReader` before we know the input compression type.
pub fn new(input: R) -> Self {
Self(CompressionType::Unknown(Some(input)))
}
}

/// Implement `Read` by checking whether we've detected compression type yet, and if not, detecting
/// it and then replacing ourselves with the appropriate type so we can continue reading.
impl<R: Read> Read for OptionalCompressionReader<R> {
    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        match self.0 {
            CompressionType::Unknown(ref mut input) => {
                // Take ownership of our `Read` object so we can store it in a new variant.
                // (`input` is always Some here: `new()` is the only constructor and this arm
                // replaces the Unknown state before returning.)
                let mut reader = input.take().expect(
                    "OptionalCompressionReader constructed with None input; programming error",
                );

                // Read the "magic" that tells us the compression type. retry_read retries on
                // interrupts and, unlike read_exact, reports a short count at EOF rather than
                // failing, so inputs shorter than MAGIC_LEN are handled.
                let mut magic = [0u8; MAGIC_LEN];
                let count = reader.retry_read(&mut magic)?;

                // We need to return all of the bytes, but we just consumed MAGIC_LEN of them.
                // This chains together those initial bytes with the remainder so we have them all.
                // `take(count)` limits the replay to the bytes actually read, not the whole
                // zero-initialized buffer.
                let magic_read = Cursor::new(magic).take(count as u64);
                let full_input = magic_read.chain(reader);

                // Detect compression type based on the magic bytes.
                if count == MAGIC_LEN && magic == GZ_MAGIC {
                    // Use a gzip decoder if gzip compressed.
                    self.0 = CompressionType::Gz(GzDecoder::new(full_input))
                } else {
                    // We couldn't detect any compression; just read the input.
                    self.0 = CompressionType::None(full_input)
                }

                // We've replaced Unknown with a known compression type; defer to that for reading.
                // This recursion is at most one level deep since self.0 is no longer Unknown.
                self.read(buf)
            }

            // After initial detection, we just perform standard reads on the reader we prepared.
            CompressionType::None(ref mut r) => r.read(buf),
            CompressionType::Gz(ref mut r) => r.read(buf),
        }
    }
}

// =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^=

/// This trait represents a `Read` operation where we want to retry after standard interruptions
/// (unlike `read()`) but also need to know the number of bytes we read (unlike `read_exact()`).
trait RetryRead<R> {
    fn retry_read(&mut self, buf: &mut [u8]) -> Result<usize>;
}

impl<R: Read> RetryRead<R> for R {
    // This implementation is based on stdlib Read::read_exact, but hitting EOF isn't a failure, we
    // just want to return the number of bytes we could read.
    fn retry_read(&mut self, buf: &mut [u8]) -> Result<usize> {
        let mut filled = 0;

        // Keep reading into the unfilled tail of the buffer until it's full or the source is
        // exhausted.
        loop {
            if filled == buf.len() {
                return Ok(filled);
            }
            match self.read(&mut buf[filled..]) {
                // EOF: report however many bytes we managed to read.
                Ok(0) => return Ok(filled),
                // Got n more bytes; advance the fill point.
                Ok(n) => filled += n,
                // Interrupted reads are retried transparently.
                Err(e) if e.kind() == ErrorKind::Interrupted => continue,
                // Any other error is fatal.
                Err(e) => return Err(e),
            }
        }
    }
}

// =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^= =^..^=

#[cfg(test)]
mod test {
    use super::*;
    use hex_literal::hex;
    use lazy_static::lazy_static;
    use std::io::Cursor;

    lazy_static! {
        /// Some plain text strings and their gzip encodings.
        /// Each encoded stream begins with the 0x1f 0x8b gzip magic that detection keys on.
        static ref DATA: &'static [(&'static str, &'static [u8])] = &[
            ("", &hex!("1f8b 0808 3863 3960 0003 656d 7074 7900 0300 0000 0000 0000 0000")),
            ("4", &hex!("1f8b 0808 6f63 3960 0003 666f 7572 0033 0100 381b b6f3 0100 0000")),
            ("42", &hex!("1f8b 0808 7c6b 3960 0003 616e 7377 6572 0033 3102 0088 b024 3202 0000 00")),
            ("hi there", &hex!("1f8b 0808 d24f 3960 0003 6869 7468 6572 6500 cbc8 5428 c948 2d4a 0500 ec76 a3e3 0800 0000")),
        ];
    }

    // Plain (uncompressed) input should pass through the reader unchanged.
    #[test]
    fn test_plain() {
        for (plain, _gz) in *DATA {
            let input = Cursor::new(plain);
            let mut output = String::new();
            OptionalCompressionReader::new(input)
                .read_to_string(&mut output)
                .unwrap();
            assert_eq!(output, *plain);
        }
    }

    // Gzip input should be detected and transparently decompressed.
    #[test]
    fn test_gz() {
        for (plain, gz) in *DATA {
            let input = Cursor::new(gz);
            let mut output = String::new();
            OptionalCompressionReader::new(input)
                .read_to_string(&mut output)
                .unwrap();
            assert_eq!(output, *plain);
        }
    }

    // The expand_slice_maybe helper with plain input.
    #[test]
    fn test_helper_plain() {
        for (plain, _gz) in *DATA {
            assert_eq!(expand_slice_maybe(plain.as_bytes()).unwrap(), *plain);
        }
    }

    // The expand_slice_maybe helper with gzip input.
    #[test]
    fn test_helper_gz() {
        for (plain, gz) in *DATA {
            assert_eq!(expand_slice_maybe(gz).unwrap(), *plain);
        }
    }

    #[test]
    fn test_magic_prefix() {
        // Confirm that if we give a prefix of valid magic, but not the whole thing, we just get
        // that input back.
        let input = Cursor::new(&[0x1f]);
        let mut output = Vec::new();
        let count = OptionalCompressionReader::new(input)
            .read_to_end(&mut output)
            .unwrap();
        assert_eq!(count, 1);
        assert_eq!(output, &[0x1f]);
    }
}
1 change: 1 addition & 0 deletions sources/api/early-boot-config/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use std::fs;
use std::str::FromStr;
use std::{env, process};

mod compression;
mod provider;
mod settings;
use crate::provider::PlatformDataProvider;
Expand Down
66 changes: 50 additions & 16 deletions sources/api/early-boot-config/src/provider/aws.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! The aws module implements the `PlatformDataProvider` trait for gathering userdata on AWS.

use super::{PlatformDataProvider, SettingsJson};
use crate::compression::expand_slice_maybe;
use http::StatusCode;
use reqwest::blocking::Client;
use serde_json::json;
Expand Down Expand Up @@ -48,7 +49,7 @@ impl AwsDataProvider {
session_token: &str,
uri: &str,
description: &str,
) -> Result<Option<String>> {
) -> Result<Option<Vec<u8>>> {
debug!("Requesting {} from {}", description, uri);
let response = client
.get(uri)
Expand All @@ -57,15 +58,35 @@ impl AwsDataProvider {
.context(error::Request { method: "GET", uri })?;
trace!("IMDS response: {:?}", &response);

// IMDS data can be larger than we'd want to log (50k+ compressed) so we don't necessarily
// want to show the whole thing, and don't want to show binary data.
fn response_string(response: &[u8]) -> String {
    // arbitrary max len; would be nice to print the start of the data if it's
    // uncompressed, but we'd need to break slice at a safe point for UTF-8, and without
    // reading in the whole thing like String::from_utf8.
    if response.len() > 2048 {
        "<very long>".to_string()
    } else if let Ok(s) = std::str::from_utf8(response) {
        // Validate UTF-8 in place (no copy of the response buffer, unlike
        // String::from_utf8(response.into())); only the final owned String allocates.
        s.to_string()
    } else {
        "<binary>".to_string()
    }
}

match response.status() {
code @ StatusCode::OK => {
info!("Received {}", description);
let response_body = response.text().context(error::ResponseBody {
method: "GET",
uri,
code,
})?;
trace!("Response text: {:?}", &response_body);
let response_body = response
.bytes()
.context(error::ResponseBody {
method: "GET",
uri,
code,
})?
.to_vec();

let response_str = response_string(&response_body);
trace!("Response: {:?}", response_str);

Ok(Some(response_body))
}
Expand All @@ -74,18 +95,24 @@ impl AwsDataProvider {
StatusCode::NOT_FOUND => Ok(None),

code @ _ => {
let response_body = response.text().context(error::ResponseBody {
method: "GET",
uri,
code,
})?;
trace!("Response text: {:?}", &response_body);
let response_body = response
.bytes()
.context(error::ResponseBody {
method: "GET",
uri,
code,
})?
.to_vec();

let response_str = response_string(&response_body);

trace!("Response: {:?}", response_str);

error::Response {
method: "GET",
uri,
code,
response_body,
response_body: response_str,
}
.fail()
}
Expand All @@ -98,11 +125,13 @@ impl AwsDataProvider {
let desc = "user data";
let uri = Self::USER_DATA_ENDPOINT;

let user_data_str = match Self::fetch_imds(client, session_token, uri, desc) {
let user_data_raw = match Self::fetch_imds(client, session_token, uri, desc) {
Err(e) => return Err(e),
Ok(None) => return Ok(None),
Ok(Some(s)) => s,
};
let user_data_str = expand_slice_maybe(&user_data_raw)
.context(error::Decompression { what: "user data" })?;
trace!("Received user data: {}", user_data_str);

// Remove outer "settings" layer before sending to API
Expand Down Expand Up @@ -131,7 +160,9 @@ impl AwsDataProvider {
match Self::fetch_imds(client, session_token, uri, desc) {
Err(e) => return Err(e),
Ok(None) => return Ok(None),
Ok(Some(s)) => s,
Ok(Some(raw)) => {
expand_slice_maybe(&raw).context(error::Decompression { what: "user data" })?
}
}
};
trace!("Received instance identity document: {}", iid_str);
Expand Down Expand Up @@ -198,6 +229,9 @@ mod error {
#[snafu(display("Response '{}' from '{}': {}", get_bad_status_code(&source), uri, source))]
BadResponse { uri: String, source: reqwest::Error },

#[snafu(display("Failed to decompress {}: {}", what, source))]
Decompression { what: String, source: io::Error },

#[snafu(display("Error deserializing from JSON: {}", source))]
DeserializeJson { source: serde_json::error::Error },

Expand Down
Loading

0 comments on commit ae1dc2b

Please sign in to comment.