Skip to content

Commit

Permalink
omp: update target api following upstream
Browse files Browse the repository at this point in the history
  • Loading branch information
jxy committed Nov 2, 2023
1 parent 96c7f39 commit c77b2d3
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 8 deletions.
1 change: 1 addition & 0 deletions lib/targets/omptarget/malloc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <quda_internal.h>
#include <device.h>
#include <shmem_helper.cuh>
#include "timer.h"

#ifdef USE_QDPJIT
#include "qdp_quda.h"
Expand Down
42 changes: 34 additions & 8 deletions lib/targets/omptarget/quda_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,14 @@ namespace quda
QudaMem copy(dst, src, count, kind, device::get_default_stream(), false, func, file, line);
}

/**
   @brief quda_ptr overload of qudaMemcpy_.  Unwraps the raw addresses
   held by dst and src and hands them, together with the default
   stream, to a QudaMem object.
   @param[out] dst Destination pointer wrapper
   @param[in] src Source pointer wrapper
   @param[in] count Number of bytes to copy; zero makes the call a no-op
   @param[in] kind Direction of the copy
   @param[in] func Name of the calling function (forwarded to QudaMem)
   @param[in] file Name of the calling file (forwarded to QudaMem)
   @param[in] line Calling line, as a string (forwarded to QudaMem)
*/
void qudaMemcpy_(const quda_ptr &dst, const quda_ptr &src, size_t count, qudaMemcpyKind kind, const char *func,
                 const char *file, const char *line)
{
  // Only issue a transfer when there is something to move.
  if (count != 0) {
    // QudaMem is declared outside this view.  The object is never used again
    // here, so the transfer is presumably carried out by its constructor and
    // the `false` argument presumably selects the blocking path -- confirm.
    QudaMem transfer(dst.data(), src.data(), count, kind, device::get_default_stream(), false, func, file, line);
  }
}

void qudaMemcpyAsync_(void *dst, const void *src, size_t count, qudaMemcpyKind kind, const qudaStream_t &stream,
const char *func, const char *file, const char *line)
{
Expand All @@ -339,25 +347,43 @@ namespace quda
QudaMem set(ptr, value, count, device::get_default_stream(), false, func, file, line);
}

/**
   @brief quda_ptr overload of qudaMemset_.  Fills count bytes with
   value, choosing the mechanism from where the pointer reports its
   memory to live.
   @param[in,out] ptr Pointer wrapper for the region to fill
   @param[in] value Byte value to write
   @param[in] count Number of bytes to fill; zero makes the call a no-op
   @param[in] func Name of the calling function (forwarded to QudaMem)
   @param[in] file Name of the calling file (forwarded to QudaMem)
   @param[in] line Calling line, as a string (forwarded to QudaMem)
*/
void qudaMemset_(quda_ptr &ptr, int value, size_t count, const char *func, const char *file, const char *line)
{
  if (count == 0) return;

  // Memory that is not on the device is filled directly with the C library memset.
  if (!ptr.is_device()) {
    memset(ptr.data(), value, count);
    return;
  }

  // Device memory goes through QudaMem on the default stream.
  QudaMem fill(ptr.data(), value, count, device::get_default_stream(), false, func, file, line);
}

/**
   @brief Raw-pointer variant of qudaMemsetAsync_.  Hands the fill
   request to a QudaMem object on the caller-supplied stream.
   @param[in,out] ptr Start of the region to fill
   @param[in] value Byte value to write
   @param[in] count Number of bytes to fill; zero makes the call a no-op
   @param[in] stream Stream the fill is issued on
   @param[in] func Name of the calling function (forwarded to QudaMem)
   @param[in] file Name of the calling file (forwarded to QudaMem)
   @param[in] line Calling line, as a string (forwarded to QudaMem)
*/
void qudaMemsetAsync_(void *ptr, int value, size_t count, const qudaStream_t &stream, const char *func,
                      const char *file, const char *line)
{
  // Skip empty requests; otherwise build the QudaMem object with its fifth
  // argument set to true (presumably the asynchronous flag -- confirm
  // against the QudaMem definition, which is outside this view).
  if (count != 0) {
    QudaMem fill(ptr, value, count, stream, true, func, file, line);
  }
}

void qudaMemset2D_(void *ptr, size_t pitch, int value, size_t width, size_t height, const char *func,
const char *file, const char *line)
void qudaMemsetAsync_(quda_ptr &ptr, int value, size_t count, const qudaStream_t &stream, const char *func,
const char *file, const char *line)
{
auto error = ompMemset2D(ptr, pitch, value, width, height);
set_runtime_error(error, __func__, func, file, line);
if (count == 0) return;
if (ptr.is_device()) {
QudaMem set(ptr.data(), value, count, stream, true, func, file, line);
} else {
memset(ptr.data(), value, count);
}
}

void qudaMemset2DAsync_(void *ptr, size_t pitch, int value, size_t width, size_t height, const qudaStream_t &stream,
const char *func, const char *file, const char *line)
void qudaMemset2DAsync_(quda_ptr &ptr, size_t offset, size_t pitch, int value, size_t width, size_t height,
const qudaStream_t &stream, const char *func, const char *file, const char *line)
{
auto error = ompMemset2DAsync(ptr, pitch, value, width, height, stream);
set_runtime_error(error, __func__, func, file, line);
if (ptr.is_device()) {
auto error = ompMemset2DAsync(static_cast<char *>(ptr.data()) + offset, pitch, value, width, height, stream);
set_runtime_error(error, __func__, func, file, line);
} else {
for (auto i = 0u; i < height; i++) memset(static_cast<char *>(ptr.data()) + offset + i * pitch, value, width);
}
}

void qudaMemPrefetchAsync_(void *ptr, size_t count, QudaFieldLocation mem_space, const qudaStream_t &stream,
Expand Down

0 comments on commit c77b2d3

Please sign in to comment.