Skip to content

Commit

Permalink
...#253
Browse files Browse the repository at this point in the history
  • Loading branch information
crowlogic committed Dec 1, 2024
1 parent 899a5f9 commit 58e75b3
Showing 1 changed file with 60 additions and 2 deletions.
62 changes: 60 additions & 2 deletions src/main/java/arb/expressions/nodes/unary/FunctionNode.java
Original file line number Diff line number Diff line change
Expand Up @@ -425,8 +425,66 @@ public String typeset()
@Override
public Node<D, R, F> differentiate(VariableNode<D, R, F> variable)
{
assert false : "TODO: apply chain rule, diff(f(g(x)),x)=D(f)(g(x))*diff(g(x),x)\n";
return null;
// Step 1: Differentiate the argument (g'(x)).
Node<D, R, F> argDerivative = arg.differentiate(variable);

// Step 2: Differentiate the function (f'(g(x))).
Node<D, R, F> functionDerivative = differentiateFunction();

// Step 3: Apply the chain rule: f'(g(x)) * g'(x).
return functionDerivative.mul(argDerivative);
}

/**
* Returns the node representing the derivative of the function. This will vary
* based on whether the function is built-in or contextual.
*/
private Node<D, R, F> differentiateFunction()
{
// Check if the function is built-in or contextual.
if (isBuiltin())
{
return differentiateBuiltinFunction();
}
else if (contextual)
{
return differentiateContextualFunction();
}
else
{
throw new UnsupportedOperationException("Cannot differentiate function: " + functionName);
}
}

/**
* Handles differentiation for built-in functions.
*/
private Node<D, R, F> differentiateBuiltinFunction()
{
switch (functionName)
{
case "sin":
return new FunctionNode<>("cos",
arg,
expression); // derivative of sin is cos
case "cos":
return new FunctionNode<>("sin",
arg,
expression).neg(); // derivative of cos is -sin
case "exp":
return this; // derivative of exp is exp
// Add other built-in function derivatives
default:
throw new UnsupportedOperationException("Derivative not implemented for function: " + functionName);
}
}

/**
* Handles differentiation for contextual functions.
*/
private Node<D, R, F> differentiateContextualFunction()
{
throw new UnsupportedOperationException("Contextual function differentiation not implemented: " + functionName);
}

@Override
Expand Down

0 comments on commit 58e75b3

Please sign in to comment.