2 #include "../shared/blas_symbs.h"     3 #include "../shared/offload.h"     4 #include "../sparse_formats/csr.h"     7 #include "../shared/mkl_symbs.h"    14   template <
typename dtype>
    27     if (m == 1 && n == 1 && k == 1) {
    28       for (
int i=0; i<l; i++){
    30         C[i]+=alpha*A[i]*B[i];
    36     if (taA == 
'n' || taA == 
'N'){
    41     if (taB == 
'n' || taB == 
'N'){
    51     int size_per_group = l;
    52     CTF_BLAS::gemm_batch<dtype>(&taA, &taB, &m, &n, &k, &alpha, ptrs_A, &lda, ptrs_B, &ldb, &beta, ptrs_C, &ldc, &group_count, &size_per_group);
    54     for (
int i=0; i<l; i++){
    55       CTF_BLAS::gemm<dtype>(&taA,&taB,&m,&n,&k,&alpha, ptrs_A[i] ,&lda, ptrs_B[i] ,&ldb,&beta, ptrs_C[i] ,&ldc);
    63 #define INST_GEMM_BATCH(dtype)            \    64   template void gemm_batch<dtype>( char , \    79 #undef INST_GEMM_BATCH    81   template <
typename dtype>
    92     int lda, lda_B, lda_C;
    94     if (tA == 
'n' || tA == 
'N'){
    99     if (tB == 
'n' || tB == 
'N'){
   104     CTF_BLAS::gemm<dtype>(&tA,&tB,&m,&n,&k,&alpha,A,&lda,B,&lda_B,&beta,C,&lda_C);
   107 #define INST_GEMM(dtype)            \   108   template void gemm<dtype>( char , \   149   void default_axpy< std::complex<float> >
   151                     std::complex<float>         alpha,
   152                     std::complex<float> 
const * X,
   154                     std::complex<float> *       Y,
   160   void default_axpy< std::complex<double> >
   162                     std::complex<double>         alpha,
   163                     std::complex<double> 
const * X,
   165                     std::complex<double> *       Y,
   181   void default_scal< std::complex<float> >
   182       (
int n, std::complex<float> alpha, std::complex<float> * X, 
int incX){
   187   void default_scal< std::complex<double> >
   188       (
int n, std::complex<double> alpha, std::complex<double> * X, 
int incX){
   192 #define DEF_COOMM_KERNEL()                                \   193     for (int j=0; j<n; j++){                              \   194       for (int i=0; i<m; i++){                            \   198     for (int i=0; i<nnz_A; i++){                          \   199       int row_A = rows_A[i]-1;                            \   200       int col_A = cols_A[i]-1;                            \   201       for (int col_C=0; col_C<n; col_C++){                \   202          C[col_C*m+row_A] += alpha*A[i]*B[col_C*k+col_A]; \   221     char matdescra[6] = {
'G',0,0,
'F',0,0};
   223                matdescra, (
float*)A, rows_A, cols_A, &nnz_A,
   224                (
float*)B, &k, &beta,
   246     char matdescra[6] = {
'G',0,0,
'F',0,0};
   249                matdescra, (
double*)A, rows_A, cols_A, &nnz_A,
   250                (
double*)B, &k, &beta,
   260   void default_coomm< std::complex<float> >
   264            std::complex<float>         alpha,
   265            std::complex<float> 
const * A,
   269            std::complex<float> 
const * B,
   270            std::complex<float>         beta,
   271            std::complex<float> *       C){
   274     char matdescra[6] = {
'G',0,0,
'F',0,0};
   276                matdescra, (std::complex<float>*)A, rows_A, cols_A, &nnz_A,
   277                (std::complex<float>*)B, &k, &beta,
   278                (std::complex<float>*)C, &m);
   285   void default_coomm< std::complex<double> >
   289       std::complex<double>         alpha,
   290       std::complex<double> 
const * A,
   294       std::complex<double> 
const * B,
   295       std::complex<double>         beta,
   296       std::complex<double> *       C){
   299     char matdescra[6] = {
'G',0,0,
'F',0,0};
   301                matdescra, (std::complex<double>*)A, rows_A, cols_A, &nnz_A,
   302                (std::complex<double>*)B, &k, &beta,
   303                (std::complex<double>*)C, &m);
   330   template <
typename dtype>
   345     #pragma omp parallel for   347     for (
int row_A=0; row_A<m; row_A++){
   349       #pragma omp parallel for   351       for (
int col_B=0; col_B<n; col_B++){
   352         C[col_B*m+row_A] *= beta;
   353         if (IA[row_A] < IA[row_A+1]){
   354           int i_A1 = IA[row_A]-1;
   355           int col_A1 = JA[i_A1]-1;
   356           dtype tmp = A[i_A1]*B[col_B*k+col_A1];
   357           for (
int i_A=IA[row_A]; i_A<IA[row_A+1]-1; i_A++){
   358             int col_A = JA[i_A]-1;
   359             tmp += A[i_A]*B[col_B*k+col_A];
   361           C[col_B*m+row_A] += alpha*tmp;
   368   template<
typename dtype>
   384     #pragma omp parallel for   386     for (
int row_A=0; row_A<m; row_A++){
   387       for (
int i_A=IA[row_A]-1; i_A<IA[row_A+1]-1; i_A++){
   388         int row_B = JA[i_A]-1; 
   389         for (
int i_B=IB[row_B]-1; i_B<IB[row_B+1]-1; i_B++){
   390           int col_B = JB[i_B]-1;
   391           C[col_B*m+row_A] += A[i_A]*B[i_B];
   417     char matdescra[6] = {
'G',0,0,
'F',0,0};
   419     CTF_BLAS::MKL_SCSRMM(&transa, &m, &n, &k, &alpha, matdescra, A, JA, IA, IA+1, B, &k, &beta, C, &m);
   421     CTF_int::muladd_csrmm<float>(m,n,k,alpha,A,JA,IA,nnz_A,B,beta,C);
   440     char matdescra[6] = {
'G',0,0,
'F',0,0};
   442     CTF_BLAS::MKL_DCSRMM(&transa, &m, &n, &k, &alpha, matdescra, A, JA, IA, IA+1, B, &k, &beta, C, &m);
   445     CTF_int::muladd_csrmm<double>(m,n,k,alpha,A,JA,IA,nnz_A,B,beta,C);
   455            std::complex<float>         alpha,
   456            std::complex<float> 
const * A,
   460            std::complex<float> 
const * B,
   461            std::complex<float>         beta,
   462            std::complex<float> *       C) 
const {
   465     char matdescra[6] = {
'G',0,0,
'F',0,0};
   467     CTF_BLAS::MKL_CCSRMM(&transa, &m, &n, &k, &alpha, matdescra, A, JA, IA, IA+1, B, &k, &beta, C, &m);
   469     CTF_int::muladd_csrmm< std::complex<float> >(m,n,k,alpha,A,JA,IA,nnz_A,B,beta,C);
   478            std::complex<double>         alpha,
   479            std::complex<double> 
const * A,
   483            std::complex<double> 
const * B,
   484            std::complex<double>         beta,
   485            std::complex<double> *       C) 
const {
   488     char matdescra[6] = {
'G',0,0,
'F',0,0};
   489     CTF_BLAS::MKL_ZCSRMM(&transa, &m, &n, &k, &alpha, matdescra, A, JA, IA, IA+1, B, &k, &beta, C, &m);
   491     CTF_int::muladd_csrmm< std::complex<double> >(m,n,k,alpha,A,JA,IA,nnz_A,B,beta,C);
   497   #define CSR_MULTD_DEF(dtype,is_ord,MKL_name) \   499   void CTF::Semiring<dtype,is_ord>::default_csrmultd \   514     if (alpha == this->taddid){ \   515       if (beta != this->tmulid) \   516         CTF_int::default_scal<dtype>(m*n, beta, C, 1); \   520     if (beta == this->taddid){ \   521       CTF_BLAS::MKL_name(&transa, &m, &k, &n, A, JA, IA, B, JB, IB, C, &m); \   522       if (alpha != this->tmulid) \   523         CTF_int::default_scal<dtype>(m*n, alpha, C, 1); \   525       dtype * tmp_C_buf = (dtype*)alloc(sizeof(dtype)*m*n); \   526       CTF_BLAS::MKL_name(&transa, &m, &k, &n, A, JA, IA, B, JB, IB, tmp_C_buf, &m); \   527       if (beta != this->tmulid) \   528         CTF_int::default_scal<dtype>(m*n, beta, C, 1); \   529       CTF_int::default_axpy<dtype>(m*n, alpha, tmp_C_buf, 1, C, 1); \   530       cdealloc(tmp_C_buf); \   534   #define CSR_MULTD_DEF(dtype,is_ord,MKL_name) \   536   void CTF::Semiring<dtype,is_ord>::default_csrmultd \   551     if (alpha == this->taddid){ \   552       if (beta != this->tmulid) \   553         CTF_int::default_scal<dtype>(m*n, beta, C, 1); \   556     if (alpha != this->tmulid || beta != this->tmulid){ \   557       CTF_int::default_scal<dtype>(m*n, beta/alpha, C, 1); \   559     CTF_int::muladd_csrmultd<dtype>(m,n,k,A,JA,IA,nnz_A,B,JB,IB,nnz_B,C); \   560     if (alpha != this->tmulid){ \   561       CTF_int::default_scal<dtype>(m*n, alpha, C, 1); \   573   #define CSR_MULTCSR_DEF(dtype,is_ord,MKL_name) \   575   void CTF::Semiring<dtype,is_ord>::default_csrmultcsr \   589                       char *&       C_CSR) const { \   591     CSR_Matrix C_in(C_CSR); \   593     int * new_ic = (int*)alloc(sizeof(int)*(m+1)); \   598     CTF_BLAS::MKL_name(&transa, &req, &sort, &m, &k, &n, A, JA, IA, B, JB, IB, NULL, NULL, new_ic, &req, &info); \   600     CSR_Matrix C_add(new_ic[m]-1, m, n, this); \   601     memcpy(C_add.IA(), new_ic, (m+1)*sizeof(int)); \   604     CTF_BLAS::MKL_name(&transa, &req, &sort, &m, 
&k, &n, A, JA, IA, B, JB, IB, (dtype*)C_add.vals(), C_add.JA(), C_add.IA(), &req, &info); \   606     if (beta == this->taddid){ \   607       C_CSR = C_add.all_data; \   609       if (C_CSR != NULL && beta != this->tmulid){ \   610         this->scal(C_in.nnz(), (char const *)&beta, C_in.vals(), 1); \   612       if (alpha != this->tmulid){ \   613         this->scal(C_add.nnz(), (char const *)&alpha, C_add.vals(), 1); \   615       if (C_CSR == NULL){ \   616         C_CSR = C_add.all_data; \   618         char * C_ret = csr_add(C_CSR, C_add.all_data); \   619         cdealloc(C_add.all_data); \   625   #define CSR_MULTCSR_DEF(dtype,is_ord,MKL_name) \   627   void CTF::Semiring<dtype,is_ord>::default_csrmultcsr \   641                       char *&       C_CSR) const { \   642     this->gen_csrmultcsr(m,n,k,alpha,A,JA,IA,nnz_A,B,JB,IB,nnz_B,beta,C_CSR); \   661   bool CTF::Semiring<std::complex<float>,0>::is_offloadable()
 const {
   662     return fgemm == &CTF_int::default_gemm< std::complex<float> >;
   672   bool CTF::Semiring<std::complex<double>,0>::is_offloadable()
 const {
   673     return fgemm == &CTF_int::default_gemm< std::complex<double> >;
   689     if (tA == 
'n' || tA == 
'N') lda_A = m;
   691     if (tB == 
'N' || tB == 
'N') lda_B = k;
   692     CTF_int::offload_gemm<float>(tA, tB, m, n, k, ((
float const*)alpha)[0], (
float const *)A, lda_A, (
float const *)B, lda_B, ((
float const*)beta)[0], (
float*)C, m);
   708     if (tA == 
'n' || tA == 
'N') lda_A = m;
   710     if (tB == 
'N' || tB == 
'N') lda_B = k;
   711     CTF_int::offload_gemm<std::complex<float>>(tA, tB, m, n, k, ((std::complex<float> 
const*)alpha)[0], (std::complex<float> 
const *)A, lda_A, (std::complex<float> 
const *)B, lda_B, ((std::complex<float> 
const*)beta)[0], (std::complex<float>*)C, m);
   727     if (tA == 
'n' || tA == 
'N') lda_A = m;
   729     if (tB == 
'N' || tB == 
'N') lda_B = k;
   730     CTF_int::offload_gemm<double>(tA, tB, m, n, k, ((
double const*)alpha)[0], (
double const *)A, lda_A, (
double const *)B, lda_B, ((
double const*)beta)[0], (
double*)C, m);
   746     if (tA == 
'n' || tA == 
'N') lda_A = m;
   748     if (tB == 
'N' || tB == 
'N') lda_B = k;
   749     CTF_int::offload_gemm<std::complex<double>>(tA, tB, m, n, k, ((std::complex<double> 
const*)alpha)[0], (std::complex<double> 
const *)A, lda_A, (std::complex<double> 
const *)B, lda_B, ((std::complex<double> 
const*)beta)[0], (std::complex<double>*)C, m);
 void MKL_ZCOOMM(char *transa, int *m, int *n, int *k, std::complex< double > *alpha, char *matdescra, std::complex< double > const *val, int const *rowind, int const *colind, int *nnz, std::complex< double > const *b, int *ldb, std::complex< double > *beta, std::complex< double > *c, int *ldc)
void SSCAL(const int *n, float *dA, float *dX, const int *incX)
void CAXPY(const int *n, std::complex< float > *dA, const std::complex< float > *dX, const int *incX, std::complex< float > *dY, const int *incY)
void default_coomm< double >(int m, int n, int k, double alpha, double const *A, int const *rows_A, int const *cols_A, int nnz_A, double const *B, double beta, double *C)
#define DEF_COOMM_KERNEL()                                                            
void ZSCAL(const int *n, std::complex< double > *dA, std::complex< double > *dX, const int *incX)
void MKL_DCOOMM(char *transa, int *m, int *n, int *k, double *alpha, char *matdescra, double const *val, int const *rowind, int const *colind, int *nnz, double const *b, int *ldb, double *beta, double *c, int *ldc)
dtype ** get_grp_ptrs(int64_t grp_sz, int64_t ngrp, dtype const *data)
#define CSR_MULTCSR_DEF(dtype, is_ord, MKL_name)
Semiring is a Monoid with an additional multiplication function; addition must have an identity and be as...
void default_csrmm(int m, int n, int k, dtype alpha, dtype const *A, int const *JA, int const *IA, int nnz_A, dtype const *B, dtype beta, dtype *C) const 
void gemm(char tA, char tB, int m, int n, int k, dtype alpha, dtype const *A, dtype const *B, dtype beta, dtype *C)
void default_scal< double >(int n, double alpha, double *X, int incX)
void ZAXPY(const int *n, std::complex< double > *dA, const std::complex< double > *dX, const int *incX, std::complex< double > *dY, const int *incY)
void MKL_DCSRMM(const char *transa, const int *m, const int *n, const int *k, const double *alpha, const char *matdescra, const double *val, const int *indx, const int *pntrb, const int *pntre, const double *b, const int *ldb, const double *beta, double *c, const int *ldc)
void default_coomm< float >(int m, int n, int k, float alpha, float const *A, int const *rows_A, int const *cols_A, int nnz_A, float const *B, float beta, float *C)
void MKL_SCOOMM(char *transa, int *m, int *n, int *k, float *alpha, char *matdescra, float const *val, int const *rowind, int const *colind, int *nnz, float const *b, int *ldb, float *beta, float *c, int *ldc)
void CSCAL(const int *n, std::complex< float > *dA, std::complex< float > *dX, const int *incX)
void offload_gemm(char tA, char tB, int m, int n, int k, char const *alpha, char const *A, char const *B, char const *beta, char *C) const 
void DAXPY(const int *n, double *dA, const double *dX, const int *incX, double *dY, const int *incY)
void MKL_ZCSRMM(const char *transa, const int *m, const int *n, const int *k, const std::complex< double > *alpha, const char *matdescra, const std::complex< double > *val, const int *indx, const int *pntrb, const int *pntre, const std::complex< double > *b, const int *ldb, const std::complex< double > *beta, std::complex< double > *c, const int *ldc)
void DSCAL(const int *n, double *dA, double *dX, const int *incX)
void SAXPY(const int *n, float *dA, const float *dX, const int *incX, float *dY, const int *incY)
void gemm_batch(char taA, char taB, int l, int m, int n, int k, dtype alpha, dtype const *A, dtype const *B, dtype beta, dtype *C)
void default_gemm< double >(char tA, char tB, int m, int n, int k, double alpha, double const *A, double const *B, double beta, double *C)
bool is_offloadable() const 
void default_gemm< float >(char tA, char tB, int m, int n, int k, float alpha, float const *A, float const *B, float beta, float *C)
void muladd_csrmm(int m, int n, int k, dtype alpha, dtype const *A, int const *JA, int const *IA, int nnz_A, dtype const *B, dtype beta, dtype *C)
void default_scal< float >(int n, float alpha, float *X, int incX)
void muladd_csrmultd(int m, int n, int k, dtype const *A, int const *JA, int const *IA, int nnz_A, dtype const *B, int const *JB, int const *IB, int nnz_B, dtype *C)
void default_axpy< float >(int n, float alpha, float const *X, int incX, float *Y, int incY)
#define CSR_MULTD_DEF(dtype, is_ord, MKL_name)
#define INST_GEMM_BATCH(dtype)                    
void offload_gemm(char tA, char tB, int m, int n, int k, dtype alpha, offload_tsr &A, int lda_A, offload_tsr &B, int lda_B, dtype beta, offload_tsr &C, int lda_C)
void MKL_CCOOMM(char *transa, int *m, int *n, int *k, std::complex< float > *alpha, char *matdescra, std::complex< float > const *val, int const *rowind, int const *colind, int *nnz, std::complex< float > const *b, int *ldb, std::complex< float > *beta, std::complex< float > *c, int *ldc)
void MKL_CCSRMM(const char *transa, const int *m, const int *n, const int *k, const std::complex< float > *alpha, const char *matdescra, const std::complex< float > *val, const int *indx, const int *pntrb, const int *pntre, const std::complex< float > *b, const int *ldb, const std::complex< float > *beta, std::complex< float > *c, const int *ldc)
void default_axpy< double >(int n, double alpha, double const *X, int incX, double *Y, int incY)
void MKL_SCSRMM(const char *transa, const int *m, const int *n, const int *k, const float *alpha, const char *matdescra, const float *val, const int *indx, const int *pntrb, const int *pntre, const float *b, const int *ldb, const float *beta, float *c, const int *ldc)