@@ -74,7 +74,7 @@ class Add(ScalarFunction):
74
74
75
75
@staticmethod
76
76
def forward (ctx : Context , a : float , b : float ) -> float :
77
- return a + b
77
+ return operators . add ( a , b )
78
78
79
79
@staticmethod
80
80
def backward (ctx : Context , d_output : float ) -> Tuple [float , ...]:
@@ -103,8 +103,8 @@ class Mul(ScalarFunction):
103
103
104
104
@staticmethod
105
105
def forward (ctx : Context , a : float , b : float ) -> float :
106
- # TODO: Implement for Task 1.2.
107
- raise NotImplementedError ( "Need to implement for Task 1.2" )
106
+ ctx . save_for_backward ( b )
107
+ return operators . mul ( a , b )
108
108
109
109
@staticmethod
110
110
def backward (ctx : Context , d_output : float ) -> Tuple [float , float ]:
@@ -117,8 +117,8 @@ class Inv(ScalarFunction):
117
117
118
118
@staticmethod
119
119
def forward (ctx : Context , a : float ) -> float :
120
- # TODO: Implement for Task 1.2.
121
- raise NotImplementedError ( "Need to implement for Task 1.2" )
120
+ ctx . save_for_backward ( a )
121
+ return operators . inv ( a )
122
122
123
123
@staticmethod
124
124
def backward (ctx : Context , d_output : float ) -> float :
@@ -131,8 +131,8 @@ class Neg(ScalarFunction):
131
131
132
132
@staticmethod
133
133
def forward (ctx : Context , a : float ) -> float :
134
- # TODO: Implement for Task 1.2.
135
- raise NotImplementedError ( "Need to implement for Task 1.2" )
134
+ ctx . save_for_backward ( a )
135
+ return operators . neg ( a )
136
136
137
137
@staticmethod
138
138
def backward (ctx : Context , d_output : float ) -> float :
@@ -145,8 +145,8 @@ class Sigmoid(ScalarFunction):
145
145
146
146
@staticmethod
147
147
def forward (ctx : Context , a : float ) -> float :
148
- # TODO: Implement for Task 1.2.
149
- raise NotImplementedError ( "Need to implement for Task 1.2" )
148
+ ctx . save_for_backward ( a )
149
+ return operators . sigmoid ( a )
150
150
151
151
@staticmethod
152
152
def backward (ctx : Context , d_output : float ) -> float :
@@ -159,8 +159,8 @@ class ReLU(ScalarFunction):
159
159
160
160
@staticmethod
161
161
def forward (ctx : Context , a : float ) -> float :
162
- # TODO: Implement for Task 1.2.
163
- raise NotImplementedError ( "Need to implement for Task 1.2" )
162
+ ctx . save_for_backward ( a )
163
+ return operators . relu ( a )
164
164
165
165
@staticmethod
166
166
def backward (ctx : Context , d_output : float ) -> float :
@@ -173,8 +173,8 @@ class Exp(ScalarFunction):
173
173
174
174
@staticmethod
175
175
def forward (ctx : Context , a : float ) -> float :
176
- # TODO: Implement for Task 1.2.
177
- raise NotImplementedError ( "Need to implement for Task 1.2" )
176
+ ctx . save_for_backward ( a )
177
+ return operators . exp ( a )
178
178
179
179
@staticmethod
180
180
def backward (ctx : Context , d_output : float ) -> float :
@@ -187,8 +187,8 @@ class LT(ScalarFunction):
187
187
188
188
@staticmethod
189
189
def forward (ctx : Context , a : float , b : float ) -> float :
190
- # TODO: Implement for Task 1.2.
191
- raise NotImplementedError ( "Need to implement for Task 1.2" )
190
+ ctx . save_for_backward ( a , b )
191
+ return operators . lt ( a , b )
192
192
193
193
@staticmethod
194
194
def backward (ctx : Context , d_output : float ) -> Tuple [float , float ]:
@@ -201,8 +201,8 @@ class EQ(ScalarFunction):
201
201
202
202
@staticmethod
203
203
def forward (ctx : Context , a : float , b : float ) -> float :
204
- # TODO: Implement for Task 1.2.
205
- raise NotImplementedError ( "Need to implement for Task 1.2" )
204
+ ctx . save_for_backward ( a , b )
205
+ return operators . eq ( a , b )
206
206
207
207
@staticmethod
208
208
def backward (ctx : Context , d_output : float ) -> Tuple [float , float ]:
0 commit comments