Skip to content

Commit

Permalink
Monkey-patch pynvml._nvmlGetFunctionPointer against process v3 APIs
Browse files Browse the repository at this point in the history
pynvml 11.510.69 has broken the backward compatibility by removing
`nvml.nvmlDeviceGetComputeRunningProcesses_v2` which is replaced by v3
APIs (`nvml.nvmlDeviceGetComputeRunningProcesses_v3`), but this function
does not exist for old nvidia drivers less than 510.39.01.

Therefore we pinned pynvml version at 11.495.46 in gpustat v1.0 (#107),
but we actually have to use recent pynvml versions for "latest" or modern
NVIDIA drivers. To make compute/graphics process information work
correctly when a combination of old nvidia drivers (`< 510.39`) AND
`pynvml >= 11.510.69` is used, we need to monkey-patch pynvml functions
in our custom manner such that, for instance, when v3 API is introduced,
we can simply fallback to v2 APIs to retrieve the process information.
  • Loading branch information
wookayin committed Nov 27, 2022
1 parent f664762 commit 5e486c7
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 23 deletions.
66 changes: 61 additions & 5 deletions gpustat/nvml.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
"""Imports pynvml with sanity checks and custom patches."""

import textwrap
import functools
import os


pynvml = None
import sys
import textwrap

# If this environment variable is set, we will bypass pynvml version validation
# so that legacy pynvml (nvidia-ml-py3) can be used. This would be useful
Expand All @@ -25,15 +24,18 @@
hasattr(pynvml, 'nvmlDeviceGetComputeRunningProcesses_v2')
) and not ALLOW_LEGACY_PYNVML:
raise RuntimeError("pynvml library is outdated.")

except (ImportError, SyntaxError, RuntimeError) as e:
_pynvml = sys.modules.get('pynvml', None)

raise ImportError(textwrap.dedent(
"""\
pynvml is missing or an outdated version is installed.
We require nvidia-ml-py>=11.450.129, and nvidia-ml-py3 shall not be used.
For more details, please refer to: https://github.com/wookayin/gpustat/issues/107
Your pynvml installation: """ + repr(pynvml) +
Your pynvml installation: """ + repr(_pynvml) +
"""
-----------------------------------------------------------
Expand All @@ -48,4 +50,58 @@
""")) from e


# Monkey-patch nvml due to breaking changes in pynvml.
# See #107, #141, and test_gpustat.py for more details.

_original_nvmlGetFunctionPointer = pynvml._nvmlGetFunctionPointer


class pynvml_monkeypatch:

@staticmethod # Note: must be defined as a staticmethod to allow mocking.
def original_nvmlGetFunctionPointer(name):
return _original_nvmlGetFunctionPointer(name)

FUNCTION_FALLBACKS = {
# for pynvml._nvmlGetFunctionPointer
'nvmlDeviceGetComputeRunningProcesses_v3': 'nvmlDeviceGetComputeRunningProcesses_v2',
'nvmlDeviceGetGraphicsRunningProcesses_v3': 'nvmlDeviceGetGraphicsRunningProcesses_v2',
}

@staticmethod
@functools.wraps(pynvml._nvmlGetFunctionPointer)
def _nvmlGetFunctionPointer(name):
"""Our monkey-patched pynvml._nvmlGetFunctionPointer().
See also:
test_gpustat::NvidiaDriverMock for test scenarios
"""

try:
ret = pynvml_monkeypatch.original_nvmlGetFunctionPointer(name)
return ret
except pynvml.NVMLError as e:
if e.value != pynvml.NVML_ERROR_FUNCTION_NOT_FOUND: # type: ignore
raise

if name in pynvml_monkeypatch.FUNCTION_FALLBACKS:
# Lack of ...Processes_v3 APIs happens for
# OLD drivers < 510.39.01 && pynvml >= 11.510, where
# we fallback to v2 APIs. (see #107 for more details)

ret = pynvml_monkeypatch.original_nvmlGetFunctionPointer(
pynvml_monkeypatch.FUNCTION_FALLBACKS[name]
)
# populate the cache, so this handler won't get executed again
pynvml._nvmlGetFunctionPointer_cache[name] = ret

else:
# Unknown case, cannot handle. re-raise again
raise

return ret

setattr(pynvml, '_nvmlGetFunctionPointer', _nvmlGetFunctionPointer)


__all__ = ['pynvml']
36 changes: 19 additions & 17 deletions gpustat/test_gpustat.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@

import psutil
import pytest
from mockito import mock, unstub, when
from mockito import mock, unstub, when, when2

import gpustat
from gpustat.nvml import pynvml
from gpustat.nvml import pynvml, pynvml_monkeypatch

MB = 1024 * 1024

Expand Down Expand Up @@ -46,8 +46,7 @@ def _configure_mock(N=pynvml,
when(N).nvmlShutdown().thenReturn()
when(N).nvmlSystemGetDriverVersion().thenReturn('415.27.mock')

when(N)._nvmlGetFunctionPointer('nvmlErrorString')\
.thenCallOriginalImplementation()
when(N)._nvmlGetFunctionPointer(...).thenCallOriginalImplementation()

NUM_GPUS = 3
mock_handles = [types.SimpleNamespace(value='mock-handle-%d' % i, index=i)
Expand Down Expand Up @@ -323,21 +322,24 @@ def _nvmlDeviceGetGraphicsRunningProcesses_v2(handle, c_count, c_procs):
return pynvml.NVML_ERROR_NOT_SUPPORTED
return pynvml.NVML_SUCCESS

def _fn_notfound(*args, **kwargs):
return pynvml.NVML_ERROR_FUNCTION_NOT_FOUND

# Note: N._nvmlGetFunctionPointer might have been monkey-patched,
# so this mock should decorate the underlying, unwrapped raw function,
# NOT a monkey-patched version of pynvml._nvmlGetFunctionPointer.
for v in [1, 2, 3]:
_v = f'_v{v}' if v != 1 else '' # backward compatible v3 -> v2
when(N) \
._nvmlGetFunctionPointer(f'nvmlDeviceGetComputeRunningProcesses{_v}') \
.thenReturn(_nvmlDeviceGetComputeRunningProcesses_v2
if v <= self.nvmlDeviceGetComputeRunningProcesses_v
else _fn_notfound)
when(N) \
._nvmlGetFunctionPointer(f'nvmlDeviceGetGraphicsRunningProcesses{_v}') \
.thenReturn(_nvmlDeviceGetGraphicsRunningProcesses_v2
if v <= self.nvmlDeviceGetComputeRunningProcesses_v
else _fn_notfound)
stub = when2(pynvml_monkeypatch.original_nvmlGetFunctionPointer,
f'nvmlDeviceGetComputeRunningProcesses{_v}')
if v <= self.nvmlDeviceGetComputeRunningProcesses_v:
stub.thenReturn(_nvmlDeviceGetComputeRunningProcesses_v2)
else:
stub.thenRaise(pynvml.NVMLError(pynvml.NVML_ERROR_FUNCTION_NOT_FOUND))

stub = when2(pynvml_monkeypatch.original_nvmlGetFunctionPointer,
f'nvmlDeviceGetGraphicsRunningProcesses{_v}')
if v <= self.nvmlDeviceGetComputeRunningProcesses_v:
stub.thenReturn(_nvmlDeviceGetGraphicsRunningProcesses_v2)
else:
stub.thenRaise(pynvml.NVMLError(pynvml.NVML_ERROR_FUNCTION_NOT_FOUND))

def __getattr__(self, k):
return self.feat[k]
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def run(self):


install_requires = [
'nvidia-ml-py>=11.450.129,<=11.495.46', # see #107
'nvidia-ml-py>=11.450.129', # see #107, #143
'psutil>=5.6.0', # GH-1447
'blessed>=1.17.1', # GH-126
]
Expand Down

0 comments on commit 5e486c7

Please sign in to comment.