#include "../shared/util.h"
...
// parameters of pad_cyclic_pup_virt_buff (excerpt)
                                   int const *       old_phys_dim,
                                   int const *       old_phys_edge_len,
                                   int const *       old_virt_edge_len,
                                   int64_t           old_virt_nelem,
                                   int const *       old_offsets,
                                   int * const *     old_permutation,
                                   int const *       new_phys_dim,
                                   int const *       new_phys_edge_len,
                                   int const *       new_virt_edge_len,
                                   int64_t           new_virt_nelem,
                                   ...
                                   int * const *     bucket_offset,
  if (old_dist.order == 0){
    ...
      sr->copy(new_data[0], old_data);
    ...
      sr->copy(old_data, new_data[0]);
    ...
      sr->acc(old_data, beta, new_data[0], alpha);
  int nbucket = total_np;
  ...
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  ...
  int max_ntd = omp_get_max_threads();
  max_ntd = MAX(1, MIN(max_ntd, new_virt_nelem/nbucket));
  int64_t old_size, new_size;
  ...
  int64_t * bucket_store;
  int64_t * count_store;
  int64_t * thread_store;
  mst_alloc_ptr(sizeof(int64_t)*MAX(old_size,new_size), (void**)&bucket_store);
  ...
  mst_alloc_ptr(sizeof(int64_t)*MAX(old_size,new_size), (void**)&thread_store);
  std::fill(bucket_store, bucket_store+MAX(old_size,new_size), -1);
  ...
  int64_t ** par_virt_counts;
  alloc_ptr(sizeof(int64_t*)*max_ntd, (void**)&par_virt_counts);
  for (int t=0; t<max_ntd; t++){
    mst_alloc_ptr(sizeof(int64_t)*nbucket, (void**)&par_virt_counts[t]);
    std::fill(par_virt_counts[t], par_virt_counts[t]+nbucket, 0);
  }
  #pragma omp parallel num_threads(max_ntd)
  ...
    if (old_offsets == NULL)
  ...
    int tid = omp_get_thread_num();
    int ntd = omp_get_num_threads();
    ...
    int gidx_st[old_dist.order];
    int gidx_end[old_dist.order];
    if (old_dist.order > 1){
      ...
      int64_t chnk = all_size/ntd;
      int64_t glb_idx_st  = chnk*tid + MIN(tid, all_size%ntd);
      int64_t glb_idx_end = glb_idx_st + chnk + (tid < (all_size%ntd));
      ...
      if (gidx_end[old_dist.order-1] != len[old_dist.order-1]){
        ...
        printf("glb_idx_end = %ld, gidx_end[%d]= %d, len[%d] = %d\n",
               glb_idx_end, dim, gidx_end[dim], dim, len[dim]);
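The chnk/glb_idx_st arithmetic above is the usual balanced block partition of all_size items over ntd threads: each thread gets all_size/ntd items and the first all_size%ntd threads get one extra. A self-contained sketch of that partition (the helper name and values are illustrative, not from the source):

  #include <cstdint>
  #include <cstdio>

  // Thread t of T owns the half-open range [start, end) of N items;
  // the first N % T threads receive one extra item.
  static void partition(int64_t N, int T, int t, int64_t * start, int64_t * end){
    int64_t chnk = N / T;
    *start = chnk * t + (t < N % T ? t : N % T);
    *end   = *start + chnk + (t < N % T ? 1 : 0);
  }

  int main(){
    int64_t s, e;
    for (int t = 0; t < 3; t++){
      partition(10, 3, t, &s, &e);
      printf("thread %d: [%ld, %ld)\n", t, (long)s, (long)e);  // ranges of size 4, 3, 3
    }
  }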
    gidx_end[0] = ends[0];
  ...
    int64_t * count = par_virt_counts[tid];
  ...
    int64_t * count; alloc_ptr(sizeof(int64_t)*nbucket, (void**)&count);
    memset(count, 0, sizeof(int64_t)*nbucket);
  ...
    memset(gidx, 0, sizeof(int)*old_dist.order);
  ...
    int64_t * virt_offset; alloc_ptr(sizeof(int64_t)*old_dist.order, (void**)&virt_offset);
    memset(virt_offset, 0, sizeof(int64_t)*old_dist.order);
  ...
    memset(idx, 0, sizeof(int)*old_dist.order);
  ...
    int64_t * virt_acc; alloc_ptr(sizeof(int64_t)*old_dist.order, (void**)&virt_acc);
    memset(virt_acc, 0, sizeof(int64_t)*old_dist.order);
  ...
    int64_t * idx_acc; alloc_ptr(sizeof(int64_t)*old_dist.order, (void**)&idx_acc);
    memset(idx_acc, 0, sizeof(int64_t)*old_dist.order);
  ...
    int64_t * old_virt_lda; alloc_ptr(sizeof(int64_t)*old_dist.order, (void**)&old_virt_lda);
    old_virt_lda[0] = old_virt_nelem;
    int64_t zero_len_toff = 0;
  ...
      int64_t ist = iist/old_dist.phase[dim];
  ...
      int plen[old_dist.order];
      memcpy(plen, old_virt_edge_len, old_dist.order*sizeof(int));
  ...
      } while (idim >= 0 && sym[idim] != NS);
  ...
      offset += idx_acc[dim];
  ...
      if (gidx[dim] > gidx_st[dim]) break;
  ...
      int64_t vst = iist - ist*old_dist.phase[dim];
  ...
        virt_offset[dim] = vst;
        offset += vst*old_virt_lda[dim];
  ...
      if (gidx[dim] > gidx_st[dim]) break;
    bool outside0 = false;
    int len_zero_max = ends[0];
  ...
    bool is_at_end = true;
    bool is_at_start = true;
  ...
      if (gidx[dim] > gidx_st[dim]){
  ...
      if (gidx[dim] < gidx_st[dim]){
  ...
      zero_len_toff = gidx_st[0];
  ...
      if (gidx_end[dim] < gidx[dim]){
  ...
      if (gidx_end[dim] > gidx[dim]){
  ...
      len_zero_max = MIN(ends[0], gidx_end[0]);
      int idx_max = (sym[0] == NS ? old_virt_edge_len[0] : idx[1]+1);
  ...
      int gidx_min = MAX(zero_len_toff, offs[0]);
      int gidx_max = (sym[0] == NS ? ends[0] : (sym[0] == SY ? gidx[1]+1 : gidx[1]));
      gidx_max = MIN(gidx_max, len_zero_max);
      for (idx[0] = idx_st; idx[0] < idx_max; idx[0]++){
  ...
          offset += old_virt_nelem*virt_min;
  ...
          for (virt_offset[0] = virt_min;
               virt_offset[0] < virt_max;
  ...
            int64_t bucket = bucket0+bucket_offset[0][virt_offset[0]+idx[0]*old_dist.virt_phase[0]];
  ...
            bucket_store[offset] = bucket;
            count_store[offset]  = count[bucket]++;
            thread_store[offset] = tid;
  ...
            sr->copy(new_data[bucket]+sr->el_size*(count[bucket]++),
                     old_data+sr->el_size*offset);
  ...
            offset += old_virt_nelem;
  ...
          for (virt_offset[0] = virt_min;
               virt_offset[0] < virt_max;
  ...
            int64_t bucket = bucket0+bucket_offset[0][virt_offset[0]+idx[0]*old_dist.virt_phase[0]];
  ...
            bucket_store[offset] = bucket;
            count_store[offset]  = count[bucket]++;
            thread_store[offset] = tid;
  ...
            sr->copy(old_data+sr->el_size*offset,
                     new_data[bucket]+sr->el_size*(count[bucket]++));
  ...
            sr->acc(old_data+sr->el_size*offset, beta,
                    new_data[bucket]+sr->el_size*(count[bucket]++), alpha);
  ...
          offset += old_virt_nelem;
  ...
        offset -= old_virt_nelem*virt_max;
        gidx[0] += old_dist.phase[0];
  ...
      gidx[0] -= idx_max*old_dist.phase[0];
  ...
      idx_acc[0] = idx_max;
      offset += old_virt_lda[dim];
  ...
        virt_offset[dim] = 0;
  ...
        offset += idx_acc[dim-1];
        idx_acc[dim] += idx_acc[dim-1];
  ...
        if (idx[dim] == (sym[dim] == NS ? old_virt_edge_len[dim] : idx[dim+1]+1)){
          offset -= idx_acc[dim];
  ...
          if (dim == old_dist.order-1) done = true;
  ...
    if (old_dist.order <= 1) done = true;
    for (int i = 0; i < nbucket-1; i++){
      if (count[i] != (int64_t)((new_data[i+1]-new_data[i])/sr->el_size)){
        printf("rank = %d count %d should have been %d is %ld\n",
               rank, i, (int)((new_data[i+1]-new_data[i])/sr->el_size), count[i]);
  ...
    par_virt_counts[tid] = count;
  ...
  for (int bckt=0; bckt<nbucket; bckt++){
    ...
    for (int thread=0; thread<max_ntd; thread++){
      par_tmp += par_virt_counts[thread][bckt];
      par_virt_counts[thread][bckt] = par_tmp - par_virt_counts[thread][bckt];
    }
    ...
    if (bckt < nbucket-1 && par_tmp != (new_data[bckt+1]-new_data[bckt])/sr->el_size){
      printf("rank = %d count for bucket %d is %d should have been %ld\n",
             rank, bckt, par_tmp, (int64_t)(new_data[bckt+1]-new_data[bckt])/sr->el_size);
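The loop above converts the per-thread bucket counts in par_virt_counts into exclusive prefix sums over threads, so each thread knows the offset at which its contribution to a bucket starts. A self-contained sketch of that transformation (array names and sizes are illustrative):

  #include <cstdint>
  #include <cstdio>

  int main(){
    // counts[t][b]: how many elements thread t sends to bucket b (made-up values)
    const int T = 3, B = 2;
    int64_t counts[T][B] = {{2, 1}, {0, 3}, {4, 2}};
    // Per bucket, replace each entry by the sum of the entries of earlier threads;
    // afterwards counts[t][b] is where thread t starts writing within bucket b.
    for (int b = 0; b < B; b++){
      int64_t run = 0;
      for (int t = 0; t < T; t++){
        int64_t c = counts[t][b];
        counts[t][b] = run;   // same effect as: run += c; counts[t][b] = run - c;
        run += c;
      }
      printf("bucket %d holds %ld elements in total\n", b, (long)run);
    }
  }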
  int64_t tot_sz = MAX(old_size, new_size);
  ...
  #pragma omp parallel for private(i)
  for (i=0; i<tot_sz; i++){
    if (bucket_store[i] != -1){
      int64_t pc = par_virt_counts[thread_store[i]][bucket_store[i]];
      int64_t ct = count_store[i]+pc;
  ...
  #pragma omp parallel for private(i)
  for (i=0; i<tot_sz; i++){
    if (bucket_store[i] != -1){
      int64_t pc = par_virt_counts[thread_store[i]][bucket_store[i]];
      int64_t ct = count_store[i]+pc;
  ...
  #pragma omp parallel for private(i)
  for (i=0; i<tot_sz; i++){
    if (bucket_store[i] != -1){
      int64_t pc = par_virt_counts[thread_store[i]][bucket_store[i]];
      int64_t ct = count_store[i]+pc;
      sr->acc(old_data+i*sr->el_size, beta, new_data[bucket_store[i]]+ct*sr->el_size, alpha);
  ...
  for (int t=0; t<max_ntd; t++){
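The sr->acc calls above apply the algebraic structure's accumulate operation, b = beta*b + alpha*a (see the algstrct reference entries further down). For the ordinary ring of doubles this is just a fused scale-and-add; a minimal illustration of the semantics (not the CTF implementation):

  #include <cstdio>

  // acc for plain doubles: b = beta*b + alpha*a
  static void acc_double(double * b, double beta, const double * a, double alpha){
    *b = beta * (*b) + alpha * (*a);
  }

  int main(){
    double a = 2.0, b = 5.0;
    acc_double(&b, 0.5, &a, 3.0);   // b = 0.5*5 + 3*2 = 8.5
    printf("%g\n", b);
  }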
// parameters of cyclic_reshuffle (excerpt)
                       int const *          old_offsets,
                       int * const *        old_permutation,
                       ...
                       int const *          new_offsets,
                       int * const *        new_permutation,
                       char **              ptr_tsr_data,
                       char **              ptr_tsr_cyclic_data,
  ...
  int i, np, old_nvirt, new_nvirt, old_np, new_np, idx_lyr;
  int64_t vbs_old, vbs_new;
  ...
  int64_t * send_counts, * recv_counts;
  ...
  int64_t * send_displs;
  int64_t * recv_displs;
  int * new_virt_lda, * old_virt_lda;
  int * old_sub_edge_len, * new_sub_edge_len;
  int order = old_dist.order;
  ...
  char * tsr_data = *ptr_tsr_data;
  char * tsr_cyclic_data = *ptr_tsr_cyclic_data;
  ...
  bool is_copy = false;
  ...
  if (ord_glb_comm.rank == 0){
    ...
      sr->copy(tsr_cyclic_data, tsr_data);
    ...
      sr->acc(tsr_cyclic_data, beta, tsr_data, alpha);
  ...
  *ptr_tsr_cyclic_data = tsr_cyclic_data;
  ...
  np = ord_glb_comm.np;
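The send/recv counts and displacements declared above drive the later all-to-all exchange; in the surrounding code they are presumably filled by calc_cnt_displs (see the reference entry below), and the displacements are simply exclusive prefix sums of the counts. A hedged sketch of that relationship with hypothetical values:

  #include <cstdint>
  #include <cstdio>

  int main(){
    const int np = 4;
    int64_t send_counts[np] = {3, 0, 5, 2};   // elements destined for each rank (made up)
    int64_t send_displs[np];
    send_displs[0] = 0;
    for (int p = 1; p < np; p++)
      send_displs[p] = send_displs[p-1] + send_counts[p-1];
    // displs = 0, 3, 3, 8: rank p's data starts at send_displs[p] in the packed send buffer
    for (int p = 0; p < np; p++)
      printf("rank %d: count %ld displ %ld\n", p, (long)send_counts[p], (long)send_displs[p]);
  }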
  alloc_ptr(order*sizeof(int), (void**)&hsym);
  alloc_ptr(order*sizeof(int), (void**)&idx);
  alloc_ptr(order*sizeof(int64_t), (void**)&idx_offs);
  alloc_ptr(order*sizeof(int), (void**)&old_virt_lda);
  alloc_ptr(order*sizeof(int), (void**)&new_virt_lda);
  ...
  idx_lyr = ord_glb_comm.rank;
  for (i=0; i<order; i++) {
    new_virt_lda[i] = new_nvirt;
    old_virt_lda[i] = old_nvirt;
  ...
  vbs_old = old_dist.size/old_nvirt;
  alloc_ptr(order*sizeof(int), (void**)&old_sub_edge_len);
  alloc_ptr(order*sizeof(int), (void**)&new_sub_edge_len);
  int ** bucket_offset;
  ...
  int * real_edge_len; alloc_ptr(sizeof(int)*order, (void**)&real_edge_len);
  ...
  int * old_phys_dim; alloc_ptr(sizeof(int)*order, (void**)&old_phys_dim);
  for (i=0; i<order; i++) old_phys_dim[i] = old_dist.phase[i]/old_dist.virt_phase[i];
  ...
  int * new_phys_dim; alloc_ptr(sizeof(int)*order, (void**)&new_phys_dim);
  for (i=0; i<order; i++) new_phys_dim[i] = new_dist.phase[i]/new_dist.virt_phase[i];
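The two loops above recover the physical processor-grid dimension from the distribution: phys_dim = phase / virt_phase, i.e. the overall cyclic phase divided by the virtualization factor. A small numeric illustration of that relationship (the values are made up):

  #include <cstdio>

  int main(){
    // A mode distributed with overall cyclic phase 8 and virtualization factor 2:
    // 4 physical processors, each owning 2 virtual blocks along this mode.
    int phase = 8, virt_phase = 2;
    int phys_dim = phase / virt_phase;
    printf("phys_dim = %d\n", phys_dim);  // 4
  }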
  int * old_phys_edge_len; alloc_ptr(sizeof(int)*order, (void**)&old_phys_edge_len);
  ...
  int * new_phys_edge_len; alloc_ptr(sizeof(int)*order, (void**)&new_phys_edge_len);
  ...
  int * old_virt_edge_len; alloc_ptr(sizeof(int)*order, (void**)&old_virt_edge_len);
  ...
  int * new_virt_edge_len; alloc_ptr(sizeof(int)*order, (void**)&new_virt_edge_len);
  ...
  for (i=0; i<order; i++){
  ...
  for (i=0; i<order; i++){
    new_sub_edge_len[i] = new_sub_edge_len[i] / new_dist.phase[i];
    old_sub_edge_len[i] = old_sub_edge_len[i] / old_dist.phase[i];
  }
  for (i=1; i<order; i++){
  ...
  swp_nval = new_nvirt*sy_packed_size(order, new_sub_edge_len, sym);
  vbs_new = swp_nval/new_nvirt;
  char * send_buffer, * recv_buffer;
  ...
  char ** new_data; alloc_ptr(sizeof(char*)*np, (void**)&new_data);
  ...
  for (int64_t p = 0; p < np; p++){
    new_data[p] = tsr_cyclic_data + sr->el_size*send_displs[p];
  }
  ...
  for (int64_t p = 0; p < np; p++){
    new_data[p] = send_buffer + sr->el_size*send_displs[p];
  }
  ...
  if (swp_nval > old_dist.size){
  ...
  send_buffer = tsr_cyclic_data;
  recv_buffer = tsr_data;
  ...
              recv_buffer, recv_counts, recv_displs);
  ...
  sr->set(tsr_cyclic_data, sr->addid(), swp_nval);
  ...
  if (recv_displs[ord_glb_comm.np-1] + recv_counts[ord_glb_comm.np-1] > 0){
    char ** new_data; alloc_ptr(sizeof(char*)*np, (void**)&new_data);
    for (int64_t p = 0; p < np; p++){
      new_data[p] = recv_buffer + recv_displs[p]*sr->el_size;
    }
  ...
  if (!reuse_buffers) cdealloc(recv_buffer);
  *ptr_tsr_cyclic_data = tsr_cyclic_data;
  *ptr_tsr_data = tsr_data;
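The `recv_buffer, recv_counts, recv_displs` fragment above is most likely the tail of the ord_glb_comm.all_to_allv call, which (per the reference entry below) performs an all-to-all-v exchange with 64-bit counts and an arbitrary datum size. As a point of comparison, a minimal plain-MPI all-to-all-v with ordinary int counts is shown here; it illustrates the communication pattern only and is not the CTF wrapper, and the message sizes are made up:

  #include <mpi.h>
  #include <vector>
  #include <cstdio>

  int main(int argc, char ** argv){
    MPI_Init(&argc, &argv);
    int rank, np;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &np);

    // Each rank sends (rank+1) doubles to every other rank.
    std::vector<int> send_counts(np, rank+1), recv_counts(np), send_displs(np), recv_displs(np);
    MPI_Alltoall(send_counts.data(), 1, MPI_INT, recv_counts.data(), 1, MPI_INT, MPI_COMM_WORLD);

    send_displs[0] = recv_displs[0] = 0;
    for (int p = 1; p < np; p++){
      send_displs[p] = send_displs[p-1] + send_counts[p-1];
      recv_displs[p] = recv_displs[p-1] + recv_counts[p-1];
    }
    std::vector<double> send_buf(send_displs[np-1] + send_counts[np-1], (double)rank);
    std::vector<double> recv_buf(recv_displs[np-1] + recv_counts[np-1]);

    MPI_Alltoallv(send_buf.data(), send_counts.data(), send_displs.data(), MPI_DOUBLE,
                  recv_buf.data(), recv_counts.data(), recv_displs.data(), MPI_DOUBLE,
                  MPI_COMM_WORLD);
    printf("rank %d received %zu doubles\n", rank, recv_buf.size());
    MPI_Finalize();
  }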
void calc_idx_arr(int order, int const *lens, int const *sym, int64_t idx, int *idx_arr)
int ** compute_bucket_offsets(distribution const &old_dist, distribution const &new_dist, int const *len, int const *old_phys_edge_len, int const *old_virt_lda, int const *old_offsets, int *const *old_permutation, int const *new_phys_edge_len, int const *new_virt_lda, int forward, int old_virt_np, int new_virt_np, int const *old_virt_edge_len)
computes offsets for redistribution targets along each edge length
virtual bool isequal(char const *a, char const *b) const
returns true if algstrct elements a and b are equal
void acc(char *b, char const *beta, char const *a, char const *alpha) const
compute b=beta*b + alpha*a
virtual void copy(char *a, char const *b) const
copies element b to element a
virtual char const * addid() const
identity element for addition i.e. 0
void cyclic_reshuffle(int const *sym, distribution const &old_dist, int const *old_offsets, int *const *old_permutation, distribution const &new_dist, int const *new_offsets, int *const *new_permutation, char **ptr_tsr_data, char **ptr_tsr_cyclic_data, algstrct const *sr, CommData ord_glb_comm, bool reuse_buffers, char const *alpha, char const *beta)
Goes from any set of phases to any new set of phases.
void all_to_allv(void *send_buffer, int64_t const *send_counts, int64_t const *send_displs, int64_t datum_size, void *recv_buffer, int64_t const *recv_counts, int64_t const *recv_displs)
performs all-to-all-v with 64-bit integer counts and offset on arbitrary length types (datum_size)...
void pad_cyclic_pup_virt_buff(int const *sym, distribution const &old_dist, distribution const &new_dist, int const *len, int const *old_phys_dim, int const *old_phys_edge_len, int const *old_virt_edge_len, int64_t old_virt_nelem, int const *old_offsets, int *const *old_permutation, int total_np, int const *new_phys_dim, int const *new_phys_edge_len, int const *new_virt_edge_len, int64_t new_virt_nelem, char *old_data, char **new_data, int forward, int *const *bucket_offset, char const *alpha, char const *beta, algstrct const *sr)
virtual void set(char *a, char const *b, int64_t n) const
sets n elements of array a to value b
int mst_alloc_ptr(int64_t len, void **const ptr)
mst_alloc abstraction
int alloc_ptr(int64_t len, void **const ptr)
alloc abstraction
void calc_cnt_displs(int const *sym, distribution const &old_dist, distribution const &new_dist, int new_nvirt, int np, int const *old_virt_edge_len, int const *new_virt_lda, int64_t *send_counts, int64_t *recv_counts, int64_t *send_displs, int64_t *recv_displs, CommData ord_glb_comm, int idx_lyr, int *const *bucket_offset)
assigns keys to an array of values
int el_size
size of each element of algstrct in bytes
int cdealloc(void *ptr)
free abstraction
algstrct (algebraic structure) defines the elementwise operations computed in each tensor contraction...
int64_t packed_size(int order, const int *len, const int *sym)
computes the size of a tensor in packed symmetric (SY, SH, or AS) layout
virtual char const * mulid() const
identity element for multiplication i.e. 1
int64_t sy_packed_size(int order, const int *len, const int *sym)
computes the size of a tensor in SY (NOT HOLLOW) packed symmetric layout
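For an order-2 SY tensor the packed symmetric layout stores only the entries with i <= j, so the packed size per symmetric pair of modes of length n is n(n+1)/2. A small sketch that counts the packed entries directly (illustrative only, not the CTF routine):

  #include <cstdint>
  #include <cstdio>

  // Count entries of an n x n symmetric (SY, non-hollow) matrix in packed layout.
  static int64_t sy_packed_count(int n){
    int64_t cnt = 0;
    for (int j = 0; j < n; j++)
      for (int i = 0; i <= j; i++)   // only the i <= j triangle is stored
        cnt++;
    return cnt;
  }

  int main(){
    int n = 5;
    printf("%ld == %d\n", (long)sy_packed_count(n), n*(n+1)/2);  // 15 == 15
  }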