Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions cuda_core/cuda/core/system/_device.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ include "_fan.pxi"
include "_field_values.pxi"
include "_inforom.pxi"
include "_memory.pxi"
include "_mig.pxi"
include "_pci_info.pxi"
include "_performance.pxi"
include "_repair_status.pxi"
Expand Down Expand Up @@ -132,14 +133,20 @@ cdef class Device:
board serial identifier.

In the upstream NVML C++ API, the UUID includes a ``gpu-`` or ``mig-``
prefix. That is not included in ``cuda.core.system``.
prefix. If you want that prefix, use the `uuid_with_prefix` property.
"""
# NVML UUIDs have a `GPU-` or `MIG-` prefix. We remove that here.

# TODO: If the user cares about the prefix, we will expose that in the
# future using the MIG-related APIs in NVML.
return nvml.device_get_uuid(self._handle)[4:]

@property
def uuid_with_prefix(self) -> str:
"""
Retrieves the globally unique immutable UUID associated with this
device, as a 5 part hexadecimal string, that augments the immutable,
board serial identifier.
"""
return nvml.device_get_uuid(self._handle)

@property
def pci_bus_id(self) -> str:
"""
Expand Down Expand Up @@ -280,6 +287,8 @@ cdef class Device:
int
The number of available devices.
"""
initialize()

return nvml.device_get_count_v2()

@classmethod
Expand All @@ -292,6 +301,8 @@ cdef class Device:
Iterator of Device
An iterator over available devices.
"""
initialize()

for device_id in range(nvml.device_get_count_v2()):
yield cls(index=device_id)

Expand All @@ -317,6 +328,18 @@ cdef class Device:
"""
return AddressingMode(nvml.device_get_addressing_mode(self._handle).value)

#########################################################################
# MIG (MULTI-INSTANCE GPU) DEVICES

@property
def mig(self) -> MigInfo:
"""
Accessor for MIG (Multi-Instance GPU) information.

For Ampere™ or newer fully supported devices.
"""
return MigInfo(self)

#########################################################################
# AFFINITY

Expand Down Expand Up @@ -853,6 +876,7 @@ __all__ = [
"InforomInfo",
"InforomObject",
"MemoryInfo",
"MigInfo",
"PcieUtilCounter",
"PciInfo",
"Pstates",
Expand Down
175 changes: 175 additions & 0 deletions cuda_core/cuda/core/system/_mig.pxi
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0


from typing import Iterable


cdef class MigInfo:
cdef Device _device

def __init__(self, device: Device):
self._device = device

@property
def is_mig_device(self) -> bool:
"""
Whether this device is a MIG (Multi-Instance GPU) device.

A MIG device handle is an NVML abstraction which maps to a MIG compute
instance. These overloaded references can be used (with some
restrictions) interchangeably with a GPU device handle to execute
queries at a per-compute instance granularity.

For Ampere™ or newer fully supported devices.
"""
return bool(nvml.device_is_mig_device_handle(self._device._handle))

@property
def mode(self) -> bool:
"""
Get current MIG mode for the device.

For Ampere™ or newer fully supported devices.

Changing MIG modes may require device unbind or reset. The "pending" MIG
mode refers to the target mode following the next activation trigger.

If the device is not a MIG device, returns `False`.

Returns
-------
bool
`True` if current MIG mode is enabled.
"""
if not self.is_mig_device:
return False

current, _ = nvml.device_get_mig_mode(self._device._handle)
return current == nvml.EnableState.FEATURE_ENABLED

@mode.setter
def mode(self, mode: bool):
"""
Set the MIG mode for the device.

For Ampere™ or newer fully supported devices.

Changing MIG modes may require device unbind or reset. The "pending" MIG
mode refers to the target mode following the next activation trigger.

Parameters
----------
mode: bool
`True` to enable MIG mode, `False` to disable MIG mode.
"""
if not self.is_mig_device:
raise ValueError("Device is not a MIG device")

nvml.device_set_mig_mode(
self._device._handle,
nvml.EnableState.FEATURE_ENABLED if mode else nvml.EnableState.FEATURE_DISABLED
)

@property
def pending_mode(self) -> bool:
"""
Get pending MIG mode for the device.

For Ampere™ or newer fully supported devices.

Changing MIG modes may require device unbind or reset. The "pending" MIG
mode refers to the target mode following the next activation trigger.

If the device is not a MIG device, returns `False`.

Returns
-------
bool
`True` if pending MIG mode is enabled.
"""
if not self.is_mig_device:
return False

_, pending = nvml.device_get_mig_mode(self._device._handle)
return pending == nvml.EnableState.FEATURE_ENABLED

def get_device_count(self) -> int:
"""
Get the maximum number of MIG devices that can exist under this device.

Returns zero if MIG is not supported or enabled.

For Ampere™ or newer fully supported devices.

Returns
-------
int
The number of MIG devices (compute instances) on this GPU.
"""
return nvml.device_get_max_mig_device_count(self._device._handle)

def get_parent_device(self) -> Device:
"""
For MIG devices, get the parent GPU device.

For Ampere™ or newer fully supported devices.

Returns
-------
Device
The parent GPU device for this MIG device.
"""
parent_handle = nvml.device_get_handle_from_mig_device_handle(self._handle)
parent_device = Device.__new__(Device)
parent_device._handle = parent_handle
return parent_device

def get_device_by_index(self, index: int) -> Device:
"""
Get MIG device for the given index under its parent device.

If the compute instance is destroyed either explicitly or by destroying,
resetting or unbinding the parent GPU instance or the GPU device itself
the MIG device handle would remain invalid and must be requested again
using this API. Handles may be reused and their properties can change in
the process.

For Ampere™ or newer fully supported devices.

Parameters
----------
index: int
The index of the MIG device (compute instance) to retrieve. Must be
between 0 and the value returned by `get_device_count() - 1`.

Returns
-------
Device
The MIG device corresponding to the given index.
"""
mig_device_handle = nvml.device_get_mig_device_handle_by_index(self._device._handle, index)
mig_device = Device.__new__(Device)
mig_device._handle = mig_device_handle
return mig_device

def get_all_devices(self) -> Iterable[Device]:
"""
Get all MIG devices under its parent device.

If the compute instance is destroyed either explicitly or by destroying,
resetting or unbinding the parent GPU instance or the GPU device itself
the MIG device handle would remain invalid and must be requested again
using this API. Handles may be reused and their properties can change in
the process.

For Ampere™ or newer fully supported devices.

Returns
-------
list[Device]
A list of all MIG devices corresponding to this GPU.
"""
for i in range(self.get_device_count()):
yield self.get_device_by_index(i)
1 change: 1 addition & 0 deletions cuda_core/docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ Types
system.GpuTopologyLevel
system.InforomInfo
system.MemoryInfo
system.MigInfo
system.PciInfo
system.RepairStatus
system.Temperature
Expand Down
27 changes: 27 additions & 0 deletions cuda_core/tests/system/test_system_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,3 +729,30 @@ def test_pstates():
assert isinstance(utilization.percentage, int)
assert isinstance(utilization.inc_threshold, int)
assert isinstance(utilization.dec_threshold, int)


@pytest.mark.skipif(helpers.IS_WSL or helpers.IS_WINDOWS, reason="MIG not supported on WSL or Windows")
def test_mig():
for device in system.Device.get_all_devices():
with unsupported_before(device, None):
mig = device.mig

assert isinstance(mig.is_mig_device, bool)
if mig.is_mig_device:
assert isinstance(mig.mode, bool)
assert isinstance(mig.pending_mode, bool)

device_count = mig.get_device_count()
assert isinstance(device_count, int)
assert device_count >= 0

for mig_device in mig.get_all_devices():
assert isinstance(mig_device, system.Device)


def test_uuid_with_prefix():
for device in system.Device.get_all_devices():
uuid_with_prefix = device.uuid_with_prefix
assert isinstance(uuid_with_prefix, str)
assert uuid_with_prefix.startswith(("GPU-", "MIG-"))
assert uuid_with_prefix[4:] == device.uuid
Loading