From 13048eb193d5b77be0fe2c6cc764322d48d6c9e4 Mon Sep 17 00:00:00 2001 From: daniel <1534513+dantp-ai@users.noreply.github.com> Date: Wed, 20 Mar 2024 17:16:39 +0100 Subject: [PATCH] Implement approx to derivative of f wrt to arg using central diff approx --- minitorch/autodiff.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/minitorch/autodiff.py b/minitorch/autodiff.py index 9431908..3a52170 100644 --- a/minitorch/autodiff.py +++ b/minitorch/autodiff.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Iterable, List, Tuple # noqa: F401 +from typing import Any, Iterable, Tuple from typing_extensions import Protocol @@ -22,8 +22,10 @@ def central_difference(f: Any, *vals: Any, arg: int = 0, epsilon: float = 1e-6) Returns: An approximation of $f'_i(x_0, \ldots, x_{n-1})$ """ - # TODO: Implement for Task 1.1. - raise NotImplementedError("Need to implement for Task 1.1") + return ( + f(*[v + epsilon if i == arg else v for i, v in enumerate(vals)]) + - f(*[v - epsilon if i == arg else v for i, v in enumerate(vals)]) + ) / (2 * epsilon) variable_count = 1