Cyclops Tensor Framework
parallel arithmetic on multidimensional arrays
semiring.cxx
Go to the documentation of this file.
1 #include "set.h"
2 #include "../shared/blas_symbs.h"
3 #include "../shared/offload.h"
4 #include "../sparse_formats/csr.h"
5 
6 #ifdef USE_MKL
7 #include "../shared/mkl_symbs.h"
8 #endif
9 
10 using namespace CTF_int;
11 
12 namespace CTF_int {
13 
14  template <typename dtype>
15  void gemm_batch(
16  char taA,
17  char taB,
18  int l,
19  int m,
20  int n,
21  int k,
22  dtype alpha,
23  dtype const* A,
24  dtype const* B,
25  dtype beta,
26  dtype * C){
27  if (m == 1 && n == 1 && k == 1) {
28  for (int i=0; i<l; i++){
29  C[i]*=beta;
30  C[i]+=alpha*A[i]*B[i];
31  }
32  return;
33  }
34  int lda, ldb, ldc;
35  ldc = m;
36  if (taA == 'n' || taA == 'N'){
37  lda = m;
38  } else {
39  lda = k;
40  }
41  if (taB == 'n' || taB == 'N'){
42  ldb = k;
43  } else {
44  ldb = n;
45  }
46  dtype ** ptrs_A = get_grp_ptrs(m*k,l,A);
47  dtype ** ptrs_B = get_grp_ptrs(k*n,l,B);
48  dtype ** ptrs_C = get_grp_ptrs(m*n,l,C);
49 #if USE_BATCH_GEMM
50  int group_count = 1;
51  int size_per_group = l;
52  CTF_BLAS::gemm_batch<dtype>(&taA, &taB, &m, &n, &k, &alpha, ptrs_A, &lda, ptrs_B, &ldb, &beta, ptrs_C, &ldc, &group_count, &size_per_group);
53 #else
54  for (int i=0; i<l; i++){
55  CTF_BLAS::gemm<dtype>(&taA,&taB,&m,&n,&k,&alpha, ptrs_A[i] ,&lda, ptrs_B[i] ,&ldb,&beta, ptrs_C[i] ,&ldc);
56  }
57 #endif
58  free(ptrs_A);
59  free(ptrs_B);
60  free(ptrs_C);
61  }
62 
63 #define INST_GEMM_BATCH(dtype) \
64  template void gemm_batch<dtype>( char , \
65  char , \
66  int , \
67  int , \
68  int , \
69  int , \
70  dtype , \
71  dtype const *, \
72  dtype const *, \
73  dtype , \
74  dtype *);
75  INST_GEMM_BATCH(float)
76  INST_GEMM_BATCH(double)
77  INST_GEMM_BATCH(std::complex<float>)
78  INST_GEMM_BATCH(std::complex<double>)
79 #undef INST_GEMM_BATCH
80 
81  template <typename dtype>
82  void gemm(char tA,
83  char tB,
84  int m,
85  int n,
86  int k,
87  dtype alpha,
88  dtype const * A,
89  dtype const * B,
90  dtype beta,
91  dtype * C){
92  int lda, lda_B, lda_C;
93  lda_C = m;
94  if (tA == 'n' || tA == 'N'){
95  lda = m;
96  } else {
97  lda = k;
98  }
99  if (tB == 'n' || tB == 'N'){
100  lda_B = k;
101  } else {
102  lda_B = n;
103  }
104  CTF_BLAS::gemm<dtype>(&tA,&tB,&m,&n,&k,&alpha,A,&lda,B,&lda_B,&beta,C,&lda_C);
105  }
106 
107 #define INST_GEMM(dtype) \
108  template void gemm<dtype>( char , \
109  char , \
110  int , \
111  int , \
112  int , \
113  dtype , \
114  dtype const *, \
115  dtype const *, \
116  dtype , \
117  dtype *);
118  INST_GEMM(float)
119  INST_GEMM(double)
120  INST_GEMM(std::complex<float>)
121  INST_GEMM(std::complex<double>)
122 #undef INST_GEMM
123 
124 
125 
126  template <>
128  (int n,
129  float alpha,
130  float const * X,
131  int incX,
132  float * Y,
133  int incY){
134  CTF_BLAS::SAXPY(&n,&alpha,X,&incX,Y,&incY);
135  }
136 
137  template <>
139  (int n,
140  double alpha,
141  double const * X,
142  int incX,
143  double * Y,
144  int incY){
145  CTF_BLAS::DAXPY(&n,&alpha,X,&incX,Y,&incY);
146  }
147 
148  template <>
149  void default_axpy< std::complex<float> >
150  (int n,
151  std::complex<float> alpha,
152  std::complex<float> const * X,
153  int incX,
154  std::complex<float> * Y,
155  int incY){
156  CTF_BLAS::CAXPY(&n,&alpha,X,&incX,Y,&incY);
157  }
158 
159  template <>
160  void default_axpy< std::complex<double> >
161  (int n,
162  std::complex<double> alpha,
163  std::complex<double> const * X,
164  int incX,
165  std::complex<double> * Y,
166  int incY){
167  CTF_BLAS::ZAXPY(&n,&alpha,X,&incX,Y,&incY);
168  }
169 
170  template <>
171  void default_scal<float>(int n, float alpha, float * X, int incX){
172  CTF_BLAS::SSCAL(&n,&alpha,X,&incX);
173  }
174 
175  template <>
176  void default_scal<double>(int n, double alpha, double * X, int incX){
177  CTF_BLAS::DSCAL(&n,&alpha,X,&incX);
178  }
179 
180  template <>
181  void default_scal< std::complex<float> >
182  (int n, std::complex<float> alpha, std::complex<float> * X, int incX){
183  CTF_BLAS::CSCAL(&n,&alpha,X,&incX);
184  }
185 
186  template <>
187  void default_scal< std::complex<double> >
188  (int n, std::complex<double> alpha, std::complex<double> * X, int incX){
189  CTF_BLAS::ZSCAL(&n,&alpha,X,&incX);
190  }
191 
/* Reference COO sparse-times-dense kernel: C = beta*C + alpha*A*B.
 * Expands in a scope that provides m, n, k, alpha, beta, the nnz_A values A,
 * their 1-based coordinates rows_A / cols_A, and dense column-major B (k x n)
 * and C (m x n). */
#define DEF_COOMM_KERNEL() \
  for (int ent=0; ent<m*n; ent++){ \
    C[ent] *= beta; \
  } \
  for (int nz=0; nz<nnz_A; nz++){ \
    int const r = rows_A[nz]-1; \
    int const c = cols_A[nz]-1; \
    for (int j=0; j<n; j++){ \
      C[j*m+r] += alpha*A[nz]*B[j*k+c]; \
    } \
  }
205 
206  template <>
208  (int m,
209  int n,
210  int k,
211  float alpha,
212  float const * A,
213  int const * rows_A,
214  int const * cols_A,
215  int nnz_A,
216  float const * B,
217  float beta,
218  float * C){
219 #if USE_MKL
220  char transa = 'N';
221  char matdescra[6] = {'G',0,0,'F',0,0};
222  CTF_BLAS::MKL_SCOOMM(&transa, &m, &n, &k, &alpha,
223  matdescra, (float*)A, rows_A, cols_A, &nnz_A,
224  (float*)B, &k, &beta,
225  (float*)C, &m);
226 #else
228 #endif
229  }
230 
231  template <>
233  (int m,
234  int n,
235  int k,
236  double alpha,
237  double const * A,
238  int const * rows_A,
239  int const * cols_A,
240  int nnz_A,
241  double const * B,
242  double beta,
243  double * C){
244 #if USE_MKL
245  char transa = 'N';
246  char matdescra[6] = {'G',0,0,'F',0,0};
247  //TAU_FSTART(MKL_DCOOMM);
248  CTF_BLAS::MKL_DCOOMM(&transa, &m, &n, &k, &alpha,
249  matdescra, (double*)A, rows_A, cols_A, &nnz_A,
250  (double*)B, &k, &beta,
251  (double*)C, &m);
252  //TAU_FSTOP(MKL_DCOOMM);
253 #else
255 #endif
256  }
257 
258 
259  template <>
260  void default_coomm< std::complex<float> >
261  (int m,
262  int n,
263  int k,
264  std::complex<float> alpha,
265  std::complex<float> const * A,
266  int const * rows_A,
267  int const * cols_A,
268  int nnz_A,
269  std::complex<float> const * B,
270  std::complex<float> beta,
271  std::complex<float> * C){
272 #if USE_MKL
273  char transa = 'N';
274  char matdescra[6] = {'G',0,0,'F',0,0};
275  CTF_BLAS::MKL_CCOOMM(&transa, &m, &n, &k, &alpha,
276  matdescra, (std::complex<float>*)A, rows_A, cols_A, &nnz_A,
277  (std::complex<float>*)B, &k, &beta,
278  (std::complex<float>*)C, &m);
279 #else
281 #endif
282  }
283 
284  template <>
285  void default_coomm< std::complex<double> >
286  (int m,
287  int n,
288  int k,
289  std::complex<double> alpha,
290  std::complex<double> const * A,
291  int const * rows_A,
292  int const * cols_A,
293  int nnz_A,
294  std::complex<double> const * B,
295  std::complex<double> beta,
296  std::complex<double> * C){
297 #if USE_MKL
298  char transa = 'N';
299  char matdescra[6] = {'G',0,0,'F',0,0};
300  CTF_BLAS::MKL_ZCOOMM(&transa, &m, &n, &k, &alpha,
301  matdescra, (std::complex<double>*)A, rows_A, cols_A, &nnz_A,
302  (std::complex<double>*)B, &k, &beta,
303  (std::complex<double>*)C, &m);
304 #else
306 #endif
307  }
308 /*
309 #if USE_MKL
310  template <>
311  bool get_def_has_csrmm<float>(){ return true; }
312  template <>
313  bool get_def_has_csrmm<double>(){ return true; }
314  template <>
315  bool get_def_has_csrmm< std::complex<float> >(){ return true; }
316  template <>
317  bool get_def_has_csrmm< std::complex<double> >(){ return true; }
318 #else
319  template <>
320  bool get_def_has_csrmm<float>(){ return true; }
321  template <>
322  bool get_def_has_csrmm<double>(){ return true; }
323  template <>
324  bool get_def_has_csrmm< std::complex<float> >(){ return true; }
325  template <>
326  bool get_def_has_csrmm< std::complex<double> >(){ return true; }
327 #endif
328 */
329 #if (USE_MKL!=1)
/**
 * \brief Reference CSR sparse-times-dense kernel: C = beta*C + alpha*A*B.
 *
 * A is m x k in 1-based CSR (values A, column indices JA, row pointers IA of
 * length m+1); B (k x n) and C (m x n) are dense column-major.
 * The first nonzero of each row seeds the accumulator, so dtype never needs
 * an explicit zero.
 */
template <typename dtype>
void muladd_csrmm
              (int           m,
               int           n,
               int           k,
               dtype         alpha,
               dtype const * A,
               int const *   JA,
               int const *   IA,
               int           nnz_A,
               dtype const * B,
               dtype         beta,
               dtype *       C){
  //TAU_FSTART(muladd_csrmm);
  // Only the row loop is parallelized. A nested 'omp parallel for' on the
  // column loop would open a new parallel region for every row (pure overhead,
  // or oversubscription when nesting is enabled).
#ifdef USE_OMP
  #pragma omp parallel for
#endif
  for (int row_A=0; row_A<m; row_A++){
    // 0-based [first, last) range of this row's nonzeros
    int const first = IA[row_A]-1;
    int const last  = IA[row_A+1]-1;
    for (int col_B=0; col_B<n; col_B++){
      C[col_B*m+row_A] *= beta;
      if (first < last){
        dtype tmp = A[first]*B[col_B*k + (JA[first]-1)];
        for (int i_A=first+1; i_A<last; i_A++){
          tmp += A[i_A]*B[col_B*k + (JA[i_A]-1)];
        }
        C[col_B*m+row_A] += alpha*tmp;
      }
    }
  }
  //TAU_FSTOP(muladd_csrmm);
}
367 
/**
 * \brief Accumulates a CSR x CSR product into a dense matrix: C += A*B.
 *
 * A is m x k in 1-based CSR (A, JA, IA); B is k x n in 1-based CSR (B, JB, IB);
 * C is dense column-major m x n and is only added to, never scaled.
 */
template<typename dtype>
void muladd_csrmultd
              (int           m,
               int           n,
               int           k,
               dtype const * A,
               int const *   JA,
               int const *   IA,
               int           nnz_A,
               dtype const * B,
               int const *   JB,
               int const *   IB,
               int           nnz_B,
               dtype *       C){
  //TAU_FSTART(muladd_csrmultd);
#ifdef _OPENMP
  #pragma omp parallel for
#endif
  for (int r=0; r<m; r++){
    int const a_end = IA[r+1]-1;
    for (int pa=IA[r]-1; pa<a_end; pa++){
      // the column of this A entry selects which row of B it combines with
      int const rb    = JA[pa]-1;
      int const b_end = IB[rb+1]-1;
      for (int pb=IB[rb]-1; pb<b_end; pb++){
        C[(JB[pb]-1)*m+r] += A[pa]*B[pb];
      }
    }
  }
  //TAU_FSTOP(muladd_csrmultd);
}
397 #endif
398 }
399 
400 namespace CTF {
401 
402  template <>
404  (int m,
405  int n,
406  int k,
407  float alpha,
408  float const * A,
409  int const * JA,
410  int const * IA,
411  int nnz_A,
412  float const * B,
413  float beta,
414  float * C) const {
415 #if USE_MKL
416  char transa = 'N';
417  char matdescra[6] = {'G',0,0,'F',0,0};
418 
419  CTF_BLAS::MKL_SCSRMM(&transa, &m, &n, &k, &alpha, matdescra, A, JA, IA, IA+1, B, &k, &beta, C, &m);
420 #else
421  CTF_int::muladd_csrmm<float>(m,n,k,alpha,A,JA,IA,nnz_A,B,beta,C);
422 #endif
423  }
424 
425  template <>
427  (int m,
428  int n,
429  int k,
430  double alpha,
431  double const * A,
432  int const * JA,
433  int const * IA,
434  int nnz_A,
435  double const * B,
436  double beta,
437  double * C) const {
438 #if USE_MKL
439  char transa = 'N';
440  char matdescra[6] = {'G',0,0,'F',0,0};
441  //TAU_FSTART(MKL_DCSRMM);
442  CTF_BLAS::MKL_DCSRMM(&transa, &m, &n, &k, &alpha, matdescra, A, JA, IA, IA+1, B, &k, &beta, C, &m);
443  //TAU_FSTOP(MKL_DCSRMM);
444 #else
445  CTF_int::muladd_csrmm<double>(m,n,k,alpha,A,JA,IA,nnz_A,B,beta,C);
446 #endif
447  }
448 
449 
450  template <>
451  void CTF::Semiring<std::complex<float>,0>::default_csrmm
452  (int m,
453  int n,
454  int k,
455  std::complex<float> alpha,
456  std::complex<float> const * A,
457  int const * JA,
458  int const * IA,
459  int nnz_A,
460  std::complex<float> const * B,
461  std::complex<float> beta,
462  std::complex<float> * C) const {
463 #if USE_MKL
464  char transa = 'N';
465  char matdescra[6] = {'G',0,0,'F',0,0};
466 
467  CTF_BLAS::MKL_CCSRMM(&transa, &m, &n, &k, &alpha, matdescra, A, JA, IA, IA+1, B, &k, &beta, C, &m);
468 #else
469  CTF_int::muladd_csrmm< std::complex<float> >(m,n,k,alpha,A,JA,IA,nnz_A,B,beta,C);
470 #endif
471  }
472 
473  template <>
474  void CTF::Semiring<std::complex<double>,0>::default_csrmm
475  (int m,
476  int n,
477  int k,
478  std::complex<double> alpha,
479  std::complex<double> const * A,
480  int const * JA,
481  int const * IA,
482  int nnz_A,
483  std::complex<double> const * B,
484  std::complex<double> beta,
485  std::complex<double> * C) const {
486 #if USE_MKL
487  char transa = 'N';
488  char matdescra[6] = {'G',0,0,'F',0,0};
489  CTF_BLAS::MKL_ZCSRMM(&transa, &m, &n, &k, &alpha, matdescra, A, JA, IA, IA+1, B, &k, &beta, C, &m);
490 #else
491  CTF_int::muladd_csrmm< std::complex<double> >(m,n,k,alpha,A,JA,IA,nnz_A,B,beta,C);
492 #endif
493  }
494 
495 
#if USE_MKL
  /* CSR x CSR -> dense: C = beta*C + alpha*A*B with C column-major m x n.
   * The MKL ?csrmultd routine writes its full output, so when beta is not the
   * additive identity the product goes to a scratch buffer and is then
   * accumulated into C with axpy. */
  #define CSR_MULTD_DEF(dtype,is_ord,MKL_name) \
  template<> \
  void CTF::Semiring<dtype,is_ord>::default_csrmultd \
                (int           m, \
                 int           n, \
                 int           k, \
                 dtype         alpha, \
                 dtype const * A, \
                 int const *   JA, \
                 int const *   IA, \
                 int           nnz_A, \
                 dtype const * B, \
                 int const *   JB, \
                 int const *   IB, \
                 int           nnz_B, \
                 dtype         beta, \
                 dtype *       C) const { \
    if (alpha == this->taddid){ \
      /* product term vanishes: only rescale C */ \
      if (beta != this->tmulid) \
        CTF_int::default_scal<dtype>(m*n, beta, C, 1); \
      return; \
    } \
    char transa = 'N'; \
    if (beta != this->taddid){ \
      dtype * scratch = (dtype*)alloc(sizeof(dtype)*m*n); \
      CTF_BLAS::MKL_name(&transa, &m, &k, &n, A, JA, IA, B, JB, IB, scratch, &m); \
      if (beta != this->tmulid) \
        CTF_int::default_scal<dtype>(m*n, beta, C, 1); \
      CTF_int::default_axpy<dtype>(m*n, alpha, scratch, 1, C, 1); \
      cdealloc(scratch); \
      return; \
    } \
    /* beta is zero: MKL may write straight into C */ \
    CTF_BLAS::MKL_name(&transa, &m, &k, &n, A, JA, IA, B, JB, IB, C, &m); \
    if (alpha != this->tmulid) \
      CTF_int::default_scal<dtype>(m*n, alpha, C, 1); \
  }
#else
  /* CSR x CSR -> dense without MKL: C = beta*C + alpha*A*B is evaluated as
   * C <- alpha*((beta/alpha)*C + A*B) on top of the accumulate-only kernel
   * CTF_int::muladd_csrmultd. MKL_name is unused here. */
  #define CSR_MULTD_DEF(dtype,is_ord,MKL_name) \
  template<> \
  void CTF::Semiring<dtype,is_ord>::default_csrmultd \
                (int           m, \
                 int           n, \
                 int           k, \
                 dtype         alpha, \
                 dtype const * A, \
                 int const *   JA, \
                 int const *   IA, \
                 int           nnz_A, \
                 dtype const * B, \
                 int const *   JB, \
                 int const *   IB, \
                 int           nnz_B, \
                 dtype         beta, \
                 dtype *       C) const { \
    if (alpha == this->taddid){ \
      /* product term vanishes: only rescale C */ \
      if (beta != this->tmulid) \
        CTF_int::default_scal<dtype>(m*n, beta, C, 1); \
      return; \
    } \
    bool const unit_alpha = (alpha == this->tmulid); \
    if (!unit_alpha || beta != this->tmulid) \
      CTF_int::default_scal<dtype>(m*n, beta/alpha, C, 1); \
    CTF_int::muladd_csrmultd<dtype>(m,n,k,A,JA,IA,nnz_A,B,JB,IB,nnz_B,C); \
    if (!unit_alpha) \
      CTF_int::default_scal<dtype>(m*n, alpha, C, 1); \
  }
#endif
565 
567  CSR_MULTD_DEF(double,1,MKL_DCSRMULTD)
568  CSR_MULTD_DEF(std::complex<float>,0,MKL_CCSRMULTD)
569  CSR_MULTD_DEF(std::complex<double>,0,MKL_ZCSRMULTD)
570 
571 
#if USE_MKL
  /* CSR x CSR -> CSR: C_CSR = beta*C_CSR + alpha*A*B.
   * mkl_?csrmultcsr is called twice: the first call passes NULL value/column
   * arrays and only fills the row pointers (sizing the result), the second
   * fills column indices and values. The product is then merged into C_CSR. */
  #define CSR_MULTCSR_DEF(dtype,is_ord,MKL_name) \
  template<> \
  void CTF::Semiring<dtype,is_ord>::default_csrmultcsr \
                (int           m, \
                 int           n, \
                 int           k, \
                 dtype         alpha, \
                 dtype const * A, \
                 int const *   JA, \
                 int const *   IA, \
                 int           nnz_A, \
                 dtype const * B, \
                 int const *   JB, \
                 int const *   IB, \
                 int           nnz_B, \
                 dtype         beta, \
                 char *&       C_CSR) const { \
    char transa = 'N'; \
    CSR_Matrix C_in(C_CSR); \
    int * row_ptr = (int*)alloc(sizeof(int)*(m+1)); \
    int sort = 1; \
    int req  = 1; \
    int info; \
    /* pass 1: row pointers only; nnz of the product is row_ptr[m]-1 (1-based) */ \
    CTF_BLAS::MKL_name(&transa, &req, &sort, &m, &k, &n, A, JA, IA, B, JB, IB, NULL, NULL, row_ptr, &req, &info); \
    CSR_Matrix C_add(row_ptr[m]-1, m, n, this); \
    memcpy(C_add.IA(), row_ptr, (m+1)*sizeof(int)); \
    cdealloc(row_ptr); \
    /* pass 2: column indices and values */ \
    req = 2; \
    CTF_BLAS::MKL_name(&transa, &req, &sort, &m, &k, &n, A, JA, IA, B, JB, IB, (dtype*)C_add.vals(), C_add.JA(), C_add.IA(), &req, &info); \
    if (beta == this->taddid){ \
      C_CSR = C_add.all_data; \
      return; \
    } \
    if (C_CSR != NULL && beta != this->tmulid) \
      this->scal(C_in.nnz(), (char const *)&beta, C_in.vals(), 1); \
    if (alpha != this->tmulid) \
      this->scal(C_add.nnz(), (char const *)&alpha, C_add.vals(), 1); \
    if (C_CSR == NULL){ \
      C_CSR = C_add.all_data; \
    } else { \
      char * merged = csr_add(C_CSR, C_add.all_data); \
      cdealloc(C_add.all_data); \
      C_CSR = merged; \
    } \
  }
#else
  /* CSR x CSR -> CSR without MKL: forwards to the generic gen_csrmultcsr.
   * MKL_name is unused here. */
  #define CSR_MULTCSR_DEF(dtype,is_ord,MKL_name) \
  template<> \
  void CTF::Semiring<dtype,is_ord>::default_csrmultcsr \
                (int           m, \
                 int           n, \
                 int           k, \
                 dtype         alpha, \
                 dtype const * A, \
                 int const *   JA, \
                 int const *   IA, \
                 int           nnz_A, \
                 dtype const * B, \
                 int const *   JB, \
                 int const *   IB, \
                 int           nnz_B, \
                 dtype         beta, \
                 char *&       C_CSR) const { \
    this->gen_csrmultcsr(m, n, k, alpha, A, JA, IA, nnz_A, B, JB, IB, nnz_B, beta, C_CSR); \
  }
#endif
645 
648  CSR_MULTCSR_DEF(std::complex<float>,0,MKL_CCSRMULTCSR)
649  CSR_MULTCSR_DEF(std::complex<double>,0,MKL_ZCSRMULTCSR)
650 
651 /* template<>
652  bool CTF::Semiring<float,1>::is_offloadable() const {
653  return fgemm == &CTF_int::default_gemm<float>;
654  }*/
655  template<>
656  bool CTF::Semiring<float,1>::is_offloadable() const {
657  return fgemm == &CTF_int::default_gemm<float>;
658  }
659 
660  template<>
661  bool CTF::Semiring<std::complex<float>,0>::is_offloadable() const {
662  return fgemm == &CTF_int::default_gemm< std::complex<float> >;
663  }
664 
665 
666  template<>
668  return fgemm == &CTF_int::default_gemm<double>;
669  }
670 
671  template<>
672  bool CTF::Semiring<std::complex<double>,0>::is_offloadable() const {
673  return fgemm == &CTF_int::default_gemm< std::complex<double> >;
674  }
675 
676  template<>
678  char tA,
679  char tB,
680  int m,
681  int n,
682  int k,
683  char const * alpha,
684  char const * A,
685  char const * B,
686  char const * beta,
687  char * C) const {
688  int lda_A = k;
689  if (tA == 'n' || tA == 'N') lda_A = m;
690  int lda_B = n;
691  if (tB == 'N' || tB == 'N') lda_B = k;
692  CTF_int::offload_gemm<float>(tA, tB, m, n, k, ((float const*)alpha)[0], (float const *)A, lda_A, (float const *)B, lda_B, ((float const*)beta)[0], (float*)C, m);
693  }
694 
695  template<>
696  void CTF::Semiring<std::complex<float>,0>::offload_gemm(
697  char tA,
698  char tB,
699  int m,
700  int n,
701  int k,
702  char const * alpha,
703  char const * A,
704  char const * B,
705  char const * beta,
706  char * C) const {
707  int lda_A = k;
708  if (tA == 'n' || tA == 'N') lda_A = m;
709  int lda_B = n;
710  if (tB == 'N' || tB == 'N') lda_B = k;
711  CTF_int::offload_gemm<std::complex<float>>(tA, tB, m, n, k, ((std::complex<float> const*)alpha)[0], (std::complex<float> const *)A, lda_A, (std::complex<float> const *)B, lda_B, ((std::complex<float> const*)beta)[0], (std::complex<float>*)C, m);
712  }
713 
714  template<>
716  char tA,
717  char tB,
718  int m,
719  int n,
720  int k,
721  char const * alpha,
722  char const * A,
723  char const * B,
724  char const * beta,
725  char * C) const {
726  int lda_A = k;
727  if (tA == 'n' || tA == 'N') lda_A = m;
728  int lda_B = n;
729  if (tB == 'N' || tB == 'N') lda_B = k;
730  CTF_int::offload_gemm<double>(tA, tB, m, n, k, ((double const*)alpha)[0], (double const *)A, lda_A, (double const *)B, lda_B, ((double const*)beta)[0], (double*)C, m);
731  }
732 
733  template<>
734  void CTF::Semiring<std::complex<double>,0>::offload_gemm(
735  char tA,
736  char tB,
737  int m,
738  int n,
739  int k,
740  char const * alpha,
741  char const * A,
742  char const * B,
743  char const * beta,
744  char * C) const {
745  int lda_A = k;
746  if (tA == 'n' || tA == 'N') lda_A = m;
747  int lda_B = n;
748  if (tB == 'N' || tB == 'N') lda_B = k;
749  CTF_int::offload_gemm<std::complex<double>>(tA, tB, m, n, k, ((std::complex<double> const*)alpha)[0], (std::complex<double> const *)A, lda_A, (std::complex<double> const *)B, lda_B, ((std::complex<double> const*)beta)[0], (std::complex<double>*)C, m);
750  }
751 
752 
753 }
void MKL_ZCOOMM(char *transa, int *m, int *n, int *k, std::complex< double > *alpha, char *matdescra, std::complex< double > const *val, int const *rowind, int const *colind, int *nnz, std::complex< double > const *b, int *ldb, std::complex< double > *beta, std::complex< double > *c, int *ldc)
void SSCAL(const int *n, float *dA, float *dX, const int *incX)
#define MKL_SCSRMULTD
Definition: mkl_symbs.h:47
void CAXPY(const int *n, std::complex< float > *dA, const std::complex< float > *dX, const int *incX, std::complex< float > *dY, const int *incY)
#define MKL_ZCSRMULTCSR
Definition: mkl_symbs.h:54
#define INST_GEMM(dtype)
Definition: semiring.cxx:107
void default_coomm< double >(int m, int n, int k, double alpha, double const *A, int const *rows_A, int const *cols_A, int nnz_A, double const *B, double beta, double *C)
Definition: semiring.cxx:233
#define DEF_COOMM_KERNEL()
Definition: semiring.cxx:192
void ZSCAL(const int *n, std::complex< double > *dA, std::complex< double > *dX, const int *incX)
void MKL_DCOOMM(char *transa, int *m, int *n, int *k, double *alpha, char *matdescra, double const *val, int const *rowind, int const *colind, int *nnz, double const *b, int *ldb, double *beta, double *c, int *ldc)
dtype ** get_grp_ptrs(int64_t grp_sz, int64_t ngrp, dtype const *data)
Definition: semiring.h:110
#define MKL_SCSRMULTCSR
Definition: mkl_symbs.h:51
#define CSR_MULTCSR_DEF(dtype, is_ord, MKL_name)
Definition: semiring.cxx:625
Semiring is a Monoid with an additional multiplication function; addition must have an identity and be as...
Definition: semiring.h:359
void default_csrmm(int m, int n, int k, dtype alpha, dtype const *A, int const *JA, int const *IA, int nnz_A, dtype const *B, dtype beta, dtype *C) const
Definition: semiring.h:632
void gemm(char tA, char tB, int m, int n, int k, dtype alpha, dtype const *A, dtype const *B, dtype beta, dtype *C)
Definition: semiring.cxx:82
void default_scal< double >(int n, double alpha, double *X, int incX)
Definition: semiring.cxx:176
void ZAXPY(const int *n, std::complex< double > *dA, const std::complex< double > *dX, const int *incX, std::complex< double > *dY, const int *incY)
void MKL_DCSRMM(const char *transa, const int *m, const int *n, const int *k, const double *alpha, const char *matdescra, const double *val, const int *indx, const int *pntrb, const int *pntre, const double *b, const int *ldb, const double *beta, double *c, const int *ldc)
void default_coomm< float >(int m, int n, int k, float alpha, float const *A, int const *rows_A, int const *cols_A, int nnz_A, float const *B, float beta, float *C)
Definition: semiring.cxx:208
void MKL_SCOOMM(char *transa, int *m, int *n, int *k, float *alpha, char *matdescra, float const *val, int const *rowind, int const *colind, int *nnz, float const *b, int *ldb, float *beta, float *c, int *ldc)
void CSCAL(const int *n, std::complex< float > *dA, std::complex< float > *dX, const int *incX)
void offload_gemm(char tA, char tB, int m, int n, int k, char const *alpha, char const *A, char const *B, char const *beta, char *C) const
Definition: semiring.h:586
void DAXPY(const int *n, double *dA, const double *dX, const int *incX, double *dY, const int *incY)
void MKL_ZCSRMM(const char *transa, const int *m, const int *n, const int *k, const std::complex< double > *alpha, const char *matdescra, const std::complex< double > *val, const int *indx, const int *pntrb, const int *pntre, const std::complex< double > *b, const int *ldb, const std::complex< double > *beta, std::complex< double > *c, const int *ldc)
void DSCAL(const int *n, double *dA, double *dX, const int *incX)
#define MKL_CCSRMULTD
Definition: mkl_symbs.h:49
void SAXPY(const int *n, float *dA, const float *dX, const int *incX, float *dY, const int *incY)
void gemm_batch(char taA, char taB, int l, int m, int n, int k, dtype alpha, dtype const *A, dtype const *B, dtype beta, dtype *C)
Definition: semiring.cxx:15
void default_gemm< double >(char tA, char tB, int m, int n, int k, double alpha, double const *A, double const *B, double beta, double *C)
Definition: semiring.h:166
#define MKL_DCSRMULTD
Definition: mkl_symbs.h:48
bool is_offloadable() const
Definition: semiring.h:600
void default_gemm< float >(char tA, char tB, int m, int n, int k, float alpha, float const *A, float const *B, float beta, float *C)
Definition: semiring.h:151
#define MKL_DCSRMULTCSR
Definition: mkl_symbs.h:52
void muladd_csrmm(int m, int n, int k, dtype alpha, dtype const *A, int const *JA, int const *IA, int nnz_A, dtype const *B, dtype beta, dtype *C)
Definition: semiring.cxx:332
Definition: apsp.cxx:17
void default_scal< float >(int n, float alpha, float *X, int incX)
Definition: semiring.cxx:171
void muladd_csrmultd(int m, int n, int k, dtype const *A, int const *JA, int const *IA, int nnz_A, dtype const *B, int const *JB, int const *IB, int nnz_B, dtype *C)
Definition: semiring.cxx:370
void default_axpy< float >(int n, float alpha, float const *X, int incX, float *Y, int incY)
Definition: semiring.cxx:128
#define CSR_MULTD_DEF(dtype, is_ord, MKL_name)
Definition: semiring.cxx:534
#define INST_GEMM_BATCH(dtype)
Definition: semiring.cxx:63
void offload_gemm(char tA, char tB, int m, int n, int k, dtype alpha, offload_tsr &A, int lda_A, offload_tsr &B, int lda_B, dtype beta, offload_tsr &C, int lda_C)
void MKL_CCOOMM(char *transa, int *m, int *n, int *k, std::complex< float > *alpha, char *matdescra, std::complex< float > const *val, int const *rowind, int const *colind, int *nnz, std::complex< float > const *b, int *ldb, std::complex< float > *beta, std::complex< float > *c, int *ldc)
void MKL_CCSRMM(const char *transa, const int *m, const int *n, const int *k, const std::complex< float > *alpha, const char *matdescra, const std::complex< float > *val, const int *indx, const int *pntrb, const int *pntre, const std::complex< float > *b, const int *ldb, const std::complex< float > *beta, std::complex< float > *c, const int *ldc)
void default_axpy< double >(int n, double alpha, double const *X, int incX, double *Y, int incY)
Definition: semiring.cxx:139
void MKL_SCSRMM(const char *transa, const int *m, const int *n, const int *k, const float *alpha, const char *matdescra, const float *val, const int *indx, const int *pntrb, const int *pntre, const float *b, const int *ldb, const float *beta, float *c, const int *ldc)
#define MKL_CCSRMULTCSR
Definition: mkl_symbs.h:53
#define MKL_ZCSRMULTD
Definition: mkl_symbs.h:50