Skip to content

Commit c3e8315

Browse files
committed
fix bug in autograd example
1 parent bfeb647 commit c3e8315

1 file changed

Lines changed: 45 additions & 13 deletions

File tree

06_rnns/Autograd_Simple.ipynb

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
{
22
"cells": [
33
{
4+
"attachments": {},
45
"cell_type": "markdown",
56
"metadata": {},
67
"source": [
@@ -9,7 +10,7 @@
910
},
1011
{
1112
"cell_type": "code",
12-
"execution_count": 1,
13+
"execution_count": 2,
1314
"metadata": {},
1415
"outputs": [],
1516
"source": [
@@ -22,7 +23,7 @@
2223
},
2324
{
2425
"cell_type": "code",
25-
"execution_count": 2,
26+
"execution_count": 3,
2627
"metadata": {},
2728
"outputs": [
2829
{
@@ -31,7 +32,7 @@
3132
"7"
3233
]
3334
},
34-
"execution_count": 2,
35+
"execution_count": 3,
3536
"metadata": {},
3637
"output_type": "execute_result"
3738
}
@@ -43,7 +44,7 @@
4344
},
4445
{
4546
"cell_type": "code",
46-
"execution_count": 3,
47+
"execution_count": 4,
4748
"metadata": {},
4849
"outputs": [
4950
{
@@ -66,7 +67,7 @@
6667
},
6768
{
6869
"cell_type": "code",
69-
"execution_count": 4,
70+
"execution_count": 5,
7071
"metadata": {},
7172
"outputs": [],
7273
"source": [
@@ -115,28 +116,29 @@
115116
" else:\n",
116117
" self.grad += backward_grad\n",
117118
" \n",
119+
" gradient_to_send = backward_grad if backward_grad is not None else 1\n",
118120
" if self.creation_op == \"add\":\n",
119-
" # Simply send backward self.grad, since increasing either of these \n",
121+
" # Simply send backward backward_grad, since increasing either of these \n",
120122
" # elements will increase the output by that same amount\n",
121-
" self.depends_on[0].backward(self.grad)\n",
122-
" self.depends_on[1].backward(self.grad) \n",
123+
" self.depends_on[0].backward(gradient_to_send)\n",
124+
" self.depends_on[1].backward(gradient_to_send) \n",
123125
"\n",
124126
" if self.creation_op == \"mul\":\n",
125127
"\n",
126128
" # Calculate the derivative with respect to the first element\n",
127-
" new = self.depends_on[1] * self.grad\n",
129+
" new = self.depends_on[1] * gradient_to_send\n",
128130
" # Send backward the derivative with respect to that element\n",
129131
" self.depends_on[0].backward(new.num)\n",
130132
"\n",
131133
" # Calculate the derivative with respect to the second element\n",
132-
" new = self.depends_on[0] * self.grad\n",
134+
" new = self.depends_on[0] * gradient_to_send\n",
133135
" # Send backward the derivative with respect to that element\n",
134136
" self.depends_on[1].backward(new.num)"
135137
]
136138
},
137139
{
138140
"cell_type": "code",
139-
"execution_count": 5,
141+
"execution_count": 6,
140142
"metadata": {},
141143
"outputs": [
142144
{
@@ -157,6 +159,31 @@
157159
"print(b.grad) # as expected"
158160
]
159161
},
162+
{
163+
"cell_type": "code",
164+
"execution_count": 8,
165+
"metadata": {},
166+
"outputs": [
167+
{
168+
"name": "stdout",
169+
"output_type": "stream",
170+
"text": [
171+
"24\n"
172+
]
173+
}
174+
],
175+
"source": [
176+
"a = NumberWithGrad(3)\n",
177+
"b = a * 4\n",
178+
"c = b + 3\n",
179+
"d = b * 5\n",
180+
"e = c + d\n",
181+
"\n",
182+
"\n",
183+
"e.backward()\n",
184+
"print(a.grad)"
185+
]
186+
},
160187
{
161188
"cell_type": "code",
162189
"execution_count": 6,
@@ -202,7 +229,7 @@
202229
],
203230
"metadata": {
204231
"kernelspec": {
205-
"display_name": "Python 3",
232+
"display_name": "Python 3.9.13 64-bit (microsoft store)",
206233
"language": "python",
207234
"name": "python3"
208235
},
@@ -216,7 +243,12 @@
216243
"name": "python",
217244
"nbconvert_exporter": "python",
218245
"pygments_lexer": "ipython3",
219-
"version": "3.7.4"
246+
"version": "3.9.13"
247+
},
248+
"vscode": {
249+
"interpreter": {
250+
"hash": "96f88a1d939096e74b5883cdeb3bbaf3df602d5ab14210a6f9e7d8e0ea241fea"
251+
}
220252
}
221253
},
222254
"nbformat": 4,

0 commit comments

Comments
 (0)