|
1 | 1 | { |
2 | 2 | "cells": [ |
3 | 3 | { |
| 4 | + "attachments": {}, |
4 | 5 | "cell_type": "markdown", |
5 | 6 | "metadata": {}, |
6 | 7 | "source": [ |
|
9 | 10 | }, |
10 | 11 | { |
11 | 12 | "cell_type": "code", |
12 | | - "execution_count": 1, |
| 13 | + "execution_count": 2, |
13 | 14 | "metadata": {}, |
14 | 15 | "outputs": [], |
15 | 16 | "source": [ |
|
22 | 23 | }, |
23 | 24 | { |
24 | 25 | "cell_type": "code", |
25 | | - "execution_count": 2, |
| 26 | + "execution_count": 3, |
26 | 27 | "metadata": {}, |
27 | 28 | "outputs": [ |
28 | 29 | { |
|
31 | 32 | "7" |
32 | 33 | ] |
33 | 34 | }, |
34 | | - "execution_count": 2, |
| 35 | + "execution_count": 3, |
35 | 36 | "metadata": {}, |
36 | 37 | "output_type": "execute_result" |
37 | 38 | } |
|
43 | 44 | }, |
44 | 45 | { |
45 | 46 | "cell_type": "code", |
46 | | - "execution_count": 3, |
| 47 | + "execution_count": 4, |
47 | 48 | "metadata": {}, |
48 | 49 | "outputs": [ |
49 | 50 | { |
|
66 | 67 | }, |
67 | 68 | { |
68 | 69 | "cell_type": "code", |
69 | | - "execution_count": 4, |
| 70 | + "execution_count": 5, |
70 | 71 | "metadata": {}, |
71 | 72 | "outputs": [], |
72 | 73 | "source": [ |
|
115 | 116 | " else:\n", |
116 | 117 | " self.grad += backward_grad\n", |
117 | 118 | " \n", |
| 119 | + " gradient_to_send = backward_grad if backward_grad is not None else 1\n", |
118 | 120 | " if self.creation_op == \"add\":\n", |
119 | | - " # Simply send backward self.grad, since increasing either of these \n", |
| 121 | + " # Simply send backward backward_grad, since increasing either of these \n", |
120 | 122 | " # elements will increase the output by that same amount\n", |
121 | | - " self.depends_on[0].backward(self.grad)\n", |
122 | | - " self.depends_on[1].backward(self.grad) \n", |
| 123 | + " self.depends_on[0].backward(gradient_to_send)\n", |
| 124 | + " self.depends_on[1].backward(gradient_to_send) \n", |
123 | 125 | "\n", |
124 | 126 | " if self.creation_op == \"mul\":\n", |
125 | 127 | "\n", |
126 | 128 | " # Calculate the derivative with respect to the first element\n", |
127 | | - " new = self.depends_on[1] * self.grad\n", |
| 129 | + " new = self.depends_on[1] * gradient_to_send\n", |
128 | 130 | " # Send backward the derivative with respect to that element\n", |
129 | 131 | " self.depends_on[0].backward(new.num)\n", |
130 | 132 | "\n", |
131 | 133 | " # Calculate the derivative with respect to the second element\n", |
132 | | - " new = self.depends_on[0] * self.grad\n", |
| 134 | + " new = self.depends_on[0] * gradient_to_send\n", |
133 | 135 | " # Send backward the derivative with respect to that element\n", |
134 | 136 | " self.depends_on[1].backward(new.num)" |
135 | 137 | ] |
136 | 138 | }, |
137 | 139 | { |
138 | 140 | "cell_type": "code", |
139 | | - "execution_count": 5, |
| 141 | + "execution_count": 6, |
140 | 142 | "metadata": {}, |
141 | 143 | "outputs": [ |
142 | 144 | { |
|
157 | 159 | "print(b.grad) # as expected" |
158 | 160 | ] |
159 | 161 | }, |
| 162 | + { |
| 163 | + "cell_type": "code", |
| 164 | + "execution_count": 8, |
| 165 | + "metadata": {}, |
| 166 | + "outputs": [ |
| 167 | + { |
| 168 | + "name": "stdout", |
| 169 | + "output_type": "stream", |
| 170 | + "text": [ |
| 171 | + "24\n" |
| 172 | + ] |
| 173 | + } |
| 174 | + ], |
| 175 | + "source": [ |
| 176 | + "a = NumberWithGrad(3)\n", |
| 177 | + "b = a * 4\n", |
| 178 | + "c = b + 3\n", |
| 179 | + "d = b * 5\n", |
| 180 | + "e = c + d\n", |
| 181 | + "\n", |
| 182 | + "\n", |
| 183 | + "e.backward()\n", |
| 184 | + "print(a.grad)" |
| 185 | + ] |
| 186 | + }, |
160 | 187 | { |
161 | 188 | "cell_type": "code", |
162 | 189 | "execution_count": 6, |
|
202 | 229 | ], |
203 | 230 | "metadata": { |
204 | 231 | "kernelspec": { |
205 | | - "display_name": "Python 3", |
| 232 | + "display_name": "Python 3.9.13 64-bit (microsoft store)", |
206 | 233 | "language": "python", |
207 | 234 | "name": "python3" |
208 | 235 | }, |
|
216 | 243 | "name": "python", |
217 | 244 | "nbconvert_exporter": "python", |
218 | 245 | "pygments_lexer": "ipython3", |
219 | | - "version": "3.7.4" |
| 246 | + "version": "3.9.13" |
| 247 | + }, |
| 248 | + "vscode": { |
| 249 | + "interpreter": { |
| 250 | + "hash": "96f88a1d939096e74b5883cdeb3bbaf3df602d5ab14210a6f9e7d8e0ea241fea" |
| 251 | + } |
220 | 252 | } |
221 | 253 | }, |
222 | 254 | "nbformat": 4, |
|
0 commit comments