Skip to content

Commit

Permalink
cuda reduce kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
nychiang committed Mar 13, 2023
1 parent 63027b7 commit 3cf6bde
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/LinAlg/VectorCudaKernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -393,9 +393,9 @@ __global__ void add_linear_damping_term_cu(int n, double* data, const double* ix
}

/** @brief y[i] = 1.0 if x[i] is positive and id[i] = 1.0, otherwise y[i] = 0 */
__global__ void is_posive_w_pattern_cu(int n, double* data, const double* vd, const double* id)
__global__ void is_posive_w_pattern_cu(int n, int* data, const double* vd, const double* id)
{
extern __shared__ float shared_sum[];
extern __shared__ int shared_sum[];
const int num_threads = blockDim.x * gridDim.x;
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
int sum = 0;
Expand Down Expand Up @@ -957,7 +957,7 @@ int is_posive_w_pattern_kernel(int n_local,
cudaMemcpy(h_retval, d_retval, num_blocks*sizeof(int), cudaMemcpyDeviceToHost);

int sum_result = 0;
for(int i=0;i<block_size;i++) {
for(int i=0;i<num_blocks;i++) {
sum_result += h_retval[i];
}

Expand Down Expand Up @@ -1242,9 +1242,10 @@ int all_positive_w_pattern_kernel(int n, const double* d1, const double* id)
// TODO: how to avoid this temp vec?
// thrust::device_vector<double> v_temp(n);
// double* dv_ptr = thrust::raw_pointer_cast(v_temp.data());
// is_posive_w_pattern_kernel(n, dv_ptr, d1, id);
// return thrust::reduce(thrust::device, v_temp.begin(), v_temp.end(), (int)0, thrust::plus<int>());

int irev = hiop::cuda::is_posive_w_pattern_kernel(n, dv_ptr, d1, id);
int irev = hiop::cuda::is_posive_w_pattern_kernel(n, d1, id);
return irev;
}

Expand Down

0 comments on commit 3cf6bde

Please sign in to comment.