Cyclops Tensor Framework
parallel arithmetic on multidimensional arrays
cyclic_reshuffle.cxx
1 /*Copyright (c) 2011, Edgar Solomonik, all rights reserved.*/
2 
3 #include "cyclic_reshuffle.h"
4 #include "../shared/util.h"
5 
6 namespace CTF_int {
7 
8  void pad_cyclic_pup_virt_buff(int const * sym,
9  distribution const & old_dist,
10  distribution const & new_dist,
11  int const * len,
12  int const * old_phys_dim,
13  int const * old_phys_edge_len,
14  int const * old_virt_edge_len,
15  int64_t old_virt_nelem,
16  int const * old_offsets,
17  int * const * old_permutation,
18  int total_np,
19  int const * new_phys_dim,
20  int const * new_phys_edge_len,
21  int const * new_virt_edge_len,
22  int64_t new_virt_nelem,
23  char * old_data,
24  char ** new_data,
25  int forward,
26  int * const * bucket_offset,
27  char const * alpha,
28  char const * beta,
29  algstrct const * sr){
30  bool is_copy = false;
31  if (sr->isequal(sr->mulid(), alpha) && sr->isequal(sr->addid(), beta)) is_copy = true;
32  if (old_dist.order == 0){
33  if (forward)
34  sr->copy(new_data[0], old_data);
35  else {
36  if (is_copy)
37  sr->copy(old_data, new_data[0]);
38  else
39  sr->acc(old_data, beta, new_data[0], alpha);
40  }
41  return;
42  }
43 
44  int old_virt_np = 1;
45  for (int dim = 0;dim < old_dist.order;dim++) old_virt_np *= old_dist.virt_phase[dim];
46 
47  int new_virt_np = 1;
48  for (int dim = 0;dim < old_dist.order;dim++) new_virt_np *= new_dist.virt_phase[dim];
49 
50  int nbucket = total_np; //*(forward ? new_virt_np : old_virt_np);
51 
52  #if DEBUG >= 1
53  int rank;
54  MPI_Comm_rank(MPI_COMM_WORLD,&rank);
55  #endif
56 
57  TAU_FSTART(cyclic_pup_bucket);
58  #ifdef USE_OMP
59  int max_ntd = omp_get_max_threads();
60  max_ntd = MAX(1,MIN(max_ntd,new_virt_nelem/nbucket));
61 
62  int64_t old_size, new_size;
63  old_size = sy_packed_size(old_dist.order, old_virt_edge_len, sym)*old_virt_np;
64  new_size = sy_packed_size(old_dist.order, new_virt_edge_len, sym)*new_virt_np;
65  /*if (forward){
66  } else {
67  old_size = sy_packed_size(old_dist.order, old_virt_edge_len, sym)*new_virt_np;
68  new_size = sy_packed_size(old_dist.order, new_virt_edge_len, sym)*old_virt_np;
69  }*/
70  /*printf("old_size=%d, new_size=%d,old_virt_np=%d,new_virt_np=%d\n",
71  old_size,new_size,old_virt_np,new_virt_np);
72  */
73  int64_t * bucket_store;
74  int64_t * count_store;
75  int64_t * thread_store;
76  mst_alloc_ptr(sizeof(int64_t)*MAX(old_size,new_size), (void**)&bucket_store);
77  mst_alloc_ptr(sizeof(int64_t)*MAX(old_size,new_size), (void**)&count_store);
78  mst_alloc_ptr(sizeof(int64_t)*MAX(old_size,new_size), (void**)&thread_store);
79  std::fill(bucket_store, bucket_store+MAX(old_size,new_size), -1);
80 
81  int64_t ** par_virt_counts;
82  alloc_ptr(sizeof(int64_t*)*max_ntd, (void**)&par_virt_counts);
83  for (int t=0; t<max_ntd; t++){
84  mst_alloc_ptr(sizeof(int64_t)*nbucket, (void**)&par_virt_counts[t]);
85  std::fill(par_virt_counts[t], par_virt_counts[t]+nbucket, 0);
86  }
87  #pragma omp parallel num_threads(max_ntd)
88  {
89  #endif
90 
91  int *offs; alloc_ptr(sizeof(int)*old_dist.order, (void**)&offs);
92  if (old_offsets == NULL)
93  for (int dim = 0;dim < old_dist.order;dim++) offs[dim] = 0;
94  else
95  for (int dim = 0;dim < old_dist.order;dim++) offs[dim] = old_offsets[dim];
96 
97  int *ends; alloc_ptr(sizeof(int)*old_dist.order, (void**)&ends);
98  for (int dim = 0;dim < old_dist.order;dim++) ends[dim] = len[dim];
99 
100  #ifdef USE_OMP
101  int tid = omp_get_thread_num();
102  int ntd = omp_get_num_threads();
103  //partition the global tensor among threads, to preserve
104  //global ordering and load balance in partitioning
105  int gidx_st[old_dist.order];
106  int gidx_end[old_dist.order];
107  if (old_dist.order > 1){
108  int64_t all_size = packed_size(old_dist.order, len, sym);
109  int64_t chnk = all_size/ntd;
110  int64_t glb_idx_st = chnk*tid + MIN(tid,all_size%ntd);
111  int64_t glb_idx_end = glb_idx_st+chnk+(tid<(all_size%ntd));
112  //calculate global indices along each dimension corresponding to partition
113 // printf("glb_idx_st = %ld, glb_idx_end = %ld\n",glb_idx_st,glb_idx_end);
114  calc_idx_arr(old_dist.order, len, sym, glb_idx_st, gidx_st);
115  calc_idx_arr(old_dist.order, len, sym, glb_idx_end, gidx_end);
116  gidx_st[0] = 0;
117  //FIXME: wrong but evidently not used
118  gidx_end[0] = 0;
119  #if DEBUG >= 1
120  if (ntd == 1){
121  if (gidx_end[old_dist.order-1] != len[old_dist.order-1]){
122  for (int dim=0; dim<old_dist.order; dim++){
123  printf("glb_idx_end = %ld, gidx_end[%d]= %d, len[%d] = %d\n",
124  glb_idx_end, dim, gidx_end[dim], dim, len[dim]);
125  }
126  ABORT;
127  }
128  ASSERT(gidx_end[old_dist.order-1] <= ends[old_dist.order-1]);
129  }
130  #endif
131  } else {
132  //FIXME the below means redistribution of a vector is non-threaded
133  if (tid == 0){
134  gidx_st[0] = 0;
135  gidx_end[0] = ends[0];
136  } else {
137  gidx_st[0] = 0;
138  gidx_end[0] = 0;
139  }
140 
141  }
142  //clip global indices to my physical cyclic phase (local tensor data)
143 
144  #endif
145  // FIXME: may be better to mst_alloc, but this should ensure the
146  // compiler knows there are no write conflicts
147  #ifdef USE_OMP
148  int64_t * count = par_virt_counts[tid];
149  #else
150  int64_t *count; alloc_ptr(sizeof(int64_t)*nbucket, (void**)&count);
151  memset(count, 0, sizeof(int64_t)*nbucket);
152  #endif
153 
154  int *gidx; alloc_ptr(sizeof(int)*old_dist.order, (void**)&gidx);
155  memset(gidx, 0, sizeof(int)*old_dist.order);
156  for (int dim = 0;dim < old_dist.order;dim++){
157  gidx[dim] = old_dist.perank[dim];
158  }
159 
160  int64_t *virt_offset; alloc_ptr(sizeof(int64_t)*old_dist.order, (void**)&virt_offset);
161  memset(virt_offset, 0, sizeof(int64_t)*old_dist.order);
162 
163  int *idx; alloc_ptr(sizeof(int)*old_dist.order, (void**)&idx);
164  memset(idx, 0, sizeof(int)*old_dist.order);
165 
166  int64_t *virt_acc; alloc_ptr(sizeof(int64_t)*old_dist.order, (void**)&virt_acc);
167  memset(virt_acc, 0, sizeof(int64_t)*old_dist.order);
168 
169  int64_t *idx_acc; alloc_ptr(sizeof(int64_t)*old_dist.order, (void**)&idx_acc);
170  memset(idx_acc, 0, sizeof(int64_t)*old_dist.order);
171 
172  int64_t *old_virt_lda; alloc_ptr(sizeof(int64_t)*old_dist.order, (void**)&old_virt_lda);
173  old_virt_lda[0] = old_virt_nelem;
174  for (int dim=1; dim<old_dist.order; dim++){
175  old_virt_lda[dim] = old_virt_lda[dim-1]*old_dist.virt_phase[dim-1];
176  }
177 
178  int64_t offset = 0;
179 
180  int64_t zero_len_toff = 0;
181 
182  #ifdef USE_OMP
183  for (int dim=old_dist.order-1; dim>=0; dim--){
184  int64_t iist = MAX(0,(gidx_st[dim]-old_dist.perank[dim]));
185  int64_t ist = iist/old_dist.phase[dim];//(old_phys_dim[dim]*old_dist.virt_phase[dim]);
186  if (sym[dim] != NS) ist = MIN(ist,idx[dim+1]);
187  int plen[old_dist.order];
188  memcpy(plen,old_virt_edge_len,old_dist.order*sizeof(int));
189  int idim = dim;
190  do {
191  plen[idim] = ist;
192  idim--;
193  } while (idim >= 0 && sym[idim] != NS);
194  //gidx[dim] += ist*old_phys_dim[dim]*old_dist.virt_phase[dim];
195  gidx[dim] += ist*old_dist.phase[dim];//old_phys_dim[dim]*old_dist.virt_phase[dim];
196  idx[dim] = ist;
197  idx_acc[dim] = sy_packed_size(dim+1, plen, sym);
198  offset += idx_acc[dim];
199 
200  ASSERT(ist == 0 || gidx[dim] <= gidx_st[dim]);
201  // ASSERT(ist < old_virt_edge_len[dim]);
202 
203  if (gidx[dim] > gidx_st[dim]) break;
204 
205  int64_t vst = iist-ist*old_dist.phase[dim];//*old_phys_dim[dim]*old_dist.virt_phase[dim];
206  if (vst > 0 ){
207  vst = MIN(old_dist.virt_phase[dim]-1,vst/old_dist.phys_phase[dim]);
208  gidx[dim] += vst*old_dist.phys_phase[dim];
209  virt_offset[dim] = vst;
210  offset += vst*old_virt_lda[dim];
211  } else vst = 0;
212  if (gidx[dim] > gidx_st[dim]) break;
213  }
214  #endif
215 
216  bool done = false;
217  for (;!done;){
218  int64_t bucket0 = 0;
219  bool outside0 = false;
220  int len_zero_max = ends[0];
221  #ifdef USE_OMP
222  bool is_at_end = true;
223  bool is_at_start = true;
224  for (int dim = old_dist.order-1;dim >0;dim--){
225  if (gidx[dim] > gidx_st[dim]){
226  is_at_start = false;
227  break;
228  }
229  if (gidx[dim] < gidx_st[dim]){
230  outside0 = true;
231  break;
232  }
233  }
234  if (is_at_start){
235  zero_len_toff = gidx_st[0];
236  }
237  for (int dim = old_dist.order-1;dim >0;dim--){
238  if (gidx_end[dim] < gidx[dim]){
239  outside0 = true;
240  done = true;
241  break;
242  }
243  if (gidx_end[dim] > gidx[dim]){
244  is_at_end = false;
245  break;
246  }
247  }
248  if (is_at_end){
249  len_zero_max = MIN(ends[0],gidx_end[0]);
250  done = true;
251  }
252  #endif
253 
254  if (!outside0){
255  for (int dim = 1;dim < old_dist.order;dim++){
256  if (bucket_offset[dim][virt_offset[dim]+idx[dim]*old_dist.virt_phase[dim]] == -1) outside0 = true;
257  bucket0 += bucket_offset[dim][virt_offset[dim]+idx[dim]*old_dist.virt_phase[dim]];
258  }
259  }
260 
261  if (!outside0){
262  for (int dim = 1;dim < old_dist.order;dim++){
263  if (gidx[dim] >= (sym[dim] == NS ? ends[dim] :
264  (sym[dim] == SY ? gidx[dim+1]+1 :
265  gidx[dim+1])) ||
266  gidx[dim] < offs[dim]){
267  outside0 = true;
268  break;
269  }
270  }
271  }
272 
273  int idx_max = (sym[0] == NS ? old_virt_edge_len[0] : idx[1]+1);
274  int idx_st = 0;
275 
276  if (!outside0){
277  int gidx_min = MAX(zero_len_toff,offs[0]);
278  int gidx_max = (sym[0] == NS ? ends[0] : (sym[0] == SY ? gidx[1]+1 : gidx[1]));
279  gidx_max = MIN(gidx_max, len_zero_max);
280  for (idx[0] = idx_st;idx[0] < idx_max;idx[0]++){
281  int virt_min = MAX(0,MIN(old_dist.virt_phase[0],(gidx_min-gidx[0])/old_dist.phys_phase[0]));
282  //int virt_min = MAX(0,MIN(old_dist.virt_phase[0],(gidx_min-gidx[0]+old_dist.phys_phase[0]-1)/old_dist.phys_phase[0]));
283  int virt_max = MAX(0,MIN(old_dist.virt_phase[0],(gidx_max-gidx[0]+old_dist.phys_phase[0]-1)/old_dist.phys_phase[0]));
284 
285  offset += old_virt_nelem*virt_min;
286  if (forward){
287  ASSERT(is_copy);
288  for (virt_offset[0] = virt_min;
289  virt_offset[0] < virt_max;
290  virt_offset[0] ++)
291  {
292  int64_t bucket = bucket0+bucket_offset[0][virt_offset[0]+idx[0]*old_dist.virt_phase[0]];
293  #ifdef USE_OMP
294  bucket_store[offset] = bucket;
295  count_store[offset] = count[bucket]++;
296  thread_store[offset] = tid;
297  #else
298 /* printf("[%d] bucket = %d offset = %ld\n", rank, bucket, offset);
299  printf("[%d] count[bucket] = %d, nbucket = %d\n", rank, count[bucket]+1, nbucket);
300  std::cout << "old_data[offset]=";
301  sr->print(old_data+ sr->el_size*offset);*/
302  sr->copy(new_data[bucket]+sr->el_size*(count[bucket]++), old_data+ sr->el_size*offset);
303 /* std::cout << "\nnew_data[bucket][count[bucket]++]=";
304  sr->print(new_data[bucket]+sr->el_size*(count[bucket]-1));
305  std::cout << "\n";*/
306  #endif
307  offset += old_virt_nelem;
308  }
309  }
310  else{
311  for (virt_offset[0] = virt_min;
312  virt_offset[0] < virt_max;
313  virt_offset[0] ++)
314  {
315  int64_t bucket = bucket0+bucket_offset[0][virt_offset[0]+idx[0]*old_dist.virt_phase[0]];
316  #ifdef USE_OMP
317  bucket_store[offset] = bucket;
318  count_store[offset] = count[bucket]++;
319  thread_store[offset] = tid;
320  #else
321  if (is_copy)
322  sr->copy(old_data+sr->el_size*offset, new_data[bucket]+sr->el_size*(count[bucket]++));
323  else
324  sr->acc( old_data+sr->el_size*offset, beta, new_data[bucket]+sr->el_size*(count[bucket]++), alpha);
325 // old_data[offset] = beta*old_data[offset] + alpha*new_data[bucket][count[bucket]++];
326  #endif
327  offset += old_virt_nelem;
328  }
329  }
330 
331  offset++;
332  offset -= old_virt_nelem*virt_max;
333  gidx[0] += old_dist.phase[0];//old_phys_dim[0]*old_dist.virt_phase[0];
334  }
335 
336  offset -= idx_max;
337  gidx[0] -= idx_max*old_dist.phase[0];//old_phys_dim[0]*old_dist.virt_phase[0];
338  }
339 
340  idx_acc[0] = idx_max;
341 
342  idx[0] = 0;
343 
344  zero_len_toff = 0;
345 
346  /* Adjust outer indices */
347  if (!done){
348  for (int dim = 1;dim < old_dist.order;dim++){
349  offset += old_virt_lda[dim];
350 
351  virt_offset[dim] ++;//= old_virt_edge_len[dim];
352  gidx[dim]+=old_dist.phys_phase[dim];
353  if (virt_offset[dim] == old_dist.virt_phase[dim]){
354  offset -= old_virt_lda[dim]*old_dist.virt_phase[dim];
355  gidx[dim] -= old_dist.phase[dim];
356  virt_offset[dim] = 0;
357 
358  offset += idx_acc[dim-1];
359  idx_acc[dim] += idx_acc[dim-1];
360  idx_acc[dim-1] = 0;
361 
362  gidx[dim] -= idx[dim]*old_dist.phase[dim];//phys_dim[dim]*old_dist.virt_phase[dim];
363  idx[dim]++;
364 
365  if (idx[dim] == (sym[dim] == NS ? old_virt_edge_len[dim] : idx[dim+1]+1)){
366  offset -= idx_acc[dim];
367  //index should always be zero here since everything is SY and not SH
368  idx[dim] = 0;//(dim == 0 || sym[dim-1] == NS ? 0 : idx[dim-1]);
369  //gidx[dim] += idx[dim]*old_phys_dim[dim]*old_dist.virt_phase[dim];
370 
371  if (dim == old_dist.order-1) done = true;
372  }
373  else{
374  //gidx[dim] += idx[dim]*old_phys_dim[dim]*old_dist.virt_phase[dim];
375  gidx[dim] += idx[dim]*old_dist.phase[dim];//old_phys_dim[dim]*old_dist.virt_phase[dim];
376  break;
377  }
378  }
379  else{
380  idx_acc[dim-1] = 0;
381  break;
382  }
383  }
384  if (old_dist.order <= 1) done = true;
385  }
386  }
387  cdealloc(gidx);
388  cdealloc(idx_acc);
389  cdealloc(virt_acc);
390  cdealloc(idx);
391  cdealloc(virt_offset);
392  cdealloc(old_virt_lda);
393 
394  #ifndef USE_OMP
395  #if DEBUG >= 1
396  bool pass = true;
397  for (int i = 0;i < nbucket-1;i++){
398  if (count[i] != (int64_t)((new_data[i+1]-new_data[i])/sr->el_size)){
399  printf("rank = %d count %d should have been %d is %ld\n", rank, i, (int)((new_data[i+1]-new_data[i])/sr->el_size), count[i]);
400  pass = false;
401  }
402  }
403  if (!pass) ABORT;
404  #endif
405  #endif
406  cdealloc(offs);
407  cdealloc(ends);
408 
409  #ifndef USE_OMP
410  cdealloc(count);
411  TAU_FSTOP(cyclic_pup_bucket);
412  #else
413  par_virt_counts[tid] = count;
414  } // end #pragma omp parallel
415  for (int bckt=0; bckt<nbucket; bckt++){
416  int par_tmp = 0;
417  for (int thread=0; thread<max_ntd; thread++){
418  par_tmp += par_virt_counts[thread][bckt];
419  par_virt_counts[thread][bckt] = par_tmp - par_virt_counts[thread][bckt];
420  }
421  #if DEBUG >= 1
422  if (bckt < nbucket-1 && par_tmp != (new_data[bckt+1]-new_data[bckt])/sr->el_size){
423  printf("rank = %d count for bucket %d is %d should have been %ld\n",rank,bckt,par_tmp,(int64_t)(new_data[bckt+1]-new_data[bckt])/sr->el_size);
424  ABORT;
425  }
426  #endif
427  }
428  TAU_FSTOP(cyclic_pup_bucket);
429  TAU_FSTART(cyclic_pup_move);
430  {
431  int64_t tot_sz = MAX(old_size, new_size);
432  int64_t i;
433  if (forward){
434  ASSERT(is_copy);
435  #pragma omp parallel for private(i)
436  for (i=0; i<tot_sz; i++){
437  if (bucket_store[i] != -1){
438  int64_t pc = par_virt_counts[thread_store[i]][bucket_store[i]];
439  int64_t ct = count_store[i]+pc;
440  sr->copy(new_data[bucket_store[i]]+ct*sr->el_size, old_data+i*sr->el_size);
441  }
442  }
443  } else {
444  if (is_copy){// alpha == 1.0 && beta == 0.0){
445  #pragma omp parallel for private(i)
446  for (i=0; i<tot_sz; i++){
447  if (bucket_store[i] != -1){
448  int64_t pc = par_virt_counts[thread_store[i]][bucket_store[i]];
449  int64_t ct = count_store[i]+pc;
450  sr->copy(old_data+i*sr->el_size, new_data[bucket_store[i]]+ct*sr->el_size);
451  }
452  }
453  } else {
454  #pragma omp parallel for private(i)
455  for (i=0; i<tot_sz; i++){
456  if (bucket_store[i] != -1){
457  int64_t pc = par_virt_counts[thread_store[i]][bucket_store[i]];
458  int64_t ct = count_store[i]+pc;
459  sr->acc(old_data+i*sr->el_size, beta, new_data[bucket_store[i]]+ct*sr->el_size, alpha);
460  }
461  }
462  }
463  }
464  }
465  TAU_FSTOP(cyclic_pup_move);
466  for (int t=0; t<max_ntd; t++){
467  cdealloc(par_virt_counts[t]);
468  }
469  cdealloc(par_virt_counts);
470  cdealloc(count_store);
471  cdealloc(bucket_store);
472  cdealloc(thread_store);
473  #endif
474 
475  }
476 
477  void cyclic_reshuffle(int const * sym,
478  distribution const & old_dist,
479  int const * old_offsets,
480  int * const * old_permutation,
481  distribution const & new_dist,
482  int const * new_offsets,
483  int * const * new_permutation,
484  char ** ptr_tsr_data,
485  char ** ptr_tsr_cyclic_data,
486  algstrct const * sr,
487  CommData ord_glb_comm,
488  bool reuse_buffers,
489  char const * alpha,
490  char const * beta){
491  int i, np, old_nvirt, new_nvirt, old_np, new_np, idx_lyr;
492  int64_t vbs_old, vbs_new;
493  int64_t swp_nval;
494  int * hsym;
495  int64_t * send_counts, * recv_counts;
496  int * idx;
497  int64_t * idx_offs;
498  int64_t * send_displs;
499  int64_t * recv_displs;
500  int * new_virt_lda, * old_virt_lda;
501  int * old_sub_edge_len, * new_sub_edge_len;
502  int order = old_dist.order;
503 
504  char * tsr_data = *ptr_tsr_data;
505  char * tsr_cyclic_data = *ptr_tsr_cyclic_data;
506  if (order == 0){
507  bool is_copy = false;
508  if (sr->isequal(sr->mulid(), alpha) && sr->isequal(sr->addid(), beta)) is_copy = true;
509  alloc_ptr(sr->el_size, (void**)&tsr_cyclic_data);
510  if (ord_glb_comm.rank == 0){
511  if (is_copy)
512  sr->copy(tsr_cyclic_data, tsr_data);
513  else
514  sr->acc(tsr_cyclic_data, beta, tsr_data, alpha);
515  } else {
516  sr->copy(tsr_cyclic_data, sr->addid());
517  }
518  *ptr_tsr_cyclic_data = tsr_cyclic_data;
519  return;
520  }
521 
522  ASSERT(!reuse_buffers || sr->isequal(beta, sr->addid()));
523  ASSERT(old_dist.is_cyclic&&new_dist.is_cyclic);
524 
526  np = ord_glb_comm.np;
527 
528  alloc_ptr(order*sizeof(int), (void**)&hsym);
529  alloc_ptr(order*sizeof(int), (void**)&idx);
530  alloc_ptr(order*sizeof(int64_t), (void**)&idx_offs);
531  alloc_ptr(order*sizeof(int), (void**)&old_virt_lda);
532  alloc_ptr(order*sizeof(int), (void**)&new_virt_lda);
533 
534  new_nvirt = 1;
535  old_nvirt = 1;
536  old_np = 1;
537  new_np = 1;
538  idx_lyr = ord_glb_comm.rank;
539  for (i=0; i<order; i++) {
540  new_virt_lda[i] = new_nvirt;
541  old_virt_lda[i] = old_nvirt;
542  // nbuf = nbuf*new_dist.phase[i];
543  /*printf("is_new_pad = %d\n", is_new_pad);
544  if (is_new_pad)
545  printf("new_dist.padding[%d] = %d\n", i, new_dist.padding[i]);
546  printf("is_old_pad = %d\n", is_old_pad);
547  if (is_old_pad)
548  printf("old_dist.padding[%d] = %d\n", i, old_dist.padding[i]);*/
549  old_nvirt = old_nvirt*old_dist.virt_phase[i];
550  new_nvirt = new_nvirt*new_dist.virt_phase[i];
551  new_np = new_np*new_dist.phase[i]/new_dist.virt_phase[i];
552  old_np = old_np*old_dist.phase[i]/old_dist.virt_phase[i];
553  idx_lyr -= old_dist.perank[i]*old_dist.pe_lda[i];
554  }
555  vbs_old = old_dist.size/old_nvirt;
556 
557  mst_alloc_ptr(np*sizeof(int64_t), (void**)&recv_counts);
558  mst_alloc_ptr(np*sizeof(int64_t), (void**)&send_counts);
559  mst_alloc_ptr(np*sizeof(int64_t), (void**)&send_displs);
560  mst_alloc_ptr(np*sizeof(int64_t), (void**)&recv_displs);
561  alloc_ptr(order*sizeof(int), (void**)&old_sub_edge_len);
562  alloc_ptr(order*sizeof(int), (void**)&new_sub_edge_len);
563  int ** bucket_offset;
564 
565  int *real_edge_len; alloc_ptr(sizeof(int)*order, (void**)&real_edge_len);
566  for (i=0; i<order; i++) real_edge_len[i] = old_dist.pad_edge_len[i]-old_dist.padding[i];
567 
568  int *old_phys_dim; alloc_ptr(sizeof(int)*order, (void**)&old_phys_dim);
569  for (i=0; i<order; i++) old_phys_dim[i] = old_dist.phase[i]/old_dist.virt_phase[i];
570 
571  int *new_phys_dim; alloc_ptr(sizeof(int)*order, (void**)&new_phys_dim);
572  for (i=0; i<order; i++) new_phys_dim[i] = new_dist.phase[i]/new_dist.virt_phase[i];
573 
574  int *old_phys_edge_len; alloc_ptr(sizeof(int)*order, (void**)&old_phys_edge_len);
575  for (int dim = 0;dim < order;dim++) old_phys_edge_len[dim] = (real_edge_len[dim]+old_dist.padding[dim])/old_phys_dim[dim];
576 
577  int *new_phys_edge_len; alloc_ptr(sizeof(int)*order, (void**)&new_phys_edge_len);
578  for (int dim = 0;dim < order;dim++) new_phys_edge_len[dim] = (real_edge_len[dim]+new_dist.padding[dim])/new_phys_dim[dim];
579 
580  int *old_virt_edge_len; alloc_ptr(sizeof(int)*order, (void**)&old_virt_edge_len);
581  for (int dim = 0;dim < order;dim++) old_virt_edge_len[dim] = old_phys_edge_len[dim]/old_dist.virt_phase[dim];
582 
583  int *new_virt_edge_len; alloc_ptr(sizeof(int)*order, (void**)&new_virt_edge_len);
584  for (int dim = 0;dim < order;dim++) new_virt_edge_len[dim] = new_phys_edge_len[dim]/new_dist.virt_phase[dim];
585 
586 
587 
588  bucket_offset =
589  compute_bucket_offsets( old_dist,
590  new_dist,
591  real_edge_len,
592  old_phys_edge_len,
593  old_virt_lda,
594  old_offsets,
595  old_permutation,
596  new_phys_edge_len,
597  new_virt_lda,
598  1,
599  old_nvirt,
600  new_nvirt,
601  old_virt_edge_len);
602 
603 
604 
606  /* Calculate bucket counts to begin exchange */
607  calc_cnt_displs(sym,
608  old_dist,
609  new_dist,
610  new_nvirt,
611  np,
612  old_virt_edge_len,
613  new_virt_lda,
614  send_counts,
615  recv_counts,
616  send_displs,
617  recv_displs,
618  ord_glb_comm,
619  idx_lyr,
620  bucket_offset);
621 
623  /*for (i=0; i<np; i++){
624  printf("[%d] send_counts[%d] = %d recv_counts[%d] = %d\n", ord_glb_comm.rank, i, send_counts[i], i, recv_counts[i]);
625  }
626  for (i=0; i<nbuf; i++){
627  printf("[%d] svirt_displs[%d] = %d rvirt_displs[%d] = %d\n", ord_glb_comm.rank, i, svirt_displs[i], i, rvirt_displs[i]);
628  }*/
629 
630  // }
631  for (i=0; i<order; i++){
632  new_sub_edge_len[i] = new_dist.pad_edge_len[i];
633  old_sub_edge_len[i] = old_dist.pad_edge_len[i];
634  }
635  for (i=0; i<order; i++){
636  new_sub_edge_len[i] = new_sub_edge_len[i] / new_dist.phase[i];
637  old_sub_edge_len[i] = old_sub_edge_len[i] / old_dist.phase[i];
638  }
639  for (i=1; i<order; i++){
640  hsym[i-1] = sym[i];
641  }
642  swp_nval = new_nvirt*sy_packed_size(order, new_sub_edge_len, sym);
643  vbs_new = swp_nval/new_nvirt;
644 
645  char * send_buffer, * recv_buffer;
646  if (reuse_buffers){
647  mst_alloc_ptr(MAX(old_dist.size,swp_nval)*sr->el_size, (void**)&tsr_cyclic_data);
648  } else {
649  mst_alloc_ptr(old_dist.size*sr->el_size, (void**)&send_buffer);
650  mst_alloc_ptr(swp_nval*sr->el_size, (void**)&recv_buffer);
651  }
652 
653  TAU_FSTART(pack_virt_buf);
654  if (idx_lyr == 0){
655  /*char new1[old_dist.size*sr->el_size];
656  char new2[old_dist.size*sr->el_size];
657  sr->set(new1, sr->addid(), old_dist.size);
658  sr->set(new2, sr->addid(), old_dist.size);
659  //if (ord_glb_comm.rank == 0)
660  //printf("old_dist.size = %ld\n",old_dist.size);
661  //std::fill((double*)new1, ((double*)new1)+old_dist.size, 0.0);
662  //std::fill((double*)new2, ((double*)new2)+old_dist.size, 0.0);
663  order_globally(sym, old_dist, old_virt_edge_len, old_virt_lda, vbs_old, 1, tsr_data, new1, sr);
664  order_globally(sym, old_dist, old_virt_edge_len, old_virt_lda, vbs_old, 0, new1, new2, sr);
665 
666  if (ord_glb_comm.rank == 0){
667  for (int64_t i=0; i<old_dist.size; i++){
668  if (!sr->isequal(new2+i*sr->el_size, tsr_data +i*sr->el_size)){
669  printf("tsr_data[%ld] was ",i);
670  sr->print(tsr_data +i*sr->el_size);
671  printf(" became ");
672  sr->print(new2+i*sr->el_size);
673  printf("\n");
674  ASSERT(0);
675  }
676  }
677  }*/
678 
679 
680  char **new_data; alloc_ptr(sizeof(char*)*np, (void**)&new_data);
681  if (reuse_buffers){
682  for (int64_t p = 0;p < np;p++){
683  new_data[p] = tsr_cyclic_data+sr->el_size*send_displs[p];
684  }
685  } else {
686  for (int64_t p = 0;p < np;p++){
687  new_data[p] = send_buffer+sr->el_size*send_displs[p];
688  }
689  }
690 
691  pad_cyclic_pup_virt_buff(sym,
692  old_dist,
693  new_dist,
694  real_edge_len,
695  old_phys_dim,
696  old_phys_edge_len,
697  old_virt_edge_len,
698  vbs_old,
699  old_offsets,
700  old_permutation,
701  np,
702  new_phys_dim,
703  new_phys_edge_len,
704  new_virt_edge_len,
705  vbs_new,
706  tsr_data,
707  new_data,
708  1,
709  bucket_offset,
710  sr->mulid(),
711  sr->addid(),
712  sr);
713  cdealloc(new_data);
714  }
715  for (int dim = 0;dim < order;dim++){
716  cdealloc(bucket_offset[dim]);
717  }
718  cdealloc(bucket_offset);
719 
720  TAU_FSTOP(pack_virt_buf);
721 
722  if (reuse_buffers){
723  if (swp_nval > old_dist.size){
724  cdealloc(tsr_data);
725  mst_alloc_ptr(swp_nval*sr->el_size, (void**)&tsr_data);
726  }
727  send_buffer = tsr_cyclic_data;
728  recv_buffer = tsr_data;
729  }
730 
731  /* Communicate data */
732  TAU_FSTART(ALL_TO_ALL_V);
733  ord_glb_comm.all_to_allv(send_buffer, send_counts, send_displs, sr->el_size,
734  recv_buffer, recv_counts, recv_displs);
735  TAU_FSTOP(ALL_TO_ALL_V);
736 
737  if (reuse_buffers)
738  sr->set(tsr_cyclic_data, sr->addid(), swp_nval);
739  else
740  cdealloc(send_buffer);
741  TAU_FSTART(unpack_virt_buf);
742  /* Deserialize data into correctly ordered virtual sub blocks */
743  if (recv_displs[ord_glb_comm.np-1] + recv_counts[ord_glb_comm.np-1] > 0){
744  char **new_data; alloc_ptr(sizeof(char*)*np, (void**)&new_data);
745  for (int64_t p = 0;p < np;p++){
746  new_data[p] = recv_buffer+recv_displs[p]*sr->el_size;
747  }
748  bucket_offset =
749  compute_bucket_offsets( new_dist,
750  old_dist,
751  real_edge_len,
752  new_phys_edge_len,
753  new_virt_lda,
754  new_offsets,
755  new_permutation,
756  old_phys_edge_len,
757  old_virt_lda,
758  0,
759  new_nvirt,
760  old_nvirt,
761  new_virt_edge_len);
762 
763  pad_cyclic_pup_virt_buff(sym,
764  new_dist,
765  old_dist,
766  real_edge_len,
767  new_phys_dim,
768  new_phys_edge_len,
769  new_virt_edge_len,
770  vbs_new,
771  new_offsets,
772  new_permutation,
773  np,
774  old_phys_dim,
775  old_phys_edge_len,
776  old_virt_edge_len,
777  vbs_old,
778  tsr_cyclic_data,
779  new_data,
780  0,
781  bucket_offset,
782  alpha,
783  beta,
784  sr);
785  for (int dim = 0;dim < order;dim++){
786  cdealloc(bucket_offset[dim]);
787  }
788  cdealloc(bucket_offset);
789  cdealloc(new_data);
790  }
791  TAU_FSTOP(unpack_virt_buf);
792 
793  if (!reuse_buffers) cdealloc(recv_buffer);
794  *ptr_tsr_cyclic_data = tsr_cyclic_data;
795  *ptr_tsr_data = tsr_data;
796 
797  cdealloc(real_edge_len);
798  cdealloc(old_phys_dim);
799  cdealloc(new_phys_dim);
800  cdealloc(hsym);
801  cdealloc(idx);
802  cdealloc(idx_offs);
803  cdealloc(old_virt_lda);
804  cdealloc(new_virt_lda);
805  cdealloc(recv_counts);
806  cdealloc(send_counts);
807  cdealloc(send_displs);
808  cdealloc(recv_displs);
809  cdealloc(old_sub_edge_len);
810  cdealloc(new_sub_edge_len);
811  cdealloc(new_virt_edge_len);
812  cdealloc(old_virt_edge_len);
813  cdealloc(new_phys_edge_len);
814  cdealloc(old_phys_edge_len);
815 
817 
818  }
819 }
void calc_idx_arr(int order, int const *lens, int const *sym, int64_t idx, int *idx_arr)
Definition: util.cxx:72
int ** compute_bucket_offsets(distribution const &old_dist, distribution const &new_dist, int const *len, int const *old_phys_edge_len, int const *old_virt_lda, int const *old_offsets, int *const *old_permutation, int const *new_phys_edge_len, int const *new_virt_lda, int forward, int old_virt_np, int new_virt_np, int const *old_virt_edge_len)
computes offsets for redistribution targets along each edge length
Definition: redist.cxx:111
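For orientation, the pup routine above combines these per-dimension offsets additively: the destination bucket of an element is the sum of one table entry per dimension (see source lines 256-257 and 292), with -1 marking elements that fall outside the target distribution. A minimal standalone illustration of that additive decomposition, using hypothetical grid sizes rather than CTF's API:

#include <cstdio>
int main(){
  const int np0 = 3, np1 = 4;                       // hypothetical 3x4 processor grid
  int off0[np0], off1[np1];                         // per-dimension contributions to the bucket id
  for (int i = 0; i < np0; i++) off0[i] = i;        // dimension 0 varies fastest
  for (int j = 0; j < np1; j++) off1[j] = j*np0;    // dimension 1 strides by np0
  // the bucket for target coordinate (i,j) is recovered by summing the two table entries
  printf("%d\n", off0[2] + off1[1]);                // prints 5, i.e. 2 + 1*3
  return 0;
}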
virtual bool isequal(char const *a, char const *b) const
returns true if algstrct elements a and b are equal
Definition: algstrct.cxx:340

void acc(char *b, char const *beta, char const *a, char const *alpha) const
compute b=beta*b + alpha*a
Definition: algstrct.cxx:514
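A scalar illustration of that update for the double case (standalone code with hypothetical values, not CTF code):

#include <cstdio>
int main(){
  double b = 2.0, a = 3.0, beta = 10.0, alpha = 0.5;
  b = beta*b + alpha*a;   // the element-wise update that acc(b, beta, a, alpha) performs
  printf("%g\n", b);      // prints 21.5
  return 0;
}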
virtual void copy(char *a, char const *b) const
copies element b to element a
Definition: algstrct.cxx:538
#define ASSERT(...)
Definition: util.h:88
virtual char const * addid() const
identity element for addition i.e. 0
Definition: algstrct.cxx:89
void cyclic_reshuffle(int const *sym, distribution const &old_dist, int const *old_offsets, int *const *old_permutation, distribution const &new_dist, int const *new_offsets, int *const *new_permutation, char **ptr_tsr_data, char **ptr_tsr_cyclic_data, algstrct const *sr, CommData ord_glb_comm, bool reuse_buffers, char const *alpha, char const *beta)
Goes from any set of phases to any new set of phases.
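Read against the listing above, the routine proceeds in four steps (an editorial summary of the control flow, not documentation taken from the source): 1. compute_bucket_offsets and calc_cnt_displs derive, for each destination process, how many elements leave this rank and at what displacements (send_counts/send_displs and the matching receive side). 2. pad_cyclic_pup_virt_buff with forward=1 packs the local data into per-destination buckets of the send buffer; since it is called with mulid/addid, this pass is a pure copy. 3. all_to_allv exchanges the buckets between processes. 4. pad_cyclic_pup_virt_buff with forward=0 unpacks the received buckets into the new cyclic layout, applying alpha and beta via acc when the update is not a plain copy.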
void all_to_allv(void *send_buffer, int64_t const *send_counts, int64_t const *send_displs, int64_t datum_size, void *recv_buffer, int64_t const *recv_counts, int64_t const *recv_displs)
performs all-to-all-v with 64-bit integer counts and offset on arbitrary length types (datum_size)...
Definition: common.cxx:424
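The exchange at source lines 733-734 passes element counts and element displacements together with sr->el_size. For readers unfamiliar with the pattern, a standalone plain-MPI sketch of the same kind of exchange (32-bit counts, hypothetical data; CommData::all_to_allv generalizes this to 64-bit counts and arbitrary datum sizes):

#include <mpi.h>
#include <cstdio>
int main(int argc, char ** argv){
  MPI_Init(&argc, &argv);
  int rank, np;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &np);
  // each rank sends one int to every rank: value = 100*rank + destination
  int * sendbuf = new int[np], * recvbuf = new int[np];
  int * counts  = new int[np], * displs  = new int[np];
  for (int p = 0; p < np; p++){
    sendbuf[p] = 100*rank + p;
    counts[p]  = 1;
    displs[p]  = p;
  }
  MPI_Alltoallv(sendbuf, counts, displs, MPI_INT,
                recvbuf, counts, displs, MPI_INT, MPI_COMM_WORLD);
  // recvbuf[p] now holds 100*p + rank, i.e. the value rank p addressed to this rank
  printf("rank %d received %d from rank 0\n", rank, recvbuf[0]);
  delete [] sendbuf; delete [] recvbuf; delete [] counts; delete [] displs;
  MPI_Finalize();
  return 0;
}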
void pad_cyclic_pup_virt_buff(int const *sym, distribution const &old_dist, distribution const &new_dist, int const *len, int const *old_phys_dim, int const *old_phys_edge_len, int const *old_virt_edge_len, int64_t old_virt_nelem, int const *old_offsets, int *const *old_permutation, int total_np, int const *new_phys_dim, int const *new_phys_edge_len, int const *new_virt_edge_len, int64_t new_virt_nelem, char *old_data, char **new_data, int forward, int *const *bucket_offset, char const *alpha, char const *beta, algstrct const *sr)
#define MAX(a, b)
Definition: util.h:180
virtual void set(char *a, char const *b, int64_t n) const
sets n elements of array a to value b
Definition: algstrct.cxx:629
int mst_alloc_ptr(int64_t len, void **const ptr)
mst_alloc abstraction
Definition: memcontrol.cxx:269
int alloc_ptr(int64_t len, void **const ptr)
alloc abstraction
Definition: memcontrol.cxx:320
void calc_cnt_displs(int const *sym, distribution const &old_dist, distribution const &new_dist, int new_nvirt, int np, int const *old_virt_edge_len, int const *new_virt_lda, int64_t *send_counts, int64_t *recv_counts, int64_t *send_displs, int64_t *recv_displs, CommData ord_glb_comm, int idx_lyr, int *const *bucket_offset)
computes the send and receive counts and displacements for the redistribution exchange
Definition: redist.cxx:170
#define TAU_FSTOP(ARG)
Definition: util.h:281
#define TAU_FSTART(ARG)
Definition: util.h:280
int el_size
size of each element of algstrct in bytes
Definition: algstrct.h:16
int cdealloc(void *ptr)
free abstraction
Definition: memcontrol.cxx:480
algstrct (algebraic structure) defines the elementwise operations computed in each tensor contraction...
Definition: algstrct.h:34
#define MIN(a, b)
Definition: util.h:176
int64_t packed_size(int order, const int *len, const int *sym)
computes the size of a tensor in packed symmetric (SY, SH, or AS) layout
Definition: util.cxx:38
virtual char const * mulid() const
identity element for multiplication i.e. 1
Definition: algstrct.cxx:93
#define ABORT
Definition: util.h:162
int64_t sy_packed_size(int order, const int *len, const int *sym)
computes the size of a tensor in SY (NOT HOLLOW) packed symmetric layout
Definition: util.cxx:10
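For a sense of the sizes involved: a fully symmetric (SY, diagonal kept) tensor of order k and edge length n packs to C(n+k-1, k) elements, e.g. n*(n+1)/2 for a symmetric matrix. A standalone check of that count (this is the combinatorial identity, not a call into util.cxx):

#include <cstdint>
#include <cstdio>
// number of index tuples with i0 <= i1 <= ... <= i_{k-1}, each in [0,n)
int64_t sy_count(int n, int k){
  int64_t num = 1, den = 1;
  for (int j = 1; j <= k; j++){ num *= (n - 1 + j); den *= j; }
  return num / den;
}
int main(){
  printf("%lld\n", (long long)sy_count(4, 2)); // 10 = 4*5/2
  printf("%lld\n", (long long)sy_count(4, 3)); // 20
  return 0;
}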