Skip to content

Commit e658d4a

Browse files
committed
OSHMEM/RUNTIME: initialized SHMEM_TEAM_WORLD and SHMEM_TEAM_SHARED pSync and pWrk
Signed-off-by: Roie Danino <rdanino@nvidia.com>
1 parent e1d0ebb commit e658d4a

File tree

6 files changed

+115
-30
lines changed

6 files changed

+115
-30
lines changed

oshmem/mca/spml/base/base.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,15 @@
2929

3030
BEGIN_C_DECLS
3131

32+
/**
33+
* Base team structure - common fields for all SPML team implementations
34+
*/
35+
struct mca_spml_base_team_t {
36+
long *pSync; /* Synchronization work array */
37+
long *pWrk; /* Reduction work array */
38+
};
39+
typedef struct mca_spml_base_team_t mca_spml_base_team_t;
40+
3241
/*
3342
* This is the base priority for a SPML wrapper component
3443
* If there exists more than one then it is undefined
@@ -97,6 +106,20 @@ OSHMEM_DECLSPEC void mca_spml_base_memuse_hook(void *addr, size_t length);
97106
OSHMEM_DECLSPEC int mca_spml_base_put_all_nb(void *target, const void *source,
98107
size_t size, long *counter);
99108

109+
/**
110+
* Helper function to allocate and initialize a sync array using private_alloc
111+
* @param count Number of long elements to allocate
112+
* @param array Pointer to store the allocated array address
113+
* @return OSHMEM_SUCCESS or OSHMEM_ERROR
114+
*/
115+
OSHMEM_DECLSPEC int mca_spml_base_alloc_sync_array(size_t count, long **array);
116+
117+
/**
118+
* Helper function to free a sync array using private_free
119+
* @param array Pointer to the array pointer (will be set to NULL)
120+
*/
121+
OSHMEM_DECLSPEC void mca_spml_base_free_sync_array(long **array);
122+
100123
/*
101124
* MCA framework
102125
*/

oshmem/mca/spml/base/spml_base.c

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
#include "opal/datatype/opal_convertor.h"
1717
#include "oshmem/proc/proc.h"
1818
#include "oshmem/mca/spml/base/base.h"
19+
#include "oshmem/mca/memheap/memheap.h"
20+
#include "oshmem/mca/memheap/base/base.h"
21+
#include "oshmem/include/shmem.h"
1922
#include "opal/mca/btl/btl.h"
2023

2124
#define SPML_BASE_DO_CMP(_res, _addr, _op, _val) \
@@ -176,3 +179,24 @@ int mca_spml_base_put_all_nb(void *target, const void *source,
176179
{
177180
return OSHMEM_ERR_NOT_IMPLEMENTED;
178181
}
182+
183+
/* Helper function to allocate and initialize a single sync array */
184+
int mca_spml_base_alloc_sync_array(size_t count, long **array)
185+
{
186+
MCA_MEMHEAP_CALL(private_alloc(count * sizeof(long), (void **)array));
187+
if (*array == NULL) {
188+
SPML_ERROR("Failed to allocate sync array");
189+
return OSHMEM_ERROR;
190+
}
191+
memset(*array, 0, count * sizeof(long));
192+
return OSHMEM_SUCCESS;
193+
}
194+
195+
/* Helper function to free a single sync array */
196+
void mca_spml_base_free_sync_array(long **array)
197+
{
198+
if (*array != NULL) {
199+
MCA_MEMHEAP_CALL(private_free(*array));
200+
*array = NULL;
201+
}
202+
}

oshmem/mca/spml/ucx/spml_ucx.c

Lines changed: 5 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1820,27 +1820,6 @@ int mca_spml_ucx_team_translate_pe(shmem_team_t src_team, int src_pe,
18201820
return (global_pe - ucx_dest_team->start) / ucx_dest_team->stride;
18211821
}
18221822

1823-
/* Helper function to allocate and initialize a single sync array */
1824-
static int mca_spml_ucx_alloc_sync_array(size_t count, long **array)
1825-
{
1826-
MCA_MEMHEAP_CALL(private_alloc(count * sizeof(long), (void **)array));
1827-
if (*array == NULL) {
1828-
SPML_UCX_ERROR("Failed to allocate sync array");
1829-
return OSHMEM_ERROR;
1830-
}
1831-
memset(*array, 0, count * sizeof(long));
1832-
return OSHMEM_SUCCESS;
1833-
}
1834-
1835-
/* Helper function to free a single sync array */
1836-
static void mca_spml_ucx_free_sync_array(long **array)
1837-
{
1838-
if (*array != NULL) {
1839-
MCA_MEMHEAP_CALL(private_free(*array));
1840-
*array = NULL;
1841-
}
1842-
}
1843-
18441823
int mca_spml_ucx_team_split_strided(shmem_team_t parent_team, int start, int
18451824
stride, int size, const shmem_team_config_t *config, long config_mask,
18461825
shmem_team_t *new_team)
@@ -1892,20 +1871,20 @@ int mca_spml_ucx_team_split_strided(shmem_team_t parent_team, int start, int
18921871
ucx_new_team->parent_team = (mca_spml_ucx_team_t*)parent_team;
18931872

18941873
/* Allocate pSync array */
1895-
if (mca_spml_ucx_alloc_sync_array(SHMEM_SYNC_SIZE, &ucx_new_team->pSync) != OSHMEM_SUCCESS) {
1874+
if (mca_spml_base_alloc_sync_array(SHMEM_SYNC_SIZE, &ucx_new_team->super.pSync) != OSHMEM_SUCCESS) {
18961875
goto cleanup_config;
18971876
}
18981877

18991878
/* Allocate pWrk array */
1900-
if (mca_spml_ucx_alloc_sync_array(SHMEM_REDUCE_MIN_WRKDATA_SIZE, &ucx_new_team->pWrk) != OSHMEM_SUCCESS) {
1879+
if (mca_spml_base_alloc_sync_array(SHMEM_REDUCE_MIN_WRKDATA_SIZE, &ucx_new_team->super.pWrk) != OSHMEM_SUCCESS) {
19011880
goto cleanup_psync;
19021881
}
19031882

19041883
*new_team = (shmem_team_t)ucx_new_team;
19051884
return OSHMEM_SUCCESS;
19061885

19071886
cleanup_psync:
1908-
mca_spml_ucx_free_sync_array(&ucx_new_team->pSync);
1887+
mca_spml_base_free_sync_array(&ucx_new_team->super.pSync);
19091888
cleanup_config:
19101889
free(ucx_new_team->config);
19111890
free(ucx_new_team);
@@ -1961,8 +1940,8 @@ int mca_spml_ucx_team_destroy(shmem_team_t team)
19611940
SPML_UCX_VALIDATE_TEAM(team);
19621941

19631942
/* Free pSync and pWrk using private_free */
1964-
mca_spml_ucx_free_sync_array(&ucx_team->pSync);
1965-
mca_spml_ucx_free_sync_array(&ucx_team->pWrk);
1943+
mca_spml_base_free_sync_array(&ucx_team->super.pSync);
1944+
mca_spml_base_free_sync_array(&ucx_team->super.pWrk);
19661945

19671946
free(ucx_team->config);
19681947
free(team);

oshmem/mca/spml/ucx/spml_ucx.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,13 @@ typedef struct mca_spml_ucx_team_config {
130130
} mca_spml_ucx_team_config_t;
131131

132132
typedef struct mca_spml_ucx_team {
133+
mca_spml_base_team_t super;
133134
int n_pes;
134135
int my_pe;
135136
int stride;
136137
int start;
137138
mca_spml_ucx_team_config_t *config;
138139
struct mca_spml_ucx_team *parent_team;
139-
long *pSync;
140-
long *pWrk;
141140
} mca_spml_ucx_team_t;
142141

143142
struct mca_spml_ucx {

oshmem/runtime/oshmem_shmem_finalize.c

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,19 @@ int oshmem_shmem_finalize(void)
9090
static int _shmem_finalize(void)
9191
{
9292
int ret = OSHMEM_SUCCESS;
93+
mca_spml_base_team_t *world_team, *shared_team;
9394

9495
shmem_barrier_all();
9596

97+
/* Free pSync and pWrk for predefined teams */
98+
world_team = (mca_spml_base_team_t *)oshmem_team_world;
99+
shared_team = (mca_spml_base_team_t *)oshmem_team_shared;
100+
101+
mca_spml_base_free_sync_array(&world_team->pSync);
102+
mca_spml_base_free_sync_array(&world_team->pWrk);
103+
mca_spml_base_free_sync_array(&shared_team->pSync);
104+
mca_spml_base_free_sync_array(&shared_team->pWrk);
105+
96106
shmem_lock_finalize();
97107

98108
/* Finalize preconnect framework */

oshmem/runtime/oshmem_shmem_init.c

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,54 @@ shmem_internal_mutex_t shmem_internal_mutex_alloc = {{0}};
9595

9696
shmem_ctx_t oshmem_ctx_default = NULL;
9797

98-
shmem_team_t oshmem_team_shared = NULL;
99-
shmem_team_t oshmem_team_world = NULL;
98+
/* Predefined teams - statically allocated base team structures */
99+
mca_spml_base_team_t oshmem_team_world_instance = {.pSync = NULL, .pWrk = NULL};
100+
mca_spml_base_team_t oshmem_team_shared_instance = {.pSync = NULL, .pWrk = NULL};
101+
102+
/* Pointers to predefined teams */
103+
shmem_team_t oshmem_team_world = (shmem_team_t)&oshmem_team_world_instance;
104+
shmem_team_t oshmem_team_shared = (shmem_team_t)&oshmem_team_shared_instance;
100105

101106
static int _shmem_init(int argc, char **argv, int requested, int *provided);
102107

108+
/* Helper function to allocate pSync and pWrk for predefined teams */
109+
static int world_and_shared_teams_alloc(void)
110+
{
111+
int ret;
112+
113+
/* Allocate pSync for SHMEM_TEAM_WORLD */
114+
ret = mca_spml_base_alloc_sync_array(SHMEM_SYNC_SIZE, &oshmem_team_world_instance.pSync);
115+
if (OSHMEM_SUCCESS != ret) {
116+
return ret;
117+
}
118+
119+
/* Allocate pWrk for SHMEM_TEAM_WORLD */
120+
ret = mca_spml_base_alloc_sync_array(SHMEM_REDUCE_MIN_WRKDATA_SIZE, &oshmem_team_world_instance.pWrk);
121+
if (OSHMEM_SUCCESS != ret) {
122+
mca_spml_base_free_sync_array(&oshmem_team_world_instance.pSync);
123+
return ret;
124+
}
125+
126+
/* Allocate pSync for SHMEM_TEAM_SHARED */
127+
ret = mca_spml_base_alloc_sync_array(SHMEM_SYNC_SIZE, &oshmem_team_shared_instance.pSync);
128+
if (OSHMEM_SUCCESS != ret) {
129+
mca_spml_base_free_sync_array(&oshmem_team_world_instance.pWrk);
130+
mca_spml_base_free_sync_array(&oshmem_team_world_instance.pSync);
131+
return ret;
132+
}
133+
134+
/* Allocate pWrk for SHMEM_TEAM_SHARED */
135+
ret = mca_spml_base_alloc_sync_array(SHMEM_REDUCE_MIN_WRKDATA_SIZE, &oshmem_team_shared_instance.pWrk);
136+
if (OSHMEM_SUCCESS != ret) {
137+
mca_spml_base_free_sync_array(&oshmem_team_shared_instance.pSync);
138+
mca_spml_base_free_sync_array(&oshmem_team_world_instance.pWrk);
139+
mca_spml_base_free_sync_array(&oshmem_team_world_instance.pSync);
140+
return ret;
141+
}
142+
143+
return OSHMEM_SUCCESS;
144+
}
145+
103146
#if OSHMEM_OPAL_THREAD_ENABLE
104147
static void* shmem_opal_thread(void* argc)
105148
{
@@ -403,6 +446,13 @@ static int _shmem_init(int argc, char **argv, int requested, int *provided)
403446

404447
OPAL_TIMING_ENV_NEXT(timing, "mca_scoll_enable()");
405448

449+
/* Initialize pSync and pWrk for SHMEM_TEAM_WORLD and SHMEM_TEAM_SHARED teams */
450+
if (OSHMEM_SUCCESS != (ret = world_and_shared_teams_alloc())) {
451+
error = "Failed to allocate sync arrays for predefined teams";
452+
goto error;
453+
}
454+
OPAL_TIMING_ENV_NEXT(timing, "world_and_shared_teams_alloc()");
455+
406456
(*provided) = oshmem_mpi_thread_provided;
407457

408458
oshmem_mpi_thread_multiple = (oshmem_mpi_thread_provided == SHMEM_THREAD_MULTIPLE) ? true : false;

0 commit comments

Comments
 (0)