Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

UCT: Memory window #10332

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
15 changes: 14 additions & 1 deletion src/uct/api/v2/uct_v2.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,8 @@ typedef struct {
typedef enum {
UCT_MD_MEM_REG_FIELD_FLAGS = UCS_BIT(0),
UCT_MD_MEM_REG_FIELD_DMABUF_FD = UCS_BIT(1),
UCT_MD_MEM_REG_FIELD_DMABUF_OFFSET = UCS_BIT(2)
UCT_MD_MEM_REG_FIELD_DMABUF_OFFSET = UCS_BIT(2),
UCT_MD_MEM_REG_FIELD_MEMH = UCS_BIT(3)
} uct_md_mem_reg_field_mask_t;


Expand Down Expand Up @@ -480,6 +481,18 @@ typedef struct uct_md_mem_reg_params {
* dmabuf region, then this field must be omitted or set to 0.
*/
size_t dmabuf_offset;

/**
* Pointer to an existing memory handle.
* Used to register a derived memory handle: a shallow copy of an existing
* UCT memory handle, which can be used to access the same memory region.
* When created, the derived memh inherits the original memh access flags
* and state. The lifetime of the derived memh is bound to the original
* memh, and the original memh cannot be destroyed until all its derived
* handles are destroyed. The derived memh cannot be used to register
* another derived memh.
*/
uct_mem_h memh;
brminich marked this conversation as resolved.
Show resolved Hide resolved
} uct_md_mem_reg_params_t;


Expand Down
8 changes: 8 additions & 0 deletions src/uct/cuda/cuda_ipc/cuda_ipc_md.c
Original file line number Diff line number Diff line change
Expand Up @@ -372,9 +372,17 @@ static ucs_status_t
uct_cuda_ipc_mem_reg(uct_md_h md, void *address, size_t length,
const uct_md_mem_reg_params_t *params, uct_mem_h *memh_p)
{
uct_mem_h base = (params != NULL) ?
UCT_MD_MEM_REG_FIELD_VALUE(params, memh, FIELD_MEMH, NULL) :
NULL;
uct_cuda_ipc_memh_t *memh;
CUdevice cu_device;

if (ENABLE_PARAMS_CHECK && (base != NULL)) {
ucs_error("CUDA IPC does not support derived memory handles");
return UCS_ERR_UNSUPPORTED;
}

UCT_CUDA_IPC_GET_DEVICE(cu_device);

memh = ucs_malloc(sizeof(*memh), "uct_cuda_ipc_memh_t");
Expand Down
41 changes: 37 additions & 4 deletions src/uct/ib/base/ib_md.c
Original file line number Diff line number Diff line change
Expand Up @@ -584,19 +584,35 @@ ucs_status_t uct_ib_mem_advise(uct_md_h uct_md, uct_mem_h memh, void *addr,
return UCS_OK;
}

ucs_status_t uct_ib_memh_alloc(uct_ib_md_t *md, size_t length,
unsigned mem_flags, size_t memh_base_size,
size_t mr_size, uct_ib_mem_t **memh_p)
/**
 * Allocate (via ucs_calloc) an IB memory handle with room for its MR slots.
 *
 * The allocation holds @a memh_base_size bytes followed by one slot of
 * @a mr_size bytes per MR in use: two slots (UCT_IB_MR_DEFAULT and
 * UCT_IB_MR_STRICT_ORDER) when md->relaxed_order is set, otherwise one
 * (UCT_IB_MR_DEFAULT).
 *
 * @param [in]  md              IB memory domain; read for relaxed_order and
 *                              for the device name used in the error log.
 * @param [in]  memh_base_size  Size of the handle structure without MR slots.
 * @param [in]  mr_size         Size of a single MR slot.
 * @param [out] memh_size_p     Set to the total allocation size. It is written
 *                              before the allocation, so it is valid even when
 *                              NULL is returned.
 *
 * @return The new handle, or NULL if the allocation failed (an error is
 *         logged in that case).
 */
static uct_ib_mem_t *
uct_ib_memh_alloc_internal(uct_ib_md_t *md, size_t memh_base_size,
                           size_t mr_size, size_t *memh_size_p)
{
    int num_mrs = md->relaxed_order ?
                  2 /* UCT_IB_MR_DEFAULT and UCT_IB_MR_STRICT_ORDER */ :
                  1 /* UCT_IB_MR_DEFAULT */;
    uct_ib_mem_t *memh;

    /* Allocate exactly once; a second allocation into the same pointer would
     * leak the first buffer */
    *memh_size_p = memh_base_size + (mr_size * num_mrs);
    memh         = ucs_calloc(1, *memh_size_p, "ib_memh");
    if (memh == NULL) {
        ucs_error("%s: failed to allocate memh struct",
                  uct_ib_device_name(&md->dev));
        return NULL;
    }

    return memh;
}

ucs_status_t uct_ib_memh_alloc(uct_ib_md_t *md, size_t length,
unsigned mem_flags, size_t memh_base_size,
size_t mr_size, uct_ib_mem_t **memh_p)
{
uct_ib_mem_t *memh;
size_t memh_size;

memh = uct_ib_memh_alloc_internal(md, memh_base_size, mr_size, &memh_size);
if (memh == NULL) {
return UCS_ERR_NO_MEMORY;
}

Expand Down Expand Up @@ -626,6 +642,23 @@ ucs_status_t uct_ib_memh_alloc(uct_ib_md_t *md, size_t length,
return UCS_OK;
}

/**
 * Create a byte-wise copy of an IB memory handle.
 *
 * A new handle is obtained from uct_ib_memh_alloc_internal() using
 * @a memh_base_size and @a mr_size, and the same number of bytes is then
 * copied from @a src, so @a src must be at least that large.
 *
 * @param [in]  md              IB memory domain used to size the allocation.
 * @param [in]  src             Handle to copy from.
 * @param [in]  memh_base_size  Size of the handle structure without MR slots.
 * @param [in]  mr_size         Size of a single MR slot.
 * @param [out] memh_p          Receives the copy; written only on success.
 *
 * @return UCS_OK on success, UCS_ERR_NO_MEMORY if the allocation failed.
 */
ucs_status_t uct_ib_memh_clone(uct_ib_md_t *md, const uct_ib_mem_t *src,
                               size_t memh_base_size, size_t mr_size,
                               uct_ib_mem_t **memh_p)
{
    size_t clone_size;
    uct_ib_mem_t *clone = uct_ib_memh_alloc_internal(md, memh_base_size,
                                                     mr_size, &clone_size);

    if (clone != NULL) {
        /* Shallow copy: any pointers stored in src are shared, not duplicated */
        memcpy(clone, src, clone_size);
        *memh_p = clone;
        return UCS_OK;
    }

    return UCS_ERR_NO_MEMORY;
}

uint64_t uct_ib_memh_access_flags(uct_ib_mem_t *memh, int relaxed_order,
uint64_t access_flags)
{
Expand Down
6 changes: 6 additions & 0 deletions src/uct/ib/base/ib_md.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ enum {
#endif
UCT_IB_MEM_FLAG_GVA = UCS_BIT(5), /**< The memory handle is a
GVA region */
UCT_IB_MEM_FLAG_DERIVED = UCS_BIT(6), /**< The memory handle is a
derived memh */
};

enum {
Expand Down Expand Up @@ -432,4 +434,8 @@ ucs_status_t uct_ib_memh_alloc(uct_ib_md_t *md, size_t length,
unsigned mem_flags, size_t memh_base_size,
size_t mr_size, uct_ib_mem_t **memh_p);

/**
 * Allocate a new IB memory handle and fill it with a byte-wise copy of
 * @a src.
 *
 * @param [in]  md              IB memory domain used to size the allocation.
 * @param [in]  src             Handle to copy from.
 * @param [in]  memh_base_size  Size of the handle structure without MR slots.
 * @param [in]  mr_size         Size of a single MR slot.
 * @param [out] memh_p          Receives the new handle; written only on
 *                              success.
 *
 * @return UCS_OK on success, UCS_ERR_NO_MEMORY if the allocation failed.
 */
ucs_status_t uct_ib_memh_clone(uct_ib_md_t *md, const uct_ib_mem_t *src,
size_t memh_base_size, size_t mr_size,
uct_ib_mem_t **memh_p);

#endif
59 changes: 59 additions & 0 deletions src/uct/ib/mlx5/dv/ib_mlx5dv_md.c
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,26 @@ uct_ib_mlx5_devx_memh_alloc(uct_ib_mlx5_md_t *md, size_t length,
return UCS_OK;
}

/**
 * Clone a DEVX memory handle through the generic uct_ib_memh_clone().
 *
 * @param [in]  md      mlx5 memory domain.
 * @param [in]  src     DEVX handle to copy from.
 * @param [out] memh_p  Receives the copy; written only on success.
 *
 * @return UCS_OK on success, otherwise the status returned by
 *         uct_ib_memh_clone().
 */
static ucs_status_t
uct_ib_mlx5_devx_memh_clone(uct_ib_mlx5_md_t *md,
                            const uct_ib_mlx5_devx_mem_t *src,
                            uct_ib_mlx5_devx_mem_t **memh_p)
{
    uct_ib_mem_t *cloned;
    size_t mr_slot_size;
    ucs_status_t status;

    /* Handles flagged as imported are cloned with zero-sized MR slots.
     * NOTE(review): presumably imported handles are allocated without MR
     * entries - confirm against the import path */
    if (src->super.flags & UCT_IB_MEM_IMPORTED) {
        mr_slot_size = 0;
    } else {
        mr_slot_size = sizeof(src->mrs[0]);
    }

    status = uct_ib_memh_clone(&md->super, &src->super, sizeof(**memh_p),
                               mr_slot_size, &cloned);
    if (status == UCS_OK) {
        *memh_p = ucs_derived_of(cloned, uct_ib_mlx5_devx_mem_t);
    }

    return status;
}

static int
uct_ib_mlx5_devx_memh_has_ro(uct_ib_mlx5_md_t *md, uct_ib_mlx5_devx_mem_t *memh)
{
Expand Down Expand Up @@ -837,13 +857,42 @@ uct_ib_mlx5_devx_mem_reg_gva(uct_md_h uct_md, unsigned flags, uct_mem_h *memh_p)
return status;
}

/**
 * Register a derived memory handle on top of an existing DEVX handle.
 *
 * The derived handle is a clone of @a base that is flagged with
 * UCT_IB_MEM_FLAG_DERIVED and has its atomic and indirect key state reset,
 * so it does not reuse the keys stored in the base handle.
 *
 * @param [in]  uct_md  Memory domain; must be an mlx5 memory domain.
 * @param [in]  base    Handle to derive from; asserted not to be a derived
 *                      handle itself.
 * @param [out] memh_p  Receives the derived handle; written only on success.
 *
 * @return UCS_OK on success, otherwise the status of the failed clone.
 */
static ucs_status_t
uct_ib_mlx5_devx_derived_mem_reg(uct_md_h uct_md, uct_ib_mlx5_devx_mem_t *base,
                                 uct_mem_h *memh_p)
{
    uct_ib_mlx5_md_t *md = ucs_derived_of(uct_md, uct_ib_mlx5_md_t);
    uct_ib_mlx5_devx_mem_t *derived;
    ucs_status_t status;

    ucs_assertv(!(base->super.flags & UCT_IB_MEM_FLAG_DERIVED),
                "memh=%p is already a derived memh", base);

    status = uct_ib_mlx5_devx_memh_clone(md, base, &derived);
    if (status != UCS_OK) {
        ucs_error("%s: failed to clone memory handle: %s",
                  uct_ib_mlx5_dev_name(md), ucs_status_string(status));
        return status;
    }

    /* Drop the key state copied from the base handle */
    derived->indirect_dvmr = NULL;
    derived->indirect_rkey = UCT_IB_INVALID_MKEY;
    derived->atomic_dvmr   = NULL;
    derived->atomic_rkey   = UCT_IB_INVALID_MKEY;
    derived->super.flags  |= UCT_IB_MEM_FLAG_DERIVED;

    *memh_p = derived;
    return UCS_OK;
}

ucs_status_t
uct_ib_mlx5_devx_mem_reg(uct_md_h uct_md, void *address, size_t length,
const uct_md_mem_reg_params_t *params,
uct_mem_h *memh_p)
{
uct_ib_mlx5_md_t *md = ucs_derived_of(uct_md, uct_ib_mlx5_md_t);
unsigned flags = UCT_MD_MEM_REG_FIELD_VALUE(params, flags, FIELD_FLAGS, 0);
uct_mem_h base = UCT_MD_MEM_REG_FIELD_VALUE(params, memh, FIELD_MEMH, NULL);
uct_ib_mlx5_devx_mem_t *memh;
ucs_status_t status;
uint32_t dummy_mkey;
Expand All @@ -852,6 +901,10 @@ uct_ib_mlx5_devx_mem_reg(uct_md_h uct_md, void *address, size_t length,
return uct_ib_mlx5_devx_mem_reg_gva(uct_md, flags, memh_p);
}

if (base != NULL) {
return uct_ib_mlx5_devx_derived_mem_reg(uct_md, base, memh_p);
}

status = uct_ib_mlx5_devx_memh_alloc(md, length, flags,
sizeof(memh->mrs[0]), &memh);
if (status != UCS_OK) {
Expand Down Expand Up @@ -1509,6 +1562,11 @@ uct_ib_mlx5_devx_mem_dereg(uct_md_h uct_md,
return status;
}

/* Derived memh owns only indirect keys, but not the other state */
if (memh->super.flags & UCT_IB_MEM_FLAG_DERIVED) {
goto out;
}

if (memh->smkey_mr != NULL) {
ucs_trace("%s: destroy smkey_mr %p with key %x",
uct_ib_device_name(&md->super.dev), memh->smkey_mr,
Expand Down Expand Up @@ -1567,6 +1625,7 @@ uct_ib_mlx5_devx_mem_dereg(uct_md_h uct_md,
uct_invoke_completion(params->comp, UCS_OK);
}

out:
ucs_free(memh);
return UCS_OK;
}
Expand Down
56 changes: 50 additions & 6 deletions test/gtest/uct/ib/test_ib_md.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ class test_ib_md : public test_md
void test_mkey_pack_mt_internal(unsigned access_mask, bool invalidate);
void test_smkey_reg_atomic(void);

/* Registers a memory handle derived from 'base' and returns it. Only the
 * 'memh' field of the registration params is marked as set; the address is
 * passed as NULL and the length as SIZE_MAX. A registration failure is
 * reported through ASSERT_UCS_OK. */
uct_mem_h reg_derived_mem(uct_mem_h base) const {
    uct_md_mem_reg_params_t reg_params;
    reg_params.field_mask = UCT_MD_MEM_REG_FIELD_MEMH;
    reg_params.memh       = base;

    uct_mem_h derived;
    ucs_status_t status = uct_md_mem_reg_v2(md(), NULL, SIZE_MAX, &reg_params,
                                            &derived);
    ASSERT_UCS_OK(status);
    return derived;
}

private:
#ifdef HAVE_MLX5_DV
uint32_t m_mlx5_flags = 0;
Expand Down Expand Up @@ -272,12 +281,7 @@ void test_ib_md::test_mkey_pack_mt_internal(unsigned access_mask,
uct_ib_mem_t *ib_memh = (uct_ib_mem_t*)memh;
EXPECT_TRUE(ib_memh->flags & UCT_IB_MEM_MULTITHREADED);

std::vector<uint8_t> rkey(md_attr().rkey_packed_size);
uct_md_mkey_pack_params_t pack_params;
pack_params.field_mask = UCT_MD_MKEY_PACK_FIELD_FLAGS;
pack_params.flags = pack_flags;
ASSERT_UCS_OK(uct_md_mkey_pack_v2(md(), memh, buffer, size,
&pack_params, rkey.data()));
mkey_pack(memh, pack_flags, buffer, size);

uct_md_mem_dereg_params_t params;
params.field_mask = UCT_MD_MEM_DEREG_FIELD_MEMH |
Expand Down Expand Up @@ -378,6 +382,46 @@ UCS_TEST_P(test_ib_md, mt_fail, "IB_REG_MT_THRESH=128K", "IB_REG_MT_CHUNK=16K")
}
}

/*
 * Exercises derived memory handles: they can be created before and after the
 * base handle packed its key, each one packs a key that differs from the base
 * and from other derived handles, repeated packing is stable, and destroying
 * the derived handles leaves the key packed from the base unchanged.
 * Skipped when RMA invalidation is not supported.
 */
UCS_TEST_SKIP_COND_P(test_ib_md, derived_mem,
                     !check_invalidate_support(UCT_MD_MEM_ACCESS_RMA))
{
    bool is_atomic    = check_caps(UCT_MD_FLAG_INVALIDATE_AMO);
    unsigned flags    = UCT_MD_MKEY_PACK_FLAG_INVALIDATE_RMA |
                        (is_atomic ? UCT_MD_MKEY_PACK_FLAG_INVALIDATE_AMO : 0);
    unsigned md_flags = UCT_MD_MEM_ACCESS_RMA |
                        (is_atomic ? UCT_MD_MEM_ACCESS_REMOTE_ATOMIC : 0);
    std::vector<uint8_t> buffer(1024);
    uct_mem_h base;

    /* Must be fatal: every later step uses 'base', which stays uninitialized
     * if the registration fails */
    ASSERT_UCS_OK(reg_mem(md_flags, buffer.data(), buffer.size(), &base));

    /* Test case 1: creating derived memh from memh before mkey_pack */
    uct_mem_h der1 = reg_derived_mem(base);

    /* Test case 2: creating derived memh from memh after mkey_pack */
    std::vector<uint8_t> base_rkey1 = mkey_pack(base, flags);
    uct_mem_h der2                  = reg_derived_mem(base);
    std::vector<uint8_t> der2_rkey1 = mkey_pack(der2, flags);
    EXPECT_NE(base_rkey1, der2_rkey1);

    /* Test case 3: subsequent mkey_pack calls return the same result */
    std::vector<uint8_t> der2_rkey2 = mkey_pack(der2, flags);
    EXPECT_EQ(der2_rkey1, der2_rkey2);

    /* Test case 4: multiple derived memhs do not share the same rkeys */
    std::vector<uint8_t> der1_rkey1 = mkey_pack(der1, flags);
    EXPECT_NE(der1_rkey1, der2_rkey1);

    /* Invalidation = destroying derived memh */
    EXPECT_UCS_OK(uct_md_mem_dereg(md(), der1));
    EXPECT_UCS_OK(uct_md_mem_dereg(md(), der2));

    /* Test case 5: base memh can still be used to pack mkeys */
    std::vector<uint8_t> base_rkey2 = mkey_pack(base, flags);
    EXPECT_EQ(base_rkey1, base_rkey2);

    EXPECT_UCS_OK(uct_md_mem_dereg(md(), base));
}

_UCT_MD_INSTANTIATE_TEST_CASE(test_ib_md, ib)

class test_ib_md_non_blocking : public test_md_non_blocking {
Expand Down
16 changes: 2 additions & 14 deletions test/gtest/uct/test_md.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,7 @@ void test_md::test_reg_mem(unsigned access_mask,
status = uct_md_mem_dereg_v2(md(), &params);
ASSERT_UCS_STATUS_EQ(UCS_ERR_INVALID_PARAM, status);

std::vector<uint8_t> rkey(md_attr().rkey_packed_size);
uct_md_mkey_pack_params_t pack_params;
pack_params.field_mask = UCT_MD_MKEY_PACK_FIELD_FLAGS;
pack_params.flags = invalidate_flag;
status = uct_md_mkey_pack_v2(md(), memh, ptr, size, &pack_params,
rkey.data());
EXPECT_UCS_OK(status);
mkey_pack(memh, invalidate_flag, ptr, size);

status = uct_md_mem_dereg_v2(md(), &params);
}
Expand Down Expand Up @@ -963,13 +957,7 @@ UCS_TEST_SKIP_COND_P(test_md, exported_mkey,
status = reg_mem(UCT_MD_MEM_ACCESS_ALL, address, size, &export_memh);
ASSERT_UCS_OK(status);

std::vector<uint8_t> mkey_buffer(md_attr().exported_mkey_packed_size);
uct_md_mkey_pack_params_t pack_params;
pack_params.field_mask = UCT_MD_MKEY_PACK_FIELD_FLAGS;
pack_params.flags = UCT_MD_MKEY_PACK_FLAG_EXPORT;
status = uct_md_mkey_pack_v2(md(), export_memh, address, size, &pack_params,
mkey_buffer.data());
ASSERT_UCS_OK(status);
mkey_pack(export_memh, UCT_MD_MKEY_PACK_FLAG_EXPORT, address, size);

uct_md_mem_dereg_params_t dereg_params;
dereg_params.field_mask = UCT_MD_MEM_DEREG_FIELD_MEMH;
Expand Down
15 changes: 15 additions & 0 deletions test/gtest/uct/test_md.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,21 @@ class test_md : public testing::TestWithParam<test_md_param>,
return m_md_attr;
}

/* Packs the key of 'memh' with the given pack flags and returns the packed
 * bytes. The buffer is sized by exported_mkey_packed_size when
 * UCT_MD_MKEY_PACK_FLAG_EXPORT is set and by rkey_packed_size otherwise, and
 * is zero-filled before packing. A packing failure is reported through the
 * non-fatal EXPECT_UCS_OK, so the buffer is returned in either case. */
std::vector<uint8_t>
mkey_pack(uct_mem_h memh, unsigned flags = 0, void *ptr = NULL,
          size_t size = SIZE_MAX) const {
    size_t packed_size;
    if (flags & UCT_MD_MKEY_PACK_FLAG_EXPORT) {
        packed_size = md_attr().exported_mkey_packed_size;
    } else {
        packed_size = md_attr().rkey_packed_size;
    }

    uct_md_mkey_pack_params_t params;
    params.field_mask = UCT_MD_MKEY_PACK_FIELD_FLAGS;
    params.flags      = flags;

    std::vector<uint8_t> packed(packed_size, 0);
    ucs_status_t status = uct_md_mkey_pack_v2(md(), memh, ptr, size, &params,
                                              packed.data());
    EXPECT_UCS_OK(status);
    return packed;
}

typedef struct {
test_md *self;
uct_completion_t comp;
Expand Down
Loading