#include "../shared/util.h"

void permute_keys(int              order,
                  int              num_pair,
                  int const *      edge_len,
                  int const *      new_edge_len,
                  int * const *    permutation,
                  char *           pairs_buf,
                  int64_t *        new_num_pair,
                  algstrct const * sr){
  int mntd = omp_get_max_threads();
  int64_t counts[mntd];
  std::fill(counts,counts+mntd,0);
  #pragma omp parallel
  {
    int i, j, tid, ntd, outside;
    int64_t lda, wkey, knew, kdim, tstart, tnum_pair, cnum_pair;
    tid = omp_get_thread_num();
    ntd = omp_get_num_threads();
    PairIterator pairs(sr, pairs_buf);
    // split the num_pair pairs evenly among the ntd threads
    tnum_pair = num_pair/ntd;
    tstart    = tnum_pair*tid + MIN(tid, num_pair % ntd);
    if (tid < num_pair % ntd) tnum_pair++;
    // per-thread buffer for the pairs that survive the permutation
    char * my_pairs_buf = sr->pair_alloc(tnum_pair);
    PairIterator my_pairs(sr, my_pairs_buf);
    cnum_pair = 0;
    for (i=tstart; i<tstart+tnum_pair; i++){
      wkey    = pairs[i].k();
      lda     = 1;
      knew    = 0;
      outside = 0;
      // peel one index per dimension off the key and remap it
      for (j=0; j<order; j++){
        kdim = wkey%edge_len[j];
        if (permutation[j] != NULL){
          if (permutation[j][kdim] == -1){
            outside = 1;  // index mapped outside the permuted tensor: drop the pair
          } else {
            knew += lda*permutation[j][kdim];
          }
        } else {
          knew += lda*kdim;
        }
        lda  *= new_edge_len[j];
        wkey  = wkey/edge_len[j];
      }
      if (!outside){
        char tkp[sr->pair_size()];
        sr->set_pair(tkp, knew, pairs[i].d());
        my_pairs[cnum_pair].write(tkp);
        cnum_pair++;
      }
    }
    counts[tid] = cnum_pair;
    #pragma omp barrier
    // pack the surviving pairs back contiguously, ordered by thread id
    int64_t pfx = 0;
    for (i=0; i<tid; i++){
      pfx += counts[i];
    }
    pairs[pfx].write(my_pairs, cnum_pair);
    sr->pair_dealloc(my_pairs_buf);
  }
  *new_num_pair = 0;
  for (int i=0; i<mntd; i++){
    *new_num_pair += counts[i];
  }
}
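// --- Illustrative sketch, not part of the CTF sources ---
// Shows the per-key remapping that permute_keys performs: the key is peeled
// into one index per dimension (mixed-radix digits with radices edge_len[j]),
// each index is sent through permutation[j], and the digits are reassembled
// with radices new_edge_len[j].  The helper name permute_one_key and the
// example numbers are made up for illustration; a return of -1 stands in for
// the "outside" case in which the pair is dropped.
#include <cstdint>
#include <cstdio>

int64_t permute_one_key(int order, int64_t key,
                        int const * edge_len, int const * new_edge_len,
                        int * const * permutation){
  int64_t lda = 1, knew = 0, wkey = key;
  for (int j=0; j<order; j++){
    int64_t kdim = wkey % edge_len[j];
    if (permutation[j] != NULL){
      if (permutation[j][kdim] == -1) return -1;  // index not in the permuted tensor
      knew += lda*permutation[j][kdim];
    } else {
      knew += lda*kdim;                           // dimension left unpermuted
    }
    lda  *= new_edge_len[j];
    wkey /= edge_len[j];
  }
  return knew;
}

int main(){
  int edge_len[2]     = {4, 3};
  int new_edge_len[2] = {4, 3};
  int perm0[4] = {2, 0, 3, 1};             // reorder indices of dimension 0
  int * permutation[2] = {perm0, NULL};
  // key 5 is index (1,1); dimension 0 maps 1 -> 0, so the new key is 0 + 1*4 = 4
  printf("%lld\n", (long long)permute_one_key(2, 5, edge_len, new_edge_len, permutation));
  return 0;
}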
void depermute_keys(int              order,
                    int              num_pair,
                    int const *      edge_len,
                    int const *      new_edge_len,
                    int * const *    permutation,
                    char *           pairs_buf,
                    algstrct const * sr){
  int mntd = omp_get_max_threads();
  int64_t counts[mntd];
  std::fill(counts,counts+mntd,0);
  // build the inverse permutation (apply P^T) for each dimension
  int ** depermutation = (int**)CTF_int::alloc(order*sizeof(int*));
  for (int d=0; d<order; d++){
    if (permutation[d] == NULL){
      depermutation[d] = NULL;
    } else {
      depermutation[d] = (int*)CTF_int::alloc(new_edge_len[d]*sizeof(int));
      std::fill(depermutation[d], depermutation[d]+new_edge_len[d], -1);
      for (int i=0; i<edge_len[d]; i++){
        if (permutation[d][i] > -1)
          depermutation[d][permutation[d][i]] = i;
      }
    }
  }
  #pragma omp parallel
  {
    int i, j, tid, ntd;
    int64_t lda, wkey, knew, kdim, tstart, tnum_pair;
    tid = omp_get_thread_num();
    ntd = omp_get_num_threads();
    PairIterator pairs(sr, pairs_buf);
    // same even partition of the pairs among threads as in permute_keys
    tnum_pair = num_pair/ntd;
    tstart    = tnum_pair*tid + MIN(tid, num_pair % ntd);
    if (tid < num_pair % ntd) tnum_pair++;
    for (i=tstart; i<tstart+tnum_pair; i++){
      wkey = pairs[i].k();
      lda  = 1;
      knew = 0;
      // remap each index of the key through the inverse permutation
      for (j=0; j<order; j++){
        kdim = wkey%new_edge_len[j];
        if (depermutation[j] != NULL){
          ASSERT(depermutation[j][kdim] != -1);
          knew += lda*depermutation[j][kdim];
        } else {
          knew += lda*kdim;
        }
        lda  *= edge_len[j];
        wkey  = wkey/new_edge_len[j];
      }
      pairs[i].write_key(knew);
    }
  }
  for (int d=0; d<order; d++){
    if (permutation[d] != NULL)
      CTF_int::cdealloc(depermutation[d]);
  }
  CTF_int::cdealloc(depermutation);
}
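// --- Illustrative sketch, not part of the CTF sources ---
// depermute_keys first inverts each per-dimension permutation: entry
// depermutation[p[i]] = i, with -1 left wherever no old index maps to that
// new index.  A minimal standalone version of that inversion (the helper name
// invert_permutation is made up for the example):
#include <cstdio>
#include <vector>

std::vector<int> invert_permutation(int const * perm, int old_len, int new_len){
  std::vector<int> deperm(new_len, -1);
  for (int i=0; i<old_len; i++){
    if (perm[i] > -1)
      deperm[perm[i]] = i;   // new index perm[i] came from old index i
  }
  return deperm;
}

int main(){
  int perm[4] = {2, 0, 3, 1};
  std::vector<int> deperm = invert_permutation(perm, 4, 4);
  for (int i=0; i<4; i++)
    printf("deperm[%d] = %d\n", i, deperm[i]);   // prints 1 3 0 2
  return 0;
}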
void assign_keys(int              order,
                 int64_t          size,
                 int              nvirt,
                 int const *      edge_len,
                 int const *      sym,
                 int const *      phase,
                 int const *      phys_phase,
                 int const *      virt_dim,
                 int *            phase_rank,
                 char const *     vdata,
                 char *           vpairs,
                 algstrct const * sr){
  int i, imax, act_lda, act_max;
  int64_t p, idx_offset, buf_offset;
  int * idx, * virt_rank;
  int64_t * edge_lda;

  memset(virt_rank, 0, sizeof(int)*order);

  // strides of the global (padded) index space
  edge_lda[0] = 1;
  for (i=1; i<order; i++){
    edge_lda[i] = edge_lda[i-1]*edge_len[i-1];
  }
  // loop over the nvirt virtual blocks held by this processor
  for (p=0;;p++){
    char const * data = vdata + sr->el_size*p*(size/nvirt);
    PairIterator pairs(sr, vpairs + sr->pair_size()*p*(size/nvirt));

    idx_offset = 0, buf_offset = 0;
    for (act_lda=1; act_lda<order; act_lda++){
      idx_offset += phase_rank[act_lda]*edge_lda[act_lda];
    }

    memset(idx, 0, order*sizeof(int));
    imax = edge_len[0]/phase[0];
    for (;;){
      // assign a global key and copy the value for each element of this column
      for (i=0; i<imax; i++){
        ASSERT(buf_offset+i<size);
        if (p*(size/nvirt) + buf_offset + i >= size){
          printf("exceeded how much I was supposed to read %ld/%ld\n",
                 p*(size/nvirt)+buf_offset+i, size);
        }
        pairs[buf_offset+i].write_key(idx_offset+i*phase[0]+phase_rank[0]);
        pairs[buf_offset+i].write_val(data+(buf_offset+i)*sr->el_size);
      }
      buf_offset += imax;
      // increment the multidimensional index (odometer) and update idx_offset
      for (act_lda=1; act_lda < order; act_lda++){
        idx_offset -= (idx[act_lda]*phase[act_lda]+phase_rank[act_lda])
                      *edge_lda[act_lda];
        idx[act_lda]++;
        act_max = edge_len[act_lda]/phase[act_lda];
        if (sym[act_lda] != NS) act_max = idx[act_lda+1]+1;
        if (idx[act_lda] >= act_max)
          idx[act_lda] = 0;
        idx_offset += (idx[act_lda]*phase[act_lda]+phase_rank[act_lda])
                      *edge_lda[act_lda];
        ASSERT(edge_len[act_lda]%phase[act_lda] == 0);
        if (idx[act_lda] > 0)
          break;
      }
      if (act_lda >= order) break;
    }
    // advance to the next virtual block (virtual-rank odometer)
    for (act_lda=0; act_lda < order; act_lda++){
      phase_rank[act_lda] -= virt_rank[act_lda]*phys_phase[act_lda];
      virt_rank[act_lda]++;
      if (virt_rank[act_lda] >= virt_dim[act_lda])
        virt_rank[act_lda] = 0;
      phase_rank[act_lda] += virt_rank[act_lda]*phys_phase[act_lda];
      if (virt_rank[act_lda] > 0)
        break;
    }
    if (act_lda >= order) break;
  }
  ASSERT(buf_offset == size/nvirt);
}
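// --- Illustrative sketch, not part of the CTF sources ---
// assign_keys generates, for each locally stored value, the global key of the
// element it represents under a cyclic distribution: along dimension j the
// global index is idx[j]*phase[j] + phase_rank[j], and the key combines the
// per-dimension indices with strides given by products of the edge lengths.
// The helper name global_key and the example numbers are made up.
#include <cstdint>
#include <cstdio>

int64_t global_key(int order, int const * idx, int const * phase,
                   int const * phase_rank, int const * edge_len){
  int64_t key = 0, lda = 1;
  for (int j=0; j<order; j++){
    key += (int64_t)(idx[j]*phase[j] + phase_rank[j])*lda;
    lda *= edge_len[j];
  }
  return key;
}

int main(){
  // 6x6 matrix distributed cyclically with phase {2,3}; this process holds
  // rows 1,3,5 and columns 2,5 (phase_rank = {1,2})
  int edge_len[2]   = {6, 6};
  int phase[2]      = {2, 3};
  int phase_rank[2] = {1, 2};
  int idx[2]        = {1, 1};   // local element (1,1) -> global element (3,5)
  printf("%lld\n", (long long)global_key(2, idx, phase, phase_rank, edge_len)); // 3 + 5*6 = 33
  return 0;
}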
void spsfy_tsr(int                              order,
               int64_t                          size,
               int                              nvirt,
               int const *                      edge_len,
               int const *                      sym,
               int const *                      phase,
               int const *                      phys_phase,
               int const *                      virt_dim,
               int *                            phase_rank,
               char const *                     vdata,
               char *&                          vpairs,
               int64_t *                        nnz_blk,
               algstrct const *                 sr,
               int64_t const *                  edge_lda,
               std::function<bool(char const*)> f){
  int i, imax, act_lda, act_max;
  int64_t p, idx_offset, buf_offset;
  int * idx, * virt_rank;
  memset(nnz_blk, 0, sizeof(int64_t)*nvirt);
  // flag array marking which values pass the sparsifier f
  bool * keep_vals;
  if (size > 0){
    CTF_int::alloc_ptr(sizeof(bool)*size, (void**)&keep_vals);
  } else vpairs = NULL;

  memset(virt_rank, 0, sizeof(int)*order);

  // first pass: evaluate f on every value and count the kept values per virtual block
  for (p=0;;p++){
    char const * data = vdata + sr->el_size*p*(size/nvirt);

    buf_offset = 0;
    memset(idx, 0, order*sizeof(int));
    imax = edge_len[0]/phase[0];
    for (;;){
      for (i=0; i<imax; i++){
        ASSERT(buf_offset+i<size);
        keep_vals[buf_offset+i] = f(data+(buf_offset+i)*sr->el_size);
        nnz_blk[virt_blk] += keep_vals[buf_offset+i];
      }
      buf_offset += imax;
      // increment the multidimensional index (odometer)
      for (act_lda=1; act_lda < order; act_lda++){
        idx[act_lda]++;
        act_max = edge_len[act_lda]/phase[act_lda];
        if (sym[act_lda] != NS) act_max = idx[act_lda+1]+1;
        if (idx[act_lda] >= act_max)
          idx[act_lda] = 0;
        ASSERT(edge_len[act_lda]%phase[act_lda] == 0);
        if (idx[act_lda] > 0)
          break;
      }
      if (act_lda >= order) break;
    }
    // advance to the next virtual block
    for (act_lda=0; act_lda < order; act_lda++){
      phase_rank[act_lda] -= virt_rank[act_lda]*phys_phase[act_lda];
      virt_rank[act_lda]++;
      if (virt_rank[act_lda] >= virt_dim[act_lda])
        virt_rank[act_lda] = 0;
      phase_rank[act_lda] += virt_rank[act_lda]*phys_phase[act_lda];
      if (virt_rank[act_lda] > 0)
        break;
    }
    if (act_lda >= order) break;
  }
  // prefix sum of the per-block counts gives each block's offset in vpairs
  int64_t * nnz_blk_lda = (int64_t*)alloc(sizeof(int64_t)*nvirt);
  nnz_blk_lda[0] = 0;
  for (int i=1; i<nvirt; i++){
    nnz_blk_lda[i] = nnz_blk_lda[i-1]+nnz_blk[i-1];
  }
  vpairs = sr->pair_alloc(nnz_blk_lda[nvirt-1]+nnz_blk[nvirt-1]);
  memset(nnz_blk, 0, sizeof(int64_t)*nvirt);

  // second pass: write the kept (key,value) pairs into vpairs
  for (p=0;;p++){
    char const * data = vdata + sr->el_size*p*(size/nvirt);

    idx_offset = 0, buf_offset = 0;
    for (act_lda=1; act_lda<order; act_lda++){
      idx_offset += phase_rank[act_lda]*edge_lda[act_lda];
    }

    memset(idx, 0, order*sizeof(int));
    imax = edge_len[0]/phase[0];
    for (;;){
      for (i=0; i<imax; i++){
        ASSERT(buf_offset+i<size);
        if (keep_vals[buf_offset+i]){
          pairs[nnz_blk[virt_blk]].write_key(idx_offset+i*phase[0]+phase_rank[0]);
          pairs[nnz_blk[virt_blk]].write_val(data+(buf_offset+i)*sr->el_size);
          nnz_blk[virt_blk]++;
        }
      }
      buf_offset += imax;
      // increment the multidimensional index and update idx_offset
      for (act_lda=1; act_lda < order; act_lda++){
        idx_offset -= (idx[act_lda]*phase[act_lda]+phase_rank[act_lda])
                      *edge_lda[act_lda];
        idx[act_lda]++;
        act_max = edge_len[act_lda]/phase[act_lda];
        if (sym[act_lda] != NS) act_max = idx[act_lda+1]+1;
        if (idx[act_lda] >= act_max)
          idx[act_lda] = 0;
        idx_offset += (idx[act_lda]*phase[act_lda]+phase_rank[act_lda])
                      *edge_lda[act_lda];
        ASSERT(edge_len[act_lda]%phase[act_lda] == 0);
        if (idx[act_lda] > 0)
          break;
      }
      if (act_lda >= order) break;
    }
    // advance to the next virtual block
    for (act_lda=0; act_lda < order; act_lda++){
      phase_rank[act_lda] -= virt_rank[act_lda]*phys_phase[act_lda];
      virt_rank[act_lda]++;
      if (virt_rank[act_lda] >= virt_dim[act_lda])
        virt_rank[act_lda] = 0;
      phase_rank[act_lda] += virt_rank[act_lda]*phys_phase[act_lda];
      if (virt_rank[act_lda] > 0)
        break;
    }
    if (act_lda >= order) break;
  }
}
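// --- Illustrative sketch, not part of the CTF sources ---
// The sparsifier argument f of spsfy_tsr is a std::function<bool(char const*)>
// that judges one element from its raw bytes.  For a tensor of doubles, a
// threshold filter could be built like this (tol and keep_large are made up
// for the example):
#include <cmath>
#include <cstdio>
#include <cstring>
#include <functional>

int main(){
  double tol = 1e-10;
  std::function<bool(char const*)> keep_large =
    [tol](char const * v){
      double x;
      memcpy(&x, v, sizeof(double));   // v points at one raw element
      return std::fabs(x) > tol;       // keep it only if it is "nonzero"
    };

  double vals[3] = {0.0, 3.5, 1e-14};
  for (int i=0; i<3; i++)
    printf("keep vals[%d] = %d\n", i, (int)keep_large((char const*)&vals[i]));
  return 0;
}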
void bucket_by_pe(int               order,
                  int64_t           num_pair,
                  int64_t           np,
                  int const *       phys_phase,
                  int const *       virt_phase,
                  int const *       bucket_lda,
                  int const *       edge_len,
                  ConstPairIterator mapped_data,
                  int64_t *         bucket_counts,
                  int64_t *         bucket_off,
                  PairIterator      bucket_data,
                  algstrct const *  sr){
  memset(bucket_counts, 0, sizeof(int64_t)*np);
#ifdef USE_OMP
  int64_t * sub_counts, * sub_offs;
  CTF_int::alloc_ptr(np*sizeof(int64_t)*omp_get_max_threads(), (void**)&sub_counts);
  CTF_int::alloc_ptr(np*sizeof(int64_t)*omp_get_max_threads(), (void**)&sub_offs);
  memset(sub_counts, 0, np*sizeof(int64_t)*omp_get_max_threads());
#endif

  // count the number of pairs destined for each processor
#ifdef USE_OMP
  #pragma omp parallel for schedule(static,256)
#endif
  for (int64_t i=0; i<num_pair; i++){
    int64_t k = mapped_data[i].k();
    int64_t loc = 0;
    // the owner of a key is determined by its physical phase in each dimension
    for (int j=0; j<order; j++){
      loc += ((k%edge_len[j])%phys_phase[j])*bucket_lda[j];
      k    = k/edge_len[j];
    }
#ifdef USE_OMP
    sub_counts[loc+omp_get_thread_num()*np]++;
#else
    bucket_counts[loc]++;
#endif
  }

#ifdef USE_OMP
  // reduce the per-thread counts into the global bucket counts
  for (int j=0; j<omp_get_max_threads(); j++){
    for (int64_t i=0; i<np; i++){
      bucket_counts[i] = sub_counts[j*np+i] + bucket_counts[i];
    }
  }
#endif

  // prefix sum of the counts gives the offset of each processor's bucket
  bucket_off[0] = 0;
  for (int64_t i=1; i<np; i++){
    bucket_off[i] = bucket_counts[i-1] + bucket_off[i-1];
  }

#ifdef USE_OMP
  // per-thread write offsets within each bucket
  memset(sub_offs, 0, sizeof(int64_t)*np);
  for (int i=1; i<omp_get_max_threads(); i++){
    for (int64_t j=0; j<np; j++){
      sub_offs[j+i*np] = sub_counts[j+(i-1)*np]+sub_offs[j+(i-1)*np];
    }
  }
#else
  memset(bucket_counts, 0, sizeof(int64_t)*np);
#endif

  // scatter the pairs into their buckets
#ifdef USE_OMP
  #pragma omp parallel for schedule(static,256)
#endif
  for (int64_t i=0; i<num_pair; i++){
    int64_t k = mapped_data[i].k();
    int64_t loc = 0;
    for (int j=0; j<order; j++){
      loc += ((k%edge_len[j])%phys_phase[j])*bucket_lda[j];
      k    = k/edge_len[j];
    }
#ifdef USE_OMP
    bucket_data[bucket_off[loc] + sub_offs[loc+omp_get_thread_num()*np]].write(mapped_data[i]);
    sub_offs[loc+omp_get_thread_num()*np]++;
#else
    bucket_data[bucket_off[loc] + bucket_counts[loc]].write(mapped_data[i]);
    bucket_counts[loc]++;
#endif
  }
}
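// --- Illustrative sketch, not part of the CTF sources ---
// The bucket (destination processor) of a key in bucket_by_pe: in dimension j
// the owning processor-grid coordinate is (index_j mod phys_phase[j]), and
// bucket_lda linearizes the grid coordinates into a rank.  The helper name
// owner_of_key and the example numbers are made up; they reuse the 6x6 example
// from above.
#include <cstdint>
#include <cstdio>

int64_t owner_of_key(int order, int64_t k, int const * edge_len,
                     int const * phys_phase, int const * bucket_lda){
  int64_t loc = 0;
  for (int j=0; j<order; j++){
    loc += ((k%edge_len[j])%phys_phase[j])*bucket_lda[j];
    k    = k/edge_len[j];
  }
  return loc;
}

int main(){
  // 6x6 matrix on a 2x3 processor grid: bucket_lda = {1, 2}
  int edge_len[2]   = {6, 6};
  int phys_phase[2] = {2, 3};
  int bucket_lda[2] = {1, 2};
  // element (3,5) has key 3 + 5*6 = 33 and lives on grid coordinate (1,2) -> rank 5
  printf("%lld\n", (long long)owner_of_key(2, 33, edge_len, phys_phase, bucket_lda));
  return 0;
}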
int64_t * bucket_by_virt(int               order,
                         int               num_virt,
                         int64_t           num_pair,
                         int const *       phys_phase,
                         int const *       virt_phase,
                         int const *       edge_len,
                         ConstPairIterator mapped_data,
                         PairIterator      bucket_data,
                         algstrct const *  sr){
  int64_t * virt_counts, * virt_prefix, * virt_lda;
  CTF_int::alloc_ptr(num_virt*sizeof(int64_t), (void**)&virt_counts);
  CTF_int::alloc_ptr(num_virt*sizeof(int64_t), (void**)&virt_prefix);
  CTF_int::alloc_ptr(order*sizeof(int64_t),    (void**)&virt_lda);

  // strides of the virtual-block index space
  virt_lda[0] = 1;
  for (int i=1; i<order; i++){
    ASSERT(virt_phase[i] > 0);
    virt_lda[i] = virt_phase[i-1]*virt_lda[i-1];
  }

  memset(virt_counts, 0, sizeof(int64_t)*num_virt);
#ifdef USE_OMP
  int64_t * sub_counts, * sub_offs;
  CTF_int::alloc_ptr(num_virt*sizeof(int64_t)*omp_get_max_threads(), (void**)&sub_counts);
  CTF_int::alloc_ptr(num_virt*sizeof(int64_t)*omp_get_max_threads(), (void**)&sub_offs);
  memset(sub_counts, 0, num_virt*sizeof(int64_t)*omp_get_max_threads());

  // count pairs per virtual block, with a private counter array per thread
  #pragma omp parallel for schedule(static)
  for (int64_t i=0; i<num_pair; i++){
    int64_t k = mapped_data[i].k();
    int64_t loc = 0;
    for (int j=0; j<order; j++){
      loc += (((k%edge_len[j])/phys_phase[j])%virt_phase[j])*virt_lda[j];
      k    = k/edge_len[j];
    }
    sub_counts[loc+omp_get_thread_num()*num_virt]++;
  }

  // reduce the per-thread counts
  for (int j=0; j<omp_get_max_threads(); j++){
    for (int64_t i=0; i<num_virt; i++){
      virt_counts[i] = sub_counts[j*num_virt+i] + virt_counts[i];
    }
  }

  // prefix sums: block offsets and per-thread offsets within each block
  virt_prefix[0] = 0;
  for (int64_t i=1; i<num_virt; i++){
    virt_prefix[i] = virt_prefix[i-1] + virt_counts[i-1];
  }
  memset(sub_offs, 0, sizeof(int64_t)*num_virt);
  for (int i=1; i<omp_get_max_threads(); i++){
    for (int64_t j=0; j<num_virt; j++){
      sub_offs[j+i*num_virt] = sub_counts[j+(i-1)*num_virt]+sub_offs[j+(i-1)*num_virt];
    }
  }
  TAU_FSTOP(bucket_by_virt_assemble_offsets);

  // scatter the pairs into their virtual-block buckets
  #pragma omp parallel for schedule(static)
  for (int64_t i=0; i<num_pair; i++){
    int64_t k = mapped_data[i].k();
    int64_t loc = 0;
    for (int j=0; j<order; j++){
      loc += (((k%edge_len[j])/phys_phase[j])%virt_phase[j])*virt_lda[j];
      k    = k/edge_len[j];
    }
    bucket_data[virt_prefix[loc] + sub_offs[loc+omp_get_thread_num()*num_virt]].write(mapped_data[i]);
    sub_offs[loc+omp_get_thread_num()*num_virt]++;
  }
  cdealloc(sub_counts);
  cdealloc(sub_offs);
#else
  // sequential path: count, prefix-sum, then scatter
  for (int64_t i=0; i<num_pair; i++){
    int64_t k = mapped_data[i].k();
    int64_t loc = 0;
    for (int j=0; j<order; j++){
      loc += (((k%edge_len[j])/phys_phase[j])%virt_phase[j])*virt_lda[j];
      k    = k/edge_len[j];
    }
    virt_counts[loc]++;
  }
  virt_prefix[0] = 0;
  for (int64_t i=1; i<num_virt; i++){
    virt_prefix[i] = virt_prefix[i-1] + virt_counts[i-1];
  }
  memset(virt_counts, 0, sizeof(int64_t)*num_virt);
  for (int64_t i=0; i<num_pair; i++){
    int64_t k = mapped_data[i].k();
    int64_t loc = 0;
    for (int j=0; j<order; j++){
      loc += (((k%edge_len[j])/phys_phase[j])%virt_phase[j])*virt_lda[j];
      k    = k/edge_len[j];
    }
    bucket_data[virt_prefix[loc] + virt_counts[loc]].write(mapped_data[i]);
    virt_counts[loc]++;
  }
#endif
  // sort the pairs within each virtual block by key
#ifdef USE_OMP
  #pragma omp parallel for
#endif
  for (int64_t i=0; i<num_virt; i++){
    bucket_data[virt_prefix[i]].sort(virt_counts[i]);
  }
  CTF_int::cdealloc(virt_prefix);
  CTF_int::cdealloc(virt_lda);
  return virt_counts;
}
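// --- Illustrative sketch, not part of the CTF sources ---
// Both bucketing routines follow the same count / prefix-sum / scatter pattern
// used in the sequential branch above.  A minimal version on plain integers,
// bucketing values by an arbitrary bucket_of() function (all names here are
// made up for the example):
#include <cstdint>
#include <cstdio>
#include <vector>

void bucket(std::vector<int> const & data, int nbucket,
            int (*bucket_of)(int), std::vector<int> & out){
  std::vector<int64_t> counts(nbucket, 0), offs(nbucket, 0);
  for (int v : data) counts[bucket_of(v)]++;                       // count per bucket
  for (int b=1; b<nbucket; b++) offs[b] = offs[b-1] + counts[b-1]; // prefix sum
  out.resize(data.size());
  for (int v : data) out[offs[bucket_of(v)]++] = v;                // stable scatter
}

int main(){
  std::vector<int> data = {7, 2, 9, 4, 1, 8}, out;
  bucket(data, 2, [](int v){ return v % 2; }, out);
  for (int v : out) printf("%d ", v);   // evens first, then odds: 2 4 8 7 9 1
  printf("\n");
  return 0;
}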
void readwrite(int              order,
               int64_t          size,
               char const *     alpha,
               char const *     beta,
               int              nvirt,
               int const *      edge_len,
               int const *      sym,
               int const *      phase,
               int const *      phys_phase,
               int const *      virt_dim,
               int *            phase_rank,
               char *           vdata,
               char *           pairs_buf,
               char             rw,
               algstrct const * sr){
  int act_lda;
  int64_t idx_offset, act_max, buf_offset, pr_offset, p;
  int64_t * idx, * virt_rank, * edge_lda;

  PairIterator pairs(sr, pairs_buf);
  // sanity check on the incoming (sorted) pairs
  for (int64_t i=1; i<size; i++){
    ASSERT(pairs[i].k() == 0 || pairs[i].d() != pairs[0].d());
  }

  memset(virt_rank, 0, sizeof(int64_t)*order);

  edge_lda[0] = 1;
  for (int i=1; i<order; i++){
    edge_lda[i] = edge_lda[i-1]*edge_len[i-1];
  }

  pr_offset  = 0;
  buf_offset = 0;
  char * data = vdata;
  // loop over the virtual blocks held by this processor
  for (p=0;;p++){
    data = data + sr->el_size*buf_offset;

    idx_offset = 0, buf_offset = 0;
    for (act_lda=1; act_lda<order; act_lda++){
      idx_offset += phase_rank[act_lda]*edge_lda[act_lda];
    }

    memset(idx, 0, order*sizeof(int64_t));
    int64_t imax = edge_len[0]/phase[0];
    for (;;){
      // scan this column against the sorted pairs
      for (int64_t i=0; i<imax;){
        if (pr_offset >= size)
          break;
        if (pairs[pr_offset].k() == idx_offset +i*phase[0]+phase_rank[0]){
          char wval[sr->el_size];
          char wval2[sr->el_size];
          if (rw == 'r'){
            // read: pair value <- alpha*(tensor value) + beta*(pair value)
            sr->mul(alpha, data + sr->el_size*(buf_offset+i), wval);
            sr->mul(beta, pairs[pr_offset].d(), wval2);
            sr->add(wval, wval2, wval);
            pairs[pr_offset].write_val(wval);
          } else {
            // write: tensor value <- beta*(tensor value) + alpha*(pair value)
            sr->mul(beta, data + sr->el_size*(buf_offset+i), wval);
            sr->mul(alpha, pairs[pr_offset].d(), wval2);
            sr->add(wval, wval2, wval);
            sr->copy(data + sr->el_size*(buf_offset+i), wval);
          }
          pr_offset++;
          // pairs with repeated keys refer to the same tensor element
          while (pr_offset < size && pairs[pr_offset].k() == pairs[pr_offset-1].k()){
            if (rw == 'r'){
              if (alpha != NULL){
                sr->mul(alpha, data + sr->el_size*(buf_offset+i), wval);
                sr->mul(beta, pairs[pr_offset].d(), wval2);
                sr->add(wval, wval2, wval);
                pairs[pr_offset].write_val(wval);
              } else {
                sr->copy(pairs[pr_offset].d(),
                         data + (buf_offset+i)*sr->el_size);
              }
            } else {
              // accumulate the repeated write into the tensor value
              sr->mul(alpha, pairs[pr_offset].d(), wval);
              sr->add(wval, data + sr->el_size*(buf_offset+i), wval);
              sr->copy(data + sr->el_size*(buf_offset+i), wval);
            }
            pr_offset++;
          }
        } else {
          i++;
        }
      }
      buf_offset += imax;
      if (pr_offset >= size)
        break;
      // increment the multidimensional index and update idx_offset
      for (act_lda=1; act_lda < order; act_lda++){
        idx_offset -= (idx[act_lda]*phase[act_lda]+phase_rank[act_lda])
                      *edge_lda[act_lda];
        idx[act_lda]++;
        act_max = edge_len[act_lda]/phase[act_lda];
        if (sym[act_lda] != NS) act_max = idx[act_lda+1]+1;
        if (idx[act_lda] >= act_max)
          idx[act_lda] = 0;
        idx_offset += (idx[act_lda]*phase[act_lda]+phase_rank[act_lda])
                      *edge_lda[act_lda];
        ASSERT(edge_len[act_lda]%phase[act_lda] == 0);
        if (idx[act_lda] > 0)
          break;
      }
      if (act_lda == order)
        break;
    }
    // advance to the next virtual block
    for (act_lda=0; act_lda < order; act_lda++){
      phase_rank[act_lda] -= virt_rank[act_lda]*phys_phase[act_lda];
      virt_rank[act_lda]++;
      if (virt_rank[act_lda] >= virt_dim[act_lda])
        virt_rank[act_lda] = 0;
      phase_rank[act_lda] += virt_rank[act_lda]*phys_phase[act_lda];
      if (virt_rank[act_lda] > 0)
        break;
    }
    if (act_lda == order)
      break;
  }
  ASSERT(pr_offset == size);
}
void wr_pairs_layout(int              order,
                     int              np,
                     int64_t          inwrite,
                     char const *     alpha,
                     char const *     beta,
                     char             rw,
                     int              num_virt,
                     int const *      sym,
                     int const *      edge_len,
                     int const *      padding,
                     int const *      phase,
                     int const *      phys_phase,
                     int const *      virt_phase,
                     int *            virt_phys_rank,
                     int const *      bucket_lda,
                     char *           wr_pairs_buf,
                     char *           rw_data,
                     CommData         glb_comm,
                     algstrct const * sr,
                     bool             is_sparse,
                     int64_t          nnz_loc,
                     int64_t *        nnz_blk,
                     char *&          pprs_new,
                     int64_t &        nnz_loc_new){
  int64_t new_num_pair, nwrite, swp;
  int64_t * bucket_counts, * recv_counts;
  int64_t * recv_displs, * send_displs;
  int * depadding, * depad_edge_len;
  int j, is_out, sign, is_perm;
  char * swap_datab, * buf_datab;
  int64_t * old_nnz_blk;
  CTF_int::alloc_ptr(num_virt*sizeof(int64_t), (void**)&old_nnz_blk);
  memcpy(old_nnz_blk, nnz_blk, num_virt*sizeof(int64_t));

  PairIterator wr_pairs(sr, wr_pairs_buf);

  // sanity check: every key must index an element of the global tensor
  int64_t total_tsr_size = 1;
  for (int i=0; i<order; i++){
    total_tsr_size *= edge_len[i];
  }
  for (int64_t i=0; i<inwrite; i++){
    if (wr_pairs[i].k() >= total_tsr_size)
      printf("[%d] %ldth key is %ld size %ld\n",
             glb_comm.rank, i, wr_pairs[i].k(), total_tsr_size);
    ASSERT(wr_pairs[i].k() >= 0);
    ASSERT(wr_pairs[i].k() < total_tsr_size);
  }

  // unpadded (logical) edge lengths
  CTF_int::alloc_ptr(order*sizeof(int), (void**)&depad_edge_len);
  for (int i=0; i<order; i++){
    depad_edge_len[i] = edge_len[i] - padding[i];
  }
  // buffers for folding keys into canonical order under the symmetry
  int * ckey;
  int64_t skey;
  CTF_int::alloc_ptr(order*sizeof(int), (void**)&ckey);

  // first pass: count how many read keys are remapped by the symmetry (or fall
  // outside the stored tensor), so the bookkeeping buffers can be sized
  int64_t nchanged = 0;
  for (int64_t i=0; i<inwrite; i++){
    cvrt_idx(order, depad_edge_len, wr_pairs[i].k(), ckey);
    is_perm = 1;
    is_out  = 0;
    while (is_perm && !is_out){
      is_perm = 0;
      for (j=0; j<order-1; j++){
        if ((sym[j] == SH || sym[j] == AS) && ckey[j] == ckey[j+1]){
          is_out = 1;
        } else if (sym[j] != NS && ckey[j] > ckey[j+1]){
          swp       = ckey[j];
          ckey[j]   = ckey[j+1];
          ckey[j+1] = swp;
          is_perm   = 1;
        }
      }
    }
    if (!is_out){
      cvrt_idx(order, depad_edge_len, ckey, &skey);
      if (rw == 'r' && skey != wr_pairs[i].k()){
        nchanged++;
      }
    } else if (rw == 'r'){
      nchanged++;
    }
  }

  // records of read requests whose canonical key differs from the requested one
  int64_t * changed_key_indices;
  char * new_changed_pairs = sr->pair_alloc(nchanged);
  PairIterator ncp(sr, new_changed_pairs);
  int * changed_key_scale;
  CTF_int::alloc_ptr(nchanged*sizeof(int64_t), (void**)&changed_key_indices);
  CTF_int::alloc_ptr(nchanged*sizeof(int),     (void**)&changed_key_scale);

  swap_datab = sr->pair_alloc(inwrite);
  buf_datab  = sr->pair_alloc(inwrite);
  PairIterator swap_data(sr, swap_datab);
  PairIterator buf_data(sr, buf_datab);

  // second pass: build swap_data, the canonicalized pairs that will be sent out
  nchanged = 0;
  nwrite   = 0;
  for (int64_t i=0; i<inwrite; i++){
    cvrt_idx(order, depad_edge_len, wr_pairs[i].k(), ckey);
    is_perm = 1;
    is_out  = 0;
    sign    = 1;
    while (is_perm && !is_out){
      is_perm = 0;
      for (j=0; j<order-1; j++){
        if ((sym[j] == SH || sym[j] == AS) && ckey[j] == ckey[j+1]){
          is_out = 1;
        } else if (sym[j] != NS && ckey[j] > ckey[j+1]){
          swp       = ckey[j];
          ckey[j]   = ckey[j+1];
          ckey[j+1] = swp;
          is_perm   = 1;
          if (sym[j] == AS) sign = -sign;
        }
      }
    }
    if (!is_out){
      int64_t ky = swap_data[nwrite].k();
      cvrt_idx(order, depad_edge_len, ckey, &ky);
      swap_data[nwrite].write_key(ky);
      if (sign == 1)
        swap_data[nwrite].write_val(wr_pairs[i].d());
      else {
        // antisymmetric fold: store the additive inverse of the value
        char ainv[sr->el_size];
        sr->addinv(wr_pairs[i].d(), ainv);
        swap_data[nwrite].write_val(ainv);
      }
      if (rw == 'r' && swap_data[nwrite].k() != wr_pairs[i].k()){
        changed_key_indices[nchanged] = i;
        swap_data[nwrite].read(ncp[nchanged].ptr);
        changed_key_scale[nchanged] = sign;
        nchanged++;
      }
      nwrite++;
    } else if (rw == 'r'){
      // the requested element is not stored (e.g. an AS/SH diagonal): it reads as zero
      changed_key_indices[nchanged] = i;
      wr_pairs[i].read(ncp[nchanged].ptr);
      changed_key_scale[nchanged] = 0;
      nchanged++;
    }
  }

  // move the outgoing keys into the padded index space used by the dense
  // distribution; sparse tensors keep unpadded keys
  int const * wlen;
  if (!is_sparse){
    pad_key(order, nwrite, depad_edge_len, padding, swap_data, sr);
    wlen = edge_len;
  } else wlen = depad_edge_len;
  // send each pair to the processor that owns its key
  bucket_by_pe(order, nwrite, np,
               phys_phase, virt_phase, bucket_lda,
               wlen, swap_data, bucket_counts,
               send_displs, buf_data, sr);

  // exchange the per-processor pair counts, then the pairs themselves
  MPI_Alltoall(bucket_counts, 1, MPI_INT64_T,
               recv_counts, 1, MPI_INT64_T, glb_comm.cm);
  recv_displs[0] = 0;
  for (int i=1; i<np; i++){
    recv_displs[i] = recv_displs[i-1] + recv_counts[i-1];
  }
  new_num_pair = recv_displs[np-1] + recv_counts[np-1];

  if (new_num_pair > nwrite){
    // grow the receive buffer to hold all incoming pairs
    sr->pair_dealloc(swap_datab);
    swap_datab = sr->pair_alloc(new_num_pair);
    swap_data  = PairIterator(sr, swap_datab);
  }
  if (glb_comm.np == 1){
    // single process: the bucketed pairs are already local, just swap the buffers
    char * save_ptr = buf_datab;
    buf_datab  = swap_datab;
    swap_datab = save_ptr;
    swap_data  = PairIterator(sr, swap_datab);
    buf_data   = PairIterator(sr, buf_datab);
  } else {
    glb_comm.all_to_allv(buf_data.ptr, bucket_counts, send_displs, sr->pair_size(),
                         swap_data.ptr, recv_counts, recv_displs);
  }
  if (new_num_pair > nwrite){
    sr->pair_dealloc(buf_datab);
    buf_datab = sr->pair_alloc(new_num_pair);
    buf_data  = PairIterator(sr, buf_datab);
  }

  if (is_sparse){
    // group the received pairs by virtual block, then merge them with the
    // tensor's existing sparse pair set
    int64_t * virt_counts =
      bucket_by_virt(order, num_virt, new_num_pair, phys_phase, virt_phase,
                     wlen, swap_data, buf_data, sr);
    // prs_tsr iterates over the tensor's current sparse pairs, prs_write over
    // the received, bucketed pairs
    ConstPairIterator prs_tsr(sr, rw_data);
    ConstPairIterator prs_write(sr, buf_datab);
    if (rw == 'r'){
      sp_read(sr, nnz_loc, prs_tsr, alpha, new_num_pair, buf_data, beta);
    } else {
      sp_write(num_virt, sr, old_nnz_blk, prs_tsr, beta, virt_counts, prs_write, alpha, nnz_blk, pprs_new);
      for (int v=0; v<num_virt; v++){
        if (v==0) nnz_loc_new  = nnz_blk[0];
        else      nnz_loc_new += nnz_blk[v];
      }
    }
  }
  // read case: match each requested key with its retrieved value, send the
  // values back to the processors that issued the reads, and restore the
  // originally requested keys
  if (rw == 'r'){
    buf_data.sort(new_num_pair);
    for (int64_t i=0; i<new_num_pair; i++){
      int64_t el_loc = buf_data.lower_bound(new_num_pair, swap_data[i]);
      if (el_loc < 0 || el_loc >= new_num_pair){
        DEBUG_PRINTF("swap_data[%d].k = %d, not found\n", i, (int)swap_data[i].k());
      }
      swap_data[i].write_val(buf_data[el_loc].d());
    }

    // return the values to the requesting processors
    glb_comm.all_to_allv(swap_data.ptr, recv_counts, recv_displs, sr->pair_size(),
                         buf_data.ptr, bucket_counts, send_displs);

    // undo the padding of the keys
    for (int i=0; i<order; i++){
      depadding[i] = -padding[i];
    }
    pad_key(order, nwrite, edge_len, depadding, buf_data, sr);

    // copy the values back into wr_pairs, handling keys that were remapped by
    // the symmetry (scale -1) or that lie outside the stored tensor (scale 0)
    buf_data.sort(nwrite);
    j = 0;
    for (int64_t i=0; i<inwrite; i++){
      if (j<(int64_t)nchanged && changed_key_indices[j] == i){
        if (changed_key_scale[j] == 0){
          wr_pairs[i].write_val(sr->addid());
        } else {
          int64_t el_loc = buf_data.lower_bound(nwrite, ncp[j]);
          if (changed_key_scale[j] == -1){
            char aspr[sr->el_size];
            sr->addinv(buf_data[el_loc].d(), aspr);
            wr_pairs[i].write_val(aspr);
          } else
            wr_pairs[i].write_val(buf_data[el_loc].d());
        }
        j++;
      } else {
        int64_t el_loc = buf_data.lower_bound(nwrite, wr_pairs[i]);
        wr_pairs[i].write_val(buf_data[el_loc].d());
      }
    }
  }
  if (is_sparse)
    cdealloc(depad_edge_len);
}
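// --- Illustrative sketch, not part of the CTF sources ---
// The canonicalization performed on each written/read key above: indices are
// repeatedly swapped into nondecreasing order along symmetric dimension pairs;
// equal indices under SH/AS symmetry put the element outside the stored
// tensor, and each swap under AS flips the sign of the value.  The enum and
// function names below are stand-ins, not CTF's own definitions.
#include <cstdio>

enum sym_t { NS_, SY_, AS_, SH_ };   // stand-ins for CTF's NS/SY/AS/SH labels

// returns 0 if the index lies outside the tensor, otherwise +1/-1 (the sign)
int canonicalize(int order, sym_t const * sym, int * ckey){
  int sign = 1, changed = 1;
  while (changed){
    changed = 0;
    for (int j=0; j<order-1; j++){
      if ((sym[j] == SH_ || sym[j] == AS_) && ckey[j] == ckey[j+1])
        return 0;                               // not stored (diagonal of SH/AS)
      if (sym[j] != NS_ && ckey[j] > ckey[j+1]){
        int swp = ckey[j]; ckey[j] = ckey[j+1]; ckey[j+1] = swp;
        if (sym[j] == AS_) sign = -sign;        // antisymmetric swap flips the sign
        changed = 1;
      }
    }
  }
  return sign;
}

int main(){
  sym_t sym[2] = {AS_, NS_};
  int ckey[2]  = {5, 2};
  int sign = canonicalize(2, sym, ckey);
  printf("index (%d,%d), sign %d\n", ckey[0], ckey[1], sign);  // (2,5), sign -1
  return 0;
}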
void read_loc_pairs(int              order,
                    int64_t          nval,
                    int              num_virt,
                    int const *      sym,
                    int const *      edge_len,
                    int const *      padding,
                    int const *      phase,
                    int const *      phys_phase,
                    int const *      virt_phase,
                    int *            phase_rank,
                    int64_t *        nread,
                    char const *     data,
                    char **          pairs,
                    algstrct const * sr){
  int i;
  int * prepadding;
  CTF_int::alloc_ptr(order*sizeof(int), (void**)&prepadding);
  memset(prepadding, 0, sizeof(int)*order);

  // dpairsb holds the (key,value) pairs generated from the locally held data;
  // strip the pairs that exist only because of padding
  int64_t new_num_pair;
  for (i=0; i<order; i++){
    pad_len[i] = edge_len[i]-padding[i];
  }
  depad_tsr(order, nval, pad_len, sym, padding, prepadding,
            dpairsb, new_pairsb, &new_num_pair, sr);
  if (new_num_pair == 0){
    sr->pair_dealloc(new_pairsb);
    new_pairsb = NULL;
  }
  *pairs = new_pairsb;
  *nread = new_num_pair;

  // return the keys in the unpadded global index space
  for (i=0; i<order; i++){
    depadding[i] = -padding[i];
  }
  PairIterator new_pairs(sr, new_pairsb);
  pad_key(order, new_num_pair, edge_len, depadding, new_pairs, sr);
}
void sp_read(algstrct const *  sr,
             int64_t           ntsr,
             ConstPairIterator prs_tsr,
             char const *      alpha,
             int64_t           nread,
             PairIterator      prs_read,
             char const *      beta){
  // both pair sets are sorted by key; walk them together
  int64_t r = 0;
  for (int64_t t=0; t<ntsr && r<nread; r++){
    // advance the tensor pointer until it reaches the requested key
    while (t<ntsr && r<nread && prs_tsr[t].k() != prs_read[r].k()){
      if (prs_tsr[t].k() < prs_read[r].k())
        t++;
      else
        r++;   // requested key not present in the tensor
    }
    if (t<ntsr && r<nread){
      // combine: read value <- alpha*(tensor value) + beta*(read value)
      char a[sr->el_size];
      char b[sr->el_size];
      sr->mul(prs_read[r].d(), beta, a);
      sr->mul(prs_tsr[t].d(), alpha, b);
      sr->add(a, b, a);
      prs_read[r].write_val(a);
    }
  }
  // remaining reads had no matching tensor entry and correspond to zeros
  if (beta == NULL && alpha != NULL){
    for (; r<nread; r++){
      prs_read[r].write_val(sr->addid());
    }
  } else if (beta != NULL){
    for (; r<nread; r++){
      char a[sr->el_size];
      sr->mul(prs_read[r].d(), beta, a);
      prs_read[r].write_val(a);
    }
  }
}
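// --- Illustrative sketch, not part of the CTF sources ---
// The merge performed by sp_read, written for plain (int64_t key, double
// value) arrays: both inputs are sorted by key, and each requested key that is
// found in the tensor gets value alpha*tensor_value + beta*old_value.  The
// repeated-key bookkeeping of the real routine is omitted; names and numbers
// are made up for the example.
#include <cstdint>
#include <cstdio>

struct pr { int64_t k; double d; };

void sp_read_sketch(int64_t ntsr, pr const * tsr, double alpha,
                    int64_t nread, pr * rd, double beta){
  int64_t t = 0;
  for (int64_t r=0; r<nread; r++){
    while (t < ntsr && tsr[t].k < rd[r].k) t++;     // advance the tensor pointer
    if (t < ntsr && tsr[t].k == rd[r].k)
      rd[r].d = alpha*tsr[t].d + beta*rd[r].d;       // key found in the tensor
    else
      rd[r].d = beta*rd[r].d;                        // key not stored: reads as zero
  }
}

int main(){
  pr tsr[3] = {{1, 10.0}, {4, 40.0}, {7, 70.0}};
  pr rd[3]  = {{1, 1.0}, {5, 1.0}, {7, 1.0}};
  sp_read_sketch(3, tsr, 2.0, 3, rd, 0.5);
  for (int i=0; i<3; i++) printf("%lld -> %g\n", (long long)rd[i].k, rd[i].d);
  // 1 -> 20.5, 5 -> 0.5, 7 -> 140.5
  return 0;
}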
void sp_write(int               num_virt,
              algstrct const *  sr,
              int64_t *         vntsr,
              ConstPairIterator vprs_tsr,
              char const *      beta,
              int64_t *         vnwrite,
              ConstPairIterator vprs_write,
              char const *      alpha,
              int64_t *         vnnew,
              char *&           pprs_new){
  // first pass: per virtual block, determine the size of the merged pair set
  // (vnnew[v] becomes vntsr[v] plus the number of distinct new write keys)
  ConstPairIterator prs_tsr   = vprs_tsr;
  ConstPairIterator prs_write = vprs_write;
  int64_t tot_new = 0;
  for (int v=0; v<num_virt; v++){
    int64_t ntsr   = vntsr[v];
    int64_t nwrite = vnwrite[v];
    if (v>0){
      prs_tsr   = prs_tsr[vntsr[v-1]];
      prs_write = prs_write[vnwrite[v-1]];
    }
    vnnew[v] = ntsr;   // every existing tensor pair survives the merge
    for (int64_t t=0, w=0; w<nwrite; w++){
      if (t<ntsr && prs_tsr[t].k() < prs_write[w].k()){
        t++;
        w--;           // revisit this write key after advancing the tensor pointer
      } else if (t<ntsr && prs_tsr[t].k() == prs_write[w].k()){
        t++;           // key already present in the tensor
      } else {
        // write key absent from the tensor and not a repeat of the previous one
        if (w==0 || prs_write[w-1].k() != prs_write[w].k())
          vnnew[v]++;
      }
    }
    tot_new += vnnew[v];
  }

  // allocate the merged pair set
  pprs_new = sr->pair_alloc(tot_new);
  PairIterator prs_new(sr, pprs_new);

  // second pass: merge each virtual block of tensor pairs with its write pairs
  prs_tsr   = vprs_tsr;
  prs_write = vprs_write;
  for (int v=0; v<num_virt; v++){
    int64_t ntsr   = vntsr[v];
    int64_t nwrite = vnwrite[v];
    int64_t nnew   = vnnew[v];
    if (v>0){
      prs_tsr   = prs_tsr[vntsr[v-1]];
      prs_write = prs_write[vnwrite[v-1]];
      prs_new   = prs_new[vnnew[v-1]];
    }
    char a[sr->el_size];
    char b[sr->el_size];
    for (int64_t t=0, w=0, n=0; n<nnew; n++){
      if (t<ntsr && (w==nwrite || prs_tsr[t].k() < prs_write[w].k())){
        // tensor pair that is not written to: copy it through
        prs_new[n].write(prs_tsr[t].ptr);
        t++;
      } else {
        if (t>=ntsr || prs_tsr[t].k() > prs_write[w].k()){
          // key that exists only in the write set
          prs_new[n].write(prs_write[w].ptr);
          if (alpha != NULL)
            sr->mul(prs_new[n].d(), alpha, prs_new[n].d());
        } else {
          // key present in both: new value = alpha*(write value) + beta*(tensor value)
          sr->mul(prs_write[w].d(), alpha, a);
          sr->mul(prs_tsr[t].d(), beta, b);
          sr->add(a, b, prs_new[n].d());
          ((int64_t*)(prs_new[n].ptr))[0] = prs_tsr[t].k();
          t++;
        }
        w++;
        // accumulate any further writes to the same key
        while (w < nwrite && prs_write[w].k() == prs_write[w-1].k()){
          if (alpha != NULL){
            sr->mul(prs_write[w].d(), alpha, a);
            sr->add(prs_new[n].d(), a, prs_new[n].d());
          } else {
            sr->add(prs_new[n].d(), prs_write[w].d(), prs_new[n].d());
          }
          w++;
        }
      }
    }
  }
}
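// --- Illustrative sketch, not part of the CTF sources ---
// The union merge behind sp_write on plain (key,value) pairs: sorted tensor
// pairs and sorted write pairs are merged; a key present in both gets
// alpha*write + beta*tensor, a key only in the write set enters as alpha*write,
// and repeated writes to one key accumulate.  The result size is between ntsr
// and ntsr+nwrite, as the documentation below states.  Names and numbers are
// made up for the example.
#include <cstdint>
#include <cstdio>
#include <vector>

struct wpr { int64_t k; double d; };

std::vector<wpr> sp_write_sketch(std::vector<wpr> const & tsr, double beta,
                                 std::vector<wpr> const & wr, double alpha){
  std::vector<wpr> out;
  size_t t = 0, w = 0;
  while (t < tsr.size() || w < wr.size()){
    if (w == wr.size() || (t < tsr.size() && tsr[t].k < wr[w].k)){
      out.push_back(tsr[t++]);                            // untouched tensor pair
    } else {
      int64_t key = wr[w].k;
      double val  = alpha*wr[w].d;
      if (t < tsr.size() && tsr[t].k == key)
        val += beta*tsr[t++].d;                           // combine with the old value
      while (++w < wr.size() && wr[w].k == key)
        val += alpha*wr[w].d;                             // accumulate repeated writes
      out.push_back({key, val});
    }
  }
  return out;
}

int main(){
  std::vector<wpr> tsr = {{1, 10.0}, {4, 40.0}};
  std::vector<wpr> wr  = {{4, 1.0}, {4, 2.0}, {9, 5.0}};
  for (wpr p : sp_write_sketch(tsr, 0.5, wr, 2.0))
    printf("%lld -> %g\n", (long long)p.k, p.d);
  // 1 -> 10, 4 -> 26, 9 -> 10
  return 0;
}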
Referenced functions, methods, and members:

void write(char const * buf, int64_t n=1)
    sets internal pairs to provided data
void write_key(int64_t key)
    sets key of head pair to key
virtual int pair_size() const
    gets pair size: el_size plus the key size
virtual char * pair_alloc(int64_t n) const
    allocates space for n (int64_t, dtype) pairs; necessary for object types
void read_loc_pairs(int order, int64_t nval, int num_virt, int const * sym, int const * edge_len, int const * padding, int const * phase, int const * phys_phase, int const * virt_phase, int * phase_rank, int64_t * nread, char const * data, char ** pairs, algstrct const * sr)
    reads tensor pairs local to this processor
virtual void copy(char * a, char const * b) const
    copies element b to element a
void read(char * buf, int64_t n=1) const
    sets external data to what this operator points to
void * alloc(int64_t len)
    alloc abstraction
virtual char const * addid() const
    returns the additive identity element of the algebraic structure
void sp_write(int num_virt, algstrct const * sr, int64_t * vntsr, ConstPairIterator vprs_tsr, char const * beta, int64_t * vnwrite, ConstPairIterator vprs_write, char const * alpha, int64_t * vnnew, char *& pprs_new)
    writes pairs in a sparse write set to the sparse set of elements defining the tensor, resulting in a set of size between ntsr and ntsr+nwrite
#define DEBUG_PRINTF(...)
void sort(int64_t n)
    sorts a set of pairs using std::sort
int64_t lower_bound(int64_t n, ConstPairIterator op)
    searches for pair op via std::lower_bound
void read_val(char * buf) const
    sets an external value to the value pointed to by the iterator
void all_to_allv(void * send_buffer, int64_t const * send_counts, int64_t const * send_displs, int64_t datum_size, void * recv_buffer, int64_t const * recv_counts, int64_t const * recv_displs)
    performs all-to-all-v with 64-bit integer counts and offsets on arbitrary length types (datum_size)...
virtual void addinv(char const * a, char * b) const
    b = -a
void sp_read(algstrct const * sr, int64_t ntsr, ConstPairIterator prs_tsr, char const * alpha, int64_t nread, PairIterator prs_read, char const * beta)
    reads elements of a sparse set defining the tensor into a sparse read set with potentially repeating...
int64_t k() const
    returns key of pair at head of ptr
int alloc_ptr(int64_t len, void ** const ptr)
    alloc abstraction
virtual void pair_dealloc(char * ptr) const
    deallocates a given pointer containing a contiguous array of pairs
void assign_keys(int order, int64_t size, int nvirt, int const * edge_len, int const * sym, int const * phase, int const * phys_phase, int const * virt_dim, int * phase_rank, char const * vdata, char * vpairs, algstrct const * sr)
    assigns keys to an array of values
void bucket_by_pe(int order, int64_t num_pair, int64_t np, int const * phys_phase, int const * virt_phase, int const * bucket_lda, int const * edge_len, ConstPairIterator mapped_data, int64_t * bucket_counts, int64_t * bucket_off, PairIterator bucket_data, algstrct const * sr)
    buckets key-value pairs by processor according to the distribution
char * d() const
    returns value of pair at head of ptr
void depermute_keys(int order, int num_pair, int const * edge_len, int const * new_edge_len, int * const * permutation, char * pairs_buf, algstrct const * sr)
    depermutes keys (applies P^T)
void permute_keys(int order, int num_pair, int const * edge_len, int const * new_edge_len, int * const * permutation, char * pairs_buf, int64_t * new_num_pair, algstrct const * sr)
    permutes keys
virtual void add(char const * a, char const * b, char * c) const
    c = a+b
int el_size
    size of each element of algstrct in bytes
void pad_key(int order, int64_t num_pair, int const * edge_len, int const * padding, PairIterator pairs, algstrct const * sr, int const * offsets)
    applies padding to keys
int cdealloc(void * ptr)
    free abstraction
int64_t * bucket_by_virt(int order, int num_virt, int64_t num_pair, int const * phys_phase, int const * virt_phase, int const * edge_len, ConstPairIterator mapped_data, PairIterator bucket_data, algstrct const * sr)
    buckets key-value pairs by block (virtual processor)
algstrct (algebraic structure)
    defines the elementwise operations computed in each tensor contraction...
void write_val(char const * buf)
    sets value of head pair to what is in buf
void wr_pairs_layout(int order, int np, int64_t inwrite, char const * alpha, char const * beta, char rw, int num_virt, int const * sym, int const * edge_len, int const * padding, int const * phase, int const * phys_phase, int const * virt_phase, int * virt_phys_rank, int const * bucket_lda, char * wr_pairs_buf, char * rw_data, CommData glb_comm, algstrct const * sr, bool is_sparse, int64_t nnz_loc, int64_t * nnz_blk, char *& pprs_new, int64_t & nnz_loc_new)
    reads or writes pairs from / to the tensor
virtual void mul(char const * a, char const * b, char * c) const
    c = a*b
void spsfy_tsr(int order, int64_t size, int nvirt, int const * edge_len, int const * sym, int const * phase, int const * phys_phase, int const * virt_dim, int * phase_rank, char const * vdata, char *& vpairs, int64_t * nnz_blk, algstrct const * sr, int64_t const * edge_lda, std::function< bool(char const *)> f)
    extracts all tensor values (in pair format) that pass a sparsifier function (including padded zeros i...
void depad_tsr(int order, int64_t num_pair, int const * edge_len, int const * sym, int const * padding, int const * prepadding, char const * pairsb, char * new_pairsb, int64_t * new_num_pair, algstrct const * sr)
    retrieves the unpadded pairs
void readwrite(int order, int64_t size, char const * alpha, char const * beta, int nvirt, int const * edge_len, int const * sym, int const * phase, int const * phys_phase, int const * virt_dim, int * phase_rank, char * vdata, char * pairs_buf, char rw, algstrct const * sr)
    reads or writes pairs from / to the tensor
void cvrt_idx(int order, int const * lens, int64_t idx, int * idx_arr)
    converts a linearized global index (key) into a multidimensional index array
virtual void set_pair(char * a, int64_t key, char const * vb) const
    sets the key and value of pair a