|
6 | 6 | use std::fmt::{self, Display, Formatter}; |
7 | 7 | use std::str::FromStr; |
8 | 8 |
|
| 9 | +use crate::expand::typetree::TypeTree; |
9 | 10 | use crate::expand::{Decodable, Encodable, HashStable_Generic}; |
10 | 11 | use crate::{Ty, TyKind}; |
11 | 12 |
|
@@ -84,6 +85,8 @@ pub struct AutoDiffItem { |
84 | 85 | /// The name of the function being generated |
85 | 86 | pub target: String, |
86 | 87 | pub attrs: AutoDiffAttrs, |
| 88 | + pub inputs: Vec<TypeTree>, |
| 89 | + pub output: TypeTree, |
87 | 90 | } |
88 | 91 |
|
89 | 92 | #[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] |
@@ -275,14 +278,22 @@ impl AutoDiffAttrs { |
275 | 278 | !matches!(self.mode, DiffMode::Error | DiffMode::Source) |
276 | 279 | } |
277 | 280 |
|
278 | | - pub fn into_item(self, source: String, target: String) -> AutoDiffItem { |
279 | | - AutoDiffItem { source, target, attrs: self } |
| 281 | + pub fn into_item( |
| 282 | + self, |
| 283 | + source: String, |
| 284 | + target: String, |
| 285 | + inputs: Vec<TypeTree>, |
| 286 | + output: TypeTree, |
| 287 | + ) -> AutoDiffItem { |
| 288 | + AutoDiffItem { source, target, inputs, output, attrs: self } |
280 | 289 | } |
281 | 290 | } |
282 | 291 |
|
283 | 292 | impl fmt::Display for AutoDiffItem { |
284 | 293 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
285 | 294 | write!(f, "Differentiating {} -> {}", self.source, self.target)?; |
286 | | - write!(f, " with attributes: {:?}", self.attrs) |
| 295 | + write!(f, " with attributes: {:?}", self.attrs)?; |
| 296 | + write!(f, " with inputs: {:?}", self.inputs)?; |
| 297 | + write!(f, " with output: {:?}", self.output) |
287 | 298 | } |
288 | 299 | } |
0 commit comments