Commit 8b7d340

ggml/hip: fix APU compatibility - soft error handling for hipMemAdviseSetCoarseGrain (ggml-org#20536)
* ggml/hip: fix APU compatibility - soft error handling for hipMemAdviseSetCoarseGrain

On AMD APU/iGPU devices (unified memory architecture), hipMemAdviseSetCoarseGrain returns hipErrorInvalidValue because the hint is not applicable to UMA systems. The previous CUDA_CHECK() call treated this as a fatal error, causing crashes on APU systems such as AMD Strix Halo (gfx1151).

Fix: treat hipMemAdviseSetCoarseGrain as an optional performance hint - call it without error checking and clear any resulting error with hipGetLastError().

Also add pre-allocation debug logging (GGML_LOG_DEBUG) to help diagnose memory issues on APU systems, and store totalGlobalMem in device info.

Context: AMD APUs on Windows are affected by a ROCm runtime bug that limits hipMallocManaged to ~64GB regardless of available system RAM. A fix has been submitted upstream: ROCm/rocm-systems#4077

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* ggml/hip: remove unrelated changes, keep only hipMemAdviseSetCoarseGrain fix

---------

Co-authored-by: moonshadow-25 <moonshadow-25@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 5596464 commit 8b7d340

1 file changed

Lines changed: 4 additions & 1 deletion

ggml/src/ggml-cuda/ggml-cuda.cu

@@ -124,7 +124,10 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
         err = cudaMallocManaged(ptr, size);
 #if defined(GGML_USE_HIP)
         if (err == hipSuccess) {
-            CUDA_CHECK(cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
+            // hipMemAdviseSetCoarseGrain is an optional performance hint;
+            // ignore errors (e.g. hipErrorInvalidValue on some APU/iGPU configs).
+            cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device);
+            (void)hipGetLastError(); // clear any error
         }
 
         // fall back to cudaMalloc if not supported (e.g. on Windows)
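
The pattern in this change (issue an optional hint, ignore its status, then clear the runtime's recorded error so later checks do not trip over it) can be sketched without a ROCm install. The following is a minimal illustration, not the ggml code: every `mock_*` name and `apply_optional_hint` are hypothetical stand-ins invented here, and the `device_is_uma` flag is an assumed simplification of the real device behavior. Only the idea that `hipGetLastError()` returns the last error and resets it to success is taken from the HIP runtime.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical stand-ins for the HIP runtime so the sketch builds anywhere.
enum mock_error_t { MOCK_SUCCESS = 0, MOCK_ERROR_INVALID_VALUE = 1 };

// Models the runtime's recorded "last error" state.
static mock_error_t g_last_error = MOCK_SUCCESS;

// Models a memory-advice call that rejects the hint on UMA devices
// (assumption: the rejection is recorded as the last error).
mock_error_t mock_mem_advise(void * ptr, std::size_t size, bool device_is_uma) {
    (void) ptr;
    (void) size;
    mock_error_t err = device_is_uma ? MOCK_ERROR_INVALID_VALUE : MOCK_SUCCESS;
    if (err != MOCK_SUCCESS) {
        g_last_error = err;
    }
    return err;
}

// Models hipGetLastError(): returns the recorded error and resets it.
mock_error_t mock_get_last_error() {
    mock_error_t err = g_last_error;
    g_last_error = MOCK_SUCCESS;
    return err;
}

// Reads the recorded error without resetting it.
mock_error_t mock_peek_at_last_error() {
    return g_last_error;
}

// Best-effort hint: the allocation stays usable whether or not the hint is
// accepted, so a failure is reported to the caller but never treated as fatal.
bool apply_optional_hint(void * ptr, std::size_t size, bool device_is_uma) {
    mock_error_t err = mock_mem_advise(ptr, size, device_is_uma);
    (void) mock_get_last_error(); // clear any recorded error
    return err == MOCK_SUCCESS;
}

// Small self-check covering the accepted and rejected cases.
void demo() {
    char buffer[64];
    assert(apply_optional_hint(buffer, sizeof(buffer), false));
    assert(!apply_optional_hint(buffer, sizeof(buffer), true));
    assert(mock_peek_at_last_error() == MOCK_SUCCESS);
}
```

Returning a `bool` here is a design choice of the sketch, so a caller could log a rejected hint; the commit itself discards the status entirely.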
