Commit bb245b2

drop PyTorch < 2.0 support and fix autocast backward in ZeRO linear
Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
1 parent: 60d20da

3 files changed: 194 additions & 215 deletions

deepspeed/runtime/zero/linear.py

Lines changed: 31 additions & 114 deletions
@@ -16,7 +16,6 @@
 #when implemented outside of torch.autograd.Function

 import math
-import functools

 import torch
 from torch import Tensor
@@ -32,139 +31,57 @@ def print_rank_0(message, debug=False, force=False):
         print(message)


-autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=get_accelerator().device_name())
-autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=get_accelerator().device_name())
-
-# PyTorch >= 2.0 supports setup_context, which is required for
-# torch.func transforms (vmap, grad, jvp, jacrev, etc.)
-_SUPPORTS_SETUP_CONTEXT = hasattr(torch.autograd.Function, 'setup_context')
-
-if _SUPPORTS_SETUP_CONTEXT:
-
-    class LinearFunctionForZeroStage3(torch.autograd.Function):
-
-        @staticmethod
-        # bias is an optional argument
-        def forward(input, weight, bias=None):
-
-            if input.dim() == 2 and bias is not None:
-                # fused op is marginally faster
-                ret = torch.addmm(bias, input, weight.t())
-            else:
-                output = input.matmul(weight.t())
-                if bias is not None:
-                    output += bias
-                ret = output
-
-            return ret
-
-        @staticmethod
-        def setup_context(ctx, inputs, output):
-            # Replicate autocast state that @autocast_custom_fwd normally sets on ctx,
-            # since the decorator assumes args[0] is ctx which is unavailable in the
-            # separate forward() + setup_context() pattern.
-            device_type = get_accelerator().device_name()
-            ctx._dtype = torch.get_autocast_dtype(device_type)
-            ctx._fwd_used_autocast = torch.is_autocast_enabled(device_type)
-            input, weight, bias = inputs[0], inputs[1], inputs[2] if len(inputs) > 2 else None
-            ctx.save_for_backward(input, weight, bias)
-
-        # This function has only a single output, so it gets only one gradient
-        @staticmethod
-        def backward(ctx, grad_output):
-            # Do not use @autocast_custom_bwd here: it pairs with @autocast_custom_fwd on
-            # legacy forward(ctx, ...). With forward + setup_context, use AMP state from setup_context.
-            device_type = get_accelerator().device_name()
-            if getattr(ctx, "_fwd_used_autocast", False):
-                with torch.amp.autocast(device_type=device_type, enabled=True, dtype=ctx._dtype):
-                    return LinearFunctionForZeroStage3._backward_core(ctx, grad_output)
-            return LinearFunctionForZeroStage3._backward_core(ctx, grad_output)
-
-        @staticmethod
-        def _backward_core(ctx, grad_output):
-            input, weight, bias = ctx.saved_tensors
+class LinearFunctionForZeroStage3(torch.autograd.Function):

-            grad_input = grad_weight = grad_bias = None
+    @staticmethod
+    # bias is an optional argument
+    def forward(input, weight, bias=None):

-            dim = grad_output.dim()
-            if ctx.needs_input_grad[0]:
-                grad_input = grad_output.matmul(weight)
-            if ctx.needs_input_grad[1]:
-                if dim > 2:
-                    grad_weight = grad_output.reshape(-1, grad_output.shape[-1]).t().matmul(
-                        input.reshape(-1, input.shape[-1]))
-                else:
-                    grad_weight = grad_output.t().matmul(input)
-            if bias is not None and ctx.needs_input_grad[2]:
-                if dim > 2:
-                    grad_bias = grad_output.sum([i for i in range(dim - 1)])
-                else:
-                    grad_bias = grad_output.sum(0)
-            return grad_input, grad_weight, grad_bias
-
-else:
-
-    class LinearFunctionForZeroStage3(torch.autograd.Function):
-
-        # Note that both forward and backward are @staticmethods
-        @staticmethod
-        @autocast_custom_fwd
-        # bias is an optional argument
-        def forward(ctx, input, weight, bias=None):
-
-            ctx.save_for_backward(input, weight, bias)
-
-            if input.dim() == 2 and bias is not None:
-                # fused op is marginally faster
-                ret = torch.addmm(bias, input, weight.t())
-            else:
-                output = input.matmul(weight.t())
-                if bias is not None:
-                    output += bias
-                ret = output
-
-            return ret
-
-        # This function has only a single output, so it gets only one gradient
-        @staticmethod
-        @autocast_custom_bwd
-        def backward(ctx, grad_output):
-            # This is a pattern that is very convenient - at the top of backward
-            # unpack saved_tensors and initialize all gradients w.r.t. inputs to
-            # None. Thanks to the fact that additional trailing Nones are
-            # ignored, the return statement is simple even when the function has
-            # optional inputs.
+        if input.dim() == 2 and bias is not None:
+            # fused op is marginally faster
+            ret = torch.addmm(bias, input, weight.t())
+        else:
+            output = input.matmul(weight.t())
+            if bias is not None:
+                output += bias
+            ret = output
+
+        return ret
+
+    @staticmethod
+    def setup_context(ctx, inputs, output):
+        device_type = get_accelerator().device_name()
+        ctx._dtype = torch.get_autocast_dtype(device_type)
+        ctx._fwd_used_autocast = torch.is_autocast_enabled(device_type)
+        input, weight, bias = inputs[0], inputs[1], inputs[2] if len(inputs) > 2 else None
+        ctx.save_for_backward(input, weight, bias)
+
+    # This function has only a single output, so it gets only one gradient
+    @staticmethod
+    def backward(ctx, grad_output):
+        # Match @custom_bwd semantics: always run backward under the same
+        # autocast state as forward — including explicitly disabling autocast
+        # when forward did not use it, to guard against outer autocast regions.
+        device_type = get_accelerator().device_name()
+        with torch.amp.autocast(device_type=device_type, enabled=ctx._fwd_used_autocast, dtype=ctx._dtype):
             input, weight, bias = ctx.saved_tensors

             grad_input = grad_weight = grad_bias = None

-            #print(f"backward shaped grad_output {grad_output.shape}, input {input.shape}, weight {weight.shape} and bias {bias.shape if bias is not None else None}")
-            # These needs_input_grad checks are optional and there only to
-            # improve efficiency. If you want to make your code simpler, you can
-            # skip them. Returning gradients for inputs that don't require it is
-            # not an error.
             dim = grad_output.dim()
             if ctx.needs_input_grad[0]:
-                #print(f"Computing grad input weight {weight.shape} grad_output {grad_output.shape}")
                 grad_input = grad_output.matmul(weight)
-                #print(f"Computed grad input {grad_input.shape}")
             if ctx.needs_input_grad[1]:
-                #print("Computing grad weight")
                 if dim > 2:
                     grad_weight = grad_output.reshape(-1, grad_output.shape[-1]).t().matmul(
                         input.reshape(-1, input.shape[-1]))
                 else:
                     grad_weight = grad_output.t().matmul(input)
-                #print(f"Computed grad weight grad_weight {grad_weight.shape}")
             if bias is not None and ctx.needs_input_grad[2]:
-                #print("Computing grad bias")
                 if dim > 2:
                     grad_bias = grad_output.sum([i for i in range(dim - 1)])
                 else:
                     grad_bias = grad_output.sum(0)
-                #print("Done computing grad bias")
-                #print("needs bias")
-            #print(f"backward shaped grad_input {grad_input.shape}, grad_weight {grad_weight.shape}, grad_bias {grad_bias.shape if grad_bias is not None else None}")
             return grad_input, grad_weight, grad_bias


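For readers unfamiliar with the forward/setup_context split used above, here is a minimal standalone sketch of the same autocast bookkeeping. It illustrates the pattern only and is not DeepSpeed code: the class name AutocastAwareMatmul and the hard-coded "cuda" device type are invented for the example, and the device-typed torch.is_autocast_enabled / torch.get_autocast_dtype overloads it calls (the same ones the new setup_context uses) assume a sufficiently recent PyTorch 2.x.

import torch


class AutocastAwareMatmul(torch.autograd.Function):
    # Toy Function using the same forward/setup_context split as the diff above.

    @staticmethod
    def forward(input, weight):
        return input.matmul(weight.t())

    @staticmethod
    def setup_context(ctx, inputs, output):
        # Record the ambient autocast state here, since forward() has no ctx argument.
        input, weight = inputs
        ctx._fwd_used_autocast = torch.is_autocast_enabled("cuda")
        ctx._dtype = torch.get_autocast_dtype("cuda")
        ctx.save_for_backward(input, weight)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        # Re-enter (or explicitly disable) autocast exactly as forward saw it.
        with torch.amp.autocast(device_type="cuda", enabled=ctx._fwd_used_autocast, dtype=ctx._dtype):
            grad_input = grad_output.matmul(weight)
            grad_weight = grad_output.t().matmul(input)
            print("backward matmul dtype:", grad_input.dtype)  # bf16 when forward ran under bf16 autocast
        # Cast back so the returned grads match the fp32 leaves in this toy example.
        return grad_input.to(input.dtype), grad_weight.to(weight.dtype)


if torch.cuda.is_available():
    x = torch.randn(4, 8, device="cuda", requires_grad=True)
    w = torch.randn(16, 8, device="cuda", requires_grad=True)
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        y = AutocastAwareMatmul.apply(x, w)
    y.sum().backward()

The point mirrors the fixed backward above: enabled=ctx._fwd_used_autocast re-enters autocast when the forward pass ran under it and explicitly disables it otherwise, so the gradient matmuls always see the dtype policy the forward saw, regardless of any autocast region around the .backward() call.
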
deepspeed/runtime/zero/parameter_offload.py

Lines changed: 43 additions & 98 deletions
@@ -18,10 +18,6 @@

 FWD_MODULE_STACK = list()

-# PyTorch >= 2.0: setup_context on autograd.Function is required for torch.func transforms.
-# Match deepspeed/runtime/zero/linear.py: keep legacy forward(ctx, ...) when unavailable.
-_SUPPORTS_SETUP_CONTEXT = hasattr(torch.autograd.Function, "setup_context")
-

 #for each tensor in outputs run the forward_function and register backward_function as hook
 def _apply_forward_and_backward_to_tensors_only(module, forward_function, backward_function, outputs):
@@ -405,45 +401,24 @@ def _run_before_backward_function(sub_module):
                         sub_module.applied_pre_backward_ref_cnt -= 1
                     #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}")

-                if _SUPPORTS_SETUP_CONTEXT:
-
-                    class PreBackwardFunctionForModule(torch.autograd.Function):
-
-                        @staticmethod
-                        def forward(outputs):
-                            return outputs.detach()
-
-                        @staticmethod
-                        def setup_context(ctx, inputs, output):
-                            ctx.module = module
-                            ctx.pre_backward_function = _run_before_backward_function
-                            if not hasattr(ctx.module, "applied_pre_backward_ref_cnt"):
-                                ctx.module.applied_pre_backward_ref_cnt = 0
-                            ctx.module.applied_pre_backward_ref_cnt += 1
-
-                        @staticmethod
-                        def backward(ctx, *args):
-                            ctx.pre_backward_function(ctx.module)
-                            return args
+                class PreBackwardFunctionForModule(torch.autograd.Function):

-                else:
-
-                    class PreBackwardFunctionForModule(torch.autograd.Function):
+                    @staticmethod
+                    def forward(outputs):
+                        return outputs.detach()

-                        @staticmethod
-                        def forward(ctx, outputs):
-                            ctx.module = module
-                            ctx.pre_backward_function = _run_before_backward_function
-                            if not hasattr(ctx.module, "applied_pre_backward_ref_cnt"):
-                                ctx.module.applied_pre_backward_ref_cnt = 0
-                            ctx.module.applied_pre_backward_ref_cnt += 1
-                            outputs = outputs.detach()
-                            return outputs
+                    @staticmethod
+                    def setup_context(ctx, inputs, output):
+                        ctx.module = module
+                        ctx.pre_backward_function = _run_before_backward_function
+                        if not hasattr(ctx.module, "applied_pre_backward_ref_cnt"):
+                            ctx.module.applied_pre_backward_ref_cnt = 0
+                        ctx.module.applied_pre_backward_ref_cnt += 1

-                        @staticmethod
-                        def backward(ctx, *args):
-                            ctx.pre_backward_function(ctx.module)
-                            return args
+                    @staticmethod
+                    def backward(ctx, *args):
+                        ctx.pre_backward_function(ctx.module)
+                        return args

                 module.pre_bwd_fn = PreBackwardFunctionForModule

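A note on the pattern kept by the hunk above: the wrapper is an identity Function whose backward fires a callback just before gradients flow into the graph that produced the wrapped output. A small self-contained sketch of that idea follows; make_pre_backward_hook, PreBackwardHook, and the print callback are invented names for the example, and the applied_pre_backward_ref_cnt bookkeeping of the real hook is left out.

import torch
import torch.nn as nn


def make_pre_backward_hook(module, pre_backward_fn):
    # Returns a Function that runs pre_backward_fn(module) right before
    # gradients flow back into the graph that produced the wrapped output.

    class PreBackwardHook(torch.autograd.Function):

        @staticmethod
        def forward(outputs):
            # Identity in the forward direction. Detaching is fine: autograd
            # still links this node to the module's graph via the input edge.
            return outputs.detach()

        @staticmethod
        def setup_context(ctx, inputs, output):
            ctx.module = module  # nothing tensor-valued needs saving

        @staticmethod
        def backward(ctx, *grads):
            pre_backward_fn(ctx.module)  # e.g. re-fetch partitioned parameters
            return grads

    return PreBackwardHook


layer = nn.Linear(4, 4)
hook = make_pre_backward_hook(layer, lambda m: print("pre-backward:", type(m).__name__))
x = torch.randn(2, 4, requires_grad=True)
out = hook.apply(layer(x))   # wrap the module output
out.sum().backward()         # prints just before grads reach `layer`
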
@@ -457,64 +432,34 @@ def _run_after_backward_function(sub_module):
                     if sub_module.ds_grads_remaining == 0:
                         self.post_sub_module_backward_function(sub_module)

-                if _SUPPORTS_SETUP_CONTEXT:
-
-                    class PostBackwardFunctionModule(torch.autograd.Function):
-
-                        @staticmethod
-                        def forward(output):
-                            return output.detach()
-
-                        @staticmethod
-                        def setup_context(ctx, inputs, output):
-                            (output_in, ) = inputs
-                            ctx.module = module
-                            if output_in.requires_grad:
-                                #TODO SOME TIMES post backward does not seem to be triggered debug in detail
-                                #Should only cause increase in memory not correctness issue
-                                #if output.grad_fn.__class__.__name__ == 'ViewBackward':
-                                #    ctx.view=True
-                                #    print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly")
-                                #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors."
-                                #if module.ds_grads_remaining == 0:
-                                #    print(f"Before Forward: {ctx.module.__class__.__name__}")
-                                module.ds_grads_remaining += 1
-                                ctx.post_backward_function = _run_after_backward_function
-
-                        @staticmethod
-                        def backward(ctx, *args):
-                            ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1
-                            if ctx.module.ds_grads_remaining == 0:
-                                ctx.post_backward_function(ctx.module)
-                            return args
-
-                else:
-
-                    class PostBackwardFunctionModule(torch.autograd.Function):
-
-                        @staticmethod
-                        def forward(ctx, output):
-                            ctx.module = module
-                            if output.requires_grad:
-                                #TODO SOME TIMES post backward does not seem to be triggered debug in detail
-                                #Should only cause increase in memory not correctness issue
-                                #if output.grad_fn.__class__.__name__ == 'ViewBackward':
-                                #    ctx.view=True
-                                #    print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly")
-                                #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors."
-                                #if module.ds_grads_remaining == 0:
-                                #    print(f"Before Forward: {ctx.module.__class__.__name__}")
-                                module.ds_grads_remaining += 1
-                                ctx.post_backward_function = _run_after_backward_function
-                            output = output.detach()
-                            return output
-
-                        @staticmethod
-                        def backward(ctx, *args):
-                            ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1
-                            if ctx.module.ds_grads_remaining == 0:
-                                ctx.post_backward_function(ctx.module)
-                            return args
+                class PostBackwardFunctionModule(torch.autograd.Function):
+
+                    @staticmethod
+                    def forward(output):
+                        return output.detach()
+
+                    @staticmethod
+                    def setup_context(ctx, inputs, output):
+                        (output_in, ) = inputs
+                        ctx.module = module
+                        if output_in.requires_grad:
+                            #TODO SOME TIMES post backward does not seem to be triggered debug in detail
+                            #Should only cause increase in memory not correctness issue
+                            #if output.grad_fn.__class__.__name__ == 'ViewBackward':
+                            #    ctx.view=True
+                            #    print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly")
+                            #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors."
+                            #if module.ds_grads_remaining == 0:
+                            #    print(f"Before Forward: {ctx.module.__class__.__name__}")
+                            module.ds_grads_remaining += 1
+                            ctx.post_backward_function = _run_after_backward_function
+
+                    @staticmethod
+                    def backward(ctx, *args):
+                        ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1
+                        if ctx.module.ds_grads_remaining == 0:
+                            ctx.post_backward_function(ctx.module)
+                        return args

                 module.post_bwd_fn = PostBackwardFunctionModule

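The post-backward wrapper kept above adds a per-module counter so the release logic runs only after the last wrapped tensor that required grad has received its gradient. A rough sketch of that counting pattern follows; attach_post_backward, PostBackwardHook, grads_remaining, and the print callback are invented names (the real code tracks ds_grads_remaining and calls post_sub_module_backward_function). Judging from the counter logic, the wrapper is meant to be applied to tensors feeding the module so that its backward runs after the module's own backward, and the sketch wraps the module input accordingly.

import torch
import torch.nn as nn


def attach_post_backward(module, post_backward_fn):
    # Count wrapped tensors that require grad; when the last of them has been
    # consumed in backward, run post_backward_fn(module), e.g. to release params.
    module.grads_remaining = 0

    class PostBackwardHook(torch.autograd.Function):

        @staticmethod
        def forward(tensor):
            return tensor.detach()

        @staticmethod
        def setup_context(ctx, inputs, output):
            (tensor_in, ) = inputs
            ctx.module = module
            if tensor_in.requires_grad:
                module.grads_remaining += 1

        @staticmethod
        def backward(ctx, *grads):
            ctx.module.grads_remaining -= 1
            if ctx.module.grads_remaining == 0:
                post_backward_fn(ctx.module)
            return grads

    module.post_bwd_fn = PostBackwardHook


layer = nn.Linear(4, 4)
attach_post_backward(layer, lambda m: print("post-backward: parameters can be released"))
x = torch.randn(2, 4, requires_grad=True)
wrapped = layer.post_bwd_fn.apply(x)   # wrap the tensor feeding the module
out = layer(wrapped)
out.sum().backward()                   # the callback fires after the layer's backward has run
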