Cyclops Tensor Framework
parallel arithmetic on multidimensional arrays
sparse_rw.cxx
1 /*Copyright (c) 2011, Edgar Solomonik, all rights reserved.*/
2 
3 #include "sparse_rw.h"
4 #include "pad.h"
5 #include "../shared/util.h"
6 
7 
8 
9 namespace CTF_int {
10  void permute_keys(int order,
11  int num_pair,
12  int const * edge_len,
13  int const * new_edge_len,
14  int * const * permutation,
15  char * pairs_buf,
16  int64_t * new_num_pair,
17  algstrct const * sr){
19  #ifdef USE_OMP
20  int mntd = omp_get_max_threads();
21  #else
22  int mntd = 1;
23  #endif
24  int64_t counts[mntd];
25  std::fill(counts,counts+mntd,0);
26  #ifdef USE_OMP
27  #pragma omp parallel
28  #endif
29  {
30  int i, j, tid, ntd, outside;
31  int64_t lda, wkey, knew, kdim, tstart, tnum_pair, cnum_pair;
32  #ifdef USE_OMP
33  tid = omp_get_thread_num();
34  ntd = omp_get_num_threads();
35  #else
36  tid = 0;
37  ntd = 1;
38  #endif
39  tnum_pair = num_pair/ntd;
40  tstart = tnum_pair * tid + MIN(tid, num_pair % ntd);
41  if (tid < num_pair % ntd) tnum_pair++;
42 
43  //std::vector< tkv_pair<dtype> > my_pairs;
44  //allocate buffer of same size as pairs,
45  //FIXME: not all space may be used, so a smaller buffer is possible
46  char * my_pairs_buf = sr->pair_alloc(tnum_pair);
47  PairIterator my_pairs(sr, my_pairs_buf);
48  PairIterator pairs(sr, pairs_buf);
49  cnum_pair = 0;
50 
51  for (i=tstart; i<tstart+tnum_pair; i++){
52  wkey = pairs[i].k();
53  lda = 1;
54  knew = 0;
55  outside = 0;
56  for (j=0; j<order; j++){
57  kdim = wkey%edge_len[j];
58  if (permutation[j] != NULL){
59  if (permutation[j][kdim] == -1){
60  outside = 1;
61  } else{
62  knew += lda*permutation[j][kdim];
63  }
64  } else {
65  knew += lda*kdim;
66  }
67  lda *= new_edge_len[j];
68  wkey = wkey/edge_len[j];
69  }
70  if (!outside){
71  char tkp[sr->pair_size()];
72  //tkp.k = knew;
73  //tkp.d = pairs[i].d;
74  sr->set_pair(tkp, knew, pairs[i].d());
75  my_pairs[cnum_pair].write(tkp);
76  cnum_pair++;
77  }
78  }
79  counts[tid] = cnum_pair;
80  {
81  #ifdef USE_OMP
82  #pragma omp barrier
83  #endif
84  int64_t pfx = 0;
85  for (i=0; i<tid; i++){
86  pfx += counts[i];
87  }
88  pairs[pfx].write(my_pairs,cnum_pair);
89  }
90  sr->pair_dealloc(my_pairs_buf);
91  }
92  *new_num_pair = 0;
93  for (int i=0; i<mntd; i++){
94  *new_num_pair += counts[i];
95  }
97  }
98 
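The remapping loop above is pure index arithmetic; the following standalone sketch (plain C++, no CTF types, helper name hypothetical) restates it under the same conventions: a key is a mixed-radix encoding of the indices with mode 0 least significant, permutation[j] == NULL means the identity on mode j, and an entry of -1 marks an index mapped outside the new tensor, which drops the pair.

#include <cstdint>

// Returns false if the key maps outside the permuted tensor, otherwise
// stores the remapped key in new_key.
bool permute_key(int order, int const * edge_len, int const * new_edge_len,
                 int * const * permutation, int64_t key, int64_t & new_key){
  int64_t lda = 1, knew = 0;
  for (int j=0; j<order; j++){
    int64_t kdim = key % edge_len[j];               // index along mode j
    if (permutation[j] != NULL){
      if (permutation[j][kdim] == -1) return false; // mapped outside
      knew += lda*permutation[j][kdim];
    } else knew += lda*kdim;
    lda *= new_edge_len[j];                         // stride in new tensor
    key  /= edge_len[j];
  }
  new_key = knew;
  return true;
}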
99  void depermute_keys(int order,
100  int num_pair,
101  int const * edge_len,
102  int const * new_edge_len,
103  int * const * permutation,
104  char * pairs_buf,
105  algstrct const * sr){
107  #ifdef USE_OMP
108  int mntd = omp_get_max_threads();
109  #else
110  int mntd = 1;
111  #endif
112  int64_t counts[mntd];
113  std::fill(counts,counts+mntd,0);
114  int ** depermutation = (int**)CTF_int::alloc(order*sizeof(int*));
115  TAU_FSTART(form_depermutation);
116  for (int d=0; d<order; d++){
117  if (permutation[d] == NULL){
118  depermutation[d] = NULL;
119  } else {
120  depermutation[d] = (int*)CTF_int::alloc(new_edge_len[d]*sizeof(int));
121  std::fill(depermutation[d],depermutation[d]+new_edge_len[d], -1);
122  for (int i=0; i<edge_len[d]; i++){
123  if (permutation[d][i] > -1)
124  depermutation[d][permutation[d][i]] = i;
125  }
126  }
127  }
128  TAU_FSTOP(form_depermutation);
129  #ifdef USE_OMP
130  #pragma omp parallel
131  #endif
132  {
133  int i, j, tid, ntd;
134  int64_t lda, wkey, knew, kdim, tstart, tnum_pair;
135  #ifdef USE_OMP
136  tid = omp_get_thread_num();
137  ntd = omp_get_num_threads();
138  #else
139  tid = 0;
140  ntd = 1;
141  #endif
142  tnum_pair = num_pair/ntd;
143  tstart = tnum_pair * tid + MIN(tid, num_pair % ntd);
144  if (tid < num_pair % ntd) tnum_pair++;
145 
146  char * my_pairs_buf = (char*)alloc(sr->pair_size()*tnum_pair);
147  //char my_pairs_buf[sr->pair_size()*tnum_pair];
148 
149  PairIterator my_pairs(sr, my_pairs_buf);
150  PairIterator pairs(sr, pairs_buf);
151 
152  for (i=tstart; i<tstart+tnum_pair; i++){
153  wkey = pairs[i].k();
154  lda = 1;
155  knew = 0;
156  for (j=0; j<order; j++){
157  kdim = wkey%new_edge_len[j];
158  if (depermutation[j] != NULL){
159  ASSERT(depermutation[j][kdim] != -1);
160  knew += lda*depermutation[j][kdim];
161  } else {
162  knew += lda*kdim;
163  }
164  lda *= edge_len[j];
165  wkey = wkey/new_edge_len[j];
166  }
167  pairs[i].write_key(knew);
168  }
169  }
170  for (int d=0; d<order; d++){
171  if (permutation[d] != NULL)
172  CTF_int::cdealloc(depermutation[d]);
173  }
174  CTF_int::cdealloc(depermutation);
175 
177  }
178 
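The depermutation formed above is the inverse of each mode's map; a minimal sketch of that inversion, assuming (as the code does) that the permutation is injective on its non-negative entries:

#include <vector>

// inv[perm[i]] = i for every stored index i; slots with no preimage stay -1.
std::vector<int> invert_mode_map(std::vector<int> const & perm, int new_len){
  std::vector<int> inv(new_len, -1);
  for (int i=0; i<(int)perm.size(); i++){
    if (perm[i] > -1) inv[perm[i]] = i;
  }
  return inv;
}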
179 
180  void assign_keys(int order,
181  int64_t size,
182  int nvirt,
183  int const * edge_len,
184  int const * sym,
185  int const * phase,
186  int const * phys_phase,
187  int const * virt_dim,
188  int * phase_rank,
189  char const * vdata,
190  char * vpairs,
191  algstrct const * sr){
192  int i, imax, act_lda, act_max;
193  int64_t p, idx_offset, buf_offset;
194  int * idx, * virt_rank;
195  int64_t * edge_lda;
196  if (order == 0){
197  ASSERT(size <= 1);
198  if (size == 1){
199  sr->set_pair(vpairs, 0, vdata);
200  }
201  return;
202  }
203 
205  CTF_int::alloc_ptr(order*sizeof(int), (void**)&idx);
206  CTF_int::alloc_ptr(order*sizeof(int), (void**)&virt_rank);
207  CTF_int::alloc_ptr(order*sizeof(int64_t), (void**)&edge_lda);
208 
209  memset(virt_rank, 0, sizeof(int)*order);
210 
211  edge_lda[0] = 1;
212  for (i=1; i<order; i++){
213  edge_lda[i] = edge_lda[i-1]*edge_len[i-1];
214  }
215  for (p=0;;p++){
216  char const * data = vdata + sr->el_size*p*(size/nvirt);
217  PairIterator pairs = PairIterator(sr, vpairs + sr->pair_size()*p*(size/nvirt));
218 
219  idx_offset = 0, buf_offset = 0;
220  for (act_lda=1; act_lda<order; act_lda++){
221  idx_offset += phase_rank[act_lda]*edge_lda[act_lda];
222  }
223 
224  //printf("size = %d\n", size);
225  memset(idx, 0, order*sizeof(int));
226  imax = edge_len[0]/phase[0];
227  for (;;){
228  if (sym[0] != NS)
229  imax = idx[1]+1;
230  /* Increment virtual bucket */
231  for (i=0; i<imax; i++){
232  ASSERT(buf_offset+i<size);
233  if (p*(size/nvirt) + buf_offset + i >= size){
234  printf("exceeded how much I was supposed to read %ld/%ld\n", p*(size/nvirt)+buf_offset+i,size);
235  ABORT;
236  }
237  pairs[buf_offset+i].write_key(idx_offset+i*phase[0]+phase_rank[0]);
238  pairs[buf_offset+i].write_val(data+(buf_offset+i)*sr->el_size);
239  }
240  buf_offset += imax;
241  /* Increment indices and set up offsets */
242  for (act_lda=1; act_lda < order; act_lda++){
243  idx_offset -= (idx[act_lda]*phase[act_lda]+phase_rank[act_lda])
244  *edge_lda[act_lda];
245  idx[act_lda]++;
246  act_max = edge_len[act_lda]/phase[act_lda];
247  if (sym[act_lda] != NS) act_max = idx[act_lda+1]+1;
248  if (idx[act_lda] >= act_max)
249  idx[act_lda] = 0;
250  idx_offset += (idx[act_lda]*phase[act_lda]+phase_rank[act_lda])
251  *edge_lda[act_lda];
252  ASSERT(edge_len[act_lda]%phase[act_lda] == 0);
253  if (idx[act_lda] > 0)
254  break;
255  }
256  if (act_lda >= order) break;
257  }
258  for (act_lda=0; act_lda < order; act_lda++){
259  phase_rank[act_lda] -= virt_rank[act_lda]*phys_phase[act_lda];
260  virt_rank[act_lda]++;
261  if (virt_rank[act_lda] >= virt_dim[act_lda])
262  virt_rank[act_lda] = 0;
263  phase_rank[act_lda] += virt_rank[act_lda]*phys_phase[act_lda];
264  if (virt_rank[act_lda] > 0)
265  break;
266  }
267  if (act_lda >= order) break;
268  }
269  ASSERT(buf_offset == size/nvirt);
270  CTF_int::cdealloc(idx);
271  CTF_int::cdealloc(virt_rank);
272  CTF_int::cdealloc(edge_lda);
274  }
275 
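The key each element receives above follows from the cyclic distribution: local index idx[j] in mode j corresponds to global index idx[j]*phase[j] + phase_rank[j]. A standalone restatement of that encoding (hypothetical helper, not part of CTF):

#include <cstdint>

// Mixed-radix key of a local element under a cyclic distribution with
// per-mode phase and this virtual block's phase_rank.
int64_t key_of_local_idx(int order, int const * idx, int const * phase,
                         int const * phase_rank, int const * edge_len){
  int64_t key = 0, lda = 1;
  for (int j=0; j<order; j++){
    key += (int64_t)(idx[j]*phase[j] + phase_rank[j])*lda;
    lda *= edge_len[j];  // edge_lda[j] in the code above
  }
  return key;
}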
276  void spsfy_tsr(int order,
277  int64_t size,
278  int nvirt,
279  int const * edge_len,
280  int const * sym,
281  int const * phase,
282  int const * phys_phase,
283  int const * virt_dim,
284  int * phase_rank,
285  char const * vdata,
286  char *& vpairs,
287  int64_t * nnz_blk,
288  algstrct const * sr,
289  int64_t const * edge_lda,
290  std::function<bool(char const*)> f){
291  int i, imax, act_lda, act_max;
292  int64_t p, idx_offset, buf_offset;
293  int * idx, * virt_rank;
294  memset(nnz_blk, 0, sizeof(int64_t)*nvirt);
295  if (order == 0){
296  ASSERT(size <= 1);
297  if (size == 1){
298  if (f(vdata)){
299  vpairs = sr->pair_alloc(1);
300  nnz_blk[0] = 1;
301  sr->set_pair(vpairs, 0, vdata);
302  } else vpairs = NULL;
303  }
304  return;
305  }
306 
308  CTF_int::alloc_ptr(order*sizeof(int), (void**)&idx);
309  CTF_int::alloc_ptr(order*sizeof(int), (void**)&virt_rank);
310 
311  memset(virt_rank, 0, sizeof(int)*order);
312  bool * keep_vals;
313  CTF_int::alloc_ptr(size*sizeof(bool), (void**)&keep_vals);
314 
315  int virt_blk = 0;
316  for (p=0;;p++){
317  char const * data = vdata + sr->el_size*p*(size/nvirt);
318 
319  buf_offset = 0;
320 
321  memset(idx, 0, order*sizeof(int));
322  imax = edge_len[0]/phase[0];
323  for (;;){
324  if (sym[0] != NS)
325  imax = idx[1]+1;
326  /* Increment virtual bucket */
327  for (i=0; i<imax; i++){
328  ASSERT(buf_offset+i<size);
329  keep_vals[buf_offset+i] = f(data+(buf_offset+i)*sr->el_size);
330  nnz_blk[virt_blk] += keep_vals[buf_offset+i];
331  }
332  buf_offset += imax;
333  /* Increment indices and set up offsets */
334  for (act_lda=1; act_lda < order; act_lda++){
335  idx[act_lda]++;
336  act_max = edge_len[act_lda]/phase[act_lda];
337  if (sym[act_lda] != NS) act_max = idx[act_lda+1]+1;
338  if (idx[act_lda] >= act_max)
339  idx[act_lda] = 0;
340  ASSERT(edge_len[act_lda]%phase[act_lda] == 0);
341  if (idx[act_lda] > 0)
342  break;
343  }
344  if (act_lda >= order) break;
345  }
346  for (act_lda=0; act_lda < order; act_lda++){
347  phase_rank[act_lda] -= virt_rank[act_lda]*phys_phase[act_lda];
348  virt_rank[act_lda]++;
349  if (virt_rank[act_lda] >= virt_dim[act_lda])
350  virt_rank[act_lda] = 0;
351  phase_rank[act_lda] += virt_rank[act_lda]*phys_phase[act_lda];
352  if (virt_rank[act_lda] > 0)
353  break;
354  }
355  virt_blk++;
356  if (act_lda >= order) break;
357  }
358  int64_t * nnz_blk_lda = (int64_t*)alloc(sizeof(int64_t)*nvirt);
359  nnz_blk_lda[0]=0;
360  for (int i=1; i<nvirt; i++){
361  nnz_blk_lda[i] = nnz_blk_lda[i-1]+nnz_blk[i-1];
362  }
363  vpairs = sr->pair_alloc(nnz_blk_lda[nvirt-1]+nnz_blk[nvirt-1]);
364 
365 
366  memset(nnz_blk, 0, sizeof(int64_t)*nvirt);
367  virt_blk = 0;
368  for (p=0;;p++){
369  char const * data = vdata + sr->el_size*p*(size/nvirt);
370  PairIterator pairs = PairIterator(sr, vpairs + sr->pair_size()*nnz_blk_lda[virt_blk]);
371 
372  idx_offset = 0, buf_offset = 0;
373  for (act_lda=1; act_lda<order; act_lda++){
374  idx_offset += phase_rank[act_lda]*edge_lda[act_lda];
375  }
376 
377 
378  memset(idx, 0, order*sizeof(int));
379  imax = edge_len[0]/phase[0];
380  for (;;){
381  if (sym[0] != NS)
382  imax = idx[1]+1;
383  /* Increment virtual bucket */
384  for (i=0; i<imax; i++){
385  ASSERT(buf_offset+i<size);
386  if (keep_vals[buf_offset+i]){
387 
388  pairs[nnz_blk[virt_blk]].write_key(idx_offset+i*phase[0]+phase_rank[0]);
389  pairs[nnz_blk[virt_blk]].write_val(data+(buf_offset+i)*sr->el_size);
390  nnz_blk[virt_blk]++;
391  }
392  }
393  buf_offset += imax;
394  /* Increment indices and set up offsets */
395  for (act_lda=1; act_lda < order; act_lda++){
396  idx_offset -= (idx[act_lda]*phase[act_lda]+phase_rank[act_lda])
397  *edge_lda[act_lda];
398  idx[act_lda]++;
399  act_max = edge_len[act_lda]/phase[act_lda];
400  if (sym[act_lda] != NS) act_max = idx[act_lda+1]+1;
401  if (idx[act_lda] >= act_max)
402  idx[act_lda] = 0;
403  idx_offset += (idx[act_lda]*phase[act_lda]+phase_rank[act_lda])
404  *edge_lda[act_lda];
405  ASSERT(edge_len[act_lda]%phase[act_lda] == 0);
406  if (idx[act_lda] > 0)
407  break;
408  }
409  if (act_lda >= order) break;
410  }
411  for (act_lda=0; act_lda < order; act_lda++){
412  phase_rank[act_lda] -= virt_rank[act_lda]*phys_phase[act_lda];
413  virt_rank[act_lda]++;
414  if (virt_rank[act_lda] >= virt_dim[act_lda])
415  virt_rank[act_lda] = 0;
416  phase_rank[act_lda] += virt_rank[act_lda]*phys_phase[act_lda];
417  if (virt_rank[act_lda] > 0)
418  break;
419  }
420  virt_blk++;
421  if (act_lda >= order) break;
422  }
423 
424  CTF_int::cdealloc(keep_vals);
425  CTF_int::cdealloc(nnz_blk_lda);
426  CTF_int::cdealloc(idx);
427  CTF_int::cdealloc(virt_rank);
429  }
430 
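spsfy_tsr() makes two passes: the first applies the filter f and counts survivors per virtual block (keep_vals, nnz_blk), a prefix sum (nnz_blk_lda) places each block in the output, and the second pass writes the surviving pairs. The same pattern in miniature, sketched for a single block of doubles:

#include <cstdint>
#include <functional>
#include <utility>
#include <vector>

std::vector<std::pair<int64_t,double> >
sparsify(std::vector<double> const & data, std::function<bool(double)> f){
  std::vector<bool> keep(data.size());
  int64_t nnz = 0;
  for (size_t i=0; i<data.size(); i++){ // pass 1: mark and count
    keep[i] = f(data[i]);
    nnz += keep[i];
  }
  std::vector<std::pair<int64_t,double> > pairs;
  pairs.reserve(nnz);
  for (size_t i=0; i<data.size(); i++){ // pass 2: write surviving pairs
    if (keep[i]) pairs.push_back(std::make_pair((int64_t)i, data[i]));
  }
  return pairs;
}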
431 
432  void bucket_by_pe(int order,
433  int64_t num_pair,
434  int64_t np,
435  int const * phys_phase,
436  int const * virt_phase,
437  int const * bucket_lda,
438  int const * edge_len,
439  ConstPairIterator mapped_data,
440  int64_t * bucket_counts,
441  int64_t * bucket_off,
442  PairIterator bucket_data,
443  algstrct const * sr){
444 
445  memset(bucket_counts, 0, sizeof(int64_t)*np);
446  #ifdef USE_OMP
447  int64_t * sub_counts, * sub_offs;
448  CTF_int::alloc_ptr(np*sizeof(int64_t)*omp_get_max_threads(), (void**)&sub_counts);
449  CTF_int::alloc_ptr(np*sizeof(int64_t)*omp_get_max_threads(), (void**)&sub_offs);
450  memset(sub_counts, 0, np*sizeof(int64_t)*omp_get_max_threads());
451  #endif
452 
453 
454  TAU_FSTART(bucket_by_pe_count);
455  /* Calculate counts */
456  #ifdef USE_OMP
457  #pragma omp parallel for schedule(static,256)
458  #endif
459  for (int64_t i=0; i<num_pair; i++){
460  int64_t k = mapped_data[i].k();
461  int64_t loc = 0;
462  // int tmp_arr[order];
463  for (int j=0; j<order; j++){
464  /* tmp_arr[j] = (k%edge_len[j])%phase[j];
465  tmp_arr[j] = tmp_arr[j]/virt_phase[j];
466  tmp_arr[j] = tmp_arr[j]*bucket_lda[j];*/
467  //FIXME: fine for dense but need extra mod for sparse :(
468  //loc += (k%phys_phase[j])*bucket_lda[j];
469  loc += ((k%edge_len[j])%phys_phase[j])*bucket_lda[j];
470  k = k/edge_len[j];
471  }
472  /* for (j=0; j<order; j++){
473  loc += tmp_arr[j];
474  }*/
475  ASSERT(loc<np);
476  #ifdef USE_OMP
477  sub_counts[loc+omp_get_thread_num()*np]++;
478  #else
479  bucket_counts[loc]++;
480  #endif
481  }
482  TAU_FSTOP(bucket_by_pe_count);
483 
484  #ifdef USE_OMP
485  for (int j=0; j<omp_get_max_threads(); j++){
486  for (int64_t i=0; i<np; i++){
487  bucket_counts[i] = sub_counts[j*np+i] + bucket_counts[i];
488  }
489  }
490  #endif
491 
492  /* Prefix sum to get offsets */
493  bucket_off[0] = 0;
494  for (int64_t i=1; i<np; i++){
495  bucket_off[i] = bucket_counts[i-1] + bucket_off[i-1];
496  }
497 
498  /* reset counts */
499  #ifdef USE_OMP
500  memset(sub_offs, 0, sizeof(int64_t)*np);
501  for (int i=1; i<omp_get_max_threads(); i++){
502  for (int64_t j=0; j<np; j++){
503  sub_offs[j+i*np]=sub_counts[j+(i-1)*np]+sub_offs[j+(i-1)*np];
504  }
505  }
506  #else
507  memset(bucket_counts, 0, sizeof(int64_t)*np);
508  #endif
509 
510  /* bucket data */
511  TAU_FSTART(bucket_by_pe_move);
512  #ifdef USE_OMP
513  #pragma omp parallel for schedule(static,256)
514  #endif
515  for (int64_t i=0; i<num_pair; i++){
516  int64_t k = mapped_data[i].k();
517  int64_t loc = 0;
518  for (int j=0; j<order; j++){
519  //FIXME: fine for dense but need extra mod for sparse :(
520  //loc += (k%phys_phase[j])*bucket_lda[j];
521  loc += ((k%edge_len[j])%phys_phase[j])*bucket_lda[j];
522  k = k/edge_len[j];
523  }
524  #ifdef USE_OMP
525  bucket_data[bucket_off[loc] + sub_offs[loc+omp_get_thread_num()*np]].write(mapped_data[i]);
526  sub_offs[loc+omp_get_thread_num()*np]++;
527  #else
528  bucket_data[bucket_off[loc] + bucket_counts[loc]].write(mapped_data[i]);
529  bucket_counts[loc]++;
530  #endif
531  }
532  #ifdef USE_OMP
533  CTF_int::cdealloc(sub_counts);
534  CTF_int::cdealloc(sub_offs);
535  #endif
536  TAU_FSTOP(bucket_by_pe_move);
537  }
538 
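Both the counting loop and the data-movement loop above compute the same destination processor; factored out as a sketch (hypothetical helper):

#include <cstdint>

// Destination processor of a key: each mode contributes its global index
// modulo the physical phase, scaled by that mode's processor-grid stride.
int64_t owner_of_key(int order, int64_t k, int const * edge_len,
                     int const * phys_phase, int const * bucket_lda){
  int64_t loc = 0;
  for (int j=0; j<order; j++){
    loc += ((k % edge_len[j]) % phys_phase[j])*bucket_lda[j];
    k /= edge_len[j];
  }
  return loc;
}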
539  int64_t * bucket_by_virt(int order,
540  int num_virt,
541  int64_t num_pair,
542  int const * phys_phase,
543  int const * virt_phase,
544  int const * edge_len,
545  ConstPairIterator mapped_data,
546  PairIterator bucket_data,
547  algstrct const * sr){
548  int64_t * virt_counts, * virt_prefix, * virt_lda;
550 
551  CTF_int::alloc_ptr(num_virt*sizeof(int64_t), (void**)&virt_counts);
552  CTF_int::alloc_ptr(num_virt*sizeof(int64_t), (void**)&virt_prefix);
553  CTF_int::alloc_ptr(order*sizeof(int64_t), (void**)&virt_lda);
554 
555 
556  if (order > 0){
557  virt_lda[0] = 1;
558  for (int i=1; i<order; i++){
559  ASSERT(virt_phase[i] > 0);
560  virt_lda[i] = virt_phase[i-1]*virt_lda[i-1];
561  }
562  }
563 
564  memset(virt_counts, 0, sizeof(int64_t)*num_virt);
565  #ifdef USE_OMP
566  int64_t * sub_counts, * sub_offs;
567  CTF_int::alloc_ptr(num_virt*sizeof(int64_t)*omp_get_max_threads(), (void**)&sub_counts);
568  CTF_int::alloc_ptr(num_virt*sizeof(int64_t)*omp_get_max_threads(), (void**)&sub_offs);
569  memset(sub_counts, 0, num_virt*sizeof(int64_t)*omp_get_max_threads());
570  #endif
571 
572 
573  /* bucket data */
574  #ifdef USE_OMP
575  TAU_FSTART(bucket_by_virt_omp_cnt);
576  #pragma omp parallel for schedule(static)
577  for (int64_t i=0; i<num_pair; i++){
578  int64_t k = mapped_data[i].k();
579  int64_t loc = 0;
580  //#pragma unroll
581  for (int j=0; j<order; j++){
582  //FIXME: fine for dense but need extra mod for sparse :(
583  //loc += ((k/phys_phase[j])%virt_phase[j])*virt_lda[j];
584  loc += (((k%edge_len[j])/phys_phase[j])%virt_phase[j])*virt_lda[j];
585  k = k/edge_len[j];
586  }
587 
588  //bucket_data[loc*num_pair_virt + virt_counts[loc]] = mapped_data[i];
589  sub_counts[loc+omp_get_thread_num()*num_virt]++;
590  }
591  TAU_FSTOP(bucket_by_virt_omp_cnt);
592  TAU_FSTART(bucket_by_virt_assemble_offsets);
593  for (int j=0; j<omp_get_max_threads(); j++){
594  for (int64_t i=0; i<num_virt; i++){
595  virt_counts[i] = sub_counts[j*num_virt+i] + virt_counts[i];
596  }
597  }
598  virt_prefix[0] = 0;
599  for (int64_t i=1; i<num_virt; i++){
600  virt_prefix[i] = virt_prefix[i-1] + virt_counts[i-1];
601  }
602 
603  memset(sub_offs, 0, sizeof(int64_t)*num_virt);
604  for (int i=1; i<omp_get_max_threads(); i++){
605  for (int64_t j=0; j<num_virt; j++){
606  sub_offs[j+i*num_virt]=sub_counts[j+(i-1)*num_virt]+sub_offs[j+(i-1)*num_virt];
607  }
608  }
609  TAU_FSTOP(bucket_by_virt_assemble_offsets);
610  TAU_FSTART(bucket_by_virt_move);
611  #pragma omp parallel for schedule(static)
612  for (int64_t i=0; i<num_pair; i++){
613  int64_t k = mapped_data[i].k();
614  int64_t loc = 0;
615  //#pragma unroll
616  for (int j=0; j<order; j++){
617  //FIXME: fine for dense but need extra mod for sparse :(
618  //loc += ((k/phys_phase[j])%virt_phase[j])*virt_lda[j];
619  loc += (((k%edge_len[j])/phys_phase[j])%virt_phase[j])*virt_lda[j];
620  k = k/edge_len[j];
621  }
622  bucket_data[virt_prefix[loc] + sub_offs[loc+omp_get_thread_num()*num_virt]].write(mapped_data[i]);
623  sub_offs[loc+omp_get_thread_num()*num_virt]++;
624  }
625  TAU_FSTOP(bucket_by_virt_move);
626  #else
627  for (int64_t i=0; i<num_pair; i++){
628  int64_t k = mapped_data[i].k();
629  int64_t loc = 0;
630  for (int j=0; j<order; j++){
631  //FIXME: fine for dense but need extra mod for sparse :(
632  //loc += ((k/phys_phase[j])%virt_phase[j])*virt_lda[j];
633  loc += (((k%edge_len[j])/phys_phase[j])%virt_phase[j])*virt_lda[j];
634  k = k/edge_len[j];
635  }
636  virt_counts[loc]++;
637  }
638 
639  virt_prefix[0] = 0;
640  for (int64_t i=1; i<num_virt; i++){
641  virt_prefix[i] = virt_prefix[i-1] + virt_counts[i-1];
642  }
643  memset(virt_counts, 0, sizeof(int64_t)*num_virt);
644 
645  for (int64_t i=0; i<num_pair; i++){
646  int64_t k = mapped_data[i].k();
647  int64_t loc = 0;
648  for (int j=0; j<order; j++){
649  //FIXME: fine for dense but need extra mod for sparse :(
650  //loc += ((k/phys_phase[j])%virt_phase[j])*virt_lda[j];
651  loc += (((k%edge_len[j])/phys_phase[j])%virt_phase[j])*virt_lda[j];
652  k = k/edge_len[j];
653  }
654  bucket_data[virt_prefix[loc] + virt_counts[loc]].write(mapped_data[i]);
655  virt_counts[loc]++;
656  }
657  #endif
658 
659  TAU_FSTART(bucket_by_virt_sort);
660  #ifdef USE_OMP
661  #pragma omp parallel for
662  #endif
663  for (int64_t i=0; i<num_virt; i++){
664  /*std::sort(bucket_data+virt_prefix[i],
665  bucket_data+(virt_prefix[i]+virt_counts[i]));*/
666  bucket_data[virt_prefix[i]].sort(virt_counts[i]);
667  }
668  TAU_FSTOP(bucket_by_virt_sort);
669  #if DEBUG >= 1
670  // FIXME: Can we handle replicated keys?
671  /* for (i=1; i<num_pair; i++){
672  ASSERT(bucket_data[i].k != bucket_data[i-1].k);
673  }*/
674  #endif
675  #ifdef USE_OMP
676  CTF_int::cdealloc(sub_counts);
677  CTF_int::cdealloc(sub_offs);
678  #endif
679  CTF_int::cdealloc(virt_prefix);
680  CTF_int::cdealloc(virt_lda);
682  return virt_counts;
683  }
684 
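The companion computation for the virtual-block id, as used in all three loops above (hypothetical helper): within a processor, mode j's block coordinate is the global index divided by the physical phase, taken modulo the virtualization factor, combined with strides virt_lda.

#include <cstdint>

int64_t virt_block_of_key(int order, int64_t k, int const * edge_len,
                          int const * phys_phase, int const * virt_phase,
                          int64_t const * virt_lda){
  int64_t loc = 0;
  for (int j=0; j<order; j++){
    loc += (((k % edge_len[j]) / phys_phase[j]) % virt_phase[j])*virt_lda[j];
    k /= edge_len[j];
  }
  return loc;
}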
685  void readwrite(int order,
686  int64_t size,
687  char const * alpha,
688  char const * beta,
689  int nvirt,
690  int const * edge_len,
691  int const * sym,
692  int const * phase,
693  int const * phys_phase,
694  int const * virt_dim,
695  int * phase_rank,
696  char * vdata,
697  char * pairs_buf,
698  char rw,
699  algstrct const * sr){
700  int act_lda;
701  int64_t idx_offset, act_max, buf_offset, pr_offset, p;
702  int64_t * idx, * virt_rank, * edge_lda;
703 
704  PairIterator pairs = PairIterator(sr, pairs_buf);
705 
706  if (order == 0){
707  if (size > 0){
708  if (size > 1){
709  for (int64_t i=1; i<size; i++){
710  //check for write conflicts
711  //FIXME this makes sense how again?
712  ASSERT(pairs[i].k() == 0 || pairs[i].d() != pairs[0].d());
713  }
714  }
715  // printf("size = " PRId64 "\n",size);
716  // ASSERT(size == 1);
717  if (rw == 'r'){
718  pairs[0].write_val(vdata);
719  } else {
720  //vdata[0] = pairs[0].d;
721  pairs[0].read_val(vdata);
722  }
723  }
724  return;
725  }
727  CTF_int::alloc_ptr(order*sizeof(int64_t), (void**)&idx);
728  CTF_int::alloc_ptr(order*sizeof(int64_t), (void**)&virt_rank);
729  CTF_int::alloc_ptr(order*sizeof(int64_t), (void**)&edge_lda);
730 
731  memset(virt_rank, 0, sizeof(int64_t)*order);
732  edge_lda[0] = 1;
733  for (int i=1; i<order; i++){
734  edge_lda[i] = edge_lda[i-1]*edge_len[i-1];
735  }
736 
737  pr_offset = 0;
738  buf_offset = 0;
739  char * data = vdata;// + buf_offset;
740 
741  for (p=0;;p++){
742  data = data + sr->el_size*buf_offset;
743  idx_offset = 0, buf_offset = 0;
744  for (act_lda=1; act_lda<order; act_lda++){
745  idx_offset += phase_rank[act_lda]*edge_lda[act_lda];
746  }
747 
748  memset(idx, 0, order*sizeof(int64_t));
749  int64_t imax = edge_len[0]/phase[0];
750  for (;;){
751  if (sym[0] != NS)
752  imax = idx[1]+1;
753  /* Increment virtual bucket */
754  for (int64_t i=0; i<imax;){// i++){
755  if (pr_offset >= size)
756  break;
757  else {
758  if (pairs[pr_offset].k() == idx_offset +i*phase[0]+phase_rank[0]){
759  if (rw == 'r'){
760  if (alpha == NULL){
761  pairs[pr_offset].write_val(data+sr->el_size*(buf_offset+i));
762 /* if (sr->isbeta == 0.0)
763  char wval[sr->pair_size()];
764  sr->mul(alpha,data + sr->el_size*(buf_offset+i), wval);
765  pairs[pr_offset].write_val(wval);*/
766  } else {
767  /* should it be the opposite? No, because 'pairs' was passed in and 'data' is being added to pairs, so data is operand, gets alpha. */
768  //pairs[pr_offset].d = alpha*data[buf_offset+i]+beta*pairs[pr_offset].d;
769  char wval[sr->pair_size()];
770  sr->mul(alpha, data + sr->el_size*(buf_offset+i), wval);
771  char wval2[sr->pair_size()];
772  sr->mul(beta, pairs[pr_offset].d(), wval2);
773  sr->add(wval, wval2, wval);
774  pairs[pr_offset].write_val(wval);
775  }
776  } else {
777  ASSERT(rw =='w');
778  //data[(int64_t)buf_offset+i] = beta*data[(int64_t)buf_offset+i]+alpha*pairs[pr_offset].d;
779  if (alpha == NULL)
780  pairs[pr_offset].read_val(data+sr->el_size*(buf_offset+i));
781  else {
782  char wval[sr->pair_size()];
783  sr->mul(beta, data + sr->el_size*(buf_offset+i), wval);
784  char wval2[sr->pair_size()];
785  sr->mul(alpha, pairs[pr_offset].d(), wval2);
786  sr->add(wval, wval2, wval);
787  sr->copy(data + sr->el_size*(buf_offset+i), wval);
788  }
789  }
790  pr_offset++;
791  //Check for write conflicts
792  //Fixed: allow and handle them!
793  while (pr_offset < size && pairs[pr_offset].k() == pairs[pr_offset-1].k()){
794  // printf("found overlapped write of key %ld and value %lf\n", pairs[pr_offset].k, pairs[pr_offset].d);
795  if (rw == 'r'){
796  if (alpha == NULL){
797  pairs[pr_offset].write_val(data + sr->el_size*(buf_offset+i));
798  } else {
799 // pairs[pr_offset].d = alpha*data[buf_offset+i]+beta*pairs[pr_offset].d;
800  char wval[sr->pair_size()];
801  sr->mul(alpha, data + sr->el_size*(buf_offset+i), wval);
802  char wval2[sr->pair_size()];
803  sr->mul(beta, pairs[pr_offset].d(), wval2);
804  sr->add(wval, wval2, wval);
805  pairs[pr_offset].write_val(wval);
806  }
807  } else {
808  //FIXME: may be problematic if someone writes entries of a symmetric tensor redundantly
809  if (alpha == NULL){
810  sr->add(data + (buf_offset+i)*sr->el_size,
811  pairs[pr_offset].d(),
812  data + (buf_offset+i)*sr->el_size);
813  } else {
814  //data[(int64_t)buf_offset+i] = beta*data[(int64_t)buf_offset+i]+alpha*pairs[pr_offset].d;
815  char wval[sr->pair_size()];
816  sr->mul(alpha, pairs[pr_offset].d(), wval);
817  sr->add(wval, data + sr->el_size*(buf_offset+i), wval);
818  sr->copy(data + sr->el_size*(buf_offset+i), wval);
819  }
820  }
821  // printf("rw = %c found overlapped write and set value to %lf\n", rw, data[(int64_t)buf_offset+i]);
822  pr_offset++;
823  }
824  } else {
825  i++;
826  /* DEBUG_PRINTF("%d key[%d] %d not matched with %d\n",
827  (int)pairs[pr_offset-1].k,
828  pr_offset, (int)pairs[pr_offset].k,
829  (idx_offset+i*phase[0]+phase_rank[0]));*/
830  }
831  }
832  }
833  buf_offset += imax;
834  if (pr_offset >= size)
835  break;
836  /* Increment indices and set up offsets */
837  for (act_lda=1; act_lda < order; act_lda++){
838  idx_offset -= (idx[act_lda]*phase[act_lda]+phase_rank[act_lda])
839  *edge_lda[act_lda];
840  idx[act_lda]++;
841  act_max = edge_len[act_lda]/phase[act_lda];
842  if (sym[act_lda] != NS) act_max = idx[act_lda+1]+1;
843  if (idx[act_lda] >= act_max)
844  idx[act_lda] = 0;
845  idx_offset += (idx[act_lda]*phase[act_lda]+phase_rank[act_lda])
846  *edge_lda[act_lda];
847  ASSERT(edge_len[act_lda]%phase[act_lda] == 0);
848  if (idx[act_lda] > 0)
849  break;
850  }
851  if (act_lda == order) break;
852  }
853  for (act_lda=0; act_lda < order; act_lda++){
854  phase_rank[act_lda] -= virt_rank[act_lda]*phys_phase[act_lda];
855  virt_rank[act_lda]++;
856  if (virt_rank[act_lda] >= virt_dim[act_lda])
857  virt_rank[act_lda] = 0;
858  phase_rank[act_lda] += virt_rank[act_lda]*phys_phase[act_lda];
859  if (virt_rank[act_lda] > 0)
860  break;
861  }
862  if (act_lda == order) break;
863  }
865  //printf("pr_offset = %ld / %ld \n",pr_offset,size);
866  ASSERT(pr_offset == size);
867  CTF_int::cdealloc(idx);
868  CTF_int::cdealloc(virt_rank);
869  CTF_int::cdealloc(edge_lda);
870  }
871 
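When a pair's key matches a dense element, readwrite() applies the scaling rules below, written out here for doubles (in the code, alpha == NULL or beta == NULL means the corresponding scaling is skipped):

// read  ('r'): the pair receives   alpha*data + beta*pair
// write ('w'): the element becomes beta*data  + alpha*pair
double read_update (double alpha, double beta, double data, double pair){
  return alpha*data + beta*pair;
}
double write_update(double alpha, double beta, double data, double pair){
  return beta*data + alpha*pair;
}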
872  void wr_pairs_layout(int order,
873  int np,
874  int64_t inwrite,
875  char const * alpha,
876  char const * beta,
877  char rw,
878  int num_virt,
879  int const * sym,
880  int const * edge_len,
881  int const * padding,
882  int const * phase,
883  int const * phys_phase,
884  int const * virt_phase,
885  int * virt_phys_rank,
886  int const * bucket_lda,
887  char * wr_pairs_buf,
888  char * rw_data,
889  CommData glb_comm,
890  algstrct const * sr,
891  bool is_sparse,
892  int64_t nnz_loc,
893  int64_t * nnz_blk,
894  char *& pprs_new,
895  int64_t & nnz_loc_new){
896  int64_t new_num_pair, nwrite, swp;
897  int64_t * bucket_counts, * recv_counts;
898  int64_t * recv_displs, * send_displs;
899  int * depadding, * depad_edge_len;
900  int * ckey;
901  int j, is_out, sign, is_perm;
902  char * swap_datab, * buf_datab;
903  int64_t * old_nnz_blk;
904  if (is_sparse){
905  CTF_int::alloc_ptr(num_virt*sizeof(int64_t), (void**)&old_nnz_blk);
906  memcpy(old_nnz_blk, nnz_blk, num_virt*sizeof(int64_t));
907  }
908 
909  buf_datab = sr->pair_alloc(inwrite);
910  swap_datab = sr->pair_alloc(inwrite);
911  CTF_int::alloc_ptr(np*sizeof(int64_t), (void**)&bucket_counts);
912  CTF_int::alloc_ptr(np*sizeof(int64_t), (void**)&recv_counts);
913  CTF_int::alloc_ptr(np*sizeof(int64_t), (void**)&send_displs);
914  CTF_int::alloc_ptr(np*sizeof(int64_t), (void**)&recv_displs);
915 
916  PairIterator buf_data = PairIterator(sr, buf_datab);
917  PairIterator swap_data = PairIterator(sr, swap_datab);
918  PairIterator wr_pairs = PairIterator(sr, wr_pairs_buf);
919 
920 
921  #if DEBUG >= 1
922  int64_t total_tsr_size = 1;
923  for (int i=0; i<order; i++){
924  total_tsr_size *= edge_len[i];
925  }
926  //printf("pair size is %d el size is %d\n",sr->pair_size(),sr->el_size);
927  for (int64_t i=0; i<inwrite; i++){
928  if (wr_pairs[i].k()>=total_tsr_size)
929  printf("[%d] %ldth key is %ld size %ld\n",glb_comm.rank, i, wr_pairs[i].k(),total_tsr_size);
930  ASSERT(wr_pairs[i].k() >= 0);
931  ASSERT(wr_pairs[i].k() < total_tsr_size);
932  }
933  #endif
935 
936  /* Copy out the input data, do not touch that array */
937  // memcpy(swap_data, wr_pairs, nwrite*sizeof(tkv_pair<dtype>));
938  CTF_int::alloc_ptr(order*sizeof(int), (void**)&depad_edge_len);
939  for (int i=0; i<order; i++){
940  depad_edge_len[i] = edge_len[i] - padding[i];
941  }
942  CTF_int::alloc_ptr(order*sizeof(int), (void**)&ckey);
943  TAU_FSTART(check_key_ranges);
944 
945  //calculate the number of keys that need to be vchanged first
946  int64_t nchanged = 0;
947  for (int64_t i=0; i<inwrite; i++){
948  cvrt_idx(order, depad_edge_len, wr_pairs[i].k(), ckey);
949  is_out = 0;
950  sign = 1;
951  is_perm = 1;
952  while (is_perm && !is_out){
953  is_perm = 0;
954  for (j=0; j<order-1; j++){
955  if ((sym[j] == SH || sym[j] == AS) && ckey[j] == ckey[j+1]){
956  is_out = 1;
957  break;
958  } else if (sym[j] != NS && ckey[j] > ckey[j+1]){
959  swp = ckey[j];
960  ckey[j] = ckey[j+1];
961  ckey[j+1] = swp;
962  if (sym[j] == AS){
963  sign *= -1;
964  }
965  is_perm = 1;
966  }/* else if (sym[j] == AS && ckey[j] > ckey[j+1]){
967  swp = ckey[j];
968  ckey[j] = ckey[j+1];
969  ckey[j+1] = swp;
970  is_perm = 1;
971  } */
972  }
973  }
974  if (!is_out){
975  int64_t skey;
976  cvrt_idx(order, depad_edge_len, ckey, &skey);
977  if (rw == 'r' && skey != wr_pairs[i].k()){
978  nchanged++;
979  }
980  } else if (rw == 'r'){
981  nchanged++;
982  }
983  }
984 
985  nwrite = 0;
986  int64_t * changed_key_indices;
987  char * new_changed_pairs = sr->pair_alloc(nchanged);
988  PairIterator ncp(sr, new_changed_pairs);
989  int * changed_key_scale;
990  CTF_int::alloc_ptr(nchanged*sizeof(int64_t), (void**)&changed_key_indices);
991  CTF_int::alloc_ptr(nchanged*sizeof(int), (void**)&changed_key_scale);
992 
993  nchanged = 0;
994  for (int64_t i=0; i<inwrite; i++){
995  cvrt_idx(order, depad_edge_len, wr_pairs[i].k(), ckey);
996  is_out = 0;
997  sign = 1;
998  is_perm = 1;
999  while (is_perm && !is_out){
1000  is_perm = 0;
1001  for (j=0; j<order-1; j++){
1002  if ((sym[j] == SH || sym[j] == AS) && ckey[j] == ckey[j+1]){
1003  is_out = 1;
1004  break;
1005  } else if (sym[j] != NS && ckey[j] > ckey[j+1]){
1006  swp = ckey[j];
1007  ckey[j] = ckey[j+1];
1008  ckey[j+1] = swp;
1009  if (sym[j] == AS){
1010  sign *= -1;
1011  }
1012  is_perm = 1;
1013  }/* else if (sym[j] == AS && ckey[j] > ckey[j+1]){
1014  swp = ckey[j];
1015  ckey[j] = ckey[j+1];
1016  ckey[j+1] = swp;
1017  is_perm = 1;
1018  } */
1019  }
1020  }
1021  if (!is_out){
1022  int64_t ky = swap_data[nwrite].k();
1023  cvrt_idx(order, depad_edge_len, ckey, &ky);
1024  swap_data[nwrite].write_key(ky);
1025  if (sign == 1)
1026  swap_data[nwrite].write_val(wr_pairs[i].d());
1027  else {
1028  char ainv[sr->el_size];
1029  sr->addinv(wr_pairs[i].d(), ainv);
1030  swap_data[nwrite].write_val(ainv);
1031  }
1032  if (rw == 'r' && swap_data[nwrite].k() != wr_pairs[i].k()){
1033  /*printf("the %lldth key has been set from %lld to %lld\n",
1034  i, wr_pairs[i].k, swap_data[nwrite].k);*/
1035  changed_key_indices[nchanged]= i;
1036  swap_data[nwrite].read(ncp[nchanged].ptr);
1037  changed_key_scale[nchanged] = sign;
1038  nchanged++;
1039  }
1040  nwrite++;
1041  } else if (rw == 'r'){
1042  changed_key_indices[nchanged] = i;
1043  wr_pairs[i].read(ncp[nchanged].ptr);
1044  changed_key_scale[nchanged] = 0;
1045  nchanged++;
1046  }
1047  }
1048  CTF_int::cdealloc(ckey);
1049  TAU_FSTOP(check_key_ranges);
1050 
1051  /* If the packed tensor is padded, pad keys */
1052  int const * wlen;
1053  if (!is_sparse){
1054  pad_key(order, nwrite, depad_edge_len, padding, swap_data, sr);
1055  CTF_int::cdealloc(depad_edge_len);
1056  wlen = edge_len;
1057  } else wlen = depad_edge_len;
1058 
1059  /* Figure out which processor each key's value lies on in the packed layout */
1060  bucket_by_pe(order, nwrite, np,
1061  phys_phase, virt_phase, bucket_lda,
1062  wlen, swap_data, bucket_counts,
1063  send_displs, buf_data, sr);
1064 
1065  /* Exchange send counts */
1066  MPI_Alltoall(bucket_counts, 1, MPI_INT64_T,
1067  recv_counts, 1, MPI_INT64_T, glb_comm.cm);
1068 
1069  /* calculate offsets */
1070  recv_displs[0] = 0;
1071  for (int i=1; i<np; i++){
1072  recv_displs[i] = recv_displs[i-1] + recv_counts[i-1];
1073  }
1074  new_num_pair = recv_displs[np-1] + recv_counts[np-1];
1075 
1076  /*for (i=0; i<np; i++){
1077  bucket_counts[i] = bucket_counts[i]*sizeof(tkv_pair<dtype>);
1078  send_displs[i] = send_displs[i]*sizeof(tkv_pair<dtype>);
1079  recv_counts[i] = recv_counts[i]*sizeof(tkv_pair<dtype>);
1080  recv_displs[i] = recv_displs[i]*sizeof(tkv_pair<dtype>);
1081  }*/
1082 
1083 /* int64_t max_np;
1084  MPI_Allreduce(&new_num_pair, &max_np, 1, MPI_INT64_T, MPI_MAX, glb_comm.cm);
1085  if (glb_comm.rank == 0) printf("max received elements is %ld, mine are %ld\n", max_np, new_num_pair);*/
1086 
1087  if (new_num_pair > nwrite){
1088  sr->pair_dealloc(swap_datab);
1089  swap_datab = sr->pair_alloc(new_num_pair);
1090  swap_data = PairIterator(sr, swap_datab);
1091  }
1092  /* Exchange data according to counts/offsets */
1093  //ALL_TO_ALLV(buf_data, bucket_counts, send_displs, MPI_CHAR,
1094  // swap_data, recv_counts, recv_displs, MPI_CHAR, glb_comm);
1095  if (glb_comm.np == 1){
1096  char * save_ptr = buf_datab;
1097  buf_datab = swap_datab;
1098  swap_datab = save_ptr;
1099  buf_data = PairIterator(sr, buf_datab);
1100  swap_data = PairIterator(sr, swap_datab);
1101  } else {
1102  glb_comm.all_to_allv(buf_data.ptr, bucket_counts, send_displs, sr->pair_size(),
1103  swap_data.ptr, recv_counts, recv_displs);
1104  }
1105 
1106 
1107 
1108  if (new_num_pair > nwrite){
1109  sr->pair_dealloc(buf_datab);
1110  buf_datab = sr->pair_alloc(new_num_pair);
1111  buf_data = PairIterator(sr, buf_datab);
1112  }
1113  /* Figure out what virtual bucket each key belongs to. Bucket
1114  and sort them accordingly */
1115  int64_t * virt_counts =
1116  bucket_by_virt(order, num_virt, new_num_pair, phys_phase, virt_phase,
1117  wlen, swap_data, buf_data, sr);
1118 
1119  /* Write or read the values corresponding to the keys */
1120  if (is_sparse){
1121  if (rw == 'r'){
1122  ConstPairIterator prs_tsr(sr, rw_data);
1123  sp_read(sr, nnz_loc, prs_tsr, alpha, new_num_pair, buf_data, beta);
1124  } else {
1125  ConstPairIterator prs_tsr(sr, rw_data);
1126  ConstPairIterator prs_write(sr, buf_data.ptr);
1127  sp_write(num_virt, sr, old_nnz_blk, prs_tsr, beta, virt_counts, prs_write, alpha, nnz_blk, pprs_new);
1128  for (int v=0; v<num_virt; v++){
1129  if (v==0) nnz_loc_new = nnz_blk[0];
1130  else nnz_loc_new += nnz_blk[v];
1131  }
1132  }
1133  } else
1134  readwrite(order,
1135  new_num_pair,
1136  alpha,
1137  beta,
1138  num_virt,
1139  edge_len,
1140  sym,
1141  phase,
1142  phys_phase,
1143  virt_phase,
1144  virt_phys_rank,
1145  rw_data,
1146  buf_datab,
1147  rw,
1148  sr);
1149 
1150  cdealloc(virt_counts);
1151 
1152  /* If we want to read the keys, we must return them to the processors
1153  that requested them */
1154  if (rw == 'r'){
1155  CTF_int::alloc_ptr(order*sizeof(int), (void**)&depadding);
1156  /* Sort the key-value pairs we determined */
1157  //std::sort(buf_data, buf_data+new_num_pair);
1158  buf_data.sort(new_num_pair);
1159  /* Search for the keys in the order in which we received the keys */
1160  for (int64_t i=0; i<new_num_pair; i++){
1161  /*el_loc = std::lower_bound(buf_data,
1162  buf_data+new_num_pair,
1163  swap_data[i]);*/
1164  int64_t el_loc = buf_data.lower_bound(new_num_pair, swap_data[i]);
1165  #if (DEBUG>=5)
1166  if (el_loc < 0 || el_loc >= new_num_pair){
1168  DEBUG_PRINTF("swap_data[%d].k = %d, not found\n", i, (int)swap_data[i].k());
1169  ASSERT(0);
1170  }
1171  #endif
1172  swap_data[i].write_val(buf_data[el_loc].d());
1173  }
1174 
1175  /* Invert the all-to-all exchange we did above to get the keys back to their requestors */
1176  //ALL_TO_ALLV(swap_data, recv_counts, recv_displs, MPI_CHAR,
1177  // buf_data, bucket_counts, send_displs, MPI_CHAR, glb_comm);
1178  glb_comm.all_to_allv(swap_data.ptr, recv_counts, recv_displs, sr->pair_size(),
1179  buf_data.ptr, bucket_counts, send_displs);
1180 
1181  /* unpad the keys if necessary */
1182  if (!is_sparse){
1183  for (int i=0; i<order; i++){
1184  depadding[i] = -padding[i];
1185  }
1186  pad_key(order, nwrite, edge_len, depadding, buf_data, sr);
1187  }
1188 
1189  /* Sort the pairs that were sent out, now with correct values */
1190 // std::sort(buf_data, buf_data+nwrite);
1191  buf_data.sort(nwrite);
1192  /* Search for the keys in the same order they were requested */
1193  j=0;
1194  for (int64_t i=0; i<inwrite; i++){
1195  if (j<(int64_t)nchanged && changed_key_indices[j] == i){
1196  if (changed_key_scale[j] == 0){
1197  wr_pairs[i].write_val(sr->addid());
1198  } else {
1199  //el_loc = std::lower_bound(buf_data, buf_data+nwrite, new_changed_pairs[j]);
1200  //wr_pairs[i].d = changed_key_scale[j]*el_loc[0].d;
1201  int64_t el_loc = buf_data.lower_bound(nwrite, ConstPairIterator(sr, new_changed_pairs+j*sr->pair_size()));
1202  if (changed_key_scale[j] == -1){
1203  char aspr[sr->el_size];
1204  sr->addinv(buf_data[el_loc].d(), aspr);
1205  wr_pairs[i].write_val(aspr);
1206  } else
1207  wr_pairs[i].write_val(buf_data[el_loc].d());
1208  }
1209  j++;
1210  } else {
1211  int64_t el_loc = buf_data.lower_bound(nwrite, wr_pairs[i]);
1212 // el_loc = std::lower_bound(buf_data, buf_data+nwrite, wr_pairs[i]);
1213  wr_pairs[i].write_val(buf_data[el_loc].d());
1214  }
1215  }
1216  CTF_int::cdealloc(depadding);
1217  }
1218  if (is_sparse) cdealloc(depad_edge_len);
1219  //FIXME: free here?
1220  cdealloc(changed_key_indices);
1221  cdealloc(changed_key_scale);
1222  sr->pair_dealloc(new_changed_pairs);
1224 
1225  if (is_sparse) CTF_int::cdealloc(old_nnz_blk);
1226  sr->pair_dealloc(swap_datab);
1227  sr->pair_dealloc(buf_datab);
1228  CTF_int::cdealloc((void*)bucket_counts);
1229  CTF_int::cdealloc((void*)recv_counts);
1230  CTF_int::cdealloc((void*)send_displs);
1231  CTF_int::cdealloc((void*)recv_displs);
1232 
1233  }
1234 
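The two identical canonicalization loops inside wr_pairs_layout() sort each key's indices into the packed symmetric region; a standalone sketch of that step (enum values are stand-ins, not CTF's):

#include <utility>

enum SymKind { NS_k, SY_k, AS_k, SH_k }; // stand-ins for CTF's NS/SY/AS/SH

// Sorts the index tuple into canonical (nondecreasing) order across
// symmetric mode pairs; returns 0 if the key lies on a diagonal that the
// packed SH/AS layout does not store. sign flips once per AS transposition.
int canonicalize_key(int order, SymKind const * sym, int * ckey, int & sign){
  sign = 1;
  bool swapped = true;
  while (swapped){
    swapped = false;
    for (int j=0; j<order-1; j++){
      if ((sym[j] == SH_k || sym[j] == AS_k) && ckey[j] == ckey[j+1])
        return 0;                        // not stored: the is_out case above
      if (sym[j] != NS_k && ckey[j] > ckey[j+1]){
        std::swap(ckey[j], ckey[j+1]);
        if (sym[j] == AS_k) sign = -sign;
        swapped = true;
      }
    }
  }
  return 1;
}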
1235  void read_loc_pairs(int order,
1236  int64_t nval,
1237  int num_virt,
1238  int const * sym,
1239  int const * edge_len,
1240  int const * padding,
1241  int const * phase,
1242  int const * phys_phase,
1243  int const * virt_phase,
1244  int * phase_rank,
1245  int64_t * nread,
1246  char const * data,
1247  char ** pairs,
1248  algstrct const * sr){
1249  int64_t i;
1250  int * prepadding;
1251  char * dpairsb;
1252  dpairsb = sr->pair_alloc(nval);
1253  CTF_int::alloc_ptr(sizeof(int)*order, (void**)&prepadding);
1254  memset(prepadding, 0, sizeof(int)*order);
1255  /* Iterate through packed layout and form key value pairs */
1256  assign_keys(order,
1257  nval,
1258  num_virt,
1259  edge_len,
1260  sym,
1261  phase,
1262  phys_phase,
1263  virt_phase,
1264  phase_rank,
1265  data,
1266  dpairsb,
1267  sr);
1268 /* for (i=0; i<nval; i++){
1269  printf("\nX[%ld] ", ((int64_t*)(dpairsb+i*sr->pair_size()))[0]);
1270  sr->print(dpairsb+i*sr->pair_size()+sizeof(int64_t));
1271  }
1272 */
1273  /* If we need to unpad */
1274  int64_t new_num_pair;
1275  int * depadding;
1276  int * pad_len;
1277  char * new_pairsb;
1278  new_pairsb = sr->pair_alloc(nval);
1279 
1280  PairIterator new_pairs = PairIterator(sr, new_pairsb);
1281 
1282  CTF_int::alloc_ptr(sizeof(int)*order, (void**)&depadding);
1283  CTF_int::alloc_ptr(sizeof(int)*order, (void**)&pad_len);
1284 
1285  for (i=0; i<order; i++){
1286  pad_len[i] = edge_len[i]-padding[i];
1287  }
1288  /* Get rid of any padded values */
1289  depad_tsr(order, nval, pad_len, sym, padding, prepadding,
1290  dpairsb, new_pairsb, &new_num_pair, sr);
1291 
1292  sr->pair_dealloc(dpairsb);
1293  if (new_num_pair == 0){
1294  sr->pair_dealloc(new_pairsb);
1295  new_pairsb = NULL;
1296  }
1297  *pairs = new_pairsb;
1298  *nread = new_num_pair;
1299 
1300  for (i=0; i<order; i++){
1301  depadding[i] = -padding[i];
1302  }
1303 
1304  /* Adjust keys to remove padding */
1305  pad_key(order, new_num_pair, edge_len, depadding, new_pairs, sr);
1306  CTF_int::cdealloc((void*)pad_len);
1307  CTF_int::cdealloc((void*)depadding);
1308  CTF_int::cdealloc(prepadding);
1309  }
1310 
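The final pad_key call converts keys from padded to true edge lengths. Assuming pad_key re-encodes each key from lengths len[j] to len[j]+offset[j], which matches how it is invoked here with negated padding, the conversion amounts to the following sketch:

#include <cstdint>

// Hypothetical restatement: decode with padded lengths, re-encode with the
// unpadded lengths padded_len[j] - padding[j].
int64_t depad_key(int order, int64_t k, int const * padded_len,
                  int const * padding){
  int64_t key = 0, lda = 1;
  for (int j=0; j<order; j++){
    key += (k % padded_len[j])*lda;
    lda *= padded_len[j] - padding[j];
    k /= padded_len[j];
  }
  return key;
}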
1311  void sp_read(algstrct const * sr,
1312  int64_t ntsr,
1313  ConstPairIterator prs_tsr,
1314  char const * alpha,
1315  int64_t nread,
1316  PairIterator prs_read,
1317  char const * beta){
1318  // each for loop iteration does one addition; t and r are incremented within
1319  // advancing only r (not t) allows multiple reads of the same tensor value
1320  int64_t r = 0;
1321  for (int64_t t=0; t<ntsr && r<nread; r++){
1322  while (t<ntsr && r<nread && prs_tsr[t].k() != prs_read[r].k()){
1323  if (prs_tsr[t].k() < prs_read[r].k())
1324  t++;
1325  else {
1326  prs_read[r].write_val(sr->addid());
1327  r++;
1328  }
1329  }
1330  // scale and add if match found
1331  if (t<ntsr && r<nread){
1332  char a[sr->el_size];
1333  char b[sr->el_size];
1334  char c[sr->el_size];
1335  if (beta != NULL){
1336  sr->mul(prs_read[r].d(), beta, a);
1337  } else {
1338  prs_read[r].read_val(a);
1339  }
1340  if (alpha != NULL){
1341  sr->mul(prs_tsr[t].d(), alpha, b);
1342  } else {
1343  if (beta == NULL){
1344  prs_read[r].write_val(prs_tsr[t].d());
1345  } else {
1346  prs_tsr[t].read_val(b);
1347  }
1348  }
1349  if (beta == NULL && alpha != NULL){
1350  prs_read[r].write_val(b);
1351  } else if (beta != NULL){
1352  sr->add(a, b, c);
1353  prs_read[r].write_val(c);
1354  }
1355  }
1356  }
1357  for (; r<nread; r++){
1358  prs_read[r].write_val(sr->addid());
1359  }
1360  }
1361 
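sp_read() is a single forward merge of two key-sorted sequences; restated for doubles (alpha and beta taken as plain scalars here, with NULL treated as "no scaling" in the code above):

#include <cstdint>
#include <utility>
#include <vector>

typedef std::pair<int64_t,double> KV; // (key, value), sorted by key

void sparse_read(std::vector<KV> const & tsr, std::vector<KV> & rd,
                 double alpha, double beta){
  size_t t = 0;
  for (size_t r=0; r<rd.size(); r++){
    while (t < tsr.size() && tsr[t].first < rd[r].first) t++;
    if (t < tsr.size() && tsr[t].first == rd[r].first)
      rd[r].second = alpha*tsr[t].second + beta*rd[r].second;
    else
      rd[r].second = 0.0; // key absent from tensor: additive identity
  }
}

Since t is not advanced past a match, repeated read keys each receive the same tensor value, as in the original.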
1362  void sp_write(int num_virt,
1363  algstrct const * sr,
1364  int64_t * vntsr,
1365  ConstPairIterator vprs_tsr,
1366  char const * beta,
1367  int64_t * vnwrite,
1368  ConstPairIterator vprs_write,
1369  char const * alpha,
1370  int64_t * vnnew,
1371  char *& pprs_new){
1372  // determine how many unique keys there are in prs_tsr and prs_write
1373  int64_t tot_new = 0;
1374  ConstPairIterator prs_tsr = vprs_tsr;
1375  ConstPairIterator prs_write = vprs_write;
1376  for (int v=0; v<num_virt; v++){
1377  int64_t ntsr = vntsr[v];
1378  int64_t nwrite = vnwrite[v];
1379  if (v>0){
1380  prs_tsr = prs_tsr[vntsr[v-1]];
1381  prs_write = prs_write[vnwrite[v-1]];
1382  }
1383  int64_t nnew = 0;
1384  nnew = ntsr;
1385  for (int64_t t=0,w=0; w<nwrite; w++){
1386  while (w<nwrite){
1387  if (t<ntsr && prs_tsr[t].k() < prs_write[w].k())
1388  t++;
1389  else if (t<ntsr && prs_tsr[t].k() == prs_write[w].k()){
1390  t++;
1391  w++;
1392  } else {
1393  if (w==0 || prs_write[w-1].k() != prs_write[w].k())
1394  nnew++;
1395  w++;
1396  }
1397  }
1398  }
1399  vnnew[v] = nnew;
1400  tot_new += nnew;
1401  }
1402  //printf("ntsr = %ld nwrite = %ld nnew = %ld\n",ntsr,nwrite,nnew);
1403  pprs_new = sr->pair_alloc(tot_new);
1404  PairIterator vprs_new(sr, pprs_new);
1405  // each for loop computes one new value of prs_new
1406  // (multiple writes may contribute to it),
1407  // t, w, and n are incremented within
1408  // advancing w while holding n fixed accumulates repeated writes to one key
1409  prs_tsr = vprs_tsr;
1410  prs_write = vprs_write;
1411  PairIterator prs_new = vprs_new;
1412  for (int v=0; v<num_virt; v++){
1413  int64_t ntsr = vntsr[v];
1414  int64_t nwrite = vnwrite[v];
1415  int64_t nnew = vnnew[v];
1416  if (v>0){
1417  prs_tsr = prs_tsr[vntsr[v-1]];
1418  prs_write = prs_write[vnwrite[v-1]];
1419  prs_new = prs_new[vnnew[v-1]];
1420  }
1421 
1422  for (int64_t t=0,w=0,n=0; n<nnew; n++){
1423  if (t<ntsr && (w==nwrite || prs_tsr[t].k() < prs_write[w].k())){
1424  prs_new[n].write(prs_tsr[t].ptr);
1425  t++;
1426  } else {
1427  if (t>=ntsr || prs_tsr[t].k() > prs_write[w].k()){
1428  prs_new[n].write(prs_write[w].ptr);
1429  if (alpha != NULL)
1430  sr->mul(prs_new[n].d(), alpha, prs_new[n].d());
1431  w++;
1432  } else {
1433  char a[sr->el_size];
1434  char b[sr->el_size];
1435  char c[sr->el_size];
1436  if (alpha != NULL){
1437  sr->mul(prs_write[w].d(), alpha, a);
1438  } else {
1439  prs_write[w].read_val(a);
1440  }
1441  if (beta != NULL){
1442  sr->mul(prs_tsr[t].d(), beta, b);
1443  } else {
1444  prs_tsr[t].read_val(b);
1445  }
1446  sr->add(a, b, c);
1447  prs_new[n].write_val(c);
1448  ((int64_t*)(prs_new[n].ptr))[0] = prs_tsr[t].k();
1449  t++;
1450  w++;
1451  }
1452  // accumulate any repeated key writes
1453  while (w < nwrite && prs_write[w].k() == prs_write[w-1].k()){
1454  if (alpha != NULL){
1455  char a[sr->el_size];
1456  sr->mul(prs_write[w].d(), alpha, a);
1457  sr->add(prs_new[n].d(), a, prs_new[n].d());
1458  } else
1459  sr->add(prs_new[n].d(), prs_write[w].d(), prs_new[n].d());
1460  w++;
1461  }
1462  }
1463  /*printf("%ldth value is ", n);
1464  sr->print(prs_new[n].d());
1465  printf(" with key %ld\n",prs_new[n].k());*/
1466  }
1467  }
1468  }
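Per virtual block, sp_write() performs the union-merge sketched below for doubles (NULL alpha/beta again meaning no scaling): tensor entries not hit by a write are copied unchanged, a write to an existing key produces alpha*write + beta*old, a write to a new key produces alpha*write, and repeated write keys accumulate into a single output pair, so the output size lies between the old size and the old size plus the number of writes.

#include <cstdint>
#include <utility>
#include <vector>

typedef std::pair<int64_t,double> KV; // (key, value), sorted by key

std::vector<KV> sparse_write(std::vector<KV> const & tsr,
                             std::vector<KV> const & wr,
                             double alpha, double beta){
  std::vector<KV> out;
  size_t t = 0, w = 0;
  while (t < tsr.size() || w < wr.size()){
    if (w == wr.size() || (t < tsr.size() && tsr[t].first < wr[w].first)){
      out.push_back(tsr[t]); t++;        // untouched tensor entry
    } else {
      KV nw(wr[w].first, alpha*wr[w].second);
      if (t < tsr.size() && tsr[t].first == nw.first){
        nw.second += beta*tsr[t].second; // write hits an existing entry
        t++;
      }
      w++;
      while (w < wr.size() && wr[w].first == nw.first){
        nw.second += alpha*wr[w].second; // accumulate repeated writes
        w++;
      }
      out.push_back(nw);
    }
  }
  return out;
}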
1469 }
1470 