Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 81 additions & 41 deletions cuda_core/tests/test_green_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,58 @@ def fill_kernel(init_cuda):
return mod.get_kernel("fill")


def _safe_two_group_count(sm):
"""Return a safe per-group SM count for a 2-group split.
def _is_invalid_resource_configuration(exc):
    """Report whether the text of *exc* names the invalid-resource-configuration error.

    The check is a plain substring match against ``str(exc)``.
    """
    message = str(exc)
    return "CUDA_ERROR_INVALID_RESOURCE_CONFIGURATION" in message

Uses min_partition_size which is always a valid split size regardless
of hardware topology. Returns None if the device doesn't have enough SMs.
"""
min_size = sm.min_partition_size
if sm.sm_count < 2 * min_size:
return None
return min_size

def _iter_requested_sm_counts(sm, n_groups=1, *, descending=False):
    """Return the even per-group SM counts worth probing on this device.

    Candidates start at ``max(2, sm.min_partition_size)`` rounded up to the
    next even number and run, in steps of two, through
    ``sm.sm_count // n_groups`` inclusive. The result is a ``range`` in
    ascending order, or its reversed iterator when *descending* is true; it
    is empty when the lower bound exceeds the per-group cap.
    """
    lowest = max(2, sm.min_partition_size)
    first_even = lowest + 1 if lowest % 2 else lowest
    per_group_cap = sm.sm_count // n_groups
    candidates = range(first_even, per_group_cap + 1, 2)
    if descending:
        return reversed(candidates)
    return candidates


def _try_sm_split(sm, *, count, backfill=False):
    """Attempt ``sm.split`` with the given options, tolerating one known failure.

    Returns whatever ``sm.split`` returns on success, or ``None`` when it
    raises a ``CUDAError`` that ``_is_invalid_resource_configuration``
    recognises. Any other ``CUDAError`` is re-raised unchanged.
    """
    try:
        # The options object is built inside the ``try`` so that a CUDAError
        # raised during construction is filtered the same way.
        outcome = sm.split(SMResourceOptions(count=count, backfill=backfill))
    except CUDAError as err:
        if not _is_invalid_resource_configuration(err):
            raise
        return None
    return outcome


def _find_supported_split(sm, *, n_groups=1, backfill=False, descending=False):
    """Return the first explicit split request this device accepts, if any.

    Each candidate per-group count from ``_iter_requested_sm_counts`` is
    tried in order. A single-group request passes the bare count; a
    multi-group request passes a tuple repeating the count *n_groups* times.

    Returns ``(count, groups, rem)`` for the first accepted request, or
    ``None`` when every candidate is rejected.
    """
    candidates = _iter_requested_sm_counts(sm, n_groups=n_groups, descending=descending)
    for per_group in candidates:
        if n_groups == 1:
            request = per_group
        else:
            request = (per_group,) * n_groups
        outcome = _try_sm_split(sm, count=request, backfill=backfill)
        if outcome is None:
            continue
        groups, rem = outcome
        return per_group, groups, rem
    return None


def _find_any_two_group_split(sm):
    """Return a supported 2-group split, preferring one that needs no backfill.

    Falls back to a backfill-enabled search only when the default search
    finds nothing; the result is ``None`` if that search fails as well.
    """
    without_backfill = _find_supported_split(sm, n_groups=2)
    if without_backfill is None:
        return _find_supported_split(sm, n_groups=2, backfill=True)
    return without_backfill


def _find_backfill_only_two_group_split(sm):
    """Return a 2-group split that succeeds only with backfill, if one exists.

    Candidate per-group counts are walked from largest to smallest. A size
    qualifies when the default request is rejected but the same request with
    ``backfill=True`` is accepted.

    Returns ``(count, groups, rem)`` for the first qualifying size, or
    ``None`` when no size behaves that way.
    """
    for per_group in _iter_requested_sm_counts(sm, n_groups=2, descending=True):
        pair = (per_group, per_group)
        # Sizes that already work without backfill do not qualify.
        if _try_sm_split(sm, count=pair) is None:
            backfilled = _try_sm_split(sm, count=pair, backfill=True)
            if backfilled is not None:
                groups, rem = backfilled
                return per_group, groups, rem
    return None


@contextlib.contextmanager
Expand Down Expand Up @@ -153,8 +195,10 @@ def test_arch_constraints_pre_hopper(self, init_cuda, sm_resource):
def test_arch_constraints_hopper_plus(self, init_cuda, sm_resource):
if init_cuda.compute_capability < (9, 0):
pytest.skip("Test is for Hopper+ architectures")
assert sm_resource.min_partition_size >= 8
assert sm_resource.coscheduled_alignment >= 8
assert sm_resource.min_partition_size >= 2
assert sm_resource.coscheduled_alignment >= 2
assert sm_resource.min_partition_size % 2 == 0
assert sm_resource.coscheduled_alignment % 2 == 0


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -221,9 +265,11 @@ def test_dry_run_cannot_create_context(self, init_cuda, sm_resource):

class TestSMResourceSplit:
def test_single_group_counts(self, sm_resource):
"""Single-group split: group gets at least requested SMs."""
requested = sm_resource.min_partition_size
groups, rem = sm_resource.split(SMResourceOptions(count=requested))
"""Single-group split: group gets at least a supported requested size."""
split = _find_supported_split(sm_resource)
if split is None:
pytest.skip("Device does not expose a valid explicit single-group split")
requested, groups, rem = split

assert len(groups) == 1
assert groups[0].sm_count >= requested
Expand All @@ -243,12 +289,11 @@ def test_discovery_respects_alignment(self, sm_resource):
assert groups[0].sm_count % sm_resource.coscheduled_alignment == 0

def test_two_groups(self, sm_resource):
"""Two-group split with min_partition_size (always topology-safe)."""
count = _safe_two_group_count(sm_resource)
if count is None:
pytest.skip("Not enough SMs for a 2-group split")

groups, rem = sm_resource.split(SMResourceOptions(count=(count, count)))
"""Two-group split succeeds for a supported explicit request."""
split = _find_supported_split(sm_resource, n_groups=2)
if split is None:
pytest.skip("Device does not expose a valid 2-group split without backfill")
count, groups, rem = split

assert len(groups) == 2
assert groups[0].sm_count >= count
Expand All @@ -257,19 +302,16 @@ def test_two_groups(self, sm_resource):
assert total <= sm_resource.sm_count

def test_two_groups_backfill(self, sm_resource):
"""Two-group split with backfill allows larger partitions."""
align = sm_resource.coscheduled_alignment
if align == 0:
align = sm_resource.min_partition_size
half = (sm_resource.sm_count // 2 // align) * align
if half < sm_resource.min_partition_size:
pytest.skip("Not enough SMs for a 2-group backfill split")

groups, rem = sm_resource.split(SMResourceOptions(count=(half, half), backfill=True))
"""Backfill unlocks a 2-group split size that default placement rejects."""
split = _find_backfill_only_two_group_split(sm_resource)
if split is None:
pytest.skip("Device does not expose a backfill-only 2-group split")
requested, groups, rem = split

assert len(groups) == 2
assert groups[0].sm_count >= half
assert groups[1].sm_count >= half
assert groups[0].sm_count >= requested
assert groups[1].sm_count >= requested
assert groups[0].sm_count + groups[1].sm_count + rem.sm_count <= sm_resource.sm_count

def test_dry_run_matches_real(self, sm_resource):
"""Dry-run reports the same SM counts as a real split."""
Expand Down Expand Up @@ -360,11 +402,10 @@ def test_green_ctx_sm_resources(self, green_ctx, sm_resource):

def test_green_ctx_resources_reflect_partition(self, init_cuda, sm_resource):
"""Two green contexts should have disjoint SM partitions."""
count = _safe_two_group_count(sm_resource)
if count is None:
pytest.skip("Not enough SMs for a 2-group split")

groups, _ = sm_resource.split(SMResourceOptions(count=(count, count)))
split = _find_any_two_group_split(sm_resource)
if split is None:
pytest.skip("Device does not expose a valid 2-group split")
_, groups, _ = split

ctx_a = ctx_b = None
try:
Expand Down Expand Up @@ -433,11 +474,10 @@ def test_launch_and_verify(self, init_cuda, green_ctx, fill_kernel):
def test_two_green_contexts_independent(self, init_cuda, sm_resource, fill_kernel):
"""Two SM groups -> two green contexts -> two independent kernels."""
dev = init_cuda
count = _safe_two_group_count(sm_resource)
if count is None:
pytest.skip("Not enough SMs for a 2-group split")

groups, _ = sm_resource.split(SMResourceOptions(count=(count, count)))
split = _find_any_two_group_split(sm_resource)
if split is None:
pytest.skip("Device does not expose a valid 2-group split")
_, groups, _ = split
assert len(groups) == 2

ctx_a = ctx_b = None
Expand Down
Loading