Skip to content

Commit 5810e9a

Browse files
committed
Implement forward pass scalar functions:
* add * lt * gt * eq * sub * neg * log * exp * sigmoid * relu
1 parent b35d173 commit 5810e9a

File tree

2 files changed

+35
-45
lines changed

2 files changed

+35
-45
lines changed

minitorch/scalar.py

+18-28
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66
import numpy as np
77

88
from .autodiff import Context, Variable, backpropagate, central_difference
9-
from .scalar_functions import EQ # noqa: F401
10-
from .scalar_functions import LT # noqa: F401
11-
from .scalar_functions import Add # noqa: F401
12-
from .scalar_functions import Exp # noqa: F401
13-
from .scalar_functions import Log # noqa: F401
14-
from .scalar_functions import Neg # noqa: F401
15-
from .scalar_functions import ReLU # noqa: F401
16-
from .scalar_functions import Sigmoid # noqa: F401
9+
from .scalar_functions import EQ
10+
from .scalar_functions import LT
11+
from .scalar_functions import Add
12+
from .scalar_functions import Exp
13+
from .scalar_functions import Log
14+
from .scalar_functions import Neg
15+
from .scalar_functions import ReLU
16+
from .scalar_functions import Sigmoid
1717
from .scalar_functions import (
1818
Inv,
1919
Mul,
@@ -92,31 +92,25 @@ def __rtruediv__(self, b: ScalarLike) -> Scalar:
9292
return Mul.apply(b, Inv.apply(self))
9393

9494
def __add__(self, b: ScalarLike) -> Scalar:
95-
# TODO: Implement for Task 1.2.
96-
raise NotImplementedError("Need to implement for Task 1.2")
95+
return Add.apply(self, b)
9796

9897
def __bool__(self) -> bool:
9998
return bool(self.data)
10099

101100
def __lt__(self, b: ScalarLike) -> Scalar:
102-
# TODO: Implement for Task 1.2.
103-
raise NotImplementedError("Need to implement for Task 1.2")
101+
return LT.apply(self, b)
104102

105103
def __gt__(self, b: ScalarLike) -> Scalar:
106-
# TODO: Implement for Task 1.2.
107-
raise NotImplementedError("Need to implement for Task 1.2")
104+
return LT.apply(b, self)
108105

109106
def __eq__(self, b: ScalarLike) -> Scalar: # type: ignore[override]
110-
# TODO: Implement for Task 1.2.
111-
raise NotImplementedError("Need to implement for Task 1.2")
107+
return EQ.apply(self, b)
112108

113109
def __sub__(self, b: ScalarLike) -> Scalar:
114-
# TODO: Implement for Task 1.2.
115-
raise NotImplementedError("Need to implement for Task 1.2")
110+
return Add.apply(self, -b)
116111

117112
def __neg__(self) -> Scalar:
118-
# TODO: Implement for Task 1.2.
119-
raise NotImplementedError("Need to implement for Task 1.2")
113+
return Neg.apply(self.data)
120114

121115
def __radd__(self, b: ScalarLike) -> Scalar:
122116
return self + b
@@ -125,20 +119,16 @@ def __rmul__(self, b: ScalarLike) -> Scalar:
125119
return self * b
126120

127121
def log(self) -> Scalar:
128-
# TODO: Implement for Task 1.2.
129-
raise NotImplementedError("Need to implement for Task 1.2")
122+
return Log.apply(self.data)
130123

131124
def exp(self) -> Scalar:
132-
# TODO: Implement for Task 1.2.
133-
raise NotImplementedError("Need to implement for Task 1.2")
125+
return Exp.apply(self.data)
134126

135127
def sigmoid(self) -> Scalar:
136-
# TODO: Implement for Task 1.2.
137-
raise NotImplementedError("Need to implement for Task 1.2")
128+
return Sigmoid.apply(self.data)
138129

139130
def relu(self) -> Scalar:
140-
# TODO: Implement for Task 1.2.
141-
raise NotImplementedError("Need to implement for Task 1.2")
131+
return ReLU.apply(self.data)
142132

143133
# Variable elements for backprop
144134

minitorch/scalar_functions.py

+17-17
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ class Add(ScalarFunction):
7474

7575
@staticmethod
7676
def forward(ctx: Context, a: float, b: float) -> float:
77-
return a + b
77+
return operators.add(a, b)
7878

7979
@staticmethod
8080
def backward(ctx: Context, d_output: float) -> Tuple[float, ...]:
@@ -103,8 +103,8 @@ class Mul(ScalarFunction):
103103

104104
@staticmethod
105105
def forward(ctx: Context, a: float, b: float) -> float:
106-
# TODO: Implement for Task 1.2.
107-
raise NotImplementedError("Need to implement for Task 1.2")
106+
ctx.save_for_backward(b)
107+
return operators.mul(a, b)
108108

109109
@staticmethod
110110
def backward(ctx: Context, d_output: float) -> Tuple[float, float]:
@@ -117,8 +117,8 @@ class Inv(ScalarFunction):
117117

118118
@staticmethod
119119
def forward(ctx: Context, a: float) -> float:
120-
# TODO: Implement for Task 1.2.
121-
raise NotImplementedError("Need to implement for Task 1.2")
120+
ctx.save_for_backward(a)
121+
return operators.inv(a)
122122

123123
@staticmethod
124124
def backward(ctx: Context, d_output: float) -> float:
@@ -131,8 +131,8 @@ class Neg(ScalarFunction):
131131

132132
@staticmethod
133133
def forward(ctx: Context, a: float) -> float:
134-
# TODO: Implement for Task 1.2.
135-
raise NotImplementedError("Need to implement for Task 1.2")
134+
ctx.save_for_backward(a)
135+
return operators.neg(a)
136136

137137
@staticmethod
138138
def backward(ctx: Context, d_output: float) -> float:
@@ -145,8 +145,8 @@ class Sigmoid(ScalarFunction):
145145

146146
@staticmethod
147147
def forward(ctx: Context, a: float) -> float:
148-
# TODO: Implement for Task 1.2.
149-
raise NotImplementedError("Need to implement for Task 1.2")
148+
ctx.save_for_backward(a)
149+
return operators.sigmoid(a)
150150

151151
@staticmethod
152152
def backward(ctx: Context, d_output: float) -> float:
@@ -159,8 +159,8 @@ class ReLU(ScalarFunction):
159159

160160
@staticmethod
161161
def forward(ctx: Context, a: float) -> float:
162-
# TODO: Implement for Task 1.2.
163-
raise NotImplementedError("Need to implement for Task 1.2")
162+
ctx.save_for_backward(a)
163+
return operators.relu(a)
164164

165165
@staticmethod
166166
def backward(ctx: Context, d_output: float) -> float:
@@ -173,8 +173,8 @@ class Exp(ScalarFunction):
173173

174174
@staticmethod
175175
def forward(ctx: Context, a: float) -> float:
176-
# TODO: Implement for Task 1.2.
177-
raise NotImplementedError("Need to implement for Task 1.2")
176+
ctx.save_for_backward(a)
177+
return operators.exp(a)
178178

179179
@staticmethod
180180
def backward(ctx: Context, d_output: float) -> float:
@@ -187,8 +187,8 @@ class LT(ScalarFunction):
187187

188188
@staticmethod
189189
def forward(ctx: Context, a: float, b: float) -> float:
190-
# TODO: Implement for Task 1.2.
191-
raise NotImplementedError("Need to implement for Task 1.2")
190+
ctx.save_for_backward(a, b)
191+
return operators.lt(a, b)
192192

193193
@staticmethod
194194
def backward(ctx: Context, d_output: float) -> Tuple[float, float]:
@@ -201,8 +201,8 @@ class EQ(ScalarFunction):
201201

202202
@staticmethod
203203
def forward(ctx: Context, a: float, b: float) -> float:
204-
# TODO: Implement for Task 1.2.
205-
raise NotImplementedError("Need to implement for Task 1.2")
204+
ctx.save_for_backward(a, b)
205+
return operators.eq(a, b)
206206

207207
@staticmethod
208208
def backward(ctx: Context, d_output: float) -> Tuple[float, float]:

0 commit comments

Comments
 (0)