Commit 444122c

roycho96 and zhangj1an committed
fix: remove @autocast_custom_fwd from forward, move autocast state to setup_context
Co-authored-by: zhangj1an <jianmusings@gmail.com>
1 parent: 39b1755 · commit: 444122c

1 file changed

deepspeed/runtime/zero/linear.py

Lines changed: 8 additions & 2 deletions
@@ -44,7 +44,7 @@ def print_rank_0(message, debug=False, force=False):
 class LinearFunctionForZeroStage3(torch.autograd.Function):
 
     @staticmethod
-    @autocast_custom_fwd
+    # bias is an optional argument
     def forward(input, weight, bias=None):
 
         if input.dim() == 2 and bias is not None:
@@ -60,7 +60,13 @@ def forward(input, weight, bias=None):
 
     @staticmethod
     def setup_context(ctx, inputs, output):
-        input, weight, bias = inputs
+        # Replicate autocast state that @autocast_custom_fwd normally sets on ctx,
+        # since the decorator assumes args[0] is ctx, which is unavailable in the
+        # separate forward() + setup_context() pattern.
+        device_type = get_accelerator().device_name()
+        ctx._dtype = torch.get_autocast_dtype(device_type)
+        ctx._fwd_used_autocast = torch.is_autocast_enabled(device_type)
+        input, weight, bias = inputs[0], inputs[1], inputs[2] if len(inputs) > 2 else None
         ctx.save_for_backward(input, weight, bias)
 
         # This function has only a single output, so it gets only one gradient
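For context on the pattern this commit targets: in PyTorch's newer autograd.Function style, forward() is written without a ctx argument and PyTorch calls setup_context(ctx, inputs, output) separately, which is why a decorator such as @autocast_custom_fwd (which expects args[0] to be ctx) cannot wrap forward() here. Below is a minimal sketch of that pattern, not part of this commit: ScaleFn is a hypothetical function, and the sketch assumes a PyTorch version (2.4+) where torch.is_autocast_enabled and torch.get_autocast_dtype accept a device-type string, as the diff itself does.

import torch

class ScaleFn(torch.autograd.Function):
    # Hypothetical example; not from deepspeed/runtime/zero/linear.py.

    @staticmethod
    def forward(input, scale):
        # No ctx parameter in this style; PyTorch passes ctx to
        # setup_context() after forward() returns.
        return input * scale

    @staticmethod
    def setup_context(ctx, inputs, output):
        input, scale = inputs
        # Record autocast state on ctx, mirroring the commit. Here the
        # device type is taken from the input tensor (an assumption of
        # this sketch; the commit uses get_accelerator().device_name()).
        device_type = input.device.type
        ctx._fwd_used_autocast = torch.is_autocast_enabled(device_type)
        ctx._dtype = torch.get_autocast_dtype(device_type)
        ctx.scale = scale

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input; scale is a non-tensor, so None.
        return grad_output * ctx.scale, None

y = ScaleFn.apply(torch.randn(4, requires_grad=True), 2.0)

Recording ctx._fwd_used_autocast and ctx._dtype matters because a matching custom-backward decorator on backward() re-enters autocast using exactly those two attributes; this mirrors what torch.amp.custom_bwd reads.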
