2 #include "../shared/blas_symbs.h" 3 #include "../shared/offload.h" 4 #include "../sparse_formats/csr.h" 7 #include "../shared/mkl_symbs.h" 14 template <
typename dtype>
27 if (m == 1 && n == 1 && k == 1) {
28 for (
int i=0; i<l; i++){
30 C[i]+=alpha*A[i]*B[i];
36 if (taA ==
'n' || taA ==
'N'){
41 if (taB ==
'n' || taB ==
'N'){
51 int size_per_group = l;
52 CTF_BLAS::gemm_batch<dtype>(&taA, &taB, &m, &n, &k, &alpha, ptrs_A, &lda, ptrs_B, &ldb, &beta, ptrs_C, &ldc, &group_count, &size_per_group);
54 for (
int i=0; i<l; i++){
55 CTF_BLAS::gemm<dtype>(&taA,&taB,&m,&n,&k,&alpha, ptrs_A[i] ,&lda, ptrs_B[i] ,&ldb,&beta, ptrs_C[i] ,&ldc);
63 #define INST_GEMM_BATCH(dtype) \ 64 template void gemm_batch<dtype>( char , \ 79 #undef INST_GEMM_BATCH 81 template <
typename dtype>
92 int lda, lda_B, lda_C;
94 if (tA ==
'n' || tA ==
'N'){
99 if (tB ==
'n' || tB ==
'N'){
104 CTF_BLAS::gemm<dtype>(&tA,&tB,&m,&n,&k,&alpha,A,&lda,B,&lda_B,&beta,C,&lda_C);
107 #define INST_GEMM(dtype) \ 108 template void gemm<dtype>( char , \ 149 void default_axpy< std::complex<float> >
151 std::complex<float> alpha,
152 std::complex<float>
const * X,
154 std::complex<float> * Y,
160 void default_axpy< std::complex<double> >
162 std::complex<double> alpha,
163 std::complex<double>
const * X,
165 std::complex<double> * Y,
181 void default_scal< std::complex<float> >
182 (
int n, std::complex<float> alpha, std::complex<float> * X,
int incX){
187 void default_scal< std::complex<double> >
188 (
int n, std::complex<double> alpha, std::complex<double> * X,
int incX){
192 #define DEF_COOMM_KERNEL() \ 193 for (int j=0; j<n; j++){ \ 194 for (int i=0; i<m; i++){ \ 198 for (int i=0; i<nnz_A; i++){ \ 199 int row_A = rows_A[i]-1; \ 200 int col_A = cols_A[i]-1; \ 201 for (int col_C=0; col_C<n; col_C++){ \ 202 C[col_C*m+row_A] += alpha*A[i]*B[col_C*k+col_A]; \ 221 char matdescra[6] = {
'G',0,0,
'F',0,0};
223 matdescra, (
float*)A, rows_A, cols_A, &nnz_A,
224 (
float*)B, &k, &beta,
246 char matdescra[6] = {
'G',0,0,
'F',0,0};
249 matdescra, (
double*)A, rows_A, cols_A, &nnz_A,
250 (
double*)B, &k, &beta,
260 void default_coomm< std::complex<float> >
264 std::complex<float> alpha,
265 std::complex<float>
const * A,
269 std::complex<float>
const * B,
270 std::complex<float> beta,
271 std::complex<float> * C){
274 char matdescra[6] = {
'G',0,0,
'F',0,0};
276 matdescra, (std::complex<float>*)A, rows_A, cols_A, &nnz_A,
277 (std::complex<float>*)B, &k, &beta,
278 (std::complex<float>*)C, &m);
285 void default_coomm< std::complex<double> >
289 std::complex<double> alpha,
290 std::complex<double>
const * A,
294 std::complex<double>
const * B,
295 std::complex<double> beta,
296 std::complex<double> * C){
299 char matdescra[6] = {
'G',0,0,
'F',0,0};
301 matdescra, (std::complex<double>*)A, rows_A, cols_A, &nnz_A,
302 (std::complex<double>*)B, &k, &beta,
303 (std::complex<double>*)C, &m);
330 template <
typename dtype>
345 #pragma omp parallel for 347 for (
int row_A=0; row_A<m; row_A++){
349 #pragma omp parallel for 351 for (
int col_B=0; col_B<n; col_B++){
352 C[col_B*m+row_A] *= beta;
353 if (IA[row_A] < IA[row_A+1]){
354 int i_A1 = IA[row_A]-1;
355 int col_A1 = JA[i_A1]-1;
356 dtype tmp = A[i_A1]*B[col_B*k+col_A1];
357 for (
int i_A=IA[row_A]; i_A<IA[row_A+1]-1; i_A++){
358 int col_A = JA[i_A]-1;
359 tmp += A[i_A]*B[col_B*k+col_A];
361 C[col_B*m+row_A] += alpha*tmp;
368 template<
typename dtype>
384 #pragma omp parallel for 386 for (
int row_A=0; row_A<m; row_A++){
387 for (
int i_A=IA[row_A]-1; i_A<IA[row_A+1]-1; i_A++){
388 int row_B = JA[i_A]-1;
389 for (
int i_B=IB[row_B]-1; i_B<IB[row_B+1]-1; i_B++){
390 int col_B = JB[i_B]-1;
391 C[col_B*m+row_A] += A[i_A]*B[i_B];
417 char matdescra[6] = {
'G',0,0,
'F',0,0};
419 CTF_BLAS::MKL_SCSRMM(&transa, &m, &n, &k, &alpha, matdescra, A, JA, IA, IA+1, B, &k, &beta, C, &m);
421 CTF_int::muladd_csrmm<float>(m,n,k,alpha,A,JA,IA,nnz_A,B,beta,C);
440 char matdescra[6] = {
'G',0,0,
'F',0,0};
442 CTF_BLAS::MKL_DCSRMM(&transa, &m, &n, &k, &alpha, matdescra, A, JA, IA, IA+1, B, &k, &beta, C, &m);
445 CTF_int::muladd_csrmm<double>(m,n,k,alpha,A,JA,IA,nnz_A,B,beta,C);
455 std::complex<float> alpha,
456 std::complex<float>
const * A,
460 std::complex<float>
const * B,
461 std::complex<float> beta,
462 std::complex<float> * C)
const {
465 char matdescra[6] = {
'G',0,0,
'F',0,0};
467 CTF_BLAS::MKL_CCSRMM(&transa, &m, &n, &k, &alpha, matdescra, A, JA, IA, IA+1, B, &k, &beta, C, &m);
469 CTF_int::muladd_csrmm< std::complex<float> >(m,n,k,alpha,A,JA,IA,nnz_A,B,beta,C);
478 std::complex<double> alpha,
479 std::complex<double>
const * A,
483 std::complex<double>
const * B,
484 std::complex<double> beta,
485 std::complex<double> * C)
const {
488 char matdescra[6] = {
'G',0,0,
'F',0,0};
489 CTF_BLAS::MKL_ZCSRMM(&transa, &m, &n, &k, &alpha, matdescra, A, JA, IA, IA+1, B, &k, &beta, C, &m);
491 CTF_int::muladd_csrmm< std::complex<double> >(m,n,k,alpha,A,JA,IA,nnz_A,B,beta,C);
497 #define CSR_MULTD_DEF(dtype,is_ord,MKL_name) \ 499 void CTF::Semiring<dtype,is_ord>::default_csrmultd \ 514 if (alpha == this->taddid){ \ 515 if (beta != this->tmulid) \ 516 CTF_int::default_scal<dtype>(m*n, beta, C, 1); \ 520 if (beta == this->taddid){ \ 521 CTF_BLAS::MKL_name(&transa, &m, &k, &n, A, JA, IA, B, JB, IB, C, &m); \ 522 if (alpha != this->tmulid) \ 523 CTF_int::default_scal<dtype>(m*n, alpha, C, 1); \ 525 dtype * tmp_C_buf = (dtype*)alloc(sizeof(dtype)*m*n); \ 526 CTF_BLAS::MKL_name(&transa, &m, &k, &n, A, JA, IA, B, JB, IB, tmp_C_buf, &m); \ 527 if (beta != this->tmulid) \ 528 CTF_int::default_scal<dtype>(m*n, beta, C, 1); \ 529 CTF_int::default_axpy<dtype>(m*n, alpha, tmp_C_buf, 1, C, 1); \ 530 cdealloc(tmp_C_buf); \ 534 #define CSR_MULTD_DEF(dtype,is_ord,MKL_name) \ 536 void CTF::Semiring<dtype,is_ord>::default_csrmultd \ 551 if (alpha == this->taddid){ \ 552 if (beta != this->tmulid) \ 553 CTF_int::default_scal<dtype>(m*n, beta, C, 1); \ 556 if (alpha != this->tmulid || beta != this->tmulid){ \ 557 CTF_int::default_scal<dtype>(m*n, beta/alpha, C, 1); \ 559 CTF_int::muladd_csrmultd<dtype>(m,n,k,A,JA,IA,nnz_A,B,JB,IB,nnz_B,C); \ 560 if (alpha != this->tmulid){ \ 561 CTF_int::default_scal<dtype>(m*n, alpha, C, 1); \ 573 #define CSR_MULTCSR_DEF(dtype,is_ord,MKL_name) \ 575 void CTF::Semiring<dtype,is_ord>::default_csrmultcsr \ 589 char *& C_CSR) const { \ 591 CSR_Matrix C_in(C_CSR); \ 593 int * new_ic = (int*)alloc(sizeof(int)*(m+1)); \ 598 CTF_BLAS::MKL_name(&transa, &req, &sort, &m, &k, &n, A, JA, IA, B, JB, IB, NULL, NULL, new_ic, &req, &info); \ 600 CSR_Matrix C_add(new_ic[m]-1, m, n, this); \ 601 memcpy(C_add.IA(), new_ic, (m+1)*sizeof(int)); \ 604 CTF_BLAS::MKL_name(&transa, &req, &sort, &m, &k, &n, A, JA, IA, B, JB, IB, (dtype*)C_add.vals(), C_add.JA(), C_add.IA(), &req, &info); \ 606 if (beta == this->taddid){ \ 607 C_CSR = C_add.all_data; \ 609 if (C_CSR != NULL && beta != this->tmulid){ \ 610 this->scal(C_in.nnz(), (char const *)&beta, C_in.vals(), 1); \ 612 if (alpha != this->tmulid){ \ 613 this->scal(C_add.nnz(), (char const *)&alpha, C_add.vals(), 1); \ 615 if (C_CSR == NULL){ \ 616 C_CSR = C_add.all_data; \ 618 char * C_ret = csr_add(C_CSR, C_add.all_data); \ 619 cdealloc(C_add.all_data); \ 625 #define CSR_MULTCSR_DEF(dtype,is_ord,MKL_name) \ 627 void CTF::Semiring<dtype,is_ord>::default_csrmultcsr \ 641 char *& C_CSR) const { \ 642 this->gen_csrmultcsr(m,n,k,alpha,A,JA,IA,nnz_A,B,JB,IB,nnz_B,beta,C_CSR); \ 661 bool CTF::Semiring<std::complex<float>,0>::is_offloadable()
const {
662 return fgemm == &CTF_int::default_gemm< std::complex<float> >;
672 bool CTF::Semiring<std::complex<double>,0>::is_offloadable()
const {
673 return fgemm == &CTF_int::default_gemm< std::complex<double> >;
689 if (tA ==
'n' || tA ==
'N') lda_A = m;
691 if (tB ==
'N' || tB ==
'N') lda_B = k;
692 CTF_int::offload_gemm<float>(tA, tB, m, n, k, ((
float const*)alpha)[0], (
float const *)A, lda_A, (
float const *)B, lda_B, ((
float const*)beta)[0], (
float*)C, m);
708 if (tA ==
'n' || tA ==
'N') lda_A = m;
710 if (tB ==
'N' || tB ==
'N') lda_B = k;
711 CTF_int::offload_gemm<std::complex<float>>(tA, tB, m, n, k, ((std::complex<float>
const*)alpha)[0], (std::complex<float>
const *)A, lda_A, (std::complex<float>
const *)B, lda_B, ((std::complex<float>
const*)beta)[0], (std::complex<float>*)C, m);
727 if (tA ==
'n' || tA ==
'N') lda_A = m;
729 if (tB ==
'N' || tB ==
'N') lda_B = k;
730 CTF_int::offload_gemm<double>(tA, tB, m, n, k, ((
double const*)alpha)[0], (
double const *)A, lda_A, (
double const *)B, lda_B, ((
double const*)beta)[0], (
double*)C, m);
746 if (tA ==
'n' || tA ==
'N') lda_A = m;
748 if (tB ==
'N' || tB ==
'N') lda_B = k;
749 CTF_int::offload_gemm<std::complex<double>>(tA, tB, m, n, k, ((std::complex<double>
const*)alpha)[0], (std::complex<double>
const *)A, lda_A, (std::complex<double>
const *)B, lda_B, ((std::complex<double>
const*)beta)[0], (std::complex<double>*)C, m);
void MKL_ZCOOMM(char *transa, int *m, int *n, int *k, std::complex< double > *alpha, char *matdescra, std::complex< double > const *val, int const *rowind, int const *colind, int *nnz, std::complex< double > const *b, int *ldb, std::complex< double > *beta, std::complex< double > *c, int *ldc)
void SSCAL(const int *n, float *dA, float *dX, const int *incX)
void CAXPY(const int *n, std::complex< float > *dA, const std::complex< float > *dX, const int *incX, std::complex< float > *dY, const int *incY)
void default_coomm< double >(int m, int n, int k, double alpha, double const *A, int const *rows_A, int const *cols_A, int nnz_A, double const *B, double beta, double *C)
#define DEF_COOMM_KERNEL()
void ZSCAL(const int *n, std::complex< double > *dA, std::complex< double > *dX, const int *incX)
void MKL_DCOOMM(char *transa, int *m, int *n, int *k, double *alpha, char *matdescra, double const *val, int const *rowind, int const *colind, int *nnz, double const *b, int *ldb, double *beta, double *c, int *ldc)
dtype ** get_grp_ptrs(int64_t grp_sz, int64_t ngrp, dtype const *data)
#define CSR_MULTCSR_DEF(dtype, is_ord, MKL_name)
Semiring is a Monoid with an addition multiplicaton function addition must have an identity and be as...
void default_csrmm(int m, int n, int k, dtype alpha, dtype const *A, int const *JA, int const *IA, int nnz_A, dtype const *B, dtype beta, dtype *C) const
void gemm(char tA, char tB, int m, int n, int k, dtype alpha, dtype const *A, dtype const *B, dtype beta, dtype *C)
void default_scal< double >(int n, double alpha, double *X, int incX)
void ZAXPY(const int *n, std::complex< double > *dA, const std::complex< double > *dX, const int *incX, std::complex< double > *dY, const int *incY)
void MKL_DCSRMM(const char *transa, const int *m, const int *n, const int *k, const double *alpha, const char *matdescra, const double *val, const int *indx, const int *pntrb, const int *pntre, const double *b, const int *ldb, const double *beta, double *c, const int *ldc)
void default_coomm< float >(int m, int n, int k, float alpha, float const *A, int const *rows_A, int const *cols_A, int nnz_A, float const *B, float beta, float *C)
void MKL_SCOOMM(char *transa, int *m, int *n, int *k, float *alpha, char *matdescra, float const *val, int const *rowind, int const *colind, int *nnz, float const *b, int *ldb, float *beta, float *c, int *ldc)
void CSCAL(const int *n, std::complex< float > *dA, std::complex< float > *dX, const int *incX)
void offload_gemm(char tA, char tB, int m, int n, int k, char const *alpha, char const *A, char const *B, char const *beta, char *C) const
void DAXPY(const int *n, double *dA, const double *dX, const int *incX, double *dY, const int *incY)
void MKL_ZCSRMM(const char *transa, const int *m, const int *n, const int *k, const std::complex< double > *alpha, const char *matdescra, const std::complex< double > *val, const int *indx, const int *pntrb, const int *pntre, const std::complex< double > *b, const int *ldb, const std::complex< double > *beta, std::complex< double > *c, const int *ldc)
void DSCAL(const int *n, double *dA, double *dX, const int *incX)
void SAXPY(const int *n, float *dA, const float *dX, const int *incX, float *dY, const int *incY)
void gemm_batch(char taA, char taB, int l, int m, int n, int k, dtype alpha, dtype const *A, dtype const *B, dtype beta, dtype *C)
void default_gemm< double >(char tA, char tB, int m, int n, int k, double alpha, double const *A, double const *B, double beta, double *C)
bool is_offloadable() const
void default_gemm< float >(char tA, char tB, int m, int n, int k, float alpha, float const *A, float const *B, float beta, float *C)
void muladd_csrmm(int m, int n, int k, dtype alpha, dtype const *A, int const *JA, int const *IA, int nnz_A, dtype const *B, dtype beta, dtype *C)
void default_scal< float >(int n, float alpha, float *X, int incX)
void muladd_csrmultd(int m, int n, int k, dtype const *A, int const *JA, int const *IA, int nnz_A, dtype const *B, int const *JB, int const *IB, int nnz_B, dtype *C)
void default_axpy< float >(int n, float alpha, float const *X, int incX, float *Y, int incY)
#define CSR_MULTD_DEF(dtype, is_ord, MKL_name)
#define INST_GEMM_BATCH(dtype)
void offload_gemm(char tA, char tB, int m, int n, int k, dtype alpha, offload_tsr &A, int lda_A, offload_tsr &B, int lda_B, dtype beta, offload_tsr &C, int lda_C)
void MKL_CCOOMM(char *transa, int *m, int *n, int *k, std::complex< float > *alpha, char *matdescra, std::complex< float > const *val, int const *rowind, int const *colind, int *nnz, std::complex< float > const *b, int *ldb, std::complex< float > *beta, std::complex< float > *c, int *ldc)
void MKL_CCSRMM(const char *transa, const int *m, const int *n, const int *k, const std::complex< float > *alpha, const char *matdescra, const std::complex< float > *val, const int *indx, const int *pntrb, const int *pntre, const std::complex< float > *b, const int *ldb, const std::complex< float > *beta, std::complex< float > *c, const int *ldc)
void default_axpy< double >(int n, double alpha, double const *X, int incX, double *Y, int incY)
void MKL_SCSRMM(const char *transa, const int *m, const int *n, const int *k, const float *alpha, const char *matdescra, const float *val, const int *indx, const int *pntrb, const int *pntre, const float *b, const int *ldb, const float *beta, float *c, const int *ldc)