@@ -44,7 +44,7 @@ def print_rank_0(message, debug=False, force=False):
 class LinearFunctionForZeroStage3(torch.autograd.Function):

     @staticmethod
-    @autocast_custom_fwd
+    # bias is an optional argument
     def forward(input, weight, bias=None):

         if input.dim() == 2 and bias is not None:
@@ -60,7 +60,13 @@ def forward(input, weight, bias=None):

     @staticmethod
     def setup_context(ctx, inputs, output):
-        input, weight, bias = inputs
+        # Replicate autocast state that @autocast_custom_fwd normally sets on ctx,
+        # since the decorator assumes args[0] is ctx, which is unavailable in the
+        # separate forward() + setup_context() pattern.
+        device_type = get_accelerator().device_name()
+        ctx._dtype = torch.get_autocast_dtype(device_type)
+        ctx._fwd_used_autocast = torch.is_autocast_enabled(device_type)
+        input, weight, bias = inputs[0], inputs[1], inputs[2] if len(inputs) > 2 else None
         ctx.save_for_backward(input, weight, bias)

     # This function has only a single output, so it gets only one gradient
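For context, a minimal sketch of the separate forward()/setup_context() pattern that motivates this change (a hypothetical ScaleFn, not DeepSpeed code): forward() never receives ctx, so a decorator such as torch.amp.custom_fwd, which expects args[0] to be ctx, cannot wrap it, and any per-call state must be recorded in setup_context() instead.

import torch


class ScaleFn(torch.autograd.Function):
    # Hypothetical example that only illustrates the pattern.

    @staticmethod
    def forward(input, scale):
        # No ctx argument: forward() sees only the user-facing inputs.
        return input * scale

    @staticmethod
    def setup_context(ctx, inputs, output):
        # ctx exists only here, so per-call state (like the autocast
        # dtype/enabled flags recorded in the patch above) must be set here.
        input, scale = inputs
        ctx.scale = scale

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward() input; scale is a Python float,
        # so its gradient slot is None.
        return grad_output * ctx.scale, None


x = torch.randn(3, requires_grad=True)
ScaleFn.apply(x, 2.0).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])

If backward() is wrapped with @autocast_custom_bwd, that decorator re-enters autocast using ctx._fwd_used_autocast and ctx._dtype (as torch.amp.custom_bwd does), which is presumably why the patch sets those attributes manually in setup_context().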