|
| 1 | +import argparse |
| 2 | +from dataclasses import dataclass |
| 3 | + |
| 4 | +import torch |
| 5 | +import torch.distributed as dist |
| 6 | +import deepspeed |
| 7 | +from transformers import AutoModelForCausalLM |
| 8 | + |
| 9 | + |
| 10 | +@dataclass |
| 11 | +class ModelParallelUnit: |
| 12 | + """Minimal MPU for DeepSpeed TP+DP.""" |
| 13 | + |
| 14 | + tp_group: dist.ProcessGroup |
| 15 | + dp_group: dist.ProcessGroup |
| 16 | + tp_size: int |
| 17 | + dp_size: int |
| 18 | + tp_rank: int |
| 19 | + dp_rank: int |
| 20 | + |
| 21 | + def get_data_parallel_group(self): |
| 22 | + return self.dp_group |
| 23 | + |
| 24 | + def get_model_parallel_group(self): |
| 25 | + return self.tp_group |
| 26 | + |
| 27 | + def get_data_parallel_world_size(self): |
| 28 | + return self.dp_size |
| 29 | + |
| 30 | + def get_model_parallel_world_size(self): |
| 31 | + return self.tp_size |
| 32 | + |
| 33 | + def get_data_parallel_rank(self): |
| 34 | + return self.dp_rank |
| 35 | + |
| 36 | + def get_model_parallel_rank(self): |
| 37 | + return self.tp_rank |
| 38 | + |
| 39 | + |
| 40 | +def parse_args(): |
| 41 | + parser = argparse.ArgumentParser(description="AutoTP training example (distilled from verify_autotp).") |
| 42 | + parser.add_argument("--local_rank", type=int, default=-1, help="Passed by deepspeed/torchrun.") |
| 43 | + parser.add_argument("--model_name", type=str, default="meta-llama/Llama-3.1-8B") |
| 44 | + parser.add_argument("--tp_size", type=int, default=4) |
| 45 | + parser.add_argument("--dp_size", type=int, default=2) |
| 46 | + parser.add_argument("--zero_stage", type=int, default=2) |
| 47 | + parser.add_argument("--batch_size", type=int, default=1) |
| 48 | + parser.add_argument("--seq_length", type=int, default=1024) |
| 49 | + parser.add_argument("--num_steps", type=int, default=10) |
| 50 | + parser.add_argument("--learning_rate", type=float, default=2e-5) |
| 51 | + parser.add_argument("--precision", type=str, default="bf16", choices=["bf16", "fp16", "fp32"]) |
| 52 | + return parser.parse_args() |
| 53 | + |
| 54 | + |
| 55 | +def build_tp_dp_groups(rank, world_size, tp_size, dp_size): |
| 56 | + if tp_size * dp_size != world_size: |
| 57 | + raise ValueError(f"tp_size ({tp_size}) * dp_size ({dp_size}) must equal world_size ({world_size})") |
| 58 | + |
| 59 | + tp_rank = rank % tp_size |
| 60 | + dp_rank = rank // tp_size |
| 61 | + |
| 62 | + tp_group = None |
| 63 | + dp_group = None |
| 64 | + |
| 65 | + for dp_idx in range(dp_size): |
| 66 | + tp_ranks = list(range(dp_idx * tp_size, (dp_idx + 1) * tp_size)) |
| 67 | + group = dist.new_group(tp_ranks) |
| 68 | + if rank in tp_ranks: |
| 69 | + tp_group = group |
| 70 | + |
| 71 | + for tp_idx in range(tp_size): |
| 72 | + dp_ranks = [tp_idx + dp_idx * tp_size for dp_idx in range(dp_size)] |
| 73 | + group = dist.new_group(dp_ranks) |
| 74 | + if rank in dp_ranks: |
| 75 | + dp_group = group |
| 76 | + |
| 77 | + return tp_group, dp_group, tp_rank, dp_rank |
| 78 | + |
| 79 | + |
| 80 | +def broadcast_inputs(input_ids, labels, tp_group, tp_src_rank): |
| 81 | + dist.broadcast(input_ids, src=tp_src_rank, group=tp_group) |
| 82 | + dist.broadcast(labels, src=tp_src_rank, group=tp_group) |
| 83 | + |
| 84 | + |
| 85 | +def main(): |
| 86 | + args = parse_args() |
| 87 | + deepspeed.init_distributed() |
| 88 | + |
| 89 | + rank = dist.get_rank() |
| 90 | + world_size = dist.get_world_size() |
| 91 | + device = torch.device(f"cuda:{rank % torch.cuda.device_count()}") |
| 92 | + |
| 93 | + tp_group, dp_group, tp_rank, dp_rank = build_tp_dp_groups( |
| 94 | + rank, world_size, args.tp_size, args.dp_size |
| 95 | + ) |
| 96 | + |
| 97 | + model = AutoModelForCausalLM.from_pretrained(args.model_name) |
| 98 | + model = model.to(device) |
| 99 | + |
| 100 | + # AutoTP is enabled via the DeepSpeed config. |
| 101 | + ds_config = { |
| 102 | + "train_batch_size": args.batch_size * args.dp_size, |
| 103 | + "train_micro_batch_size_per_gpu": args.batch_size, |
| 104 | + "zero_optimization": {"stage": args.zero_stage}, |
| 105 | + "tensor_parallel": {"autotp_size": args.tp_size}, |
| 106 | + "data_parallel_size": args.dp_size, |
| 107 | + } |
| 108 | + if args.precision == "bf16": |
| 109 | + ds_config["bf16"] = {"enabled": True} |
| 110 | + elif args.precision == "fp16": |
| 111 | + ds_config["fp16"] = {"enabled": True} |
| 112 | + |
| 113 | + optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate) |
| 114 | + mpu = ModelParallelUnit(tp_group, dp_group, args.tp_size, args.dp_size, tp_rank, dp_rank) |
| 115 | + engine, _, _, _ = deepspeed.initialize(model=model, optimizer=optimizer, config=ds_config, mpu=mpu) |
| 116 | + |
| 117 | + vocab_size = model.config.vocab_size |
| 118 | + for _ in range(args.num_steps): |
| 119 | + if tp_rank == 0: |
| 120 | + input_ids = torch.randint(0, vocab_size, (args.batch_size, args.seq_length), device=device) |
| 121 | + labels = input_ids.clone() |
| 122 | + else: |
| 123 | + input_ids = torch.empty((args.batch_size, args.seq_length), dtype=torch.long, device=device) |
| 124 | + labels = torch.empty((args.batch_size, args.seq_length), dtype=torch.long, device=device) |
| 125 | + |
| 126 | + tp_src_rank = dp_rank * args.tp_size |
| 127 | + broadcast_inputs(input_ids, labels, tp_group, tp_src_rank) |
| 128 | + outputs = engine(input_ids=input_ids, labels=labels) |
| 129 | + engine.backward(outputs.loss) |
| 130 | + engine.step() |
| 131 | + |
| 132 | + if rank == 0: |
| 133 | + print("AutoTP example completed.") |
| 134 | + |
| 135 | + |
| 136 | +if __name__ == "__main__": |
| 137 | + main() |
0 commit comments