diff --git a/miette-derive/src/diagnostic.rs b/miette-derive/src/diagnostic.rs index ef1bdae3..1239fcea 100644 --- a/miette-derive/src/diagnostic.rs +++ b/miette-derive/src/diagnostic.rs @@ -4,6 +4,7 @@ use syn::{punctuated::Punctuated, DeriveInput, Token}; use crate::code::Code; use crate::diagnostic_arg::DiagnosticArg; +use crate::diagnostic_source::DiagnosticSource; use crate::forward::{Forward, WhichFn}; use crate::help::Help; use crate::label::Labels; @@ -66,6 +67,7 @@ pub struct DiagnosticConcreteArgs { pub url: Option, pub forward: Option, pub related: Option, + pub diagnostic_source: Option, } impl DiagnosticConcreteArgs { @@ -74,6 +76,7 @@ impl DiagnosticConcreteArgs { let source_code = SourceCode::from_fields(fields)?; let related = Related::from_fields(fields)?; let help = Help::from_fields(fields)?; + let diagnostic_source = DiagnosticSource::from_fields(fields)?; Ok(DiagnosticConcreteArgs { code: None, help, @@ -83,6 +86,7 @@ impl DiagnosticConcreteArgs { url: None, forward: None, source_code, + diagnostic_source, }) } @@ -283,6 +287,8 @@ impl Diagnostic { let source_code_method = forward.gen_struct_method(WhichFn::SourceCode); let severity_method = forward.gen_struct_method(WhichFn::Severity); let related_method = forward.gen_struct_method(WhichFn::Related); + let diagnostic_source_method = + forward.gen_struct_method(WhichFn::DiagnosticSource); quote! { impl #impl_generics miette::Diagnostic for #ident #ty_generics #where_clause { @@ -293,6 +299,7 @@ impl Diagnostic { #severity_method #source_code_method #related_method + #diagnostic_source_method } } } @@ -338,6 +345,11 @@ impl Diagnostic { .as_ref() .and_then(|x| x.gen_struct(fields)) .or_else(|| forward(WhichFn::SourceCode)); + let diagnostic_source = concrete + .diagnostic_source + .as_ref() + .and_then(|x| x.gen_struct()) + .or_else(|| forward(WhichFn::DiagnosticSource)); quote! { impl #impl_generics miette::Diagnostic for #ident #ty_generics #where_clause { #code_body @@ -347,6 +359,7 @@ impl Diagnostic { #url_body #labels_body #src_body + #diagnostic_source } } } @@ -365,6 +378,7 @@ impl Diagnostic { let src_body = SourceCode::gen_enum(variants); let rel_body = Related::gen_enum(variants); let url_body = Url::gen_enum(ident, variants); + let diagnostic_source_body = DiagnosticSource::gen_enum(variants); quote! { impl #impl_generics miette::Diagnostic for #ident #ty_generics #where_clause { #code_body @@ -374,6 +388,7 @@ impl Diagnostic { #src_body #rel_body #url_body + #diagnostic_source_body } } } diff --git a/miette-derive/src/diagnostic_source.rs b/miette-derive/src/diagnostic_source.rs new file mode 100644 index 00000000..949defed --- /dev/null +++ b/miette-derive/src/diagnostic_source.rs @@ -0,0 +1,78 @@ +use proc_macro2::TokenStream; +use quote::quote; +use syn::spanned::Spanned; + +use crate::forward::WhichFn; +use crate::{ + diagnostic::{DiagnosticConcreteArgs, DiagnosticDef}, + utils::{display_pat_members, gen_all_variants_with}, +}; + +pub struct DiagnosticSource(syn::Member); + +impl DiagnosticSource { + pub(crate) fn from_fields(fields: &syn::Fields) -> syn::Result> { + match fields { + syn::Fields::Named(named) => Self::from_fields_vec(named.named.iter().collect()), + syn::Fields::Unnamed(unnamed) => { + Self::from_fields_vec(unnamed.unnamed.iter().collect()) + } + syn::Fields::Unit => Ok(None), + } + } + + fn from_fields_vec(fields: Vec<&syn::Field>) -> syn::Result> { + for (i, field) in fields.iter().enumerate() { + for attr in &field.attrs { + if attr.path.is_ident("diagnostic_source") { + let diagnostic_source = if let Some(ident) = field.ident.clone() { + syn::Member::Named(ident) + } else { + syn::Member::Unnamed(syn::Index { + index: i as u32, + span: field.span(), + }) + }; + return Ok(Some(DiagnosticSource(diagnostic_source))); + } + } + } + Ok(None) + } + + pub(crate) fn gen_enum(variants: &[DiagnosticDef]) -> Option { + gen_all_variants_with( + variants, + WhichFn::DiagnosticSource, + |ident, + fields, + DiagnosticConcreteArgs { + diagnostic_source, .. + }| { + let (display_pat, _display_members) = display_pat_members(fields); + diagnostic_source.as_ref().map(|diagnostic_source| { + let rel = match &diagnostic_source.0 { + syn::Member::Named(ident) => ident.clone(), + syn::Member::Unnamed(syn::Index { index, .. }) => { + quote::format_ident!("_{}", index) + } + }; + quote! { + Self::#ident #display_pat => { + std::option::Option::Some(#rel.as_ref()) + } + } + }) + }, + ) + } + + pub(crate) fn gen_struct(&self) -> Option { + let rel = &self.0; + Some(quote! { + fn diagnostic_source<'a>(&'a self) -> std::option::Option<&'a dyn miette::Diagnostic> { + std::option::Option::Some(&self.#rel) + } + }) + } +} diff --git a/miette-derive/src/forward.rs b/miette-derive/src/forward.rs index ca7e1b38..c8757b20 100644 --- a/miette-derive/src/forward.rs +++ b/miette-derive/src/forward.rs @@ -38,6 +38,7 @@ pub enum WhichFn { Labels, SourceCode, Related, + DiagnosticSource, } impl WhichFn { @@ -50,6 +51,7 @@ impl WhichFn { Self::Labels => quote! { labels() }, Self::SourceCode => quote! { source_code() }, Self::Related => quote! { related() }, + Self::DiagnosticSource => quote! { diagnostic_source() }, } } @@ -76,6 +78,9 @@ impl WhichFn { Self::SourceCode => quote! { fn source_code(&self) -> std::option::Option<&dyn miette::SourceCode> }, + Self::DiagnosticSource => quote! { + fn diagnostic_source(&self) -> std::option::Option<&dyn miette::Diagnostic> + }, } } diff --git a/miette-derive/src/lib.rs b/miette-derive/src/lib.rs index da8f8bbb..0f7e64e5 100644 --- a/miette-derive/src/lib.rs +++ b/miette-derive/src/lib.rs @@ -6,6 +6,7 @@ use diagnostic::Diagnostic; mod code; mod diagnostic; mod diagnostic_arg; +mod diagnostic_source; mod fmt; mod forward; mod help; @@ -16,7 +17,10 @@ mod source_code; mod url; mod utils; -#[proc_macro_derive(Diagnostic, attributes(diagnostic, source_code, label, related, help))] +#[proc_macro_derive( + Diagnostic, + attributes(diagnostic, source_code, label, related, help, diagnostic_source) +)] pub fn derive_diagnostic(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let input = parse_macro_input!(input as DeriveInput); let cmd = match Diagnostic::from_derive_input(input) { diff --git a/src/diagnostic_chain.rs b/src/diagnostic_chain.rs new file mode 100644 index 00000000..1e5e0c2b --- /dev/null +++ b/src/diagnostic_chain.rs @@ -0,0 +1,93 @@ +/*! +Iterate over error `.diagnostic_source()` chains. +*/ + +use crate::protocol::Diagnostic; + +/// Iterator of a chain of cause errors. +#[derive(Clone, Default)] +#[allow(missing_debug_implementations)] +pub(crate) struct DiagnosticChain<'a> { + state: Option>, +} + +impl<'a> DiagnosticChain<'a> { + pub(crate) fn from_diagnostic(head: &'a dyn Diagnostic) -> Self { + DiagnosticChain { + state: Some(ErrorKind::Diagnostic(head)), + } + } + + pub(crate) fn from_stderror(head: &'a (dyn std::error::Error + 'static)) -> Self { + DiagnosticChain { + state: Some(ErrorKind::StdError(head)), + } + } +} + +impl<'a> Iterator for DiagnosticChain<'a> { + type Item = ErrorKind<'a>; + + fn next(&mut self) -> Option { + if let Some(err) = self.state.take() { + self.state = err.get_nested(); + Some(err) + } else { + None + } + } + + fn size_hint(&self) -> (usize, Option) { + let len = self.len(); + (len, Some(len)) + } +} + +impl ExactSizeIterator for DiagnosticChain<'_> { + fn len(&self) -> usize { + fn depth(d: Option<&ErrorKind<'_>>) -> usize { + match d { + Some(d) => 1 + depth(d.get_nested().as_ref()), + None => 0, + } + } + + depth(self.state.as_ref()) + } +} + +#[derive(Clone)] +pub(crate) enum ErrorKind<'a> { + Diagnostic(&'a dyn Diagnostic), + StdError(&'a (dyn std::error::Error + 'static)), +} + +impl<'a> ErrorKind<'a> { + fn get_nested(&self) -> Option> { + match self { + ErrorKind::Diagnostic(d) => d + .diagnostic_source() + .map(ErrorKind::Diagnostic) + .or_else(|| d.source().map(ErrorKind::StdError)), + ErrorKind::StdError(e) => e.source().map(ErrorKind::StdError), + } + } +} + +impl<'a> std::fmt::Debug for ErrorKind<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ErrorKind::Diagnostic(d) => d.fmt(f), + ErrorKind::StdError(e) => e.fmt(f), + } + } +} + +impl<'a> std::fmt::Display for ErrorKind<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ErrorKind::Diagnostic(d) => d.fmt(f), + ErrorKind::StdError(e) => e.fmt(f), + } + } +} diff --git a/src/handlers/debug.rs b/src/handlers/debug.rs index bc245bac..a9460d1d 100644 --- a/src/handlers/debug.rs +++ b/src/handlers/debug.rs @@ -51,6 +51,9 @@ impl DebugReportHandler { let labels: Vec<_> = labels.collect(); diag.field("labels", &format!("{:?}", labels)); } + if let Some(cause) = diagnostic.diagnostic_source() { + diag.field("caused by", &format!("{:?}", cause)); + } diag.finish()?; writeln!(f)?; writeln!(f, "NOTE: If you're looking for the fancy error reports, install miette with the `fancy` feature, or write your own and hook it up with miette::set_hook().") diff --git a/src/handlers/graphical.rs b/src/handlers/graphical.rs index 377922f1..b0ab478d 100644 --- a/src/handlers/graphical.rs +++ b/src/handlers/graphical.rs @@ -2,7 +2,7 @@ use std::fmt::{self, Write}; use owo_colors::{OwoColorize, Style}; -use crate::chain::Chain; +use crate::diagnostic_chain::DiagnosticChain; use crate::handlers::theme::*; use crate::protocol::{Diagnostic, Severity}; use crate::{LabeledSpan, MietteError, ReportHandler, SourceCode, SourceSpan, SpanContents}; @@ -199,8 +199,12 @@ impl GraphicalReportHandler { writeln!(f, "{}", textwrap::fill(&diagnostic.to_string(), opts))?; - if let Some(cause) = diagnostic.source() { - let mut cause_iter = Chain::new(cause).peekable(); + if let Some(mut cause_iter) = diagnostic + .diagnostic_source() + .map(DiagnosticChain::from_diagnostic) + .or_else(|| diagnostic.source().map(DiagnosticChain::from_stderror)) + .map(|it| it.peekable()) + { while let Some(error) = cause_iter.next() { let is_last = cause_iter.peek().is_none(); let char = if !is_last { diff --git a/src/handlers/narratable.rs b/src/handlers/narratable.rs index c9b656a6..5ff3ae95 100644 --- a/src/handlers/narratable.rs +++ b/src/handlers/narratable.rs @@ -2,7 +2,7 @@ use std::fmt; use unicode_width::{UnicodeWidthChar, UnicodeWidthStr}; -use crate::chain::Chain; +use crate::diagnostic_chain::DiagnosticChain; use crate::protocol::{Diagnostic, Severity}; use crate::{LabeledSpan, MietteError, ReportHandler, SourceCode, SourceSpan, SpanContents}; @@ -80,8 +80,12 @@ impl NarratableReportHandler { } fn render_causes(&self, f: &mut impl fmt::Write, diagnostic: &(dyn Diagnostic)) -> fmt::Result { - if let Some(cause) = diagnostic.source() { - for error in Chain::new(cause) { + if let Some(cause_iter) = diagnostic + .diagnostic_source() + .map(DiagnosticChain::from_diagnostic) + .or_else(|| diagnostic.source().map(DiagnosticChain::from_stderror)) + { + for error in cause_iter { writeln!(f, " Caused by: {}", error)?; } } diff --git a/src/lib.rs b/src/lib.rs index d76f8051..e397712e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -584,6 +584,7 @@ pub use panic::*; pub use protocol::*; mod chain; +mod diagnostic_chain; mod error; mod eyreish; #[cfg(feature = "fancy-no-backtrace")] diff --git a/src/protocol.rs b/src/protocol.rs index 2a6545a1..d5cd1fa6 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -59,6 +59,11 @@ pub trait Diagnostic: std::error::Error { fn related<'a>(&'a self) -> Option + 'a>> { None } + + /// The cause of the error. + fn diagnostic_source(&self) -> Option<&dyn Diagnostic> { + None + } } impl std::error::Error for Box { diff --git a/tests/test_diagnostic_source_macro.rs b/tests/test_diagnostic_source_macro.rs new file mode 100644 index 00000000..d7c5d2d9 --- /dev/null +++ b/tests/test_diagnostic_source_macro.rs @@ -0,0 +1,20 @@ +use miette::Diagnostic; + +#[derive(Debug, miette::Diagnostic, thiserror::Error)] +#[error("AnErr")] +struct AnErr; + +#[derive(Debug, miette::Diagnostic, thiserror::Error)] +#[error("TestError")] +struct TestError { + #[diagnostic_source] + asdf_inner_foo: AnErr, +} + +#[test] +fn test_diagnostic_source() { + let error = TestError { + asdf_inner_foo: AnErr, + }; + assert!(error.diagnostic_source().is_some()); +}