Skip to content

Commit 3a336f3

Browse files
committed
Fix a bug in autograd example
1 parent bfeb647 commit 3a336f3

1 file changed

Lines changed: 31 additions & 5 deletions

File tree

06_rnns/Autograd_Simple.ipynb

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,21 +115,22 @@
115115
" else:\n",
116116
" self.grad += backward_grad\n",
117117
" \n",
118+
" gradient_to_send = backward_grad if backward_grad is not None else 1\n",
118119
" if self.creation_op == \"add\":\n",
119-
" # Simply send backward self.grad, since increasing either of these \n",
120+
" # Simply send backward backward_grad, since increasing either of these \n",
120121
" # elements will increase the output by that same amount\n",
121-
" self.depends_on[0].backward(self.grad)\n",
122-
" self.depends_on[1].backward(self.grad) \n",
122+
" self.depends_on[0].backward(gradient_to_send)\n",
123+
" self.depends_on[1].backward(gradient_to_send) \n",
123124
"\n",
124125
" if self.creation_op == \"mul\":\n",
125126
"\n",
126127
" # Calculate the derivative with respect to the first element\n",
127-
" new = self.depends_on[1] * self.grad\n",
128+
" new = self.depends_on[1] * gradient_to_send\n",
128129
" # Send backward the derivative with respect to that element\n",
129130
" self.depends_on[0].backward(new.num)\n",
130131
"\n",
131132
" # Calculate the derivative with respect to the second element\n",
132-
" new = self.depends_on[0] * self.grad\n",
133+
" new = self.depends_on[0] * gradient_to_send\n",
133134
" # Send backward the derivative with respect to that element\n",
134135
" self.depends_on[1].backward(new.num)"
135136
]
@@ -157,6 +158,31 @@
157158
"print(b.grad) # as expected"
158159
]
159160
},
161+
{
162+
"cell_type": "code",
163+
"execution_count": 8,
164+
"metadata": {},
165+
"outputs": [
166+
{
167+
"name": "stdout",
168+
"output_type": "stream",
169+
"text": [
170+
"24\n"
171+
]
172+
}
173+
],
174+
"source": [
175+
"a = NumberWithGrad(3)\n",
176+
"b = a * 4\n",
177+
"c = b + 3\n",
178+
"d = b * 5\n",
179+
"e = c + d\n",
180+
"\n",
181+
"\n",
182+
"e.backward()\n",
183+
"print(a.grad)"
184+
]
185+
},
160186
{
161187
"cell_type": "code",
162188
"execution_count": 6,

0 commit comments

Comments
 (0)