Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 42 additions & 4 deletions devito/passes/iet/langbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from devito.passes import is_on_device
from devito.passes.iet.engine import iet_pass
from devito.symbolics import Byref, CondNe, SizeOf
from sympy import Ge
from devito.tools import as_list, is_integer, prod
from devito.types import Symbol, QueueID, Wildcard

Expand Down Expand Up @@ -56,11 +57,16 @@ class LangBB(metaclass=LangMeta):
"""

@classmethod
def _get_num_devices(cls):
def _get_num_devices(cls, platform):
"""
Get the number of accessible devices.
Returns a tuple of (ngpus_symbol, call_to_get_num_devices).
"""
raise NotImplementedError
from devito.types import Symbol
ngpus = Symbol(name='ngpus', dtype='int32')
devicetype = as_list(cls[platform])
call_ngpus = cls['num-devices'](devicetype, retobj=ngpus)
return ngpus, call_ngpus

@classmethod
def _map_to(cls, f, imask=None, qid=None):
Expand Down Expand Up @@ -426,9 +432,27 @@ def _make_setdevice_seq(iet, nodes=()):
devicetype = as_list(self.langbb[self.platform])
deviceid = self.deviceid

# Add device validation check
ngpus, call_ngpus = self.langbb._get_num_devices(self.platform)

validation = Conditional(
Ge(deviceid, ngpus),
List(body=[
Call('printf', ['"%s: Error - device %d >= %d devices\\n"',
self.langbb['name'], deviceid, ngpus]),
Call('exit', [1])
])
)

device_setup = List(body=[
call_ngpus,
validation,
self.langbb['set-device']([deviceid] + devicetype)
])

return list(nodes) + [Conditional(
CondNe(deviceid, -1),
self.langbb['set-device']([deviceid] + devicetype)
device_setup
)]

def _make_setdevice_mpi(iet, objcomm, nodes=()):
Expand All @@ -441,7 +465,21 @@ def _make_setdevice_mpi(iet, objcomm, nodes=()):

ngpus, call_ngpus = self.langbb._get_num_devices(self.platform)

osdd_then = self.langbb['set-device']([deviceid] + devicetype)
# Add device validation for explicit device ID
validation = Conditional(
Ge(deviceid, ngpus),
List(body=[
Call('printf', ['"%s: Error - device %d >= %d devices\\n"',
self.langbb['name'], deviceid, ngpus]),
Call('exit', [1])
])
)

osdd_then = List(body=[
call_ngpus,
validation,
self.langbb['set-device']([deviceid] + devicetype)
])
osdd_else = self.langbb['set-device']([rank % ngpus] + devicetype)

return list(nodes) + [Conditional(
Expand Down
24 changes: 24 additions & 0 deletions tests/test_gpu_openacc.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,30 @@ def test_op_apply(self):

assert np.all(np.array(u.data[0, :, :, :]) == time_steps)

def test_device_validation_error_message(self):
"""Test that OpenACC device validation includes helpful error messages."""
grid = Grid(shape=(3, 3, 3))

u = TimeFunction(name='u', grid=grid, dtype=np.int32)

op = Operator(Eq(u.forward, u + 1), platform='nvidiaX', language='openacc')

# Check that the generated code contains device validation
code = str(op)

# Should contain device count check
assert 'acc_get_num_devices' in code, "Missing OpenACC device count check"

# Should contain validation condition
assert 'deviceid >= ngpus' in code, "Missing OpenACC device ID " + \
"validation condition"

# Should contain error message
assert 'Error - device' in code, "Missing error message"

# Should contain exit call to prevent undefined behavior
assert 'exit(1)' in code, "Missing exit call on validation failure"

def iso_acoustic(self, opt):
shape = (101, 101)
extent = (1000, 1000)
Expand Down
52 changes: 42 additions & 10 deletions tests/test_gpu_openmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,13 @@ def test_init_omp_env(self):

op = Operator(Eq(u.forward, u.dx+1), language='openmp')

assert str(op.body.init[0].body[0]) ==\
'if (deviceid != -1)\n{\n omp_set_default_device(deviceid);\n}'
# With device validation, the generated code now includes validation logic
init_code = str(op.body.init[0].body[0])
assert 'if (deviceid != -1)' in init_code
assert 'int ngpus = omp_get_num_devices()' in init_code
assert 'if (deviceid >= ngpus)' in init_code
assert 'Error - device' in init_code
assert 'omp_set_default_device(deviceid)' in init_code

@pytest.mark.parallel(mode=1)
def test_init_omp_env_w_mpi(self, mode):
Expand All @@ -31,14 +36,41 @@ def test_init_omp_env_w_mpi(self, mode):

op = Operator(Eq(u.forward, u.dx+1), language='openmp')

assert str(op.body.init[0].body[0]) ==\
('if (deviceid != -1)\n'
'{\n omp_set_default_device(deviceid);\n}\n'
'else\n'
'{\n int rank = 0;\n'
' MPI_Comm_rank(comm,&rank);\n'
' int ngpus = omp_get_num_devices();\n'
' omp_set_default_device((rank)%(ngpus));\n}')
# With device validation, the MPI case also includes validation for explicit
# deviceid
init_code = str(op.body.init[0].body[0])
assert 'if (deviceid != -1)' in init_code
assert 'int ngpus = omp_get_num_devices()' in init_code
# For MPI case with explicit deviceid, should have validation
assert 'if (deviceid >= ngpus)' in init_code
assert 'Error - device' in init_code
# Should still have MPI rank-based assignment in else clause
assert 'int rank = 0' in init_code
assert 'MPI_Comm_rank(comm,&rank)' in init_code
assert '(rank)%(ngpus)' in init_code

def test_device_validation_error_message(self):
"""Test that device validation includes helpful error messages."""
grid = Grid(shape=(3, 3, 3))

u = TimeFunction(name='u', grid=grid)

op = Operator(Eq(u.forward, u.dx+1), language='openmp')

# Check that the generated code contains device validation
code = str(op)

# Should contain device count check
assert 'omp_get_num_devices()' in code, "Missing device count check"

# Should contain validation condition
assert 'deviceid >= ngpus' in code, "Missing device ID validation condition"

# Should contain error message
assert 'Error - device' in code, "Missing error message"

# Should contain exit call to prevent undefined behavior
assert 'exit(1)' in code, "Missing exit call on validation failure"

def test_basic(self):
grid = Grid(shape=(3, 3, 3))
Expand Down