Skip to content

Commit 13048eb

Browse files
committed
Implement approx to derivative of f wrt to arg using central diff approx
1 parent b41dc13 commit 13048eb

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

minitorch/autodiff.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass
2-
from typing import Any, Iterable, List, Tuple # noqa: F401
2+
from typing import Any, Iterable, Tuple
33

44
from typing_extensions import Protocol
55

@@ -22,8 +22,10 @@ def central_difference(f: Any, *vals: Any, arg: int = 0, epsilon: float = 1e-6)
2222
Returns:
2323
An approximation of $f'_i(x_0, \ldots, x_{n-1})$
2424
"""
25-
# TODO: Implement for Task 1.1.
26-
raise NotImplementedError("Need to implement for Task 1.1")
25+
return (
26+
f(*[v + epsilon if i == arg else v for i, v in enumerate(vals)])
27+
- f(*[v - epsilon if i == arg else v for i, v in enumerate(vals)])
28+
) / (2 * epsilon)
2729

2830

2931
variable_count = 1

0 commit comments

Comments
 (0)