Skip to content

Commit

Permalink
[bug] Fix file download (#567)
Browse files Browse the repository at this point in the history
* fix file download

* fix nits

* fix tests

* fix tests
  • Loading branch information
KCarretto authored Feb 10, 2024
1 parent 5f185bd commit 1c8eff8
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 119 deletions.
48 changes: 12 additions & 36 deletions implants/imix/src/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,56 +147,32 @@ impl TaskHandle {
tavern: &mut impl Transport,
req: FileRequest,
) -> Result<()> {
let (ch_file_chunk, file_chunk) = channel::<DownloadFileResponse>();
let (tx, rx) = channel::<DownloadFileResponse>();

tavern
.download_file(DownloadFileRequest { name: req.name() }, ch_file_chunk)
.download_file(DownloadFileRequest { name: req.name() }, tx)
.await?;

let task_id = self.id;
let handle = tokio::task::spawn(async move {
loop {
let resp = match file_chunk.recv() {
Ok(r) => r,
Err(_err) => {
match _err.to_string().as_str() {
"receiving on a closed channel" => {}
_ => {
#[cfg(debug_assertions)]
log::error!(
"failed to download file chunk: task_id={}, name={}: {}",
task_id,
req.name(),
_err
);
}
}
return;
}
};

#[cfg(debug_assertions)]
log::info!(
"downloaded file chunk: task_id={}, name={}, size={}",
task_id,
req.name(),
resp.chunk.len()
);

match req.send_chunk(resp.chunk) {
let handle = tokio::task::spawn_blocking(move || {
for r in rx {
match req.send_chunk(r.chunk) {
Ok(_) => {}
Err(_err) => {
#[cfg(debug_assertions)]
log::error!(
"failed to send downloaded file chunk: task_id={}, name={}: {}",
task_id,
"failed to send downloaded file chunk: {}: {}",
req.name(),
_err
);

return;
}
};
}
}
#[cfg(debug_assertions)]
log::info!("file download completed: {}", req.name());
});

self.download_handles.push(handle);
Ok(())
}
Expand Down
1 change: 1 addition & 0 deletions implants/lib/c2/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ edition = "2021"

[dependencies]
eldritch = { workspace = true }
log = { workspace = true }
tonic = { workspace = true, features = ["tls-roots"] }
prost = { workspace = true}
prost-types = { workspace = true }
Expand Down
39 changes: 35 additions & 4 deletions implants/lib/c2/src/grpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,44 @@ impl crate::Transport for GRPC {
async fn download_file(
&mut self,
request: crate::pb::DownloadFileRequest,
sender: Sender<crate::pb::DownloadFileResponse>,
tx: Sender<crate::pb::DownloadFileResponse>,
) -> Result<()> {
#[cfg(debug_assertions)]
let filename = request.name.clone();

let resp = self.download_file_impl(request).await?;
let mut stream = resp.into_inner();
while let Some(file_chunk) = stream.message().await? {
sender.send(file_chunk)?;
}
tokio::spawn(async move {
loop {
let msg = match stream.message().await {
Ok(maybe_msg) => match maybe_msg {
Some(msg) => msg,
None => {
break;
}
},
Err(_err) => {
#[cfg(debug_assertions)]
log::error!("failed to download file: {}: {}", filename, _err);

return;
}
};
match tx.send(msg) {
Ok(_) => {}
Err(_err) => {
#[cfg(debug_assertions)]
log::error!(
"failed to send downloaded file chunk: {}: {}",
filename,
_err
);

return;
}
}
}
});
Ok(())
}

Expand Down
198 changes: 120 additions & 78 deletions implants/lib/eldritch/src/assets/copy_impl.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::runtime::Client;
use anyhow::{Context, Result};
use starlark::{eval::Evaluator, values::list::ListRef};
use std::fs::OpenOptions;
use std::io::Write;
use std::{fs, sync::mpsc::Receiver};

fn copy_local(src: String, dst: String) -> Result<()> {
Expand All @@ -15,36 +17,42 @@ fn copy_local(src: String, dst: String) -> Result<()> {
}
}

fn copy_remote(file_reciever: Receiver<Vec<u8>>, dst: String) -> Result<()> {
loop {
let val = match file_reciever.recv() {
Ok(v) => v,
Err(err) => {
match err.to_string().as_str() {
"channel is empty and sending half is closed" => {
break;
}
"timed out waiting on channel" => {
continue;
}
_ => {
#[cfg(debug_assertions)]
log::debug!("failed to drain channel: {}", err)
}
}
break;
}
};
match fs::write(dst.clone(), val) {
Ok(_) => {}
Err(local_err) => return Err(local_err.try_into()?),
};
fn copy_remote(rx: Receiver<Vec<u8>>, dst_path: String) -> Result<()> {
// Truncate file
let mut dst = OpenOptions::new()
.create(true)
.truncate(true)
.write(true)
.open(&dst_path)
.context(format!(
"failed to truncate destination file: {}",
&dst_path
))?;
dst.flush()
.context(format!("failed to flush file truncation: {}", &dst_path))?;

// Reopen file for writing
let mut dst = OpenOptions::new()
.create(true)
.append(true)
.open(&dst_path)
.context(format!("failed to open file for writing: {}", &dst_path))?;

// Listen for downloaded chunks and write them
for chunk in rx {
dst.write_all(&chunk)
.context(format!("failed to write file chunk: {}", &dst_path))?;
}

// Ensure all chunks gets written
dst.flush()
.context(format!("failed to flush file: {}", &dst_path))?;

Ok(())
}

pub fn copy(starlark_eval: &mut Evaluator<'_, '_>, src: String, dst: String) -> Result<()> {
// #[allow(clippy::needless_pass_by_ref_mut)]
pub fn copy(starlark_eval: &Evaluator<'_, '_>, src: String, dst: String) -> Result<()> {
let remote_assets = starlark_eval.module().get("remote_assets");

if let Some(assets) = remote_assets {
Expand All @@ -63,64 +71,98 @@ pub fn copy(starlark_eval: &mut Evaluator<'_, '_>, src: String, dst: String) ->

#[cfg(test)]
mod tests {
use crate::assets::copy_impl::copy_remote;
use crate::Runtime;

use std::sync::mpsc::channel;
use std::{collections::HashMap, io::prelude::*};
use tempfile::NamedTempFile;

// fn init_log() {
// pretty_env_logger::formatted_timed_builder()
// .filter_level(log::LevelFilter::Info)
// .parse_env("IMIX_LOG")
// .init();
// }

// #[tokio::test]
// async fn test_remote_copy() -> anyhow::Result<()> {
// // Create files
// let mut tmp_file_dst = NamedTempFile::new()?;
// let path_dst = String::from(tmp_file_dst.path().to_str().unwrap());

// let (sender, reciver) = channel::<Vec<u8>>();
// sender.send("Hello from a remote asset".as_bytes().to_vec())?;

// copy_remote(reciver, path_dst)?;

// let mut contents = String::new();
// tmp_file_dst.read_to_string(&mut contents)?;
// assert!(contents.contains("Hello from a remote asset"));
// Ok(())
// }

// #[tokio::test]
// async fn test_remote_copy_full() -> anyhow::Result<()> {
// init_log();
// log::debug!("Testing123");

// // Create files
// let mut tmp_file_dst = NamedTempFile::new()?;
// let path_dst = String::from(tmp_file_dst.path().to_str().unwrap());

// let (runtime, broker) = Runtime::new();
// let handle = tokio::task::spawn_blocking(move || {
// runtime.run(crate::pb::Tome {
// eldritch: r#"assets.copy("test_tome/test_file.txt", input_params['test_output'])"#
// .to_owned(),
// parameters: HashMap::from([("test_output".to_string(), path_dst)]),
// file_names: Vec::from(["test_tome/test_file.txt".to_string()]),
// })
// });
// handle.await?;
// println!("{:?}", broker.collect_file_requests().len());
// assert!(broker.collect_errors().is_empty()); // No errors even though the remote asset is inaccessible

// let mut contents = String::new();
// tmp_file_dst.read_to_string(&mut contents)?;
// // Compare - Should be empty basically just didn't error
// assert!(contents.contains(""));

// Ok(())
// }
#[tokio::test]
async fn test_remote_copy() -> anyhow::Result<()> {
// Create files
let mut tmp_file_dst = NamedTempFile::new()?;
let path_dst = String::from(tmp_file_dst.path().to_str().unwrap());

let (ch_data, data) = channel::<Vec<u8>>();
let handle = tokio::task::spawn_blocking(|| {
copy_remote(data, path_dst).expect("copy_remote failed")
});

ch_data.send("Hello from a remote asset".as_bytes().to_vec())?;
ch_data.send("Goodbye from a remote asset".as_bytes().to_vec())?;

// Drop the Sender, to indicate no more data will be sent (channel closed)
drop(ch_data);

handle.await?;

let mut contents = String::new();
tmp_file_dst.read_to_string(&mut contents)?;
assert!(contents.contains("Hello from a remote asset"));
assert!(contents.contains("Goodbye from a remote asset"));
Ok(())
}

#[tokio::test]
async fn test_remote_copy_full() -> anyhow::Result<()> {
// Create files
let mut tmp_file_dst = NamedTempFile::new()?;
let path_dst = String::from(tmp_file_dst.path().to_str().unwrap());

// Create a runtime
let (runtime, broker) = Runtime::new();

// Execute eldritch in it's own thread
let handle = tokio::task::spawn_blocking(move || {
runtime.run(crate::pb::Tome {
eldritch: r#"assets.copy("test_tome/test_file.txt", input_params['test_output'])"#
.to_owned(),
parameters: HashMap::from([("test_output".to_string(), path_dst)]),
file_names: Vec::from(["test_tome/test_file.txt".to_string()]),
})
});

// We now mock the agent, looping until eldritch requests a file
// We omit the sleep performed by the agent, just to save test time
loop {
// The broker only returns the data that is currently available
// So this may return an empty vec if our eldritch tokio task has not yet been scheduled
let mut reqs = broker.collect_file_requests();

// If no file request is yet available, just continue looping
if reqs.is_empty() {
continue;
}

// Ensure the right file was requested
assert!(reqs.len() == 1);
let req = reqs.pop().expect("no file request received!");
assert!(req.name() == "test_tome/test_file.txt");

// Now, we provide the file to eldritch (as a series of chunks)
req.send_chunk("chunk1\n".as_bytes().to_vec())
.expect("failed to send file chunk to eldritch");
req.send_chunk("chunk2\n".as_bytes().to_vec())
.expect("failed to send file chunk to eldritch");

// We've finished providing the file, so we stop looping
// This will drop `req`, which consequently drops the underlying `Sender` for the file channel
// This will cause the next `recv()` to error with "channel is empty and sending half is closed"
// which is what tells eldritch that there are no more file chunks to wait for
break;
}

// Now that we've finished writing data, we wait for eldritch to finish
handle.await?;

// Lastly, assert the file was written correctly
let mut contents = String::new();
tmp_file_dst.read_to_string(&mut contents)?;
assert_eq!("chunk1\nchunk2\n", contents.as_str());

Ok(())
}

#[test]
fn test_embedded_copy() -> anyhow::Result<()> {
Expand Down
7 changes: 6 additions & 1 deletion tavern/internal/c2/c2test/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package c2test

import (
"context"
"errors"
"net"
"testing"

Expand Down Expand Up @@ -34,8 +35,9 @@ func New(t *testing.T) (c2pb.C2Client, *ent.Client, func()) {
baseSrv := grpc.NewServer()
c2pb.RegisterC2Server(baseSrv, c2.New(graph))

grpcErrCh := make(chan error, 1)
go func() {
require.NoError(t, baseSrv.Serve(lis), "failed to serve grpc")
grpcErrCh <- baseSrv.Serve(lis)
}()

conn, err := grpc.DialContext(
Expand All @@ -52,5 +54,8 @@ func New(t *testing.T) (c2pb.C2Client, *ent.Client, func()) {
assert.NoError(t, lis.Close())
baseSrv.Stop()
assert.NoError(t, graph.Close())
if err := <-grpcErrCh; err != nil && !errors.Is(err, grpc.ErrServerStopped) {
t.Fatalf("failed to serve grpc")
}
}
}

0 comments on commit 1c8eff8

Please sign in to comment.