Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
348 changes: 308 additions & 40 deletions apps/desktop/src-tauri/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")]

use rfd::FileDialog;
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Deserializer, Serialize};
use serde_json::{json, Value};
use std::{
collections::HashMap,
Expand Down Expand Up @@ -141,6 +141,49 @@ struct RehearsalRolePayload {
simplification: String,
setup_note: String,
manual_overrides: Vec<ManualOverridePayload>,
overlap_warnings: Vec<String>,
}

#[derive(Clone, Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct SectionTimeRangePayload {
start: u32,
end: u32,
}

impl<'de> Deserialize<'de> for SectionTimeRangePayload {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
#[serde(rename_all = "camelCase", deny_unknown_fields)]
struct RawSectionTimeRangePayload {
start: u32,
end: u32,
}

let raw = RawSectionTimeRangePayload::deserialize(deserializer)?;
if raw.end <= raw.start {
return Err(serde::de::Error::custom(
"section timeRange end must be greater than start",
));
}

Ok(Self {
start: raw.start,
end: raw.end,
})
}
}

#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
struct PartGraphNodePayload {
role_id: String,
is_active: bool,
handoff_to: Vec<String>,
handoff_from: Vec<String>,
}

#[derive(Clone, Debug, Deserialize, Serialize)]
Expand All @@ -149,8 +192,10 @@ struct RehearsalSectionPayload {
id: String,
label: String,
groove: String,
time_range: SectionTimeRangePayload,
confidence: ConfidencePayload,
roles: Vec<RehearsalRolePayload>,
part_graph: Vec<PartGraphNodePayload>,
}

#[derive(Clone, Debug, Deserialize, Serialize)]
Expand Down Expand Up @@ -353,6 +398,74 @@ fn normalize_local_audio_source(path: &Path) -> Result<LocalAudioSourcePayload,
})
}

fn youtube_source_from_metadata(
metadata: &Value,
cache_root: &Path,
) -> Result<LocalAudioSourcePayload, String> {
let filepath = metadata
.get("filepath")
.and_then(|value| value.as_str())
.filter(|value| !value.trim().is_empty())
.ok_or_else(|| "Failed to parse YouTube import response.".to_string())?;
let title = metadata
.get("title")
.and_then(|value| value.as_str())
.unwrap_or("Unknown YouTube Audio");
let path = Path::new(filepath);
let link_metadata = std::fs::symlink_metadata(path)
.map_err(|_| "Could not read downloaded audio file.".to_string())?;
if link_metadata.file_type().is_symlink() {
return Err("YouTube import returned an invalid audio path.".to_string());
}

let canonical_cache_root = cache_root
.canonicalize()
.map_err(|_| "Could not validate YouTube import workspace.".to_string())?;
let canonical = path
.canonicalize()
.map_err(|_| "Could not read downloaded audio file.".to_string())?;
if !canonical.starts_with(&canonical_cache_root) {
return Err("YouTube import returned an invalid audio path.".to_string());
}

let file_metadata = std::fs::metadata(&canonical)
.map_err(|_| "Could not read downloaded audio file.".to_string())?;
if !file_metadata.is_file() || file_metadata.len() == 0 {
return Err("YouTube import returned an invalid audio file.".to_string());
}

let extension = canonical
.extension()
.and_then(|value| value.to_str())
.map(|value| value.to_ascii_lowercase())
.ok_or_else(|| "YouTube import returned an unsupported audio format.".to_string())?;
if !AUDIO_EXTENSIONS.contains(&extension.as_str()) {
return Err("YouTube import returned an unsupported audio format.".to_string());
}

let safe_title: String = title
.chars()
.map(|c| match c {
'/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' | '.' => '_',
c if c.is_control() => '_',
c => c,
})
.take(100)
.collect();
let safe_title = if safe_title.is_empty() {
"youtube_audio".to_string()
} else {
safe_title
};

Ok(LocalAudioSourcePayload {
source_path: canonical.to_string_lossy().into_owned(),
file_name: format!("{safe_title}.{extension}"),
extension,
file_size_bytes: file_metadata.len(),
})
}

fn parse_request_payload(payload: Value) -> Result<AnalysisJobRequest, String> {
let Value::Object(map) = payload else {
return Err("Invalid analysis job request: invalid field 'root'".into());
Expand Down Expand Up @@ -787,44 +900,7 @@ async fn import_youtube_url(

if parsed.get("ok").and_then(|v| v.as_bool()) == Some(true) {
if let Some(metadata) = parsed.get("metadata") {
let filepath = metadata
.get("filepath")
.and_then(|v| v.as_str())
.unwrap_or("");
let title = metadata
.get("title")
.and_then(|v| v.as_str())
.unwrap_or("Unknown YouTube Audio");
let path = Path::new(filepath);
let metadata_fs = std::fs::metadata(path)
.map_err(|_| "Could not read downloaded audio file.".to_string())?;
let extension = path
.extension()
.and_then(|v| v.to_str())
.unwrap_or("m4a")
.to_string();

let safe_title: String = title
.chars()
.map(|c| match c {
'/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' | '.' => '_',
c if c.is_control() => '_',
c => c,
})
.take(100)
.collect();
let safe_title = if safe_title.is_empty() {
"youtube_audio".to_string()
} else {
safe_title
};

let source = LocalAudioSourcePayload {
source_path: filepath.to_string(),
file_name: format!("{}.{}", safe_title, extension),
extension,
file_size_bytes: metadata_fs.len(),
};
let source = youtube_source_from_metadata(metadata, &cache_root)?;

let summary = ProjectBootstrapSummaryPayload {
project_id,
Expand Down Expand Up @@ -883,6 +959,30 @@ fn is_supported_youtube_url(url: &str) -> bool {

false
}

fn project_payload_from_content(content: &str) -> Result<RehearsalSongPayload, String> {
if let Ok(parsed) = serde_json::from_str::<RehearsalSongPayload>(content) {
return Ok(parsed);
}

let payload = serde_json::from_str::<Value>(content)
.map_err(|_| "Invalid project file format".to_string())?;
if let Some(sections) = payload.get("sections").and_then(Value::as_array) {
for (section_index, section) in sections.iter().enumerate() {
if section
.as_object()
.is_some_and(|section_object| !section_object.contains_key("timeRange"))
{
return Err(format!(
"Invalid project file format: sections[{section_index}].timeRange is required; reanalyze the project to restore section timing."
));
}
}
}

serde_json::from_value(payload).map_err(|_| "Invalid project file format".to_string())
}

#[tauri::command]
fn save_project(payload: Value) -> Result<(), String> {
let parsed = serde_json::from_value::<RehearsalSongPayload>(payload)
Expand Down Expand Up @@ -913,7 +1013,175 @@ fn load_project() -> Result<RehearsalSongPayload, String> {
}

let content = std::fs::read_to_string(path).map_err(|_| "Failed to read file".to_string())?;
serde_json::from_str(&content).map_err(|_| "Invalid project file format".to_string())
project_payload_from_content(&content)
}

#[cfg(test)]
mod tests {
use super::*;
use std::time::{SystemTime, UNIX_EPOCH};

fn unique_test_dir(name: &str) -> PathBuf {
let suffix = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("system clock should be after epoch")
.as_nanos();
std::env::temp_dir().join(format!("bandscope-{name}-{suffix}"))
}

fn shared_contract_payload(time_range: Value) -> Value {
json!({
"id": "demo-song",
"title": "Late Night Set",
"sections": [
{
"id": "verse-1",
"label": "verse",
"groove": "Straight eighths with a late snare feel",
"timeRange": time_range,
"confidence": {
"level": "medium",
"source": "model",
"notes": "Double-check the pickup into the chorus."
},
"roles": [
{
"id": "bass-guitar",
"name": "Bass Guitar",
"roleType": "instrument",
"harmony": {
"chord": "C#m7",
"functionLabel": "vi pedal anchor",
"source": "model"
},
"cue": {
"kind": "transition",
"value": "Hold through the pickup before the downbeat."
},
"range": {
"lowestNote": "C#2",
"highestNote": "E3"
},
"confidence": {
"level": "medium",
"source": "model",
"notes": "Watch the slide into the turnaround."
},
"rehearsalPriority": "high",
"simplification": "Stay on roots if the chorus entrance gets muddy.",
"setupNote": "Keep the attack short so the verse breathes.",
"manualOverrides": [],
"overlapWarnings": [
"Density warning: competing with Keyboard Left Hand in low register."
]
}
],
"partGraph": [
{
"role_id": "bass-guitar",
"is_active": true,
"handoff_to": ["lead-vocal"],
"handoff_from": []
}
]
}
],
"exportSummary": {
"format": "cue-sheet",
"headline": "Start with the verse handoff and low-register overlap.",
"focusSections": ["verse-1"]
}
})
}

#[test]
fn rehearsal_song_payload_accepts_shared_section_contract() {
let payload = shared_contract_payload(json!({ "start": 10, "end": 30 }));

let parsed = serde_json::from_value::<RehearsalSongPayload>(payload)
.expect("shared rehearsal song contract should deserialize in Tauri");

assert_eq!(parsed.sections[0].id, "verse-1");
}

#[test]
fn rehearsal_song_payload_rejects_reversed_time_range() {
let payload = shared_contract_payload(json!({ "start": 30, "end": 10 }));

assert!(serde_json::from_value::<RehearsalSongPayload>(payload).is_err());
}

#[test]
fn project_payload_from_content_rejects_legacy_missing_time_range() {
let mut payload = shared_contract_payload(json!({ "start": 10, "end": 30 }));
payload["sections"][0]
.as_object_mut()
.expect("section should be an object")
.remove("timeRange");
let content = serde_json::to_string(&payload).expect("legacy payload should serialize");

let error = project_payload_from_content(&content)
.expect_err("legacy sections without timing should fail closed");

assert!(error.contains("timeRange"));
}

#[test]
fn youtube_metadata_must_reference_supported_audio_inside_cache_root() {
let cache_root = unique_test_dir("youtube-cache");
let outside_root = unique_test_dir("youtube-outside");
std::fs::create_dir_all(&cache_root).expect("cache root should be created");
std::fs::create_dir_all(&outside_root).expect("outside root should be created");

let inside_file = cache_root.join("downloaded.m4a");
let empty_file = cache_root.join("empty.m4a");
let unsupported_file = cache_root.join("downloaded.txt");
let outside_file = outside_root.join("downloaded.m4a");
std::fs::write(&inside_file, b"audio").expect("inside file should be written");
std::fs::write(&empty_file, b"").expect("empty file should be written");
std::fs::write(&unsupported_file, b"not audio")
.expect("unsupported file should be written");
std::fs::write(&outside_file, b"audio").expect("outside file should be written");

let accepted = youtube_source_from_metadata(
&json!({ "filepath": inside_file, "title": "Live/Test" }),
&cache_root,
)
.expect("in-cache supported audio should be accepted");
assert_eq!(accepted.extension, "m4a");
assert_eq!(accepted.file_name, "Live_Test.m4a");

assert!(youtube_source_from_metadata(
&json!({ "filepath": empty_file, "title": "Live" }),
&cache_root,
)
.is_err());
assert!(youtube_source_from_metadata(
&json!({ "filepath": unsupported_file, "title": "Live" }),
&cache_root,
)
.is_err());
assert!(youtube_source_from_metadata(
&json!({ "filepath": outside_file, "title": "Live" }),
&cache_root,
)
.is_err());

#[cfg(unix)]
{
let symlink_file = cache_root.join("linked.m4a");
std::os::unix::fs::symlink(&inside_file, &symlink_file)
.expect("symlink should be created");
assert!(youtube_source_from_metadata(
&json!({ "filepath": symlink_file, "title": "Live" }),
&cache_root,
)
.is_err());
}

let _ = std::fs::remove_dir_all(cache_root);
let _ = std::fs::remove_dir_all(outside_root);
}
}

fn main() {
Expand Down
Loading
Loading