Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 67 additions & 11 deletions cmd/nvidia-validator/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,10 @@ const (
shell = "sh"
// defaultVFWaitTimeout is the default timeout for waiting for VFs to be created
defaultVFWaitTimeout = 5 * time.Minute
// sriovManageBinaryPath is the path to NVIDIA's sriov-manage script inside the
// driver root. It ships with the vGPU Manager (host driver) and enables
// SR-IOV Virtual Functions on the NVIDIA GPUs.
sriovManageBinaryPath = "/usr/lib/nvidia/sriov-manage"
// constants for driver components
GDRCOPY = "gdrcopy"
NVIDIAFS = "nvidia-fs"
Expand Down Expand Up @@ -1747,6 +1751,14 @@ func (v *VGPUManager) validate() error {
return err
}

// SR-IOV VFs are runtime state and do not survive a node reboot, so
// re-establish them before waiting. This is best-effort: on failure we still
// fall through to waitForVFs, which preserves the prior behavior on setups
// where the VFs are created out-of-band.
if err := enableVFs(hostDriver); err != nil {
log.Warnf("Unable to enable SR-IOV VFs, will wait for them to appear: %v", err)
}

log.Info("Waiting for VFs to be available...")
if err := waitForVFs(ctx, defaultVFWaitTimeout); err != nil {
return fmt.Errorf("vGPU Manager VFs not ready: %w", err)
Expand Down Expand Up @@ -1783,6 +1795,60 @@ func (v *VGPUManager) runValidation(silent bool) (hostDriver bool, err error) {
return hostDriver, runCommand(command, args, silent)
}

// countVFs sums the expected (TotalVFs) and enabled (NumVFs) VF counts across
// all SR-IOV physical functions among the given NVIDIA GPUs, and returns the
// number of physical functions found.
func countVFs(gpus []*nvpci.NvidiaPCIDevice) (totalExpected, totalEnabled uint64, pfCount int) {
for _, gpu := range gpus {
sriovInfo := gpu.SriovInfo
if sriovInfo.IsPF() {
pfCount++
totalExpected += sriovInfo.PhysicalFunction.TotalVFs
totalEnabled += sriovInfo.PhysicalFunction.NumVFs
}
}
return totalExpected, totalEnabled, pfCount
}

// enableVFs re-creates SR-IOV Virtual Functions on the NVIDIA GPUs by invoking
// NVIDIA's 'sriov-manage -e ALL' inside the driver root. On the vGPU (sandbox)
// workload path, VFs are runtime state that does not survive a node reboot;
// without re-enabling them, after a reboot the vGPU devices cannot be created
// and validation blocks in waitForVFs waiting for VFs that never appear.
//
// It is idempotent: enablement is skipped when every SR-IOV-capable GPU already
// has its full VF count, which is the normal steady state, so the common case
// is a no-op. The post-reboot trigger this targets has no VFs enabled and no
// running VMs yet, so re-enabling is safe. It covers only VF re-enablement — no
// GPU reset and no MIG reconfiguration are performed here.
func enableVFs(hostDriver bool) error {
gpus, err := nvpci.New().GetGPUs()
if err != nil {
return fmt.Errorf("error getting GPUs: %w", err)
}

totalExpected, totalEnabled, _ := countVFs(gpus)
if totalExpected == 0 {
log.Info("No SR-IOV capable GPUs found, skipping VF enablement")
return nil
}
if totalEnabled >= totalExpected {
log.Info("SR-IOV VFs already enabled on all capable GPUs, skipping VF enablement")
return nil
}

// sriov-manage lives inside the driver root: the driver container root when
// the vGPU Manager is deployed as a container, or the host root when the
// vGPU Manager driver is pre-installed on the host.
driverRoot := defaultDriverInstallDir
if hostDriver {
driverRoot = "/host"
}

log.Infof("Enabling SR-IOV VFs on NVIDIA GPUs via 'sriov-manage -e ALL' (driver root: %q)", driverRoot)
return runCommand("chroot", []string{driverRoot, sriovManageBinaryPath, "-e", "ALL"}, false)
}

// waitForVFs waits for Virtual Functions to be created on all NVIDIA GPUs.
// It polls sriov_numvfs until all GPUs have their full VF count enabled.
func waitForVFs(ctx context.Context, timeout time.Duration) error {
Expand All @@ -1796,17 +1862,7 @@ func waitForVFs(ctx context.Context, timeout time.Duration) error {
return false, nil
}

var totalExpected, totalEnabled uint64
var pfCount int
for _, gpu := range gpus {
sriovInfo := gpu.SriovInfo
if sriovInfo.IsPF() {
pfCount++
totalExpected += sriovInfo.PhysicalFunction.TotalVFs
totalEnabled += sriovInfo.PhysicalFunction.NumVFs
}
}

totalExpected, totalEnabled, pfCount := countVFs(gpus)
if totalExpected == 0 {
log.Info("No SR-IOV capable GPUs found, skipping VF wait")
return true, nil
Expand Down
97 changes: 97 additions & 0 deletions cmd/nvidia-validator/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"strings"
"testing"

"github.com/NVIDIA/go-nvlib/pkg/nvpci"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -290,3 +291,99 @@ UNKNOWN_FEATURE: true`,
})
}
}

// pfDevice returns an SR-IOV physical function GPU with the given VF counts.
func pfDevice(address string, totalVFs, numVFs uint64) *nvpci.NvidiaPCIDevice {
return &nvpci.NvidiaPCIDevice{
Address: address,
SriovInfo: nvpci.SriovInfo{
PhysicalFunction: &nvpci.SriovPhysicalFunction{
TotalVFs: totalVFs,
NumVFs: numVFs,
},
},
}
}

// TestCountVFs verifies the shared VF-accounting helper that drives both the
// idempotency guard in enableVFs and the readiness check in waitForVFs. Getting
// this wrong would either skip a needed 'sriov-manage -e' (VFs never come back
// after a reboot) or disturb VFs already assigned to running VMs, so the guard
// (totalEnabled >= totalExpected) is exercised across the boundary cases.
func TestCountVFs(t *testing.T) {
testCases := []struct {
description string
gpus []*nvpci.NvidiaPCIDevice
wantExpected uint64
wantEnabled uint64
wantPFCount int
wantNeedsEnabling bool
}{
{
description: "no SR-IOV capable GPUs",
gpus: []*nvpci.NvidiaPCIDevice{{Address: "0000:41:00.0"}},
wantExpected: 0,
wantEnabled: 0,
wantPFCount: 0,
wantNeedsEnabling: false,
},
{
description: "VFs missing after reboot",
gpus: []*nvpci.NvidiaPCIDevice{pfDevice("0000:41:00.0", 16, 0)},
wantExpected: 16,
wantEnabled: 0,
wantPFCount: 1,
wantNeedsEnabling: true,
},
{
description: "VFs fully enabled",
gpus: []*nvpci.NvidiaPCIDevice{pfDevice("0000:41:00.0", 16, 16)},
wantExpected: 16,
wantEnabled: 16,
wantPFCount: 1,
wantNeedsEnabling: false,
},
{
description: "partially enabled across multiple PFs",
gpus: []*nvpci.NvidiaPCIDevice{
pfDevice("0000:41:00.0", 16, 16),
pfDevice("0000:c1:00.0", 16, 0),
},
wantExpected: 32,
wantEnabled: 16,
wantPFCount: 2,
wantNeedsEnabling: true,
},
{
description: "virtual functions are not counted as PFs",
gpus: []*nvpci.NvidiaPCIDevice{
pfDevice("0000:41:00.0", 16, 16),
{
Address: "0000:41:00.4",
SriovInfo: nvpci.SriovInfo{
VirtualFunction: &nvpci.SriovVirtualFunction{},
},
},
},
wantExpected: 16,
wantEnabled: 16,
wantPFCount: 1,
wantNeedsEnabling: false,
},
}

for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
totalExpected, totalEnabled, pfCount := countVFs(tc.gpus)
require.Equal(t, tc.wantExpected, totalExpected, "totalExpected")
require.Equal(t, tc.wantEnabled, totalEnabled, "totalEnabled")
require.Equal(t, tc.wantPFCount, pfCount, "pfCount")

// This mirrors the guard enableVFs uses to decide whether to invoke
// sriov-manage: enable only when there is at least one SR-IOV GPU and
// not every VF is already present.
needsEnabling := totalExpected > 0 && totalEnabled < totalExpected
require.Equal(t, tc.wantNeedsEnabling, needsEnabling, "needsEnabling")
})
}
}