Cyclops Tensor Framework
parallel arithmetic on multidimensional arrays
semiring.h
Go to the documentation of this file.
1 #ifndef __SEMIRING_H__
2 #define __SEMIRING_H__
3 
4 #include "functions.h"
5 #include "../sparse_formats/csr.h"
6 #include <iostream>
7 
8 using namespace std;
9 
10 namespace CTF_int {
11 
12 
13  template <typename dtype>
15  return a*b;
16  }
17 
/**
 * \brief generic strided update Y[incY*i] += alpha*X[incX*i] for i in [0,n)
 * \param[in] n number of elements to update
 * \param[in] alpha factor multiplied onto every element read from X
 * \param[in] X input vector, read with stride incX
 * \param[in] incX distance between consecutive elements of X
 * \param[in,out] Y vector accumulated into, accessed with stride incY
 * \param[in] incY distance between consecutive elements of Y
 */
template <typename dtype>
void default_axpy(int n, dtype alpha, dtype const * X, int incX, dtype * Y, int incY){
  int idx = 0;
  while (idx < n){
    dtype const scaled = alpha*X[incX*idx];
    Y[incY*idx] += scaled;
    ++idx;
  }
}
29 
30  template <>
32  (int,float,float const *,int,float *,int);
33 
34  template <>
36  (int,double,double const *,int,double *,int);
37 
38  template <>
39  void default_axpy< std::complex<float> >
40  (int,std::complex<float>,std::complex<float> const *,int,std::complex<float> *,int);
41 
42  template <>
43  void default_axpy< std::complex<double> >
44  (int,std::complex<double>,std::complex<double> const *,int,std::complex<double> *,int);
45 
/**
 * \brief generic strided scaling X[incX*i] *= alpha for i in [0,n)
 * \param[in] n number of elements to scale
 * \param[in] alpha scaling factor
 * \param[in,out] X vector scaled in place, accessed with stride incX
 * \param[in] incX distance between consecutive elements of X
 */
template <typename dtype>
void default_scal(int n, dtype alpha, dtype * X, int incX){
  int idx = 0;
  while (idx < n){
    dtype & elem = X[incX*idx];
    elem *= alpha;
    ++idx;
  }
}
55 
56  template <>
57  void default_scal<float>(int n, float alpha, float * X, int incX);
58 
59  template <>
60  void default_scal<double>(int n, double alpha, double * X, int incX);
61 
62  template <>
63  void default_scal< std::complex<float> >
64  (int n, std::complex<float> alpha, std::complex<float> * X, int incX);
65 
66  template <>
67  void default_scal< std::complex<double> >
68  (int n, std::complex<double> alpha, std::complex<double> * X, int incX);
69 
/**
 * \brief naive triple-loop matrix multiply C = beta*C + alpha*op(A)*op(B)
 *        on column-major buffers (C is indexed as C[j*m+i])
 * \param[in] tA 'N' or 'n' reads A as stored; any other value reads it transposed
 * \param[in] tB 'N' or 'n' reads B as stored; any other value reads it transposed
 * \param[in] m number of rows of C
 * \param[in] n number of columns of C
 * \param[in] k length of the contracted dimension
 * \param[in] alpha factor applied to every product term
 * \param[in] A left operand
 * \param[in] B right operand
 * \param[in] beta factor applied to each existing entry of C before accumulation
 * \param[in,out] C m-by-n output, updated in place
 */
template<typename dtype>
void default_gemm(char tA, char tB, int m, int n, int k,
                  dtype alpha, dtype const * A, dtype const * B,
                  dtype beta, dtype * C){
  //TAU_FSTART(default_gemm);
  bool const A_as_stored = (tA == 'N' || tA == 'n');
  bool const B_as_stored = (tB == 'N' || tB == 'n');
  // step sizes along the row index (i) and contracted index (l) of A
  int const A_step_i = A_as_stored ? 1 : k;
  int const A_step_l = A_as_stored ? m : 1;
  // step sizes along the column index (j) and contracted index (l) of B
  int const B_step_j = B_as_stored ? k : 1;
  int const B_step_l = B_as_stored ? 1 : n;
  for (int col=0; col<n; col++){
    for (int row=0; row<m; row++){
      dtype & c_elem = C[col*m+row];
      c_elem *= beta;
      for (int p=0; p<k; p++){
        c_elem += alpha*A[A_step_i*row+A_step_l*p]*B[B_step_l*p+B_step_j*col];
      }
    }
  }
  //TAU_FSTOP(default_gemm);
}
108 
109  template<typename dtype>
110  dtype ** get_grp_ptrs(int64_t grp_sz,
111  int64_t ngrp,
112  dtype const * data){
113  dtype ** data_ptrs = (dtype**)alloc(sizeof(dtype*)*ngrp);
114 #ifdef USE_OMP
115  #pragma omp parallel for
116 #endif
117  for (int i=0; i<ngrp; i++){
118  data_ptrs[i] = ((dtype*)data)+i*grp_sz;
119  }
120  return data_ptrs;
121  }
122 
123  template <typename dtype>
124  void gemm_batch(
125  char taA,
126  char taB,
127  int l,
128  int m,
129  int n,
130  int k,
131  dtype alpha,
132  dtype const* A,
133  dtype const* B,
134  dtype beta,
135  dtype * C);
136 
137  template <typename dtype>
138  void gemm(char tA,
139  char tB,
140  int m,
141  int n,
142  int k,
143  dtype alpha,
144  dtype const * A,
145  dtype const * B,
146  dtype beta,
147  dtype * C);
148 
149  template<>
150  inline void default_gemm<float>
151  (char tA,
152  char tB,
153  int m,
154  int n,
155  int k,
156  float alpha,
157  float const * A,
158  float const * B,
159  float beta,
160  float * C){
161  CTF_int::gemm<float>(tA,tB,m,n,k,alpha,A,B,beta,C);
162  }
163 
164  template<>
165  inline void default_gemm<double>
166  (char tA,
167  char tB,
168  int m,
169  int n,
170  int k,
171  double alpha,
172  double const * A,
173  double const * B,
174  double beta,
175  double * C){
176  CTF_int::gemm<double>(tA,tB,m,n,k,alpha,A,B,beta,C);
177  }
178 
179  template<>
180  inline void default_gemm< std::complex<float> >
181  (char tA,
182  char tB,
183  int m,
184  int n,
185  int k,
186  std::complex<float> alpha,
187  std::complex<float> const * A,
188  std::complex<float> const * B,
189  std::complex<float> beta,
190  std::complex<float> * C){
191  CTF_int::gemm< std::complex<float> >(tA,tB,m,n,k,alpha,A,B,beta,C);
192  }
193 
194  template<>
195  inline void default_gemm< std::complex<double> >
196  (char tA,
197  char tB,
198  int m,
199  int n,
200  int k,
201  std::complex<double> alpha,
202  std::complex<double> const * A,
203  std::complex<double> const * B,
204  std::complex<double> beta,
205  std::complex<double> * C){
206  CTF_int::gemm< std::complex<double> >(tA,tB,m,n,k,alpha,A,B,beta,C);
207  }
208 
209  template<typename dtype>
210  void default_gemm_batch
211  (char taA,
212  char taB,
213  int l,
214  int m,
215  int n,
216  int k,
217  dtype alpha,
218  dtype const* A,
219  dtype const* B,
220  dtype beta,
221  dtype * C){
222  if (m == 1 && n == 1 && k == 1){
223  for (int i=0; i<l; i++){
224  C[i] = C[i]*beta + alpha*A[i]*B[i];
225  }
226  } else {
227  for (int i=0; i<l; i++){
228  default_gemm<dtype>(taA, taB, m, n, k, alpha, A+i*m*k, B+i*k*n, beta, C+i*m*n);
229  }
230  }
231  }
232 
233  template<>
234  inline void default_gemm_batch<float>
235  (char taA,
236  char taB,
237  int l,
238  int m,
239  int n,
240  int k,
241  float alpha,
242  float const* A,
243  float const* B,
244  float beta,
245  float * C){
246  CTF_int::gemm_batch<float>(taA, taB, l, m, n, k, alpha, A, B, beta, C);
247  }
248 
249  template<>
250  inline void default_gemm_batch<double>
251  (char taA,
252  char taB,
253  int l,
254  int m,
255  int n,
256  int k,
257  double alpha,
258  double const* A,
259  double const* B,
260  double beta,
261  double * C){
262  CTF_int::gemm_batch<double>(taA, taB, l, m, n, k, alpha, A, B, beta, C);
263  }
264 
265  template<>
266  inline void default_gemm_batch<std::complex<float>>
267  (char taA,
268  char taB,
269  int l,
270  int m,
271  int n,
272  int k,
273  std::complex<float> alpha,
274  std::complex<float> const* A,
275  std::complex<float> const* B,
276  std::complex<float> beta,
277  std::complex<float> * C){
278  CTF_int::gemm_batch< std::complex<float> >(taA, taB, l, m, n, k, alpha, A, B, beta, C);
279  }
280 
281  template<>
282  inline void default_gemm_batch<std::complex<double>>
283  (char taA,
284  char taB,
285  int l,
286  int m,
287  int n,
288  int k,
289  std::complex<double> alpha,
290  std::complex<double> const* A,
291  std::complex<double> const* B,
292  std::complex<double> beta,
293  std::complex<double> * C){
294  CTF_int::gemm_batch< std::complex<double> >(taA, taB, l, m, n, k, alpha, A, B, beta, C);
295  }
296 
/**
 * \brief generic sparse-times-dense multiply C = beta*C + alpha*A*B with A given
 *        as coordinate triples and B, C as dense column-major buffers
 * \param[in] m number of rows of C
 * \param[in] n number of columns of B and C
 * \param[in] k number of rows of B
 * \param[in] alpha factor applied to every product term
 * \param[in] A values of the nonzeros of A
 * \param[in] rows_A 1-based row index of each nonzero
 * \param[in] cols_A 1-based column index of each nonzero
 * \param[in] nnz_A number of nonzeros in A
 * \param[in] B dense right operand, indexed as B[col*k+row]
 * \param[in] beta factor applied to every entry of C before accumulation
 * \param[in,out] C dense output, indexed as C[col*m+row]
 */
template <typename dtype>
void default_coomm
               (int m, int n, int k,
                dtype alpha,
                dtype const * A, int const * rows_A, int const * cols_A, int nnz_A,
                dtype const * B,
                dtype beta,
                dtype * C){
  //TAU_FSTART(default_coomm);
  // scale the whole output first, column by column
  for (int col=0; col<n; col++){
    dtype * C_col = C + col*m;
    for (int row=0; row<m; row++){
      C_col[row] *= beta;
    }
  }
  // scatter the contribution of each stored entry of A into every column of C
  for (int e=0; e<nnz_A; e++){
    int const r = rows_A[e]-1;
    int const c = cols_A[e]-1;
    for (int col=0; col<n; col++){
      C[col*m+r] += alpha*A[e]*B[col*k+c];
    }
  }
  //TAU_FSTOP(default_coomm);
}
325 
326  template <>
328  (int,int,int,float,float const *,int const *,int const *,int,float const *,float,float *);
329 
330  template <>
332  (int,int,int,double,double const *,int const *,int const *,int,double const *,double,double *);
333 
334  template <>
335  void default_coomm< std::complex<float> >
336  (int,int,int,std::complex<float>,std::complex<float> const *,int const *,int const *,int,std::complex<float> const *,std::complex<float>,std::complex<float> *);
337 
338  template <>
339  void default_coomm< std::complex<double> >
340  (int,int,int,std::complex<double>,std::complex<double> const *,int const *,int const *,int,std::complex<double> const *,std::complex<double>,std::complex<double> *);
341 
342 
343 }
344 
345 
346 namespace CTF {
358  template <typename dtype=double, bool is_ord=CTF_int::get_default_is_ord<dtype>()>
359  class Semiring : public Monoid<dtype, is_ord> {
360  public:
361  bool is_def;
363  void (*fscal)(int,dtype,dtype*,int);
364  void (*faxpy)(int,dtype,dtype const*,int,dtype*,int);
365  dtype (*fmul)(dtype a, dtype b);
366  void (*fgemm)(char,char,int,int,int,dtype,dtype const*,dtype const*,dtype,dtype*);
367  void (*fcoomm)(int,int,int,dtype,dtype const*,int const*,int const*,int,dtype const*,dtype,dtype*);
368  void (*fgemm_batch)(char,char,int,int,int,int,dtype,dtype const*,dtype const*,dtype,dtype*);
369  //void (*fcsrmm)(int,int,int,dtype,dtype const*,int const*,int const*,dtype const*,dtype,dtype*);
370  //csrmultd_ kernel for multiplying two sparse matrices into a dense output
371  //void (*fcsrmultd)(int,int,int,dtype const*,int const*,int const*,dtype const*,int const*, int const*,dtype*,int);
372 
373  Semiring(Semiring const & other) : Monoid<dtype, is_ord>(other) {
374  this->tmulid = other.tmulid;
375  this->fscal = other.fscal;
376  this->faxpy = other.faxpy;
377  this->fmul = other.fmul;
378  this->fgemm = other.fgemm;
379  this->fcoomm = other.fcoomm;
380  this->is_def = other.is_def;
381  this->fgemm_batch = other.fgemm_batch;
382  }
383 
384  virtual CTF_int::algstrct * clone() const {
385  return new Semiring<dtype, is_ord>(*this);
386  }
387 
400  Semiring(dtype addid_,
401  dtype (*fadd_)(dtype a, dtype b),
402  MPI_Op addmop_,
403  dtype mulid_,
404  dtype (*fmul_)(dtype a, dtype b),
405  void (*gemm_)(char,char,int,int,int,dtype,dtype const*,dtype const*,dtype,dtype*)=NULL,
406  void (*axpy_)(int,dtype,dtype const*,int,dtype*,int)=NULL,
407  void (*scal_)(int,dtype,dtype*,int)=NULL,
408  void (*coomm_)(int,int,int,dtype,dtype const*,int const*,int const*,int,dtype const*,dtype,dtype*)=NULL,
409  void (*fgemm_batch_)(char,char,int,int,int,int,dtype,dtype const*,dtype const*,dtype,dtype*)=NULL)
410  : Monoid<dtype, is_ord>(addid_, fadd_, addmop_) {
411  fmul = fmul_;
412  fgemm = gemm_;
413  faxpy = axpy_;
414  fscal = scal_;
415  fcoomm = coomm_;
416  fgemm_batch = fgemm_batch_;
417  tmulid = mulid_;
418  // if provided a coordinate MM kernel, don't use CSR
419  this->has_coo_ker = (coomm_ != NULL);
420  is_def = false;
421  }
422 
426  Semiring() : Monoid<dtype,is_ord>() {
427  tmulid = dtype(1);
428  fmul = &CTF_int::default_mul<dtype>;
429  fgemm = &CTF_int::default_gemm<dtype>;
430  faxpy = &CTF_int::default_axpy<dtype>;
431  fscal = &CTF_int::default_scal<dtype>;
432  fcoomm = &CTF_int::default_coomm<dtype>;
433  fgemm_batch = &CTF_int::default_gemm_batch<dtype>;
434  is_def = true;
435  }
436 
437  void mul(char const * a,
438  char const * b,
439  char * c) const {
440  ((dtype*)c)[0] = fmul(((dtype*)a)[0],((dtype*)b)[0]);
441  }
442 
443  void safemul(char const * a,
444  char const * b,
445  char *& c) const {
446  if (a == NULL && b == NULL){
447  if (c!=NULL) CTF_int::cdealloc(c);
448  c = NULL;
449  } else if (a == NULL) {
450  if (c==NULL) c = (char*)CTF_int::alloc(this->el_size);
451  memcpy(c,b,this->el_size);
452  } else if (b == NULL) {
453  if (c==NULL) c = (char*)CTF_int::alloc(this->el_size);
454  memcpy(c,b,this->el_size);
455  } else {
456  if (c==NULL) c = (char*)CTF_int::alloc(this->el_size);
457  ((dtype*)c)[0] = fmul(((dtype*)a)[0],((dtype*)b)[0]);
458  }
459  }
460 
461  char const * mulid() const {
462  return (char const *)&tmulid;
463  }
464 
465  bool has_mul() const { return true; }
466 
468  void scal(int n,
469  char const * alpha,
470  char * X,
471  int incX) const {
472  if (fscal != NULL) fscal(n, ((dtype const *)alpha)[0], (dtype *)X, incX);
473  else {
474  dtype const a = ((dtype*)alpha)[0];
475  dtype * dX = (dtype*) X;
476  for (int64_t i=0; i<n; i++){
477  dX[i] = fmul(a,dX[i]);
478  }
479  }
480  }
481 
483  void axpy(int n,
484  char const * alpha,
485  char const * X,
486  int incX,
487  char * Y,
488  int incY) const {
489  if (faxpy != NULL) faxpy(n, ((dtype const *)alpha)[0], (dtype const *)X, incX, (dtype *)Y, incY);
490  else {
491  assert(incX==1);
492  assert(incY==1);
493  dtype a = ((dtype*)alpha)[0];
494  dtype const * dX = (dtype*) X;
495  dtype * dY = (dtype*) Y;
496  for (int64_t i=0; i<n; i++){
497  dY[i] = this->fadd(fmul(a,dX[i]), dY[i]);
498  }
499  }
500  }
501 
503  void gemm(char tA,
504  char tB,
505  int m,
506  int n,
507  int k,
508  char const * alpha,
509  char const * A,
510  char const * B,
511  char const * beta,
512  char * C) const {
513  if (fgemm != NULL) {
514  fgemm(tA, tB, m, n, k, ((dtype const *)alpha)[0], (dtype const *)A, (dtype const *)B, ((dtype const *)beta)[0], (dtype *)C);
515  } else {
516  //TAU_FSTART(sring_gemm);
517  dtype const * dA = (dtype const *) A;
518  dtype const * dB = (dtype const *) B;
519  dtype * dC = (dtype*) C;
520  if (!this->isequal(beta, this->mulid())){
521  scal(m*n, beta, C, 1);
522  }
523  int lda_Cj, lda_Ci, lda_Al, lda_Ai, lda_Bj, lda_Bl;
524  lda_Cj = m;
525  lda_Ci = 1;
526  if (tA == 'N'){
527  lda_Al = m;
528  lda_Ai = 1;
529  } else {
530  assert(tA == 'T');
531  lda_Al = 1;
532  lda_Ai = k;
533  }
534  if (tB == 'N'){
535  lda_Bj = k;
536  lda_Bl = 1;
537  } else {
538  assert(tB == 'T');
539  lda_Bj = 1;
540  lda_Bl = n;
541  }
542  if (!this->isequal(alpha, this->mulid())){
543  dtype a = ((dtype*)alpha)[0];
544  for (int64_t j=0; j<n; j++){
545  for (int64_t i=0; i<m; i++){
546  for (int64_t l=0; l<k; l++){
547  //dC[j*m+i] = this->fadd(fmul(a,fmul(dA[l*m+i],dB[j*k+l])), dC[j*m+i]);
548  dC[j*lda_Cj+i*lda_Ci] = this->fadd(fmul(a,fmul(dA[l*lda_Al+i*lda_Ai],dB[j*lda_Bj+l*lda_Bl])), dC[j*lda_Cj+i*lda_Ci]);
549  }
550  }
551  }
552  } else {
553  for (int64_t j=0; j<n; j++){
554  for (int64_t i=0; i<m; i++){
555  for (int64_t l=0; l<k; l++){
556  //dC[j*m+i] = this->fadd(fmul(a,fmul(dA[l*m+i],dB[j*k+l])), dC[j*m+i]);
557  dC[j*lda_Cj+i*lda_Ci] = this->fadd(fmul(dA[l*lda_Al+i*lda_Ai],dB[j*lda_Bj+l*lda_Bl]), dC[j*lda_Cj+i*lda_Ci]);
558  }
559  }
560  }
561  }
562  //TAU_FSTOP(sring_gemm);
563  }
564  }
565 
566  void gemm_batch(char tA,
567  char tB,
568  int l,
569  int m,
570  int n,
571  int k,
572  char const * alpha,
573  char const * A,
574  char const * B,
575  char const * beta,
576  char * C) const {
577  if (fgemm_batch != NULL) {
578  fgemm_batch(tA, tB, l, m, n, k, ((dtype const *)alpha)[0], ((dtype const *)A), ((dtype const *)B), ((dtype const *)beta)[0], ((dtype *)C));
579  } else {
580  for (int i=0; i<l; i++){
581  gemm(tA, tB, m, n, k, alpha, A+m*k*i*sizeof(dtype), B+k*n*i*sizeof(dtype), beta, C+m*n*i*sizeof(dtype));
582  }
583  }
584  }
585 
586  void offload_gemm(char tA,
587  char tB,
588  int m,
589  int n,
590  int k,
591  char const * alpha,
592  char const * A,
593  char const * B,
594  char const * beta,
595  char * C) const {
596  printf("CTF ERROR: offload gemm not present for this semiring\n");
597  assert(0);
598  }
599 
600  bool is_offloadable() const {
601  return false;
602  }
603 
604 
605  void coomm(int m, int n, int k, char const * alpha, char const * A, int const * rows_A, int const * cols_A, int64_t nnz_A, char const * B, char const * beta, char * C, CTF_int::bivar_function const * func) const {
606  if (func == NULL && alpha != NULL && fcoomm != NULL){
607  fcoomm(m, n, k, ((dtype const *)alpha)[0], (dtype const *)A, rows_A, cols_A, nnz_A, (dtype const *)B, ((dtype const *)beta)[0], (dtype *)C);
608  return;
609  }
610  if (func == NULL && alpha != NULL && this->isequal(beta,mulid())){
611  //TAU_FSTART(func_coomm);
612  dtype const * dA = (dtype const*)A;
613  dtype const * dB = (dtype const*)B;
614  dtype * dC = (dtype*)C;
615  dtype a = ((dtype*)alpha)[0];
616  if (!this->isequal(beta, this->mulid())){
617  scal(m*n, beta, C, 1);
618  }
619  for (int64_t i=0; i<nnz_A; i++){
620  int row_A = rows_A[i]-1;
621  int col_A = cols_A[i]-1;
622  for (int col_C=0; col_C<n; col_C++){
623  dC[col_C*m+row_A] = this->fadd(fmul(a,fmul(dA[i],dB[col_C*k+col_A])), dC[col_C*m+row_A]);
624  }
625  }
626  //TAU_FSTOP(func_coomm);
627  } else { assert(0); }
628  }
629 
630 
631  void default_csrmm
632  (int m,
633  int n,
634  int k,
635  dtype alpha,
636  dtype const * A,
637  int const * JA,
638  int const * IA,
639  int nnz_A,
640  dtype const * B,
641  dtype beta,
642  dtype * C) const {
643 #ifdef _OPENMP
644  #pragma omp parallel for
645 #endif
646  for (int row_A=0; row_A<m; row_A++){
647 #ifdef _OPENMP
648  #pragma omp parallel for
649 #endif
650  for (int col_B=0; col_B<n; col_B++){
651  C[col_B*m+row_A] = this->fmul(beta,C[col_B*m+row_A]);
652  if (IA[row_A] < IA[row_A+1]){
653  int i_A1 = IA[row_A]-1;
654  int col_A1 = JA[i_A1]-1;
655  dtype tmp = this->fmul(A[i_A1],B[col_B*k+col_A1]);
656  for (int i_A=IA[row_A]; i_A<IA[row_A+1]-1; i_A++){
657  int col_A = JA[i_A]-1;
658  tmp = this->fadd(tmp, this->fmul(A[i_A],B[col_B*k+col_A]));
659  }
660  C[col_B*m+row_A] = this->fadd(C[col_B*m+row_A], this->fmul(alpha,tmp));
661  }
662  }
663  }
664  }
665 
666 // void (*fcsrmultd)(int,int,int,dtype const*,int const*,int const*,dtype const*,int const*, int const*,dtype*,int);
667 
669  void csrmm(int m,
670  int n,
671  int k,
672  char const * alpha,
673  char const * A,
674  int const * JA,
675  int const * IA,
676  int64_t nnz_A,
677  char const * B,
678  char const * beta,
679  char * C,
680  CTF_int::bivar_function const * func) const {
681  assert(!this->has_coo_ker);
682  assert(func == NULL);
683  this->default_csrmm(m,n,k,((dtype*)alpha)[0],(dtype*)A,JA,IA,nnz_A,(dtype*)B,((dtype*)beta)[0],(dtype*)C);
684  }
685 
686  void default_csrmultd
687  (int m,
688  int n,
689  int k,
690  dtype alpha,
691  dtype const * A,
692  int const * JA,
693  int const * IA,
694  int nnz_A,
695  dtype const * B,
696  int const * JB,
697  int const * IB,
698  int nnz_B,
699  dtype beta,
700  dtype * C) const {
701 
702  if (!this->isequal((char const*)&beta, this->mulid())){
703  this->scal(m*n, (char const *)&beta, (char*)C, 1);
704  }
705 #ifdef _OPENMP
706  #pragma omp parallel for
707 #endif
708  for (int row_A=0; row_A<m; row_A++){
709  for (int i_A=IA[row_A]-1; i_A<IA[row_A+1]-1; i_A++){
710  int row_B = JA[i_A]-1; //=col_A
711  for (int i_B=IB[row_B]-1; i_B<IB[row_B+1]-1; i_B++){
712  int col_B = JB[i_B]-1;
713  if (!this->isequal((char const*)&alpha, this->mulid()))
714  this->fadd(C[col_B*m+row_A], this->fmul(alpha,this->fmul(A[i_A],B[i_B])));
715  else
716  this->fadd(C[col_B*m+row_A], this->fmul(A[i_A],B[i_B]));
717  }
718  }
719  }
720 
721  }
722  void gen_csrmultcsr
723  (int m,
724  int n,
725  int k,
726  dtype alpha,
727  dtype const * A, // A m by k
728  int const * JA,
729  int const * IA,
730  int nnz_A,
731  dtype const * B, // B k by n
732  int const * JB,
733  int const * IB,
734  int nnz_B,
735  dtype beta,
736  char *& C_CSR) const {
737  int * IC = (int*)CTF_int::alloc(sizeof(int)*(m+1));
738  memset(IC, 0, sizeof(int)*(m+1));
739 #ifdef _OPENMP
740  #pragma omp parallel
741  {
742 #endif
743  int * has_col = (int*)CTF_int::alloc(sizeof(int)*(n+1)); //n is the num of col of B
744  int nnz = 0;
745 #ifdef _OPENMP
746  #pragma omp for schedule(dynamic) // TO DO test other strategies
747 #endif
748  for (int i=0; i<m; i++){
749  memset(has_col, 0, sizeof(int)*(n+1));
750  nnz = 0;
751  for (int j=0; j<IA[i+1]-IA[i]; j++){
752  int row_B = JA[IA[i]+j-1]-1;
753  for (int kk=0; kk<IB[row_B+1]-IB[row_B]; kk++){
754  int idx_B = IB[row_B]+kk-1;
755  if (has_col[JB[idx_B]] == 0){
756  nnz++;
757  has_col[JB[idx_B]] = 1;
758  }
759  }
760  IC[i+1]=nnz;
761  }
762  }
763  CTF_int::cdealloc(has_col);
764 #ifdef _OPENMP
765  } // END PARALLEL
766 #endif
767  int ic_prev = 1;
768  for(int i=0;i < m+1; i++){
769  ic_prev += IC[i];
770  IC[i] = ic_prev;
771  }
772  CTF_int::CSR_Matrix C(IC[m]-1, m, n, this);
773  dtype * vC = (dtype*)C.vals();
774  this->set((char *)vC, this->addid(), IC[m]+1);
775  int * JC = C.JA();
776  memcpy(C.IA(), IC, sizeof(int)*(m+1));
777  CTF_int::cdealloc(IC);
778  IC = C.IA();
779 #ifdef _OPENMP
780  #pragma omp parallel
781  {
782 #endif
783  int ins = 0;
784  int *dcol = (int *) CTF_int::alloc(n*sizeof(int));
785  dtype *acc_data = (dtype*)CTF_int::alloc(n*sizeof (dtype));
786 #ifdef _OPENMP
787  #pragma omp for
788 #endif
789  for (int i=0; i<m; i++){
790  std::fill(acc_data, acc_data+n, this->taddid);
791  memset(dcol, 0, sizeof(int)*(n));
792  ins = 0;
793  for (int j=0; j<IA[i+1]-IA[i]; j++){
794  int row_b = JA[IA[i]+j-1]-1; // 1-based
795  int idx_a = IA[i]+j-1;
796  for (int ii = 0; ii < IB[row_b+1]-IB[row_b]; ii++){
797  int col_b = IB[row_b]+ii-1;
798  int col_c = JB[col_b]-1; // 1-based
799  dtype val = fmul(A[idx_a], B[col_b]);
800  if (dcol[col_c] == 0){
801  dcol[col_c] = JB[col_b];
802  }
803  //acc_data[col_c] += val;
804  acc_data[col_c]= this->fadd(acc_data[col_c], val);
805  }
806  }
807  for(int jj = 0; jj < n; jj++){
808  if (dcol[jj] != 0){
809  JC[IC[i]+ins-1] = dcol[jj];
810  vC[IC[i]+ins-1] = acc_data[jj];
811  ++ins;
812  }
813  }
814  }
815  CTF_int::cdealloc(dcol);
816  CTF_int::cdealloc(acc_data);
817 #ifdef _OPENMP
818  } //PRAGMA END
819 #endif
820  CTF_int::CSR_Matrix C_in(C_CSR);
821  if (!this->isequal((char const *)&alpha, this->mulid())){
822  this->scal(C.nnz(), (char const *)&alpha, C.vals(), 1);
823  }
824  if (C_CSR == NULL || C_in.nnz() == 0 || this->isequal((char const *)&beta, this->addid())){
825  C_CSR = C.all_data;
826  } else {
827  if (!this->isequal((char const *)&beta, this->mulid())){
828  this->scal(C_in.nnz(), (char const *)&beta, C_in.vals(), 1);
829  }
830  char * ans = this->csr_add(C_CSR, C.all_data);
832  C_CSR = ans;
833  }
834  }
835 
836 
837  /* void gen_csrmultcsr_old
838  (int m,
839  int n,
840  int k,
841  dtype alpha,
842  dtype const * A,
843  int const * JA,
844  int const * IA,
845  int nnz_A,
846  dtype const * B,
847  int const * JB,
848  int const * IB,
849  int nnz_B,
850  dtype beta,
851  char *& C_CSR) const {
852  int * IC = (int*)CTF_int::alloc(sizeof(int)*(m+1));
853  int * has_col = (int*)CTF_int::alloc(sizeof(int)*n);
854  IC[0] = 1;
855  for (int i=0; i<m; i++){
856  memset(has_col, 0, sizeof(int)*n);
857  IC[i+1] = IC[i];
858  CTF_int::CSR_Matrix::compute_has_col(JA, IA, JB, IB, i, has_col);
859  for (int j=0; j<n; j++){
860  IC[i+1] += has_col[j];
861  }
862  }
863  CTF_int::CSR_Matrix C(IC[m]-1, m, n, sizeof(dtype));
864  dtype * vC = (dtype*)C.vals();
865  this->set((char *)vC, this->addid(), IC[m]-1);
866  int * JC = C.JA();
867  memcpy(C.IA(), IC, sizeof(int)*(m+1));
868  CTF_int::cdealloc(IC);
869  IC = C.IA();
870  int64_t * rev_col = (int64_t*)CTF_int::alloc(sizeof(int64_t)*n);
871  for (int i=0; i<m; i++){
872  memset(has_col, 0, sizeof(int)*n);
873  CTF_int::CSR_Matrix::compute_has_col(JA, IA, JB, IB, i, has_col);
874  int vs = 0;
875  for (int j=0; j<n; j++){
876  if (has_col[j]){
877  JC[IC[i]+vs-1] = j+1;
878  rev_col[j] = IC[i]+vs-1;
879  vs++;
880  }
881  }
882  for (int j=0; j<IA[i+1]-IA[i]; j++){
883  int row_B = JA[IA[i]+j-1]-1;
884  int idx_A = IA[i]+j-1;
885  for (int l=0; l<IB[row_B+1]-IB[row_B]; l++){
886  int idx_B = IB[row_B]+l-1;
887  dtype tmp = fmul(A[idx_A],B[idx_B]);
888  vC[(rev_col[JB[idx_B]-1])] = this->fadd(vC[(rev_col[JB[idx_B]-1])], tmp);
889  }
890  }
891  }
892  CTF_int::CSR_Matrix C_in(C_CSR);
893  if (!this->isequal((char const *)&alpha, this->mulid())){
894  this->scal(C.nnz(), (char const *)&alpha, C.vals(), 1);
895  }
896  if (C_CSR == NULL || C_in.nnz() == 0 || this->isequal((char const *)&beta, this->addid())){
897  C_CSR = C.all_data;
898  } else {
899  if (!this->isequal((char const *)&beta, this->mulid())){
900  this->scal(C_in.nnz(), (char const *)&beta, C_in.vals(), 1);
901  }
902  char * ans = this->csr_add(C_CSR, C.all_data);
903  CTF_int::cdealloc(C.all_data);
904  C_CSR = ans;
905  }
906  CTF_int::cdealloc(has_col);
907  CTF_int::cdealloc(rev_col);
908  }*/
909 
910 
911  void default_csrmultcsr
912  (int m,
913  int n,
914  int k,
915  dtype alpha,
916  dtype const * A,
917  int const * JA,
918  int const * IA,
919  int nnz_A,
920  dtype const * B,
921  int const * JB,
922  int const * IB,
923  int nnz_B,
924  dtype beta,
925  char *& C_CSR) const {
926  this->gen_csrmultcsr(m,n,k,alpha,A,JA,IA,nnz_A,B,JB,IB,nnz_B,beta,C_CSR);
927  }
928 
929 
930  void csrmultd
931  (int m,
932  int n,
933  int k,
934  char const * alpha,
935  char const * A,
936  int const * JA,
937  int const * IA,
938  int64_t nnz_A,
939  char const * B,
940  int const * JB,
941  int const * IB,
942  int64_t nnz_B,
943  char const * beta,
944  char * C) const {
945  this->default_csrmultd(m,n,k,((dtype const*)alpha)[0],(dtype const*)A,JA,IA,nnz_A,(dtype const*)B,JB,IB,nnz_B,((dtype const*)beta)[0],(dtype*)C);
946  }
947 
948 
949  void csrmultcsr
950  (int m,
951  int n,
952  int k,
953  char const * alpha,
954  char const * A,
955  int const * JA,
956  int const * IA,
957  int64_t nnz_A,
958  char const * B,
959  int const * JB,
960  int const * IB,
961  int64_t nnz_B,
962  char const * beta,
963  char *& C_CSR) const {
964 
965  if (is_def){
966  this->default_csrmultcsr(m,n,k,((dtype const*)alpha)[0],(dtype const*)A,JA,IA,nnz_A,(dtype const*)B,JB,IB,nnz_B,((dtype const*)beta)[0],C_CSR);
967  } else {
968  this->gen_csrmultcsr(m,n,k,((dtype const*)alpha)[0],(dtype const*)A,JA,IA,nnz_A,(dtype const*)B,JB,IB,nnz_B,((dtype const*)beta)[0],C_CSR);
969  }
970  }
971 
972  };
976 }
977 namespace CTF {
978  template <>
979  void CTF::Semiring<float,1>::default_csrmm(int,int,int,float,float const *,int const *,int const *,int,float const *,float,float *) const;
980  template <>
981  void CTF::Semiring<double,1>::default_csrmm(int,int,int,double,double const *,int const *,int const *,int,double const *,double,double *) const;
982  template <>
983  void CTF::Semiring<std::complex<float>,0>::default_csrmm(int,int,int,std::complex<float>,std::complex<float> const *,int const *,int const *,int,std::complex<float> const *,std::complex<float>,std::complex<float> *) const;
984  template <>
985  void CTF::Semiring<std::complex<double>,0>::default_csrmm(int,int,int,std::complex<double>,std::complex<double> const *,int const *,int const *,int,std::complex<double> const *,std::complex<double>,std::complex<double> *) const;
986 
987 
988  template <>
989  void CTF::Semiring<float,1>::default_csrmultd(int,int,int,float,float const *,int const *,int const *,int,float const *,int const *,int const *,int,float,float *) const;
990  template <>
991  void CTF::Semiring<double,1>::default_csrmultd(int,int,int,double,double const *,int const *,int const *,int,double const *,int const *,int const *,int,double,double *) const;
992  template <>
993  void CTF::Semiring<std::complex<float>,0>::default_csrmultd(int,int,int,std::complex<float>,std::complex<float> const *,int const *,int const *,int,std::complex<float> const *,int const *,int const *,int,std::complex<float>,std::complex<float> *) const;
994  template <>
995  void CTF::Semiring<std::complex<double>,0>::default_csrmultd(int,int,int,std::complex<double>,std::complex<double> const *,int const *,int const *,int,std::complex<double> const *,int const *,int const *,int,std::complex<double>,std::complex<double> *) const;
996 
997  template <>
998  void CTF::Semiring<float,1>::default_csrmultcsr(int,int,int,float,float const *,int const *,int const *,int,float const *,int const *,int const *,int,float,char *&) const;
999  template <>
1000  void CTF::Semiring<double,1>::default_csrmultcsr(int,int,int,double,double const *,int const *,int const *,int,double const *,int const *,int const *,int,double,char *&) const;
1001  template <>
1002  void CTF::Semiring<std::complex<float>,0>::default_csrmultcsr(int,int,int,std::complex<float>,std::complex<float> const *,int const *,int const *,int,std::complex<float> const *,int const *,int const *,int,std::complex<float>,char *&) const;
1003  template <>
1004  void CTF::Semiring<std::complex<double>,0>::default_csrmultcsr(int,int,int,std::complex<double>,std::complex<double> const *,int const *,int const *,int,std::complex<double> const *,int const *,int const *,int,std::complex<double>,char *&) const;
1005 
1006 
1007  template<>
1009  template<>
1011  template<>
1012  bool CTF::Semiring<std::complex<float>,0>::is_offloadable() const;
1013  template<>
1014  bool CTF::Semiring<std::complex<double>,0>::is_offloadable() const;
1015 
1016  template<>
1017  void CTF::Semiring<double,1>::offload_gemm(char,char,int,int,int,char const *,char const *,char const *,char const *,char *) const;
1018  template<>
1019  void CTF::Semiring<double,1>::offload_gemm(char,char,int,int,int,char const *,char const *,char const *,char const *,char *) const;
1020  template<>
1021  void CTF::Semiring<std::complex<float>,0>::offload_gemm(char,char,int,int,int,char const *,char const *,char const *,char const *,char *) const;
1022  template<>
1023  void CTF::Semiring<std::complex<double>,0>::offload_gemm(char,char,int,int,int,char const *,char const *,char const *,char const *,char *) const;
1024 }
1025 
1026 #include "ring.h"
1027 #endif
void(* fscal)(int, dtype, dtype *, int)
Definition: semiring.h:363
void scal(int n, char const *alpha, char *X, int incX) const
X["i"]=alpha*X["i"];.
Definition: semiring.h:468
void mul(char const *a, char const *b, char *c) const
c = a*b
Definition: semiring.h:437
void default_scal(int n, dtype alpha, dtype *X, int incX)
Definition: semiring.h:47
int * IA() const
retrieves prefix sum of number of nonzeros for each row (of size nrow()+1) out of all_data ...
Definition: csr.cxx:107
void default_coomm< double >(int m, int n, int k, double alpha, double const *A, int const *rows_A, int const *cols_A, int nnz_A, double const *B, double beta, double *C)
Definition: semiring.cxx:233
void coomm(int m, int n, int k, char const *alpha, char const *A, int const *rows_A, int const *cols_A, int64_t nnz_A, char const *B, char const *beta, char *C, CTF_int::bivar_function const *func) const
sparse version of gemm using coordinate format for A
Definition: semiring.h:605
dtype ** get_grp_ptrs(int64_t grp_sz, int64_t ngrp, dtype const *data)
Definition: semiring.h:110
dtype tmulid
Definition: semiring.h:362
dtype(* fmul)(dtype a, dtype b)
Definition: semiring.h:365
Semiring is a Monoid with an additional multiplication function; addition must have an identity and be as...
Definition: semiring.h:359
void * alloc(int64_t len)
alloc abstraction
Definition: memcontrol.cxx:365
void(* faxpy)(int, dtype, dtype const *, int, dtype *, int)
Definition: semiring.h:364
void default_csrmm(int m, int n, int k, dtype alpha, dtype const *A, int const *JA, int const *IA, int nnz_A, dtype const *B, dtype beta, dtype *C) const
Definition: semiring.h:632
void csrmm(int m, int n, int k, char const *alpha, char const *A, int const *JA, int const *IA, int64_t nnz_A, char const *B, char const *beta, char *C, CTF_int::bivar_function const *func) const
sparse version of gemm using CSR format for A
Definition: semiring.h:669
void safemul(char const *a, char const *b, char *&c) const
c = a*b, with NULL treated as mulid
Definition: semiring.h:443
void gemm(char tA, char tB, int m, int n, int k, dtype alpha, dtype const *A, dtype const *B, dtype beta, dtype *C)
Definition: semiring.cxx:82
untyped internal class for triply-typed bivariate function
Definition: ctr_comm.h:16
void default_scal< double >(int n, double alpha, double *X, int incX)
Definition: semiring.cxx:176
void default_coomm(int m, int n, int k, dtype alpha, dtype const *A, int const *rows_A, int const *cols_A, int nnz_A, dtype const *B, dtype beta, dtype *C)
Definition: semiring.h:299
void default_gemm_batch(char taA, char taB, int l, int m, int n, int k, dtype alpha, dtype const *A, dtype const *B, dtype beta, dtype *C)
Definition: semiring.h:211
void default_coomm< float >(int m, int n, int k, float alpha, float const *A, int const *rows_A, int const *cols_A, int nnz_A, float const *B, float beta, float *C)
Definition: semiring.cxx:208
void default_axpy(int n, dtype alpha, dtype const *X, int incX, dtype *Y, int incY)
Definition: semiring.h:19
virtual CTF_int::algstrct * clone() const
''copy constructor''
Definition: semiring.h:384
bool has_mul() const
Definition: semiring.h:465
void gemm_batch(char tA, char tB, int l, int m, int n, int k, char const *alpha, char const *A, char const *B, char const *beta, char *C) const
beta*C["ijl"]=alpha*A^tA["ikl"]*B^tB["kjl"];
Definition: semiring.h:566
void(* fgemm)(char, char, int, int, int, dtype, dtype const *, dtype const *, dtype, dtype *)
Definition: semiring.h:366
void offload_gemm(char tA, char tB, int m, int n, int k, char const *alpha, char const *A, char const *B, char const *beta, char *C) const
Definition: semiring.h:586
void default_gemm_batch< double >(char taA, char taB, int l, int m, int n, int k, double alpha, double const *A, double const *B, double beta, double *C)
Definition: semiring.h:251
int * JA() const
retrieves column indices of each value in vals stored in sorted form by row
Definition: csr.cxx:119
void default_csrmultcsr(int m, int n, int k, dtype alpha, dtype const *A, int const *JA, int const *IA, int nnz_A, dtype const *B, int const *JB, int const *IB, int nnz_B, dtype beta, char *&C_CSR) const
Definition: semiring.h:912
void axpy(int n, char const *alpha, char const *X, int incX, char *Y, int incY) const
Y["i"]+=alpha*X["i"];
Definition: semiring.h:483
void gemm(char tA, char tB, int m, int n, int k, char const *alpha, char const *A, char const *B, char const *beta, char *C) const
beta*C["ij"]=alpha*A^tA["ik"]*B^tB["kj"];
Definition: semiring.h:503
int64_t nnz() const
retrieves number of nonzeros out of all_data
Definition: csr.cxx:80
char const * mulid() const
identity element for multiplication i.e. 1
Definition: semiring.h:461
abstraction for a serialized sparse matrix stored in compressed-sparse-row (CSR) layout ...
Definition: csr.h:22
void default_csrmultd(int m, int n, int k, dtype alpha, dtype const *A, int const *JA, int const *IA, int nnz_A, dtype const *B, int const *JB, int const *IB, int nnz_B, dtype beta, dtype *C) const
Definition: semiring.h:687
void gemm_batch(char taA, char taB, int l, int m, int n, int k, dtype alpha, dtype const *A, dtype const *B, dtype beta, dtype *C)
Definition: semiring.cxx:15
void default_gemm< double >(char tA, char tB, int m, int n, int k, double alpha, double const *A, double const *B, double beta, double *C)
Definition: semiring.h:166
Semiring()
constructor for algstrct equipped with + only
Definition: semiring.h:426
bool is_offloadable() const
Definition: semiring.h:600
Semiring(dtype addid_, dtype(*fadd_)(dtype a, dtype b), MPI_Op addmop_, dtype mulid_, dtype(*fmul_)(dtype a, dtype b), void(*gemm_)(char, char, int, int, int, dtype, dtype const *, dtype const *, dtype, dtype *)=NULL, void(*axpy_)(int, dtype, dtype const *, int, dtype *, int)=NULL, void(*scal_)(int, dtype, dtype *, int)=NULL, void(*coomm_)(int, int, int, dtype, dtype const *, int const *, int const *, int, dtype const *, dtype, dtype *)=NULL, void(*fgemm_batch_)(char, char, int, int, int, int, dtype, dtype const *, dtype const *, dtype, dtype *)=NULL)
constructor for algstrct equipped with * and +
Definition: semiring.h:400
void default_gemm< float >(char tA, char tB, int m, int n, int k, float alpha, float const *A, float const *B, float beta, float *C)
Definition: semiring.h:151
char * all_data
serialized buffer containing all info, index, and values related to matrix
Definition: csr.h:25
char * vals() const
retrieves array of values out of all_data
Definition: csr.cxx:101
int cdealloc(void *ptr)
free abstraction
Definition: memcontrol.cxx:480
algstrct (algebraic structure) defines the elementwise operations computed in each tensor contraction...
Definition: algstrct.h:34
Definition: apsp.cxx:17
void default_scal< float >(int n, float alpha, float *X, int incX)
Definition: semiring.cxx:171
A Monoid is a Set equipped with a binary addition operator '+' or a custom function; addition must hav...
Definition: monoid.h:69
void default_axpy< float >(int n, float alpha, float const *X, int incX, float *Y, int incY)
Definition: semiring.cxx:128
void(* fgemm_batch)(char, char, int, int, int, int, dtype, dtype const *, dtype const *, dtype, dtype *)
Definition: semiring.h:368
Semiring(Semiring const &other)
Definition: semiring.h:373
void offload_gemm(char tA, char tB, int m, int n, int k, dtype alpha, offload_tsr &A, int lda_A, offload_tsr &B, int lda_B, dtype beta, offload_tsr &C, int lda_C)
void default_gemm(char tA, char tB, int m, int n, int k, dtype alpha, dtype const *A, dtype const *B, dtype beta, dtype *C)
Definition: semiring.h:71
void default_axpy< double >(int n, double alpha, double const *X, int incX, double *Y, int incY)
Definition: semiring.cxx:139
void(* fcoomm)(int, int, int, dtype, dtype const *, int const *, int const *, int, dtype const *, dtype, dtype *)
Definition: semiring.h:367
void default_gemm_batch< float >(char taA, char taB, int l, int m, int n, int k, float alpha, float const *A, float const *B, float beta, float *C)
Definition: semiring.h:235
dtype default_mul(dtype a, dtype b)
Definition: semiring.h:14