Skip to content

Commit

Permalink
Implement cublasDgetrsBatched.
Browse files Browse the repository at this point in the history
  • Loading branch information
lshqqytiger committed Mar 18, 2024
1 parent 605254b commit cd1e0a3
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 1 deletion.
14 changes: 13 additions & 1 deletion zluda_blas/src/cublas.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4366,7 +4366,19 @@ pub unsafe extern "system" fn cublasDgetrsBatched(
info: *mut ::std::os::raw::c_int,
batchSize: ::std::os::raw::c_int,
) -> cublasStatus_t {
crate::unsupported()
crate::dgetrs_batched(
handle,
trans,
n,
nrhs,
Aarray,
lda,
devIpiv,
Barray,
ldb,
info,
batchSize,
)
}

#[no_mangle]
Expand Down
31 changes: 31 additions & 0 deletions zluda_blas/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use rocsolver_sys::{
rocsolver_cgetrf_batched,
rocsolver_cgetri_outofplace_batched,
rocsolver_sgetrs_batched,
rocsolver_dgetrs_batched,
rocsolver_zgetrf_batched,
rocsolver_zgetri_outofplace_batched,
};
Expand Down Expand Up @@ -742,6 +743,36 @@ unsafe fn sgetrs_batched(
))
}

/// Solves a batch of double-precision linear systems by forwarding to
/// `rocsolver_dgetrs_batched`.
///
/// `trans` is translated with `op_from_cuda_for_solver`, the remaining
/// arguments are passed through in order, and the rocSOLVER status is
/// converted back with `to_cuda_solver`.
///
/// The pivot stride handed to rocSOLVER is `n`: cuBLAS documents `devIpiv` as
/// an `n x batchSize` array stored linearly, so the pivot vector of problem
/// `i + 1` starts `n` entries after the one of problem `i`. The number of
/// right-hand sides (`nrhs`) does not affect the pivot layout.
///
/// NOTE(review): `info` is accepted but never written by this function;
/// cuBLAS documents it as a host-side output, so confirm whether callers rely
/// on it being set.
///
/// # Safety
///
/// Presumably every pointer must satisfy the requirements of the underlying
/// rocSOLVER call (valid handle, device-accessible arrays of at least
/// `batch_size` entries) — verify against the rocSOLVER documentation.
unsafe fn dgetrs_batched(
    handle: *mut cublasContext,
    trans: cublasOperation_t,
    n: i32,
    nrhs: i32,
    a: *const *const f64,
    lda: i32,
    dev_ipiv: *const i32,
    b: *const *mut f64,
    ldb: i32,
    info: *mut i32,
    batch_size: i32,
) -> cublasStatus_t {
    let trans = op_from_cuda_for_solver(trans);
    // One pivot vector of length `n` per batch entry, packed back to back.
    let pivot_stride = n;
    to_cuda_solver(rocsolver_dgetrs_batched(
        handle.cast(),
        trans,
        n,
        nrhs,
        a.cast(),
        lda,
        dev_ipiv,
        pivot_stride as _,
        b,
        ldb,
        batch_size,
    ))
}

unsafe fn dtrmm_v2(
handle: *mut cublasContext,
side: cublasSideMode_t,
Expand Down

0 comments on commit cd1e0a3

Please sign in to comment.