|
115 | 115 | " else:\n", |
116 | 116 | " self.grad += backward_grad\n", |
117 | 117 | " \n", |
| 118 | + " gradient_to_send = backward_grad if backward_grad is not None else 1\n", |
118 | 119 | " if self.creation_op == \"add\":\n", |
119 | | - " # Simply send backward self.grad, since increasing either of these \n", |
| 120 | + " # Simply send backward backward_grad, since increasing either of these \n", |
120 | 121 | " # elements will increase the output by that same amount\n", |
121 | | - " self.depends_on[0].backward(self.grad)\n", |
122 | | - " self.depends_on[1].backward(self.grad) \n", |
| 122 | + " self.depends_on[0].backward(gradient_to_send)\n", |
| 123 | + " self.depends_on[1].backward(gradient_to_send) \n", |
123 | 124 | "\n", |
124 | 125 | " if self.creation_op == \"mul\":\n", |
125 | 126 | "\n", |
126 | 127 | " # Calculate the derivative with respect to the first element\n", |
127 | | - " new = self.depends_on[1] * self.grad\n", |
| 128 | + " new = self.depends_on[1] * gradient_to_send\n", |
128 | 129 | " # Send backward the derivative with respect to that element\n", |
129 | 130 | " self.depends_on[0].backward(new.num)\n", |
130 | 131 | "\n", |
131 | 132 | " # Calculate the derivative with respect to the second element\n", |
132 | | - " new = self.depends_on[0] * self.grad\n", |
| 133 | + " new = self.depends_on[0] * gradient_to_send\n", |
133 | 134 | " # Send backward the derivative with respect to that element\n", |
134 | 135 | " self.depends_on[1].backward(new.num)" |
135 | 136 | ] |
|
157 | 158 | "print(b.grad) # as expected" |
158 | 159 | ] |
159 | 160 | }, |
| 161 | + { |
| 162 | + "cell_type": "code", |
| 163 | + "execution_count": 8, |
| 164 | + "metadata": {}, |
| 165 | + "outputs": [ |
| 166 | + { |
| 167 | + "name": "stdout", |
| 168 | + "output_type": "stream", |
| 169 | + "text": [ |
| 170 | + "24\n" |
| 171 | + ] |
| 172 | + } |
| 173 | + ], |
| 174 | + "source": [ |
| 175 | + "a = NumberWithGrad(3)\n", |
| 176 | + "b = a * 4\n", |
| 177 | + "c = b + 3\n", |
| 178 | + "d = b * 5\n", |
| 179 | + "e = c + d\n", |
| 180 | + "\n", |
| 181 | + "\n", |
| 182 | + "e.backward()\n", |
| 183 | + "print(a.grad)" |
| 184 | + ] |
| 185 | + }, |
160 | 186 | { |
161 | 187 | "cell_type": "code", |
162 | 188 | "execution_count": 6, |
|
0 commit comments