#include "../sparse_formats/csr.h"

/**
 * \brief GPU kernel computing a dense matrix multiply with custom elementwise
 *        operators: C[mi,ni] = g(C[mi,ni], f(A[mi,ki], B[ki,ni])) over ki,
 *        where f is the "multiply" and g the "accumulate".
 *        Matrices are column-major; tA/tB select transposition of A/B.
 * \param[in] tA 'N' if A is not transposed, otherwise transposed (k-by-m)
 * \param[in] tB 'N' if B is not transposed, otherwise transposed (n-by-k)
 * \param[in] m number of rows of C (and of op(A))
 * \param[in] n number of columns of C (and of op(B))
 * \param[in] k inner (contracted) dimension
 * \param[in] A dense input matrix
 * \param[in] B dense input matrix
 * \param[in,out] C dense output matrix, accumulated into via g
 * NOTE(review): reconstructed from an extraction-garbled region; the
 * signature is taken from the launch site
 * cuda_gemmf<...><<<NBLK,NTRD>>>(tA, tB, m, n, k, A, B, C) later in this file.
 */
template<typename dtype_A,
         typename dtype_B,
         typename dtype_C,
         dtype_C (*f)(dtype_A, dtype_B),
         void (*g)(dtype_C, dtype_C&)>
__global__
void cuda_gemmf(char            tA,
                char            tB,
                int             m,
                int             n,
                int             k,
                dtype_A const * A,
                dtype_B const * B,
                dtype_C       * C){
  int bidx = blockIdx.x;
  int tidx = threadIdx.x;
  // Column-major strides for A and B; transposition just swaps the
  // row-stride and column-stride.
  int lda_A_m = tA == 'N' ? 1 : k;
  int lda_A_k = tA == 'N' ? m : 1;
  int lda_B_k = tB == 'N' ? 1 : n;
  int lda_B_n = tB == 'N' ? k : 1;
  // Blocks stride over rows of C and threads over columns; the strides match
  // the fixed <<<NBLK,NTRD>>> launch configuration used in this file.
  for (int mi=bidx; mi<m; mi+=NBLK){
    for (int ni=tidx; ni<n; ni+=NTRD){
      for (int ki=0; ki<k; ki++){
        // C assumed column-major, i.e. C[ni*m+mi], matching the CSR kernels'
        // C[col_B*m+row_A] indexing in this file — TODO(review): confirm
        // against the original (the accumulation target line was lost in
        // extraction).
        g(f(A[mi*lda_A_m+ki*lda_A_k],
            B[ki*lda_B_k+ni*lda_B_n]),
          C[ni*m+mi]);
      }
    }
  }
}
/**
 * \brief device routine computing C = g(C, f(A,B)) where A is a sparse
 *        m-by-k matrix in CSR layout with ONE-based row pointers IA (size
 *        m+1) and ONE-based column indices JA, while B (k-by-n) and C
 *        (m-by-n) are dense column-major.
 * \param[in] m number of rows of A and C
 * \param[in] n number of columns of B and C
 * \param[in] k number of columns of A / rows of B
 * \param[in] A nonzero values of the CSR matrix
 * \param[in] JA one-based column index of each value in A
 * \param[in] IA one-based row-pointer array of size m+1
 * \param[in] B dense input matrix
 * \param[in,out] C dense output matrix, accumulated into via g
 * NOTE(review): reconstructed from an extraction-garbled region; the
 * signature is taken from the in-file call
 * cuda_csrmmf<dtype_A,dtype_B,dtype_C,f,g>(m,n,k,A,JA,IA,B,C), which is made
 * from inside a __global__ kernel, so this must be a __device__ function.
 */
template<typename dtype_A,
         typename dtype_B,
         typename dtype_C,
         dtype_C (*f)(dtype_A, dtype_B),
         void (*g)(dtype_C, dtype_C&)>
__device__
void cuda_csrmmf(int             m,
                 int             n,
                 int             k,
                 dtype_A const * A,
                 int const *     JA,
                 int const *     IA,
                 dtype_B const * B,
                 dtype_C       * C){
  int bidx = blockIdx.x;
  int tidx = threadIdx.x;
  // Blocks stride over columns of C, threads over rows; strides match the
  // fixed <<<NBLK,NTRD>>> launch configuration used in this file.
  for (int col_B=bidx; col_B<n; col_B+=NBLK){
    for (int row_A=tidx; row_A<m; row_A+=NTRD){
      // IA/JA are one-based (Fortran-style CSR), hence the -1 adjustments.
      for (int i_A=IA[row_A]-1; i_A<IA[row_A+1]-1; i_A++){
        int col_A = JA[i_A]-1;
        g(f(A[i_A],B[col_B*k+col_A]),C[col_B*m+row_A]);
      }
    }
  }
}
// NOTE(review): extraction-garbled second overload of cuda_csrmmf (original
// source lines 58-79). The parameter list is cut after "int m," and several
// interior lines are missing, so the full signature cannot be recovered from
// this view — restore it from the original file before editing. The visible
// body matches the first cuda_csrmmf overload: blocks stride over columns of
// C, threads over rows, and IA/JA are one-based CSR arrays (hence the -1
// adjustments); B and C are indexed column-major (B[col_B*k+col_A],
// C[col_B*m+row_A]).
58 template<
typename dtype_A,
typename dtype_B,
typename dtype_C, dtype_C(*f)(dtype_A, dtype_B),
void(*g)(dtype_C, dtype_C&)>
60 void cuda_csrmmf(
int m,
68 int bidx = blockIdx.x;
69 int tidx = threadIdx.x;
70 for (
int col_B=bidx; col_B<n; col_B+=NBLK){
71 for (
int row_A=tidx; row_A<m; row_A+=NTRD){
72 for (
int i_A=IA[row_A]-1; i_A<IA[row_A+1]-1; i_A++){
73 int col_A = JA[i_A]-1;
74 g(f(A[i_A],B[col_B*k+col_A]),C[col_B*m+row_A]);
/**
 * \brief GPU kernel that unpacks a serialized CSR matrix from a raw buffer
 *        and multiplies it by dense B into dense C via cuda_csrmmf.
 *        Launched as offload_csrmm<...><<<NBLK,NTRD>>>(m, n, k, all_data, B, C).
 * \param[in] m number of rows of the sparse matrix and of C
 * \param[in] n number of columns of B and C
 * \param[in] k number of columns of the sparse matrix / rows of B
 * \param[in] all_data serialized CSR blob: a header of 3 int64_t fields
 *            (the first is the nonzero count), then nnz_A values of type
 *            dtype_A, then m+1 one-based row pointers, then one-based
 *            column indices
 * \param[in] B dense input matrix (column-major)
 * \param[in,out] C dense output matrix (column-major), accumulated via g
 * NOTE(review): reconstructed from an extraction-garbled region; the
 * signature is taken from the launch site
 * offload_csrmm<...><<<NBLK,NTRD>>>(m, n, k, all_data, B, C) later in this
 * file.
 */
template<typename dtype_A,
         typename dtype_B,
         typename dtype_C,
         dtype_C (*f)(dtype_A, dtype_B),
         void (*g)(dtype_C, dtype_C&)>
__global__
void offload_csrmm(int             m,
                   int             n,
                   int             k,
                   char const *    all_data,
                   dtype_B const * B,
                   dtype_C       * C){
  // First header field is the number of nonzeros; the other two header
  // fields are skipped here.
  int64_t nnz_A = ((int64_t*)all_data)[0];
  // Use a 64-bit offset: nnz_A is 64-bit, so nnz_A*sizeof(dtype_A) can
  // overflow a 32-bit int for large matrices (the original used int).
  int64_t offset = 3*sizeof(int64_t);
  // Values array.
  dtype_A const * A = (dtype_A const *)(all_data + offset);
  offset += nnz_A*sizeof(dtype_A);
  // One-based row-pointer array of size m+1.
  int const * IA = (int*)(all_data + offset);
  offset += (m+1)*sizeof(int);
  // One-based column indices, one per nonzero.
  int const * JA = (int*)(all_data + offset);
  cuda_csrmmf<dtype_A,dtype_B,dtype_C,f,g>(m,n,k,A,JA,IA,B,C);
}
108 template<
typename dtype>
114 template<
typename dtype=
double,
void(*g)(dtype, dtype&)=default_mono
id<dtype> >
122 [](
void *
a,
void *
b,
int * n, MPI_Datatype*){
123 for (
int i=0; i<*n; i++){
147 for (
int i=0; i<n; i++){
148 g(X[incX*i],Y[incY*i]);
157 for (
int i=0; i<n; i++){
158 memcpy(arr+i*
el_size,(
char*)&dummy,el_size);
165 template<
typename dtype_A,
typename dtype_B,
typename dtype_C, dtype_C(*f)(dtype_A, dtype_B),
void(*g)(dtype_C, dtype_C&)=default_mono
id<dtype_C> >
169 this->has_kernel =
true;
171 this->has_off_gemm =
true;
173 this->
el_size =
sizeof(dtype_C);
177 this->has_kernel =
true;
179 this->has_off_gemm =
true;
193 int lda_A_m = tA ==
'N' ? 1 : k;
194 int lda_A_k = tA ==
'N' ? m : 1;
195 int lda_B_k = tB ==
'N' ? 1 : n;
196 int lda_B_n = tB ==
'N' ? k : 1;
198 #pragma omp parallel for 200 for (
int mi=0; mi<m; mi++){
202 #pragma omp parallel for 204 for (
int ni=0; ni<n; ni++){
205 for (
int ki=0; ki<k; ki++){
206 g(f(A[mi*lda_A_m+ki*lda_A_k],
207 B[ki*lda_B_k+ni*lda_B_n]),
225 for (
int i=0; i<nnz_A; i++){
226 int row_A = rows_A[i]-1;
227 int col_A = cols_A[i]-1;
228 for (
int col_C=0; col_C<n; col_C++){
229 g(f(A[i],B[col_C*k+col_A]),C[col_C*m+row_A]);
244 coomm(m, n, k, (dtype_A
const *)A, rows_A, cols_A, nnz_A,
245 (dtype_B
const *)B, (dtype_C *)C);
259 #pragma omp parallel for 261 for (
int row_A=0; row_A<m; row_A++){
263 #pragma omp parallel for 265 for (
int col_B=0; col_B<n; col_B++){
266 for (
int i_A=IA[row_A]-1; i_A<IA[row_A+1]-1; i_A++){
267 int col_A = JA[i_A]-1;
268 g(f(A[i_A],B[col_B*k+col_A]),C[col_B*m+row_A]);
282 gemm(tA, tB, m, n, k,
283 (dtype_A
const *)A, (dtype_B
const *)B, (dtype_C *)C);
302 #pragma omp parallel for 304 for (
int row_A=0; row_A<m; row_A++){
305 for (
int i_A=IA[row_A]-1; i_A<IA[row_A+1]-1; i_A++){
306 int row_B = JA[i_A]-1;
307 for (
int i_B=IB[row_B]-1; i_B<IB[row_B+1]-1; i_B++){
308 int col_B = JB[i_B]-1;
309 g(f(A[i_A],B[i_B]),C[col_B*m+row_A]);
328 char *& C_CSR)
const {
332 for (
int i=0; i<m; i++){
333 memset(has_col, 0,
sizeof(
int)*n);
336 for (
int j=0; j<n; j++){
337 IC[i+1] += has_col[j];
341 dtype_C * vC = (dtype_C*)C.
vals();
343 memcpy(C.
IA(), IC,
sizeof(int)*(m+1));
347 for (
int i=0; i<m; i++){
348 memset(has_col, 0,
sizeof(
int)*n);
351 for (
int j=0; j<n; j++){
353 JC[IC[i]+vs-1] = j+1;
354 rev_col[j] = IC[i]+vs-1;
358 memset(has_col, 0,
sizeof(
int)*n);
359 for (
int j=0; j<IA[i+1]-IA[i]; j++){
360 int row_B = JA[IA[i]+j-1]-1;
361 int idx_A = IA[i]+j-1;
362 for (
int l=0; l<IB[row_B+1]-IB[row_B]; l++){
363 int idx_B = IB[row_B]+l-1;
364 if (has_col[JB[idx_B]-1])
365 g(f(A[idx_A],B[idx_B]), vC[rev_col[JB[idx_B]-1]]);
367 vC[rev_col[JB[idx_B]-1]] = f(A[idx_A],B[idx_B]);
368 has_col[JB[idx_B]-1] = 1;
373 if (C_CSR == NULL || C_in.
nnz() == 0){
396 char *& C_CSR)
const {
399 memset(IC, 0,
sizeof(
int)*(m+1));
407 #pragma omp for schedule(dynamic) // TO DO test other strategies 409 for (
int i=0; i<m; i++){
410 memset(has_col, 0,
sizeof(
int)*(n+1));
412 for (
int j=0; j<IA[i+1]-IA[i]; j++){
413 int row_B = JA[IA[i]+j-1]-1;
414 for (
int kk=0; kk<IB[row_B+1]-IB[row_B]; kk++){
415 int idx_B = IB[row_B]+kk-1;
416 if (has_col[JB[idx_B]] == 0){
418 has_col[JB[idx_B]] = 1;
429 for(
int i=0;i < m+1; i++){
434 dtype_C * vC = (dtype_C*)C.
vals();
436 memcpy(C.
IA(), IC,
sizeof(int)*(m+1));
445 dtype_C *acc_data =
new dtype_C[n];
449 for (
int i=0; i<m; i++){
450 memset(dcol, 0,
sizeof(
int)*(n));
452 for (
int j=0; j<IA[i+1]-IA[i]; j++){
453 int row_b = JA[IA[i]+j-1]-1;
454 int idx_a = IA[i]+j-1;
455 for (
int ii = 0; ii < IB[row_b+1]-IB[row_b]; ii++){
456 int col_b = IB[row_b]+ii-1;
457 int col_c = JB[col_b]-1;
459 if (dcol[col_c] == 0){
460 dcol[col_c] = JB[col_b];
461 acc_data[col_c] =f(A[idx_a],B[col_b]);
463 g(f(A[idx_a],B[col_b]), acc_data[col_c]);
467 for(
int jj = 0; jj < n; jj++){
469 JC[IC[i]+ins-1] = dcol[jj];
470 vC[IC[i]+ins-1] = acc_data[jj];
481 if (C_CSR == NULL || C_in.
nnz() == 0){
507 csrmultd(m,n,k,(dtype_A
const *)A,JA,IA,nnz_A,(dtype_B
const *)B,JB,IB,nnz_B,(dtype_C *)C);
524 csrmultcsr(m,n,k,(dtype_A
const *)A,JA,IA,nnz_A,(dtype_B
const *)B, JB, IB, nnz_B, C_CSR);
537 csrmm(m,n,k,(dtype_A
const *)A,JA,IA,nnz_A,(dtype_B
const *)B, (dtype_C *)C);
551 #ifdef PROFILE_CUGEMM 554 cuda_gemmf<dtype_A,dtype_B,dtype_C,f,g><<<NBLK,NTRD>>>(tA, tB, m, n, k, A, B, C);
555 #ifdef PROFILE_CUGEMM 556 cudaDeviceSynchronize();
572 offload_gemm(tA, tB, m, n, k, (dtype_A
const *)A, (dtype_B
const *)B, (dtype_C*)C);
591 char const * all_data,
595 #ifdef PROFILE_CUGEMM 598 offload_csrmm<dtype_A,dtype_B,dtype_C,f,g><<<NBLK,NTRD>>>(m, n, k, all_data, (dtype_B
const *)B, (dtype_C *)C);
599 #ifdef PROFILE_CUGEMM 600 cudaDeviceSynchronize();
void ccsrmm(int m, int n, int k, char const *A, int const *JA, int const *IA, int64_t nnz_A, char const *B, char *C, CTF_int::algstrct const *sr_C) const
int * IA() const
retrieves the prefix sum of the number of nonzeros in each row (an array of size nrow()+1) from all_data
void cgemm(char tA, char tB, int m, int n, int k, char const *A, char const *B, char *C) const
void accum(char const *a, char *b) const
b+=a
static char * csr_add(char *cA, char *cB, accumulatable const *adder)
void ccoomm(int m, int n, int k, char const *A, int const *rows_A, int const *cols_A, int64_t nnz_A, char const *B, char *C) const
void * alloc(int64_t len)
alloc abstraction
static void gemm(char tA, char tB, int m, int n, int k, dtype_A const *A, dtype_B const *B, dtype_C *C)
custom bivariate function on two tensors: e.g. C["ij"] = f(A["ik"],B["kj"])
virtual void init_shell(int64_t n, char *arr) const
initialize n objects to zero
void gemm(char tA, char tB, int m, int n, int k, dtype alpha, dtype const *A, dtype const *B, dtype beta, dtype *C)
static void compute_has_col(int const *JA, int const *IA, int const *JB, int const *IB, int i, int *has_col)
int * JA() const
retrieves column indices of each value in vals stored in sorted form by row
void coffload_gemm(char tA, char tB, int m, int n, int k, char const *A, char const *B, char *C) const
int64_t nnz() const
retrieves number of nonzeros out of all_data
abstract class that knows how to add
static void offload_gemm(char tA, char tB, int m, int n, int k, dtype_A const *A, dtype_B const *B, dtype_C *C)
abstraction for a serialized sparse matrix stored in compressed sparse row (CSR) layout
static MPI_Op get_MPI_Op()
static void coomm(int m, int n, int k, dtype_A const *A, int const *rows_A, int const *cols_A, int nnz_A, dtype_B const *B, dtype_C *C)
Bivar_Kernel(bool is_comm)
char * all_data
serialized buffer containing all info, index, and values related to matrix
char * vals() const
retrieves array of values out of all_data
static void csrmm(int m, int n, int k, dtype_A const *A, int const *JA, int const *IA, int64_t nnz_A, dtype_B const *B, dtype_C *C)
int el_size
size of each element of algstrct in bytes
int cdealloc(void *ptr)
free abstraction
algstrct (algebraic structure) defines the elementwise operations computed in each tensor contraction
void offload_gemm(char tA, char tB, int m, int n, int k, dtype alpha, offload_tsr &A, int lda_A, offload_tsr &B, int lda_B, dtype beta, offload_tsr &C, int lda_C)
void default_monoid(dtype a, dtype &b)
static void xpy(int n, dtype const *X, int incX, dtype *Y, int incY)
void coffload_csrmm(int m, int n, int k, char const *all_data, char const *B, char *C) const