Cyclops Tensor Framework
parallel arithmetic on multidimensional arrays
algstrct.cxx
Go to the documentation of this file.
1 /*Copyright (c) 2014, Edgar Solomonik, all rights reserved.*/
2 #include "../shared/util.h"
3 #include "../shared/blas_symbs.h"
4 #include "untyped_tensor.h"
5 #include "algstrct.h"
6 #include "../sparse_formats/csr.h"
7 
8 using namespace std;
9 
10 namespace CTF_int {
// Three-parameter linear cost models, each seeded from the like-named *_init array.
// csrred_mdl is fed timings by algstrct::csr_reduce and queried by
// algstrct::estimate_csr_red_time; csrred_mdl_cst is only defined here.
11  LinModel<3> csrred_mdl(csrred_mdl_init,"csrred_mdl");
12  LinModel<3> csrred_mdl_cst(csrred_mdl_cst_init,"csrred_mdl_cst");
13 
14 
/**
 * \brief (key, fixed-size payload) record that is ordered by key alone
 * \tparam l number of payload bytes stored after the key
 */
template<int l>
struct CompPair{
  int64_t key;
  char data[l];
  /** \brief strict weak ordering on key; the payload is ignored */
  bool operator < (const CompPair& rhs) const {
    return (rhs.key > key);
  }
}; // __attribute__((packed));
23 
/** \brief (key, int value) record that is ordered by key alone */
struct IntPair{
  int64_t key;
  int data;
  /** \brief strict weak ordering on key; data is ignored */
  bool operator < (const IntPair& rhs) const {
    return (rhs.key > key);
  }
}; // __attribute__((packed));
31 
/** \brief (key, short value) record that is ordered by key alone */
struct ShortPair{
  int64_t key;
  short data;
  /** \brief strict weak ordering on key; data is ignored */
  bool operator < (const ShortPair& rhs) const {
    return (rhs.key > key);
  }
}; // __attribute__((packed));
39 
/** \brief (key, bool value) record that is ordered by key alone */
struct BoolPair{
  int64_t key;
  bool data;
  /** \brief strict weak ordering on key; data is ignored */
  bool operator < (const BoolPair& rhs) const {
    return (rhs.key > key);
  }
}; // __attribute__((packed));
47 
// Explicit instantiations of CompPair for payload sizes 8..32 bytes in steps of 4.
// The 1-, 2- and 4-byte instantiations are left commented out.
48 /* template struct CompPair<1>;
49  template struct CompPair<2>;
50  template struct CompPair<4>;*/
51  template struct CompPair<8>;
52  template struct CompPair<12>;
53  template struct CompPair<16>;
54  template struct CompPair<20>;
55  template struct CompPair<24>;
56  template struct CompPair<28>;
57  template struct CompPair<32>;
58 
/**
 * \brief (key, position) record that is ordered by key alone; algstrct::sort
 *        uses it to sort indices of pairs instead of the pairs themselves
 */
struct CompPtrPair{
  int64_t key;
  int64_t idx;
  /** \brief strict weak ordering on key; idx is ignored */
  bool operator < (const CompPtrPair& rhs) const {
    return (rhs.key > key);
  }
};
66 
67 
68  algstrct::algstrct(int el_size_){
69  el_size = el_size_;
70  has_coo_ker = false;
71  }
72 
73  MPI_Op algstrct::addmop() const {
74  printf("CTF ERROR: no addition MPI_Op present for this algebraic structure\n");
75  ASSERT(0);
76  assert(0);
77  return MPI_SUM;
78  }
79 
80  MPI_Datatype algstrct::mdtype() const {
81  printf("CTF ERROR: no MPI_Datatype present for this algebraic structure\n");
82  ASSERT(0);
83  assert(0);
84  return MPI_CHAR;
85  }
86 
87 
88 
89  char const * algstrct::addid() const {
90  return NULL;
91  }
92 
93  char const * algstrct::mulid() const {
94  return NULL;
95  }
96 
97  void algstrct::safeaddinv(char const * a, char *& b) const {
98  printf("CTF ERROR: no additive inverse present for this algebraic structure\n");
99  ASSERT(0);
100  assert(0);
101  }
102 
103  void algstrct::addinv(char const * a, char * b) const {
104  printf("CTF ERROR: no additive inverse present for this algebraic structure\n");
105  ASSERT(0);
106  assert(0);
107  }
108 
109  void algstrct::add(char const * a, char const * b, char * c) const {
110  printf("CTF ERROR: no addition operation present for this algebraic structure\n");
111  ASSERT(0);
112  assert(0);
113  }
114 
115  void algstrct::accum(char const * a, char * b) const {
116  this->add(a, b, b);
117  }
118 
119 
120  void algstrct::mul(char const * a, char const * b, char * c) const {
121  printf("CTF ERROR: no multiplication operation present for this algebraic structure\n");
122  ASSERT(0);
123  assert(0);
124  }
125 
126  void algstrct::safemul(char const * a, char const * b, char *& c) const {
127  printf("CTF ERROR: no multiplication operation present for this algebraic structure\n");
128  ASSERT(0);
129  assert(0);
130  }
131 
132  void algstrct::min(char const * a, char const * b, char * c) const {
133  printf("CTF ERROR: no min operation present for this algebraic structure\n");
134  ASSERT(0);
135  assert(0);
136  }
137 
138  void algstrct::max(char const * a, char const * b, char * c) const {
139  printf("CTF ERROR: no max operation present for this algebraic structure\n");
140  ASSERT(0);
141  assert(0);
142  }
143 
144  void algstrct::cast_int(int64_t i, char * c) const {
145  printf("CTF ERROR: integer scaling not possible for this algebraic structure\n");
146  ASSERT(0);
147  assert(0);
148  }
149 
150  void algstrct::cast_double(double d, char * c) const {
151  printf("CTF ERROR: double scaling not possible for this algebraic structure\n");
152  ASSERT(0);
153  assert(0);
154  }
155 
156  double algstrct::cast_to_double(char const * c) const {
157  printf("CTF ERROR: double cast not possible for this algebraic structure\n");
158  ASSERT(0);
159  assert(0);
160  return 0.0;
161  }
162 
163  int64_t algstrct::cast_to_int(char const * c) const {
164  printf("CTF ERROR: int cast not possible for this algebraic structure\n");
165  ASSERT(0);
166  assert(0);
167  return 0;
168  }
169 
170  void algstrct::print(char const * a, FILE * fp) const {
171  for (int i=0; i<el_size; i++){
172  fprintf(fp,"%x",a[i]);
173  }
174  }
175 
176  void algstrct::min(char * c) const {
177  printf("CTF ERROR: min limit not present for this algebraic structure\n");
178  ASSERT(0);
179  assert(0);
180  }
181 
182  void algstrct::max(char * c) const {
183  printf("CTF ERROR: max limit not present for this algebraic structure\n");
184  ASSERT(0);
185  assert(0);
186  }
187 
188  void algstrct::sort(int64_t n, char * pairs) const {
189  switch (this->el_size){
190  case 1:
191  ASSERT(sizeof(BoolPair)==this->pair_size());
192  std::sort((BoolPair*)pairs,((BoolPair*)pairs)+n);
193  break;
194  case 2:
195  ASSERT(sizeof(ShortPair)==this->pair_size());
196  std::sort((ShortPair*)pairs,((ShortPair*)pairs)+n);
197  break;
198  case 4:
199  ASSERT(sizeof(IntPair)==this->pair_size());
200  std::sort((IntPair*)pairs,((IntPair*)pairs)+n);
201  break;
202  case 8:
203  ASSERT(sizeof(CompPair<8>)==this->pair_size());
204  std::sort((CompPair<8>*)pairs,((CompPair<8>*)pairs)+n);
205  break;
206  case 12:
207  ASSERT(sizeof(CompPair<12>)==this->pair_size());
208  std::sort((CompPair<12>*)pairs,((CompPair<12>*)pairs)+n);
209  break;
210  case 16:
211  ASSERT(sizeof(CompPair<16>)==this->pair_size());
212  std::sort((CompPair<16>*)pairs,((CompPair<16>*)pairs)+n);
213  break;
214  case 20:
215  ASSERT(sizeof(CompPair<20>)==this->pair_size());
216  std::sort((CompPair<20>*)pairs,((CompPair<20>*)pairs)+n);
217  break;
218  case 24:
219  ASSERT(sizeof(CompPair<24>)==this->pair_size());
220  std::sort((CompPair<24>*)pairs,((CompPair<24>*)pairs)+n);
221  break;
222  case 28:
223  ASSERT(sizeof(CompPair<28>)==this->pair_size());
224  std::sort((CompPair<28>*)pairs,((CompPair<28>*)pairs)+n);
225  break;
226  case 32:
227  ASSERT(sizeof(CompPair<32>)==this->pair_size());
228  std::sort((CompPair<32>*)pairs,((CompPair<32>*)pairs)+n);
229  break;
230  default:
231  //Causes a bogus uninitialized variable warning with GNU
232  CompPtrPair idx_pairs[n];
233 #ifdef USE_OMP
234  #pragma omp parallel for
235 #endif
236  for (int64_t i=0; i<n; i++){
237  idx_pairs[i].key = *(int64_t*)(pairs+i*(sizeof(int64_t)+this->el_size));
238  idx_pairs[i].idx = i;
239  }
240  //FIXME :(
241  char * swap_buffer = this->pair_alloc(n);
242 
243  this->copy_pairs(swap_buffer, pairs, n);
244 
245  std::sort(idx_pairs, idx_pairs+n);
246 
247  ConstPairIterator piw(this, swap_buffer);
248  PairIterator pip(this, pairs);
249 
250 #ifdef USE_OMP
251  #pragma omp parallel for
252 #endif
253  for (int64_t i=0; i<n; i++){
254  pip[i].write_val(piw[idx_pairs[i].idx].d());
255  }
256  this->pair_dealloc(swap_buffer);
257  break; //compiler warning here seems to be gcc bug
258  }
259 
260  }
261 
262  void algstrct::scal(int n,
263  char const * alpha,
264  char * X,
265  int incX) const {
266  if (isequal(alpha, addid())){
267  if (incX == 1) set(X, addid(), n);
268  else {
269  for (int i=0; i<n; i++){
270  copy(X+i*el_size, addid());
271  }
272  }
273  } else {
274  printf("CTF ERROR: scal not present for this algebraic structure\n");
275  ASSERT(0);
276  assert(0);
277  }
278  }
279 
280  void algstrct::axpy(int n,
281  char const * alpha,
282  char const * X,
283  int incX,
284  char * Y,
285  int incY) const {
286  printf("CTF ERROR: axpy not present for this algebraic structure\n");
287  ASSERT(0);
288  assert(0);
289  }
290 
291  void algstrct::gemm_batch(char tA,
292  char tB,
293  int l,
294  int m,
295  int n,
296  int k,
297  char const * alpha,
298  char const * A,
299  char const * B,
300  char const * beta,
301  char * C) const {
302  printf("CTF ERROR: gemm_batch not present for this algebraic structure\n");
303  ASSERT(0);
304  }
305 
306 
307  void algstrct::gemm(char tA,
308  char tB,
309  int m,
310  int n,
311  int k,
312  char const * alpha,
313  char const * A,
314  char const * B,
315  char const * beta,
316  char * C) const {
317  printf("CTF ERROR: gemm not present for this algebraic structure\n");
318  ASSERT(0);
319  }
320 
321 
323  char tB,
324  int m,
325  int n,
326  int k,
327  char const * alpha,
328  char const * A,
329  char const * B,
330  char const * beta,
331  char * C) const {
332  printf("CTF ERROR: offload gemm not present for this algebraic structure\n");
333  ASSERT(0);
334  }
335 
336  bool algstrct::is_offloadable() const {
337  return false;
338  }
339 
340  bool algstrct::isequal(char const * a, char const * b) const {
341  if (a == NULL && b == NULL) return true;
342  if (a == NULL || b == NULL) return false;
343  bool iseq = true;
344  for (int i=0; i<el_size; i++) {
345  if (a[i] != b[i]) iseq = false;
346  }
347  return iseq;
348  }
349 
350  void algstrct::coo_to_csr(int64_t nz, int nrow, char * csr_vs, int * csr_cs, int * csr_rs, char const * coo_vs, int const * coo_rs, int const * coo_cs) const {
351  printf("CTF ERROR: cannot convert elements of this algebraic structure to CSR\n");
352  ASSERT(0);
353  }
354 
355  void algstrct::csr_to_coo(int64_t nz, int nrow, char const * csr_vs, int const * csr_ja, int const * csr_ia, char * coo_vs, int * coo_rs, int * coo_cs) const {
356  printf("CTF ERROR: cannot convert elements of this algebraic structure to CSR\n");
357  ASSERT(0);
358  }
359 
360 
361 // void algstrct::csr_add(int64_t m, int64_t n, char const * a, int const * ja, int const * ia, char const * b, int const * jb, int const * ib, char *& c, int *& jc, int *& ic){
362  char * algstrct::csr_add(char * cA, char * cB) const {
363 
364  return CTF_int::CSR_Matrix::csr_add(cA, cB, this);
365  }
366 
// Sums the serialized CSR matrices held by the processes of cm.
//   cA   : this process's serialized CSR matrix
//   root : rank in cm that receives the result
//   cm   : communicator to reduce over
// Returns cA itself when cm has a single process. Otherwise the root returns
// out.all_data of the assembled matrix and every other rank returns NULL.
// Outline: pick s, the smallest integer >= 2 dividing p; split cm into groups of
// s consecutive ranks (scm) and into groups of ranks sharing r%s (rcm); split A
// into s parts and exchange them inside scm; add the s pieces pairwise with
// csr_add; recurse over rcm; finally gather the pieces inside the root's scm
// group and assemble them with CSR_Matrix(smnds, s).
367  char * algstrct::csr_reduce(char * cA, int root, MPI_Comm cm) const {
368  int r, p;
369  MPI_Comm_rank(cm, &r);
370  MPI_Comm_size(cm, &p);
371  if (p==1) return cA;
372  TAU_FSTART(csr_reduce);
373  int s = 2;
374  double t_st = MPI_Wtime();
     // smallest s >= 2 that divides p
375  while (p%s != 0) s++;
376  int sr = r%s;
377  MPI_Comm scm;
378  MPI_Comm rcm;
     // scm: ranks with equal r/s, ordered by sr; rcm: ranks with equal sr, ordered by r/s
379  MPI_Comm_split(cm, r/s, sr, &scm);
380  MPI_Comm_split(cm, sr, r/s, &rcm);
381 
382  CSR_Matrix A(cA);
383  int64_t sz_A = A.size();
384  char * parts_buffer;
385  CSR_Matrix ** parts = (CSR_Matrix**)CTF_int::alloc(sizeof(CSR_Matrix*)*s);
386  A.partition(s, &parts_buffer, parts);
387  //MPI_Request reqs[2*(s-1)];
388  int rcv_szs[s];
389  int snd_szs[s];
390  int64_t tot_buf_size = 0;
     // the part this process keeps (index sr) is sent with size 0,
     // so tot_buf_size excludes it
391  for (int i=0; i<s; i++){
392  if (i==sr) snd_szs[i] = 0;
393  else snd_szs[i] = parts[i]->size();
394  tot_buf_size += snd_szs[i];
395  }
396 
397  MPI_Alltoall(snd_szs, 1, MPI_INT, rcv_szs, 1, MPI_INT, scm);
398  int64_t tot_rcv_sz = 0;
399  for (int i=0; i<s; i++){
400  //printf("i=%d/%d,rcv_szs[i]=%d\n",i,s,rcv_szs[i]);
401  tot_rcv_sz += rcv_szs[i];
402  }
403  char * rcv_buf = (char*)CTF_int::alloc(tot_rcv_sz);
     // smnds[i] is the i-th summand: the local part for i==sr,
     // otherwise the slice of rcv_buf received from scm rank i
404  char * smnds[s];
405  int rcv_displs[s];
406  int snd_displs[s];
407  rcv_displs[0] = 0;
408  for (int i=0; i<s; i++){
409  if (i>0) rcv_displs[i] = rcv_szs[i-1]+rcv_displs[i-1];
     // send displacements are byte offsets of each part relative to parts[0]
410  snd_displs[i] = parts[i]->all_data - parts[0]->all_data;
411  if (i==sr) smnds[i] = parts[i]->all_data;
412  else smnds[i] = rcv_buf + rcv_displs[i];
413 // printf("parts[%d].all_data = %p\n",i,parts[i]->all_data);
414  // printf("snd_dipls[%d] = %d\n", i, snd_displs[i]);
415 // printf("rcv_dipls[%d] = %d\n", i, rcv_displs[i]);
416  }
417  MPI_Alltoallv(parts[0]->all_data, snd_szs, snd_displs, MPI_CHAR, rcv_buf, rcv_szs, rcv_displs, MPI_CHAR, scm);
418  for (int i=0; i<s; i++){
419  delete parts[i]; //does not actually free buffer space
420  }
421  cdealloc(parts);
422  /* smnds[i] = (char*)alloc(rcv_szs[i]);
423  int sbw = (r/phase - i + s-1)%s;
424  int rbw = sbw + (r/(phase*s))*s + (r%phase);
425  int rfw = sfw + (r/(phase*s))*s + (r%phase);
426  char * rcv_data = (char*)alloc(rcv_szs[i]);
427  smnds[i] = rcv_data;
428  MPI_Isend(parts[sfw], snd_szs[i], MPI_CHAR, rfw, s+i, cm, reqs+i);
429  MPI_Irecv(rcv_data, rcv_szs[i], MPI_CHAR, rbw, s+i, cm, reqs+s-1+i);
430  }
431  MPI_Status stats[2*(s-1)];
432  MPI_Waitall(2*(s-1), reqs, stats);
433  for (int i=1; i<s; i++){
434  int sfw = (r/phase + i + s-1)%s;
435  cdealloc(parts[sfw]);
436  }
437  cdealloc(parts);*/
     // pairwise tree summation: at distance z, smnds[i] += smnds[i+z]
438  for (int z=1; z<s; z<<=1){
439  for (int i=0; i<s-z; i+=2*z){
440  char * csr_new = csr_add(smnds[i], smnds[i+z]);
     // free an operand only if it points outside both parts_buffer and rcv_buf,
     // i.e. it is not a slice of one of the two bulk buffers freed below
441  if ((smnds[i] < parts_buffer ||
442  smnds[i] > parts_buffer+tot_buf_size) &&
443  (smnds[i] < rcv_buf ||
444  smnds[i] > rcv_buf+tot_rcv_sz))
445  cdealloc(smnds[i]);
446  if ((smnds[i+z] < parts_buffer ||
447  smnds[i+z] > parts_buffer+tot_buf_size) &&
448  (smnds[i+z] < rcv_buf ||
449  smnds[i+z] > rcv_buf+tot_rcv_sz))
450  cdealloc(smnds[i+z]);
451  smnds[i] = csr_new;
452  }
453  }
454  cdealloc(parts_buffer); //dealloc all parts
455  cdealloc(rcv_buf);
     // the timer is stopped around the recursive call, which starts its own
456  TAU_FSTOP(csr_reduce);
457  char * red_sum = csr_reduce(smnds[0], root/s, rcm);
458  TAU_FSTART(csr_reduce);
     // the recursion returns its input unchanged when rcm has one process
459  if (smnds[0] != red_sum) cdealloc(smnds[0]);
     // only the scm group that contains the root gathers the result
460  if (r/s == root/s){
461  CSR_Matrix cf(red_sum);
462  int sz = cf.size();
463  int sroot = root%s;
464  int cb_sizes[s];
     // the root contributes 0 bytes to the gather and uses red_sum directly
465  if (sroot == sr) sz = 0;
466  MPI_Gather(&sz, 1, MPI_INT, cb_sizes, 1, MPI_INT, sroot, scm);
467  int64_t tot_cb_size = 0;
468  int cb_displs[s];
469  if (sr == sroot){
470  for (int i=0; i<s; i++){
471  cb_displs[i] = tot_cb_size;
472  tot_cb_size += cb_sizes[i];
473  }
474  }
     // tot_cb_size stays 0 on non-root ranks
475  char * cb_bufs = (char*)CTF_int::alloc(tot_cb_size);
476  MPI_Gatherv(red_sum, sz, MPI_CHAR, cb_bufs, cb_sizes, cb_displs, MPI_CHAR, sroot, scm);
477  MPI_Comm_free(&scm);
478  MPI_Comm_free(&rcm);
479  if (sr == sroot){
480  for (int i=0; i<s; i++){
481  smnds[i] = cb_bufs + cb_displs[i];
482  if (i==sr) smnds[i] = red_sum;
483  }
484  CSR_Matrix out(smnds, s);
485  cdealloc(red_sum);
486  cdealloc(cb_bufs);
487  double t_end = MPI_Wtime() - t_st;
488  double tps[] = {t_end, 1.0, log2((double)p), (double)sz_A};
489 
490  // not quite sure
491  csrred_mdl.observe(tps);
492  TAU_FSTOP(csr_reduce);
     // NOTE(review): relies on out.all_data outliving the local `out` object -- confirm
493  return out.all_data;
494  } else {
495  cdealloc(red_sum);
496  cdealloc(cb_bufs);
497  TAU_FSTOP(csr_reduce);
498  return NULL;
499  }
500  } else {
501  MPI_Comm_free(&scm);
502  MPI_Comm_free(&rcm);
503  TAU_FSTOP(csr_reduce);
504  return NULL;
505  }
506  }
507 
508  double algstrct::estimate_csr_red_time(int64_t msg_sz, CommData const * cdt) const {
509 
510  double ps[] = {1.0, log2((double)cdt->np), (double)msg_sz};
511  return csrred_mdl.est_time(ps);
512  }
513 
514  void algstrct::acc(char * b, char const * beta, char const * a, char const * alpha) const {
515  char tmp[el_size];
516  mul(b, beta, tmp);
517  mul(a, alpha, b);
518  add(b, tmp, b);
519  }
520 
521  void algstrct::accmul(char * c, char const * a, char const * b, char const * alpha) const {
522  char tmp[el_size];
523  mul(a, b, tmp);
524  mul(tmp, alpha, tmp);
525  add(c, tmp, c);
526  }
527 
528 
529  void algstrct::safecopy(char *& a, char const * b) const {
530  if (b == NULL){
531  if (a != NULL) cdealloc(a);
532  a = NULL;
533  } else {
534  if (a == NULL) a = (char*)CTF_int::alloc(el_size);
535  this->copy(a, b);
536  }
537  }
538  void algstrct::copy(char * a, char const * b) const {
539  memcpy(a, b, el_size);
540  }
541 
542  void algstrct::copy_pair(char * a, char const * b) const {
543  memcpy(a, b, pair_size());
544  }
545 
546  void algstrct::copy(char * a, char const * b, int64_t n) const {
547  memcpy(a, b, el_size*n);
548  }
549 
550  void algstrct::copy_pairs(char * a, char const * b, int64_t n) const {
551  memcpy(a, b, pair_size()*n);
552  }
553 
554 
555  void algstrct::copy(int64_t nn, char const * a, int inc_a, char * b, int inc_b) const {
556  int n = nn;
557  switch (el_size) {
558  case 4:
559  CTF_BLAS::SCOPY(&n, (float const*)a, &inc_a, (float*)b, &inc_b);
560  break;
561  case 8:
562  CTF_BLAS::DCOPY(&n, (double const*)a, &inc_a, (double*)b, &inc_b);
563  break;
564  case 16:
565  CTF_BLAS::ZCOPY(&n, (std::complex<double> const*)a, &inc_a, (std::complex<double>*)b, &inc_b);
566  break;
567  default:
568 #ifdef USE_OMP
569  #pragma omp parallel for
570 #endif
571  for (int64_t i=0; i<nn; i++){
572  copy(b+el_size*inc_b*i, a+el_size*inc_a*i);
573  }
574  break;
575  }
576  }
577 
578  void algstrct::copy(int64_t m,
579  int64_t n,
580  char const * a,
581  int64_t lda_a,
582  char * b,
583  int64_t lda_b) const {
584  if (lda_a == m && lda_b == n){
585  memcpy(b,a,el_size*m*n);
586  } else {
587  for (int i=0; i<n; i++){
588  memcpy(b+el_size*lda_b*i,a+el_size*lda_a*i,m*el_size);
589  }
590  }
591  }
592 
593  void algstrct::copy(int64_t m,
594  int64_t n,
595  char const * a,
596  int64_t lda_a,
597  char const * alpha,
598  char * b,
599  int64_t lda_b,
600  char const * beta) const {
601  if (!isequal(beta, mulid())){
602  if (isequal(beta, addid())){
603  if (lda_b == 1)
604  set(b, addid(), m*n);
605  else {
606  for (int i=0; i<n; i++){
607  set(b+i*lda_b*el_size, addid(), m);
608  }
609  }
610  } else {
611  if (lda_b == m)
612  scal(m*n, beta, b, 1);
613  else {
614  for (int i=0; i<n; i++){
615  scal(m, beta, b+i*lda_b*el_size, 1);
616  }
617  }
618  }
619  }
620  if (lda_a == m && lda_b == m){
621  axpy(m*n, alpha, a, 1, b, 1);
622  } else {
623  for (int i=0; i<n; i++){
624  axpy(m, alpha, a+el_size*lda_a*i, 1, b+el_size*lda_b*i, 1);
625  }
626  }
627  }
628 
629  void algstrct::set(char * a, char const * b, int64_t n) const {
630  switch (el_size) {
631  case 4: {
632  float * ia = (float*)a;
633  float ib = *((float*)b);
634  std::fill(ia, ia+n, ib);
635  }
636  break;
637  case 8: {
638  double * ia = (double*)a;
639  double ib = *((double*)b);
640  std::fill(ia, ia+n, ib);
641  }
642  break;
643  case 16: {
644  std::complex<double> * ia = (std::complex<double>*)a;
645  std::complex<double> ib = *((std::complex<double>*)b);
646  std::fill(ia, ia+n, ib);
647  }
648  break;
649  default: {
650  for (int i=0; i<n; i++) {
651  memcpy(a+i*el_size, b, el_size);
652  }
653  }
654  break;
655  }
656  }
657 
658  void algstrct::set_pair(char * a, int64_t key, char const * vb) const {
659  memcpy(a, &key, sizeof(int64_t));
660  memcpy(get_value(a), vb, el_size);
661  }
662 
663  void algstrct::set_pairs(char * a, char const * b, int64_t n) const {
664  for (int i=0; i<n; i++) {
665  memcpy(a + i*pair_size(), b, pair_size());
666  }
667  }
668 
669  int64_t algstrct::get_key(char const * a) const {
670  return (int64_t)*a;
671  }
672 
673  char * algstrct::get_value(char * a) const {
674  return a+sizeof(int64_t);
675  }
676 
677  char const * algstrct::get_const_value(char const * a) const {
678  return a+sizeof(int64_t);
679  }
680 
681  char * algstrct::pair_alloc(int64_t n) const {
682  return (char*)CTF_int::alloc(n*pair_size());
683  }
684 
685  char * algstrct::alloc(int64_t n) const {
686  return (char*)CTF_int::alloc(n*el_size);
687  }
688 
689  void algstrct::dealloc(char * ptr) const {
690  CTF_int::cdealloc(ptr);
691  }
692 
693  void algstrct::pair_dealloc(char * ptr) const {
694  CTF_int::cdealloc(ptr);
695  }
696 
697  void algstrct::init(int64_t n, char * arr) const {
698 
699  }
700 
701 
702 
703  void algstrct::coomm(int m, int n, int k, char const * alpha, char const * A, int const * rows_A, int const * cols_A, int64_t nnz_A, char const * B, char const * beta, char * C, bivar_function const * func) const {
704  printf("CTF ERROR: coomm not present for this algebraic structure\n");
705  ASSERT(0);
706  }
707 
708  void algstrct::csrmm(int m, int n, int k, char const * alpha, char const * A, int const * JA, int const * IA, int64_t nnz_A, char const * B, char const * beta, char * C, bivar_function const * func) const {
709  printf("CTF ERROR: csrmm not present for this algebraic structure\n");
710  ASSERT(0);
711  }
712 
713  void algstrct::csrmultd
714  (int m,
715  int n,
716  int k,
717  char const * alpha,
718  char const * A,
719  int const * JA,
720  int const * IA,
721  int64_t nnz_A,
722  char const * B,
723  int const * JB,
724  int const * IB,
725  int64_t nnz_B,
726  char const * beta,
727  char * C) const {
728  printf("CTF ERROR: csrmultd not present for this algebraic structure\n");
729  ASSERT(0);
730  }
731 
732  void algstrct::csrmultcsr
733  (int m,
734  int n,
735  int k,
736  char const * alpha,
737  char const * A,
738  int const * JA,
739  int const * IA,
740  int64_t nnz_A,
741  char const * B,
742  int const * JB,
743  int const * IB,
744  int64_t nnz_B,
745  char const * beta,
746  char *& C_CSR) const {
747 
748  printf("CTF ERROR: csrmultcsr not present for this algebraic structure\n");
749  ASSERT(0);
750  }
751 
752  ConstPairIterator::ConstPairIterator(PairIterator const & pi){
753  sr=pi.sr; ptr=pi.ptr;
754  }
755 
756  ConstPairIterator::ConstPairIterator(algstrct const * sr_, char const * ptr_){
757  sr=sr_; ptr=ptr_;
758  }
759 
760  ConstPairIterator ConstPairIterator::operator[](int n) const {
761  return ConstPairIterator(sr,ptr+sr->pair_size()*n);
762  }
763 
764  int64_t ConstPairIterator::k() const {
765  return ((int64_t*)ptr)[0];
766  }
767 
768  char const * ConstPairIterator::d() const {
769  return sr->get_const_value(ptr);
770  }
771 
772  void ConstPairIterator::read(char * buf, int64_t n) const {
773  memcpy(buf, ptr, sr->pair_size()*n);
774  }
775 
776  void ConstPairIterator::read_val(char * buf) const {
777  memcpy(buf, sr->get_const_value(ptr), sr->el_size);
778  }
779 
780  PairIterator::PairIterator(algstrct const * sr_, char * ptr_){
781  sr=sr_;
782  ptr=ptr_;
783  }
784 
785  PairIterator PairIterator::operator[](int n) const {
786  return PairIterator(sr,ptr+sr->pair_size()*n);
787  }
788 
789  int64_t PairIterator::k() const {
790  return ((int64_t*)ptr)[0];
791  }
792 
793  char * PairIterator::d() const {
794  return sr->get_value(ptr);
795  }
796 
797  void PairIterator::read(char * buf, int64_t n) const {
798  sr->copy_pair(buf, ptr);
799  }
800 
801  void PairIterator::read_val(char * buf) const {
802  sr->copy(buf, sr->get_const_value(ptr));
803  }
804 
805  void PairIterator::write(char const * buf, int64_t n){
806  sr->copy_pairs(ptr, buf, n);
807  }
808 
809  void PairIterator::write(PairIterator const iter, int64_t n){
810  this->write(iter.ptr, n);
811  }
812 
813  void PairIterator::write(ConstPairIterator const iter, int64_t n){
814  this->write(iter.ptr, n);
815  }
816 
817  void PairIterator::write_val(char const * buf){
818  sr->copy(sr->get_value(ptr), buf);
819  }
820 
821  void PairIterator::write_key(int64_t key){
822  ((int64_t*)ptr)[0] = key;
823  }
824 
825  void PairIterator::sort(int64_t n){
826  sr->sort(n, ptr);
827  }
828 
829  void ConstPairIterator::permute(int64_t n, int order, int const * old_lens, int64_t const * new_lda, PairIterator wA){
830  ConstPairIterator rA = * this;
831 #ifdef USE_OMP
832  #pragma omp parallel for
833 #endif
834  for (int64_t i=0; i<n; i++){
835  int64_t k = rA[i].k();
836  int64_t k_new = 0;
837  for (int j=0; j<order; j++){
838  k_new += (k%old_lens[j])*new_lda[j];
839  k = k/old_lens[j];
840  }
841  ((int64_t*)wA[i].ptr)[0] = k_new;
842  wA[i].write_val(rA[i].d());
843  //printf("value %lf old key %ld new key %ld\n",((double*)wA[i].d())[0], rA[i].k(), wA[i].k());
844  }
845 
846 
847  }
848 
849  void ConstPairIterator::pin(int64_t n, int order, int const * lens, int const * divisor, PairIterator pi_new){
850  TAU_FSTART(pin);
851  ConstPairIterator pi = *this;
852  int * div_lens;
853  alloc_ptr(order*sizeof(int), (void**)&div_lens);
854  for (int j=0; j<order; j++){
855  div_lens[j] = (lens[j]/divisor[j] + (lens[j]%divisor[j] > 0));
856 // printf("lens[%d] = %d divisor[%d] = %d div_lens[%d] = %d\n",j,lens[j],j,divisor[j],j,div_lens[j]);
857  }
858 #ifdef USE_OMP
859  #pragma omp parallel for
860 #endif
861  for (int64_t i=0; i<n; i++){
862  int64_t key = pi[i].k();
863  int64_t new_key = 0;
864  int64_t lda = 1;
865 // printf("rank = %d, in key = %ld, val = %lf\n", phys_rank[0], save_key, ((double*)pi_new[i].d())[0]);
866  for (int j=0; j<order; j++){
867 // printf("%d %ld %d\n",j,(key%lens[j])%divisor[j],phys_rank[j]);
868  //ASSERT(((key%lens[j])%(divisor[j]/virt_dim[j])) == phys_rank[j]);
869  new_key += ((key%lens[j])/divisor[j])*lda;
870  lda *= div_lens[j];
871  key = key/lens[j];
872  }
873  ((int64_t*)pi_new[i].ptr)[0] = new_key;
874 /* if (i>0 && pi[i].k() > pi[i-1].k()){
875  assert(pi_new[i].k() > pi_new[i-1].k());
876  }*/
877  }
878  cdealloc(div_lens);
879  TAU_FSTOP(pin);
880 
881  }
882 
// Converts the keys of blocked pairs in X from per-block indices to indices in
// the full lens[] extents: index j of a key becomes
// (index % div_lens[j])*divisor[j] + virt_offset[j], with
// div_lens[j] = ceil(lens[j]/divisor[j]).
//   sr            : algebraic structure describing the pairs
//   order         : number of indices encoded in a key
//   lens, divisor : extent and divisor of each index
//   nvirt         : number of blocks; virt_dim gives the blocks along each index
//   phys_rank     : offset added to the virtual offset of each index
//   X             : input pairs, the nvirt blocks stored back to back
//   new_nnz_B     : in: number of pairs to allocate for new_B when filtering;
//                   out: number of pairs kept (only written when filtering)
//   nnz_blk       : pairs per block; overwritten with the kept counts when filtering
//   new_B         : out: newly allocated filtered pairs (only set when filtering)
//   check_padding : if true, test whether any block offset lies in the padded
//                   region; only then are out-of-range pairs dropped
// Without filtering the keys are rewritten in place in X.
883  void depin(algstrct const * sr, int order, int const * lens, int const * divisor, int nvirt, int const * virt_dim, int const * phys_rank, char * X, int64_t & new_nnz_B, int64_t * nnz_blk, char *& new_B, bool check_padding){
884 
885  TAU_FSTART(depin);
886 
887  int * div_lens;
888  alloc_ptr(order*sizeof(int), (void**)&div_lens);
889  for (int j=0; j<order; j++){
     // quotient plus one when there is a positive remainder
890  div_lens[j] = (lens[j]/divisor[j] + (lens[j]%divisor[j] > 0));
891 // printf("lens[%d] = %d divisor[%d] = %d div_lens[%d] = %d\n",j,lens[j],j,divisor[j],j,div_lens[j]);
892  }
893  if (check_padding){
     // keep check_padding set only if some block offset reaches the padded region
894  check_padding = false;
895  for (int v=0; v<nvirt; v++){
896  int vv = v;
897  for (int j=0; j<order; j++){
898  int vo = (vv%virt_dim[j])*(divisor[j]/virt_dim[j])+phys_rank[j];
899  if (lens[j]%divisor[j] != 0 && vo >= lens[j]%divisor[j]){
900  check_padding = true;
901  }
902  vv=vv/virt_dim[j];
903  }
904  }
905  }
906  int64_t * old_nnz_blk_B = nnz_blk;
907  if (check_padding){
908  //FIXME: uses a bit more memory than we will probably need, but probably worth not doing another round to count first
909  new_B = sr->pair_alloc(new_nnz_B);
     // keep the input counts; nnz_blk is reset and refilled with the kept counts
910  old_nnz_blk_B = (int64_t*)CTF_int::alloc(sizeof(int64_t)*nvirt);
911  memcpy(old_nnz_blk_B, nnz_blk, sizeof(int64_t)*nvirt);
912  memset(nnz_blk, 0, sizeof(int64_t)*nvirt);
913  }
914 
915  int * virt_offset;
916  alloc_ptr(order*sizeof(int), (void**)&virt_offset);
     // nnz_off: number of input pairs consumed so far, i.e. start of block v in X
917  int64_t nnz_off = 0;
918  if (check_padding)
919  new_nnz_B = 0;
920  for (int v=0; v<nvirt; v++){
921  //printf("%d %p new_B %p pin %p new_blk_nnz_B[%d] = %ld\n",A_or_B,this,new_B,nnz_blk,v,nnz_blk[v]);
922  int vv=v;
     // decode block id v into one offset per index
923  for (int j=0; j<order; j++){
924  virt_offset[j] = (vv%virt_dim[j])*(divisor[j]/virt_dim[j])+phys_rank[j];
925  vv=vv/virt_dim[j];
926  }
927 
928  if (check_padding){
     // filtering path: kept pairs are appended to new_B
929  int64_t new_nnz_blk = 0;
930  ConstPairIterator vpi(sr, X+nnz_off*sr->pair_size());
931  PairIterator vpi_new(sr, new_B+new_nnz_B*sr->pair_size());
932  for (int64_t i=0; i<old_nnz_blk_B[v]; i++){
933  int64_t key = vpi[i].k();
934  int64_t new_key = 0;
935  int64_t lda = 1;
936  bool is_outside = false;
937  for (int j=0; j<order; j++){
938  //printf("%d %ld %ld %d\n",j,vpi[i].k(),((key%div_lens[j])*divisor[j]+virt_offset[j]),lens[j]);
     // an index at or beyond lens[j] marks the pair as outside
939  if (((key%div_lens[j])*divisor[j]+virt_offset[j])>=lens[j]){
940  //printf("element is outside\n");
941  is_outside = true;
942  }
943  new_key += ((key%div_lens[j])*divisor[j]+virt_offset[j])*lda;
944  lda *= lens[j];
945  key = key/div_lens[j];
946  }
947  if (!is_outside){
948  //printf("key = %ld, new_key = %ld, val = %lf\n", vpi[i].k(), new_key, ((double*)vpi[i].d())[0]);
949  ((int64_t*)vpi_new[new_nnz_blk].ptr)[0] = new_key;
950  vpi_new[new_nnz_blk].write_val(vpi[i].d());
951  new_nnz_blk++;
952  }
953  }
954  nnz_blk[v] = new_nnz_blk;
955  new_nnz_B += nnz_blk[v];
956  nnz_off += old_nnz_blk_B[v];
957 
958  } else {
     // in-place path: only the keys in X are rewritten, values stay where they are
959  ConstPairIterator vpi(sr, X+nnz_off*sr->pair_size());
960  PairIterator vpi_new(sr, X+nnz_off*sr->pair_size());
961  #ifdef USE_OMP
962  #pragma omp parallel for
963  #endif
964  for (int64_t i=0; i<nnz_blk[v]; i++){
965  int64_t key = vpi[i].k();
966  //int64_t save_key = vpi[i].k();
967  int64_t new_key = 0;
968  int64_t lda = 1;
969  for (int64_t j=0; j<order; j++){
970  new_key += ((key%div_lens[j])*divisor[j]+virt_offset[j])*lda;
971  lda *= lens[j];
972  key = key/div_lens[j];
973  }
974  ((int64_t*)vpi_new[i].ptr)[0] = new_key;
975  //printf(",,key = %ld, new_key = %ld, val = %lf\n", save_key, new_key, ((double*)vpi_new[i].d())[0]);
976  }
977  nnz_off += nnz_blk[v];
978  }
979  }
980  if (check_padding){
981  cdealloc(old_nnz_blk_B);
982  }
983  cdealloc(virt_offset);
984  cdealloc(div_lens);
985 
986  TAU_FSTOP(depin);
987  }
988 
989 
990 
991  int64_t PairIterator::lower_bound(int64_t n, ConstPairIterator op){
992  switch (sr->el_size){
993  case 1:
994  return std::lower_bound((CompPair<1>*)ptr,((CompPair<1>*)ptr)+n, ((CompPair<1>*)op.ptr)[0]) - (CompPair<1>*)ptr;
995  break;
996  case 2:
997  return std::lower_bound((CompPair<2>*)ptr,((CompPair<2>*)ptr)+n, ((CompPair<2>*)op.ptr)[0]) - (CompPair<2>*)ptr;
998  break;
999  case 4:
1000  return std::lower_bound((CompPair<4>*)ptr,((CompPair<4>*)ptr)+n, ((CompPair<4>*)op.ptr)[0]) - (CompPair<4>*)ptr;
1001  break;
1002  case 8:
1003  return std::lower_bound((CompPair<8>*)ptr,((CompPair<8>*)ptr)+n, ((CompPair<8>*)op.ptr)[0]) - (CompPair<8>*)ptr;
1004  break;
1005  case 12:
1006  return std::lower_bound((CompPair<12>*)ptr,((CompPair<12>*)ptr)+n, ((CompPair<12>*)op.ptr)[0]) - (CompPair<12>*)ptr;
1007  break;
1008  case 16:
1009  return std::lower_bound((CompPair<16>*)ptr,((CompPair<16>*)ptr)+n, ((CompPair<16>*)op.ptr)[0]) - (CompPair<16>*)ptr;
1010  break;
1011  case 20:
1012  return std::lower_bound((CompPair<20>*)ptr,((CompPair<20>*)ptr)+n, ((CompPair<20>*)op.ptr)[0]) - (CompPair<20>*)ptr;
1013  break;
1014  case 24:
1015  return std::lower_bound((CompPair<24>*)ptr,((CompPair<24>*)ptr)+n, ((CompPair<24>*)op.ptr)[0]) - (CompPair<24>*)ptr;
1016  break;
1017  case 28:
1018  return std::lower_bound((CompPair<28>*)ptr,((CompPair<28>*)ptr)+n, ((CompPair<28>*)op.ptr)[0]) - (CompPair<28>*)ptr;
1019  break;
1020  case 32:
1021  return std::lower_bound((CompPair<32>*)ptr,((CompPair<32>*)ptr)+n, ((CompPair<32>*)op.ptr)[0]) - (CompPair<32>*)ptr;
1022  break;
1023  default: {
1024  int64_t keys[n];
1025 #ifdef USE_OMP
1026  #pragma omp parallel for
1027 #endif
1028  for (int64_t i=0; i<n; i++){
1029  keys[i] = (*this)[i].k();
1030  }
1031  return std::lower_bound(keys, keys+n, op.k())-keys;
1032  } break;
1033  }
1034  }
1035 
1036 }
void permute(int order, int const *perm, int *arr)
permute an array
Definition: util.cxx:205
virtual int pair_size() const
gets pair size el_size plus the key size
Definition: algstrct.h:46
virtual char * pair_alloc(int64_t n) const
allocate space for n (int64_t,dtype) pairs, necessary for object types
Definition: algstrct.cxx:681
algstrct const * sr
Definition: algstrct.h:436
static char * csr_add(char *cA, char *cB, accumulatable const *adder)
Definition: csr.cxx:332
LinModel< 3 > csrred_mdl_cst(csrred_mdl_cst_init,"csrred_mdl_cst")
LinModel< 3 > csrred_mdl(csrred_mdl_init,"csrred_mdl")
void SCOPY(const int *n, const float *dX, const int *incX, float *dY, const int *incY)
#define ASSERT(...)
Definition: util.h:88
void * alloc(int64_t len)
alloc abstraction
Definition: memcontrol.cxx:365
void gemm(char tA, char tB, int m, int n, int k, dtype alpha, dtype const *A, dtype const *B, dtype beta, dtype *C)
Definition: semiring.cxx:82
untyped internal class for triply-typed bivariate function
Definition: ctr_comm.h:16
void DCOPY(const int *n, const double *dX, const int *incX, double *dY, const int *incY)
void depin(algstrct const *sr, int order, int const *lens, int const *divisor, int nvirt, int const *virt_dim, int const *phys_rank, char *X, int64_t &new_nnz_B, int64_t *nnz_blk, char *&new_B, bool check_padding)
depins keys of n pairs
Definition: algstrct.cxx:883
int64_t k() const
returns key of pair at head of ptr
Definition: algstrct.cxx:764
int alloc_ptr(int64_t len, void **const ptr)
alloc abstraction
Definition: memcontrol.cxx:320
abstraction for a serialized sparse matrix stored in compressed-sparse-row (CSR) layout ...
Definition: csr.h:22
int64_t k() const
returns key of pair at head of ptr
Definition: algstrct.cxx:789
void gemm_batch(char taA, char taB, int l, int m, int n, int k, dtype alpha, dtype const *A, dtype const *B, dtype beta, dtype *C)
Definition: semiring.cxx:15
double csrred_mdl_init[]
Definition: init_models.cxx:38
#define TAU_FSTOP(ARG)
Definition: util.h:281
#define TAU_FSTART(ARG)
Definition: util.h:280
def copy(tensor, A)
Definition: core.pyx:3583
char * all_data
serialized buffer containing all info, index, and values related to matrix
Definition: csr.h:25
int cdealloc(void *ptr)
free abstraction
Definition: memcontrol.cxx:480
algstrct (algebraic structure) defines the elementwise operations computed in each tensor contraction...
Definition: algstrct.h:34
void write_val(char const *buf)
sets value of head pair to what is in buf
Definition: algstrct.cxx:817
void ZCOPY(const int *n, const std::complex< double > *dX, const int *incX, std::complex< double > *dY, const int *incY)
void partition(int s, char **parts_buffer, CSR_Matrix **parts)
splits CSR matrix into s submatrices (returned) corresponding to subsets of rows, all parts allocated...
Definition: csr.cxx:219
double csrred_mdl_cst_init[]
Definition: init_models.cxx:39
void offload_gemm(char tA, char tB, int m, int n, int k, dtype alpha, offload_tsr &A, int lda_A, offload_tsr &B, int lda_B, dtype beta, offload_tsr &C, int lda_C)
int64_t key
Definition: back_comp.h:66
int64_t size() const
retrieves buffer size out of all_data
Definition: csr.cxx:89