diff --git a/autoparallel/__init__.py b/autoparallel/__init__.py index d3b0aca0..3ca70871 100644 --- a/autoparallel/__init__.py +++ b/autoparallel/__init__.py @@ -7,6 +7,7 @@ from autoparallel.collectives import with_sharding_constraint from autoparallel.compile import autoparallel_backend from autoparallel.input_validation import ForwardInputs +from autoparallel.mesh_search import build_split_dim_seed __all__ = [ "auto_parallel", @@ -14,4 +15,5 @@ "autoparallel_backend", "ForwardInputs", "with_sharding_constraint", + "build_split_dim_seed", ] diff --git a/autoparallel/cost_models/nccl_cost_model.py b/autoparallel/cost_models/nccl_cost_model.py index 15d6a3d5..dfcc9bc7 100644 --- a/autoparallel/cost_models/nccl_cost_model.py +++ b/autoparallel/cost_models/nccl_cost_model.py @@ -80,6 +80,7 @@ class NCCLTopoConfig: has_collnet: bool = False # Enables CollNet Direct/Chain (SHARP) # Additional network latency beyond base hw latency (us) net_latency: float = 0.0 + mesh_dim_topo_override: "MeshDimTopo | None" = None @dataclass @@ -1005,6 +1006,15 @@ def derive_mesh_dim_topo( ) -> MeshDimTopo: """Derive per-mesh-dimension NCCL topology parameters.""" dim_size = mesh_shape[dim_idx] + if config.mesh_dim_topo_override is not None: + if len(mesh_shape) != 1 or dim_idx != 0: + raise ValueError("mesh_dim_topo_override can only be used for 1D dim0") + if config.mesh_dim_topo_override.n_ranks != dim_size: + raise ValueError( + "mesh_dim_topo_override.n_ranks must match the 1D mesh size" + ) + return config.mesh_dim_topo_override + inner_product = math.prod(mesh_shape[dim_idx + 1 :]) ppn = max(1, min(config.gpus_per_node // inner_product, dim_size)) n_nodes = dim_size // ppn diff --git a/autoparallel/mesh_search.py b/autoparallel/mesh_search.py new file mode 100644 index 00000000..dd0d7e29 --- /dev/null +++ b/autoparallel/mesh_search.py @@ -0,0 +1,239 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from dataclasses import replace +from typing import Any, Callable, cast + +import torch +from torch._subclasses.fake_tensor import unset_fake_temporarily +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor.placement_types import Placement, Replicate, Shard + +from .cost_models.collective_runtime_estimation import ( + get_nccl_topo_config, + reset_comms_cost_cache, + set_nccl_topo_config, +) +from .cost_models.compute_estimation import reset_compute_cost_cache +from .cost_models.nccl_cost_model import ( + NCCLTopoConfig, + derive_mesh_dim_topo, + detect_nccl_topo_config, +) +from .optimize_sharding import ShardingOptimizer +from .shardings.placement_options import reset_placement_options_cache + + +def reset_mesh_search_caches() -> None: + """Clear mesh-dependent strategy and cost caches.""" + + reset_placement_options_cache() + reset_comms_cost_cache() + reset_compute_cost_cache() + + +def _set_cost_model_for_mesh(mesh, cost_model: Any) -> None: + if isinstance(cost_model, NCCLTopoConfig): + set_nccl_topo_config(cost_model) + elif cost_model == "nccl": + set_nccl_topo_config(detect_nccl_topo_config(mesh)) + else: + set_nccl_topo_config(None) + + +def _placement_code(p: Placement) -> str: + if isinstance(p, Shard): + return f"S{p.dim}" + if isinstance(p, Replicate): + return "R" + return type(p).__name__ + + +def _split_dim_seed_dim_cost_model( + cost_model: Any, + mesh_shape: tuple[int, ...], + dim_idx: int, + *, + fabric_aware: bool, +) -> Any: + if not fabric_aware or not isinstance(cost_model, NCCLTopoConfig): + return cost_model + + topo = derive_mesh_dim_topo(cost_model, mesh_shape, dim_idx) + return replace(cost_model, mesh_dim_topo_override=topo) + + +def _split_dim_seed_cache_key( + size: int, + input_placement: Placement, + cost_model: Any, + mesh_shape: tuple[int, ...], + dim_idx: int, + *, + fabric_aware: bool, +) -> tuple[Any, ...]: + placement = _placement_code(input_placement) + if isinstance(cost_model, NCCLTopoConfig): + dim_cost_model = _split_dim_seed_dim_cost_model( + cost_model, mesh_shape, dim_idx, fabric_aware=fabric_aware + ) + topo = derive_mesh_dim_topo(dim_cost_model, (int(size),), 0) + return ( + "nccl", + int(size), + placement, + cost_model.arch.name, + cost_model.num_nodes, + cost_model.gpus_per_node, + cost_model.bw_intra, + cost_model.bw_inter, + cost_model.num_channels, + topo.n_nodes, + topo.ppn, + topo.bw_intra, + topo.bw_inter, + topo.n_channels, + dim_cost_model.has_nvswitch, + dim_cost_model.has_collnet, + dim_cost_model.net_latency, + ) + return (str(cost_model), int(size), placement) + + +def _first_output_placements(output_specs) -> tuple[Placement, ...] | None: + if isinstance(output_specs, DTensorSpec): + return output_specs.placements + if isinstance(output_specs, (tuple, list)): + for output_spec in output_specs: + if isinstance(output_spec, DTensorSpec): + return output_spec.placements + return None + + +def build_split_dim_seed( + gm: torch.fx.GraphModule, + mesh_shape: tuple[int, ...], + input_placements: tuple[Placement, ...], + *, + cost_model: Any = "nccl", + force_grad_reduce_in_higher_precision: bool = False, + repeated_subgraphs: bool = True, + memory_high_fn: Callable[[int], float] | None = None, + one_d_cache: dict[tuple[Any, ...], dict[str, Placement]] | None = None, + device_type: str = "cuda", + fabric_aware: bool = True, +) -> dict[str, tuple[Placement, ...]]: + """Return a per-node placement seed for a target mesh shape. + + Args: + gm: Joint graph to optimize. + mesh_shape: Target mesh shape. + input_placements: Required input placement for each target mesh dim. + cost_model: Cost model identifier or NCCL topology config. + force_grad_reduce_in_higher_precision: Whether gradient reductions use + higher precision costs. + repeated_subgraphs: Whether repeated graph regions share decisions. + memory_high_fn: Function returning the parameter memory upper bound for + a one-dimensional solve size. + one_d_cache: Optional cache reused across calls. + device_type: Device mesh type. + fabric_aware: Whether one-dimensional solves use per-dim fabric topology. + + Returns: + A mapping from FX node name to placement tuple. + """ + + ndim = len(mesh_shape) + if len(input_placements) != ndim: + raise ValueError( + f"input_placements has {len(input_placements)} entries, expected {ndim}" + ) + if memory_high_fn is None: + memory_high_fn = lambda size: 1.0 / size # noqa: E731 + + cache = one_d_cache if one_d_cache is not None else {} + seed_cost_model = cost_model + if fabric_aware and cost_model == "nccl": + with unset_fake_temporarily(): + full_mesh = init_device_mesh( + device_type, + mesh_shape, + mesh_dim_names=tuple(f"d{i}" for i in range(ndim)), + ) + seed_cost_model = detect_nccl_topo_config(full_mesh) + + per_dim: list[dict[str, Placement]] = [] + for dim_idx, size in enumerate(mesh_shape): + input_placement = input_placements[dim_idx] + key = _split_dim_seed_cache_key( + int(size), + input_placement, + seed_cost_model, + mesh_shape, + dim_idx, + fabric_aware=fabric_aware, + ) + if key not in cache: + with unset_fake_temporarily(): + mesh_1d = init_device_mesh( + device_type, + (int(size),), + mesh_dim_names=("d",), + ) + prev = get_nccl_topo_config() + try: + dim_cost_model = _split_dim_seed_dim_cost_model( + seed_cost_model, + mesh_shape, + dim_idx, + fabric_aware=fabric_aware, + ) + _set_cost_model_for_mesh(mesh_1d, dim_cost_model) + reset_mesh_search_caches() + opt = ShardingOptimizer( + gm, + mesh_1d, + force_grad_reduce_in_higher_precision, + repeated_subgraphs=repeated_subgraphs, + ) + opt.add_sharded_input_constraint([(input_placement,)]) + opt.add_sharded_output_constraint([(input_placement,)]) + opt.add_parameter_memory_constraint(0.0, memory_high_fn(int(size))) + solution = opt.get_solution() + finally: + set_nccl_topo_config(prev) + + node_placements: dict[str, Placement] = {} + for node, strategy in solution.items(): + placements = _first_output_placements(strategy.output_specs) + if placements is not None: + node_placements[node.name] = placements[0] + cache[key] = node_placements + per_dim.append(cache[key]) + + seed: dict[str, tuple[Placement, ...]] = {} + for node in gm.graph.nodes: + if node.op == "output": + continue + seed[node.name] = tuple( + per_dim[i].get(node.name, Replicate()) for i in range(ndim) + ) + + from torch._functorch._aot_autograd.fx_utils import ( + get_plain_input_and_grad_nodes, + get_plain_output_and_tangent_nodes, + ) + + input_tuple = tuple(input_placements) + for getter in (get_plain_input_and_grad_nodes, get_plain_output_and_tangent_nodes): + for _desc, (node, companion) in cast(Any, getter(gm.graph)).items(): + seed[node.name] = input_tuple + if companion is not None: + seed[companion.name] = input_tuple + + return seed diff --git a/autoparallel/optimize_sharding.py b/autoparallel/optimize_sharding.py index 6b72878b..48a59633 100644 --- a/autoparallel/optimize_sharding.py +++ b/autoparallel/optimize_sharding.py @@ -102,8 +102,11 @@ build_param_derived_set, build_terminal_derived_set, ) -from .shardings.placement_options import get_placement_options_for_node -from .shardings.propagation_rules import _create_all_options +from .shardings.placement_options import ( + get_placement_options_for_node, + reset_placement_options_cache, +) +from .shardings.propagation_rules import _create_all_options, set_current_seed_node logger = logging.getLogger(__name__) @@ -224,6 +227,8 @@ def __init__( mesh, force_grad_reduce_in_higher_precision=False, repeated_subgraphs=False, + strategy_seed=None, + strategy_radius=None, ): self.orig_gm = gm # The optimizer works on a concretized copy of the graph where all @@ -239,6 +244,8 @@ def __init__( self.force_grad_reduce_in_higher_precision = ( force_grad_reduce_in_higher_precision ) + self.strategy_seed = strategy_seed + self.strategy_radius = strategy_radius self._constraint_log: list[tuple[str, dict]] = [] self._memory_constraint: tuple[float, float] | None = None # Maps ILP constraint name → node_name for active node constraints, @@ -247,7 +254,21 @@ def __init__( self._node_constraint_names: dict[str, str] = {} self._name_counters: dict[str, int] = {} t0 = time.perf_counter() - self.strats = self.build_sharding_metadata() + if self.strategy_seed is None: + self.strats = self.build_sharding_metadata() + else: + from .shardings import propagation_rules as _propagation_rules + + _propagation_rules.set_strategy_seed( + self.strategy_seed, self.strategy_radius + ) + reset_placement_options_cache() + try: + self.strats = self.build_sharding_metadata() + finally: + _propagation_rules.set_strategy_seed(None, None) + set_current_seed_node(None) + reset_placement_options_cache() # nodes/node_map are derived from strats (not graph.nodes) so that # shape-computation nodes skipped by build_sharding_metadata don't # appear and indices stay consistent. @@ -306,6 +327,7 @@ def _normalize_node(self, node): def build_sharding_metadata(self): strats = {} for node in self.graph.nodes: + set_current_seed_node(node.name) if node.op in ("placeholder", "get_attr"): val = node.meta.get("val") if isinstance(val, torch.Tensor): @@ -351,6 +373,7 @@ def build_sharding_metadata(self): strats[node] = user_strats else: raise ValueError(f"Unexpected node op: {node.op}") + set_current_seed_node(None) return strats def create_cluster_links(self, clusters): @@ -962,6 +985,68 @@ def get_solution(self, verbose=False): ) return self._to_orig_solution(self._extract_and_validate_solution()) + def solve_lp_relaxation(self, verbose=False, frac_tol=1e-6, extract=False): + """Solve the LP relaxation and return objective, status, and fractionality. + + Args: + verbose: Whether to print solver output. + frac_tol: Tolerance for counting a variable as fractional. + extract: Return a sharding solution when the relaxation is integral. + """ + old_objective = self.prob.objective + if old_objective is None: + self._set_objective() + self._apply_memory_constraint() + + variables = self.prob.variables() + original_cats = [v.cat for v in variables] + t0 = time.perf_counter() + solution = None + try: + for v in variables: + v.cat = pulp.LpContinuous + + solver = pulp.PULP_CBC_CMD(msg=verbose) + with tempfile.TemporaryDirectory() as tmpdir: + solver.tmpDir = tmpdir + self.prob.solve(solver) + + objective = pulp.value(self.prob.objective) + n_fractional = 0 + n_vars = 0 + for v in variables: + val = v.value() + if val is None: + continue + n_vars += 1 + if frac_tol < val < 1.0 - frac_tol: + n_fractional += 1 + + status = pulp.LpStatus[self.prob.status] + if extract and status == "Optimal" and n_fractional == 0: + self.selected_keys = [ + key + for key, dv in self.decision_vars.items() + if dv.var.value() is not None and dv.var.value() > 0.5 + ] + for root_key in list(self.selected_keys): + self.selected_keys.extend(self._root_to_linked.get(root_key, [])) + solution = self._to_orig_solution(self._extract_and_validate_solution()) + finally: + for v, cat in zip(variables, original_cats): + v.cat = cat + if old_objective is None: + self.prob.objective = None + + return { + "objective": objective, + "solve_time": time.perf_counter() - t0, + "n_fractional": n_fractional, + "n_vars": n_vars, + "status": status, + "solution": solution, + } + def resolve(self, verbose=False): """Re-solve the ILP after adding or removing constraints. diff --git a/autoparallel/shardings/placement_options.py b/autoparallel/shardings/placement_options.py index e2d3496a..bce8b308 100644 --- a/autoparallel/shardings/placement_options.py +++ b/autoparallel/shardings/placement_options.py @@ -30,7 +30,14 @@ from autoparallel.shardings.propagation_rules import generate_dummy_redistribute_costs from .dtensor_sharding_helpers import get_op_strategy, with_implicit_strategies -from .propagation_rules import _op_rules, remove_invalid_configs +from .propagation_rules import ( + _op_rules, + get_current_seed_node, + get_strategy_radius, + get_strategy_seed, + remove_invalid_configs, + within_strategy_seed_ball, +) logger = logging.getLogger(__name__) @@ -276,19 +283,53 @@ def reset_placement_options_timer(): _placement_options_timer = PlacementOptionsTimer() +def _seed_cache_key(): + seed = get_strategy_seed() + if seed is None: + return None + node_name = get_current_seed_node() + placements = seed.get(node_name) if node_name is not None else None + placement_key = None if placements is None else tuple(str(p) for p in placements) + return node_name, placement_key, get_strategy_radius() + + +def _filter_by_strategy_seed(out_strat): + if get_strategy_seed() is None: + return out_strat + + kept = [] + for strategy in out_strat.strategies: + specs = strategy.output_specs + placements = None + if isinstance(specs, DTensorSpec): + placements = specs.placements + elif isinstance(specs, (tuple, list)): + for spec in specs: + if isinstance(spec, DTensorSpec): + placements = spec.placements + break + if placements is None or within_strategy_seed_ball(placements): + kept.append(strategy) + + return out_strat if not kept else OpStrategy(kept) + + def get_placement_options(mesh, op, specs, user_args, user_kwargs): assert len(specs) == len(user_args) timer = _placement_options_timer t_start = time.perf_counter() try: + mesh_key = (id(mesh), mesh.device_type, tuple(mesh.shape), mesh.ndim) cache_key = ( + mesh_key, op, tuple(_fingerprint_arg(s) for s in specs), tuple(_fingerprint_arg(a) for a in user_args), tuple(_fingerprint_arg(v) for v in user_kwargs.values()) if user_kwargs else (), + _seed_cache_key(), ) hash(cache_key) # fail fast if key contains unhashable types (e.g. SymInts) except TypeError: @@ -330,6 +371,7 @@ def get_placement_options(mesh, op, specs, user_args, user_kwargs): else: with with_implicit_strategies(): out_strat = get_op_strategy(op, op_schema) + out_strat = _filter_by_strategy_seed(out_strat) t1 = time.perf_counter() # operator.getitem is self-contained: its input is a tuple of tensors @@ -397,9 +439,7 @@ def get_local_map_placement_option( mesh, None, ), "Not yet implemented" - assert "call_local_map" in str(node.target) or "call_local_map_backward" in str( - node.target - ) + assert "local_map" in str(node.target) in_specs = [] num_activation_inputs = len(user_args) - len(in_placements) # activations are always replicated diff --git a/autoparallel/shardings/propagation_rules.py b/autoparallel/shardings/propagation_rules.py index 0c5a9849..58ec8c1d 100644 --- a/autoparallel/shardings/propagation_rules.py +++ b/autoparallel/shardings/propagation_rules.py @@ -53,6 +53,55 @@ from ..cast_parametrization import dtype_cast # noqa from .dtensor_sharding_helpers import _try_single_dim_strategy, get_op_strategy +_strategy_seed: "dict[str, tuple[Placement, ...]] | None" = None +_strategy_radius: "int | dict[str, int] | None" = None +_current_seed_node: "str | None" = None + + +def set_strategy_seed(seed, radius): + global _strategy_seed, _strategy_radius + _strategy_seed = seed + _strategy_radius = radius + + +def get_strategy_seed(): + return _strategy_seed + + +def get_strategy_radius(): + return _strategy_radius + + +def set_current_seed_node(name): + global _current_seed_node + _current_seed_node = name + + +def get_current_seed_node(): + return _current_seed_node + + +def within_strategy_seed_ball(placements) -> bool: + if _strategy_seed is None: + return True + node_name = _current_seed_node + if node_name is None: + return True + seed_placements = _strategy_seed.get(node_name) + if seed_placements is None: + return True + + radius = _strategy_radius + if isinstance(radius, dict): + radius = radius.get(node_name, 0) + if radius is None: + radius = 0 + + distance = abs(len(placements) - len(seed_placements)) + distance += sum(1 for a, b in zip(placements, seed_placements) if a != b) + return distance <= radius + + _op_rules = {} @@ -194,6 +243,8 @@ def _create_all_options(mesh, shape, tensor_meta=None, tensor=None): all_options = list(itertools.product(*[possible_options for _ in range(mesh.ndim)])) strats = [] for placement in all_options: + if not within_strategy_seed_ball(placement): + continue spec = DTensorSpec(mesh, placement, tensor_meta=tensor_meta) strats.append(OpSpec(spec, input_specs=[spec], redistribute_cost=[[0.0]])) out_strats = OpStrategy(strats) diff --git a/autoparallel/tools/overlap_simulator/run.py b/autoparallel/tools/overlap_simulator/run.py index ec5532d2..83584c07 100644 --- a/autoparallel/tools/overlap_simulator/run.py +++ b/autoparallel/tools/overlap_simulator/run.py @@ -252,7 +252,9 @@ def get_hint(x: Union[int, torch.SymInt]) -> Optional[int]: if isinstance(x, int): return x if hasattr(x, "node") and hasattr(x.node, "hint"): - return x.node.hint + hint = x.node.hint + if isinstance(hint, int): + return hint return None @staticmethod diff --git a/docs/README.md b/docs/README.md index 9299286f..dc5d718e 100644 --- a/docs/README.md +++ b/docs/README.md @@ -20,8 +20,10 @@ If you're new to the project, use the reading order below. - [How AutoParallel Chooses a Strategy](how_autoparallel_chooses_a_strategy.md) - [Adaptive Sharding: Sequence-Parallel vs Column-Parallel](adaptive_sharding.md) +- [Split-Dim Seed Design](split_dim_seed_design.md) ## Advanced usage - [Using `local_map` for MoE and Custom Communication Patterns](local_map_and_moe.md) +- [Split-Dim Seed Search](split_dim_seed.md) - [Saving and Loading Optimizer State](save_load.md) diff --git a/docs/split_dim_seed.md b/docs/split_dim_seed.md new file mode 100644 index 00000000..8d156620 --- /dev/null +++ b/docs/split_dim_seed.md @@ -0,0 +1,68 @@ +# Split-Dim Seed Search + +Split-dim seed search builds a placement seed for a full mesh by solving one +one-dimensional sharding problem per mesh dimension. The full mesh optimizer can +then search only the Hamming ball around that seed with the existing PuLP ILP +solver, or solve the LP relaxation of the same restricted problem. + +This is intended for mesh-discovery sweeps where full strategy enumeration is +too large, but a small neighborhood around a fabric-aware seed is still useful. + +## Usage + +```python +from torch.distributed.tensor.placement_types import Replicate, Shard + +from autoparallel.mesh_search import build_split_dim_seed +from autoparallel.optimize_sharding import ShardingOptimizer + +input_placement = (Shard(0), Replicate(), Replicate()) + +seed = build_split_dim_seed( + gm, + mesh_shape=(16, 8, 4), + input_placements=input_placement, + cost_model=topo_config, + force_grad_reduce_in_higher_precision=True, + repeated_subgraphs=True, + one_d_cache=seed_cache, +) + +opt = ShardingOptimizer( + gm, + mesh, + force_grad_reduce_in_higher_precision=True, + repeated_subgraphs=True, + strategy_seed=seed, + strategy_radius=2, +) + +opt.add_sharded_input_constraint([input_placement]) +opt.add_sharded_output_constraint([input_placement]) +opt.add_parameter_memory_constraint(0.0, 1.0 / mesh.size()) + +solution = opt.get_solution() +lp_result = opt.solve_lp_relaxation(extract=True) +``` + +`strategy_seed` maps FX node names to placement tuples. `strategy_radius` +keeps placements whose Hamming distance from the seed placement is at most that +radius. Nodes not present in the seed keep their full strategy space. + +## Topology + +When `cost_model` is an `NCCLTopoConfig`, each one-dimensional seed solve uses +the topology of the corresponding full-mesh dimension. This preserves the +original physical node configuration while making the 1D solve see the right +fabric tier for that dimension. + +For `cost_model="nccl"`, the topology is detected from the full mesh before the +per-dimension seed solves. + +## Solvers + +`get_solution()` solves the restricted Hamming space as an ILP. + +`solve_lp_relaxation()` solves the LP relaxation of the same restricted problem +and returns the objective, status, variable fractionality, and optionally an +extracted sharding solution when the relaxation is integral. diff --git a/docs/split_dim_seed_design.md b/docs/split_dim_seed_design.md new file mode 100644 index 00000000..705dad34 --- /dev/null +++ b/docs/split_dim_seed_design.md @@ -0,0 +1,66 @@ +# Split-Dim Seed Design + +## Goal + +Provide a small, reusable mesh-discovery primitive: + +- build a fabric-aware seed by solving independent 1D problems; +- restrict the full-mesh strategy space to a Hamming ball around that seed; +- solve that restricted problem with the existing ILP path or its LP relaxation. + +This branch intentionally does not include TRW-S, lazy cost builds, approximate +solver changes, or experiment runners. + +## Public Surface + +`autoparallel.mesh_search.build_split_dim_seed(...)` returns +`{node.name: placement_tuple}` for a target mesh shape. + +`ShardingOptimizer(..., strategy_seed=seed, strategy_radius=r)` applies the +Hamming-ball restriction during strategy generation, then builds the normal PuLP +problem. + +`ShardingOptimizer.solve_lp_relaxation(...)` solves the continuous relaxation of +the same PuLP problem and reports objective/status diagnostics. + +## Topology Handling + +`NCCLTopoConfig.mesh_dim_topo_override` lets a one-dimensional seed solve reuse +the `MeshDimTopo` derived from the corresponding full-mesh dimension. The +override is accepted only for 1D dim0 and only when the override rank count +matches the 1D mesh size. + +The cache key for 1D seeds includes: + +- one-dimensional size and input placement; +- physical NCCL config, including original `num_nodes` and `gpus_per_node`; +- derived per-dimension `MeshDimTopo`. + +That keeps same-size dimensions on different fabrics from sharing a seed. + +## Strategy Filtering + +Strategy generation stores the active seed and current FX node name in +`propagation_rules`. Placement generation keeps strategies whose output +placement is inside the active Hamming radius. + +The placement-option cache includes mesh identity and seed information. The +optimizer resets that cache when installing and removing a seed so seeded and +unseeded builds cannot reuse each other's filtered entries. + +If a node has no seed entry, its full strategy space is kept. If filtering an +operator strategy would remove every option, the original options are kept so +the build does not fail because of an over-tight seed. + +## Solver Behavior + +The restricted search still uses the normal optimizer lifecycle: + +1. generate placement options; +2. build decision variables and costs; +3. add default and user constraints; +4. solve with PuLP. + +The ILP path is `get_solution()`. The LP path relaxes existing binary variables +to continuous variables, solves, reports fractionality, and restores the +variable categories before returning. diff --git a/tests/conftest.py b/tests/conftest.py index d5d23ea1..df9e7d75 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,6 +27,7 @@ "name": "H100", "total_memory": 80 * 1024**3, "multi_processor_count": 132, + "L2_cache_size": 50 * 1024**2, }, )(), ), diff --git a/tests/test_split_dim_seed.py b/tests/test_split_dim_seed.py new file mode 100644 index 00000000..3970953b --- /dev/null +++ b/tests/test_split_dim_seed.py @@ -0,0 +1,93 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import pulp +import pytest +import torch +from conftest import apply_cuda_patches +from torch._subclasses.fake_tensor import unset_fake_temporarily +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.tensor.placement_types import Replicate, Shard + +from autoparallel.api import AutoParallel +from autoparallel.cost_models.nccl_cost_model import h100_topo_config +from autoparallel.mesh_search import build_split_dim_seed +from autoparallel.optimize_sharding import ShardingOptimizer + +pytestmark = [ + pytest.mark.filterwarnings("ignore:Constructing LpVariable.*:DeprecationWarning"), + pytest.mark.filterwarnings( + "ignore:Using LpProblem.constraints.*:DeprecationWarning" + ), +] + + +class TinyMLP(torch.nn.Module): + def __init__(self): + super().__init__() + self.in_proj = torch.nn.Linear(16, 32) + self.out_proj = torch.nn.Linear(32, 16) + + def forward(self, x): + return self.out_proj(torch.relu(self.in_proj(x))) + + +def _input_fn(): + return torch.randn(8, 16, device="cuda", requires_grad=True) + + +@apply_cuda_patches +def test_split_dim_seed_hamming_space_solves_with_ilp_and_lp(): + config = h100_topo_config(num_nodes=2, gpus_per_node=4) + with unset_fake_temporarily(): + mesh = init_device_mesh( + "cuda", + (2, 2, 2), + mesh_dim_names=("dp", "mid", "inner"), + ) + + with torch.device("meta"): + model = TinyMLP() + + input_placement = (Shard(0), Replicate(), Replicate()) + one_d_cache = {} + + with AutoParallel( + model, + _input_fn, + mesh, + cost_model=config, + repeated_subgraphs=False, + ) as autop: + seed = build_split_dim_seed( + autop.gm, + tuple(mesh.shape), + input_placement, + cost_model=config, + repeated_subgraphs=False, + one_d_cache=one_d_cache, + ) + + opt = ShardingOptimizer( + autop.gm, + mesh, + repeated_subgraphs=False, + strategy_seed=seed, + strategy_radius=2, + ) + opt.add_sharded_input_constraint([input_placement]) + opt.add_sharded_output_constraint([input_placement]) + opt.add_parameter_memory_constraint(0.0, 1.0) + + lp_result = opt.solve_lp_relaxation(extract=True) + assert lp_result["status"] == "Optimal" + assert math.isfinite(lp_result["objective"]) + + solution = opt.get_solution(verbose=False) + assert solution + assert pulp.LpStatus[opt.prob.status] == "Optimal" + assert math.isfinite(pulp.value(opt.prob.objective))