3 #include "../shared/util.h"
18 #pragma omp parallel for private(knew, k, lda, i, j)
20 for (i=0; i<num_pair; i++){
24 for (j=0; j<order; j++){
25 knew += lda*(k%edge_len[j]);
26 lda *= (edge_len[j]+padding[j]);
33 #pragma omp parallel for private(knew, k, lda, i, j)
35 for (i=0; i<num_pair; i++){
39 for (j=0; j<order; j++){
40 knew += lda*((k%edge_len[j])+offsets[j]);
41 lda *= (edge_len[j]+padding[j]);
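The two loops above perform the same digit-by-digit remapping: a key is read as a mixed-radix number whose radices are the unpadded edge lengths, and rewritten with radices edge_len[j]+padding[j], optionally shifting each digit by offsets[j]. A minimal standalone sketch of that arithmetic (illustrative names, not the CTF routine itself; the division of k by edge_len[j] is implied by lines not shown here):

  // Sketch only: remap one mixed-radix key from unpadded to padded radices.
  #include <cstdint>

  int64_t pad_one_key(int64_t k, int order,
                      int const * edge_len, int const * padding,
                      int const * offsets /* may be nullptr */){
    int64_t knew = 0;  // key in the padded index space
    int64_t lda  = 1;  // stride of dimension j in the padded layout
    for (int j = 0; j < order; j++){
      int64_t idx = k % edge_len[j];              // digit of k in dimension j
      if (offsets != nullptr) idx += offsets[j];  // shift, as in the second loop above
      knew += lda * idx;
      lda  *= (edge_len[j] + padding[j]);         // padded extent is the new radix
      k    /= edge_len[j];                        // advance to the next digit
    }
    return knew;
  }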
56 int const * prepadding,
59 int64_t * new_num_pair,
66 int mntd = omp_get_max_threads();
67 int64_t * num_ins_t = (int64_t*)CTF_int::alloc(sizeof(int64_t)*mntd);
68 int64_t * pre_ins_t = (int64_t*)CTF_int::alloc(sizeof(int64_t)*mntd);
70 std::fill(num_ins_t, num_ins_t+mntd, 0);
76 int64_t i, j, st, end, tid;
78 int64_t kparts[order];
79 tid = omp_get_thread_num();
80 int ntd = omp_get_num_threads();
83 act_ntd = omp_get_num_threads();
86 st = (num_pair/ntd)*tid;
90 end = (num_pair/ntd)*(tid+1);
93 for (i=st; i<end; i++){
95 for (j=0; j<order; j++){
96 kparts[j] = k%(edge_len[j]+padding[j]);
97 if (kparts[j] >= (int64_t)edge_len[j] ||
98     kparts[j] < prepadding[j]) break;
99 k = k/(edge_len[j]+padding[j]);
102 for (j=0; j<order; j++){
104   if (kparts[j+1] < kparts[j])
107   if (sym[j] == AS || sym[j] == SH){
108     if (kparts[j+1] <= kparts[j])
121 for (int j=1; j<mntd; j++){
122   pre_ins_t[j] = num_ins_t[j-1] + pre_ins_t[j-1];
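The first parallel region only counts: each thread stores the number of pairs it would keep in num_ins_t[tid], and the loop at lines 121-122 turns those counts into exclusive prefix sums, giving every thread a disjoint starting offset for the write pass that follows. A hedged sketch of that offset computation in isolation (illustrative names):

  // Sketch only: exclusive prefix sum over per-thread counts.
  // Thread t can then write count[t] items starting at offset[t] with no overlap.
  #include <cstdint>
  #include <vector>

  std::vector<int64_t> exclusive_offsets(std::vector<int64_t> const & count){
    std::vector<int64_t> offset(count.size(), 0);
    for (size_t t = 1; t < count.size(); t++)
      offset[t] = offset[t-1] + count[t-1];
    return offset;
  }

In the write pass below, pre_ins_t[tid] then serves as the running write cursor (presumably advanced after each write, which is not shown in the fragment), so pre_ins_t[act_ntd-1] at line 171 ends up holding the total number of kept pairs.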
128 int64_t i, j, st, end, tid;
130 int64_t kparts[order];
131 tid = omp_get_thread_num();
132 int ntd = omp_get_num_threads();
136 assert(act_ntd == ntd);
138 st = (num_pair/ntd)*tid;
142 end = (num_pair/ntd)*(tid+1);
144 for (i=st; i<end; i++){
146 for (j=0; j<order; j++){
147 kparts[j] = k%(edge_len[j]+padding[j]);
148 if (kparts[j] >= (int64_t)edge_len[j] ||
149     kparts[j] < prepadding[j]) break;
150 k = k/(edge_len[j]+padding[j]);
153 for (j=0; j<order; j++){
155   if (kparts[j+1] < kparts[j])
158   if (sym[j] == AS || sym[j] == SH){
159     if (kparts[j+1] <= kparts[j])
164 new_pairs[pre_ins_t[tid]].write(pairs[i]);
171 num_ins = pre_ins_t[act_ntd-1];
173 *new_num_pair = num_ins;
177 int64_t i, j, num_ins;
184 for (i=0; i<num_pair; i++){
186 for (j=0; j<order; j++){
187 kparts[j] = k%(edge_len[j]+padding[j]);
188 if (kparts[j] >= (int64_t)edge_len[j] ||
189     kparts[j] < prepadding[j]) break;
190 k = k/(edge_len[j]+padding[j]);
193 for (j=0; j<order; j++){
195   if (kparts[j+1] < kparts[j])
198   if (sym[j] == AS || sym[j] == SH){
199     if (kparts[j+1] <= kparts[j])
204 new_pairs[num_ins].write(pairs[i]);
209 *new_num_pair = num_ins;
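Both depad_tsr variants repeat the same test on every pair: decompose its padded key into per-dimension parts kparts[], drop the pair if any part lands in the padding (at or beyond edge_len[j]) or before prepadding[j], and drop it if it violates the symmetry packing. A condensed sketch of that predicate, assuming SY/AS/SH/NS carry their usual CTF meanings (symmetric, antisymmetric, symmetric-hollow, nonsymmetric) and that the strict/non-strict comparisons are guarded as in the fragments above:

  // Sketch only: should a pair with padded key k survive depad-style filtering?
  #include <cstdint>
  #include <vector>

  enum SymKind { NS_, SY_, AS_, SH_ };  // hypothetical stand-ins for the CTF labels

  bool keep_pair(int64_t k, int order,
                 int const * edge_len, int const * padding,
                 int const * prepadding, SymKind const * sym){
    std::vector<int64_t> kparts(order);
    for (int j = 0; j < order; j++){
      kparts[j] = k % (edge_len[j] + padding[j]);
      if (kparts[j] >= (int64_t)edge_len[j] || kparts[j] < prepadding[j])
        return false;                     // index lies in the padded region
      k /= (edge_len[j] + padding[j]);
    }
    for (int j = 0; j < order-1; j++){
      if (sym[j] == SY_ && kparts[j+1] < kparts[j])
        return false;                     // SY keeps only non-decreasing index tuples
      if ((sym[j] == AS_ || sym[j] == SH_) && kparts[j+1] <= kparts[j])
        return false;                     // AS/SH keep only strictly increasing tuples
    }
    return true;
  }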
377 int const * edge_len,
381 int const * phys_phase,
382 int const * virt_phase,
383 int const * cphase_rank,
387 if (order == 0) return;
393 int i, act_lda, act_max, curr_idx, sym_idx;
395 int64_t p, buf_offset;
396 int * idx, * virt_rank, * phase_rank, * virt_len;
404 virt_len[dim] = edge_len[dim]/phase[dim];
407 memcpy(phase_rank, cphase_rank, order*sizeof(int));
408 memset(virt_rank, 0, sizeof(int)*order);
410 int tid, ntd, vst, vend;
412 tid = omp_get_thread_num();
413 ntd = omp_get_num_threads();
419 int * st_idx=NULL, * end_idx;
420 int64_t st_index = 0;
421 int64_t end_index = size/nvirt;
424 vst = (nvirt/ntd)*tid;
425 vst += MIN(nvirt%ntd,tid);
426 vend = vst+(nvirt/ntd);
427 if (tid < nvirt % ntd) vend++;
429 int64_t vrt_sz = size/nvirt;
430 int64_t chunk = size/ntd;
431 int64_t st_chunk = chunk*tid + MIN(tid,size%ntd);
432 int64_t end_chunk = st_chunk+chunk;
435 vst = st_chunk/vrt_sz;
436 vend = end_chunk/vrt_sz;
437 if ((end_chunk%vrt_sz) > 0) vend++;
439 st_index = st_chunk-vst*vrt_sz;
440 end_index = end_chunk-(vend-1)*vrt_sz;
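In zero_padding, each thread first claims a contiguous range of elements [st_chunk, end_chunk) out of the size elements owned locally, balanced with MIN so the remainder is spread over the lowest-ranked threads, and then converts that range into a range of virtual blocks [vst, vend) together with a start offset inside the first block (st_index) and a stop offset inside the last one (end_index). A small sketch of that conversion with illustrative names; the end_chunk adjustment for the remainder is an assumption, since the corresponding line is not shown above:

  // Sketch only: turn a thread's element range into virtual-block coordinates.
  // Assumes size is a multiple of nvirt, as for a blocked virtual layout.
  #include <cstdint>
  #include <algorithm>

  struct BlockRange {
    int64_t vst, vend;   // first and one-past-last virtual block owned by the thread
    int64_t st_index;    // offset inside block vst where the thread starts
    int64_t end_index;   // offset inside block vend-1 where the thread stops
  };

  BlockRange split_among_threads(int64_t size, int64_t nvirt, int64_t tid, int64_t ntd){
    int64_t vrt_sz    = size / nvirt;                     // elements per virtual block
    int64_t chunk     = size / ntd;
    int64_t st_chunk  = chunk*tid + std::min(tid, size % ntd);
    int64_t end_chunk = st_chunk + chunk + (tid < size % ntd ? 1 : 0);  // assumed remainder handling

    BlockRange r;
    r.vst  = st_chunk / vrt_sz;
    r.vend = end_chunk / vrt_sz + ((end_chunk % vrt_sz) > 0 ? 1 : 0);
    r.st_index  = st_chunk  - r.vst*vrt_sz;
    r.end_index = end_chunk - (r.vend-1)*vrt_sz;
    return r;
  }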
455 calc_idx_arr(order, virt_len, ssym, end_index, end_idx);
460 st_index -= st_idx[0];
463 if (end_idx[0] != 0){
464 end_index += virt_len[0]-end_idx[0];
468 ASSERT(tid != ntd-1 || vend == nvirt);
469 for (p=0; p<nvirt; p++){
470 if (p>=vst && p<vend){
472 if (((sym[0] == AS || sym[0] == SH) && phase_rank[0] >= phase_rank[1]) ||
473     ( sym[0] == SY                  && phase_rank[0] >  phase_rank[1]) ) {
476 int pad0 = (padding[0]+phase_rank[0])/phase[0];
477 int len0 = virt_len[0]-pad0;
478 int plen0 = virt_len[0];
479 data = vdata + sr->el_size*p*(size/nvirt);
481 if (p==vst && st_index != 0){
483 memcpy(idx+1,st_idx+1,(order-1)*sizeof(int));
484 buf_offset = st_index;
487 memset(idx, 0, order*sizeof(int));
492 for (i=1; i<order; i++){
493 curr_idx = idx[i]*phase[i]+phase_rank[i];
494 if (curr_idx >= edge_len[i] - padding[i]){
497 } else if (i < order-1) {
498 sym_idx = idx[i+1]*phase[i+1]+phase_rank[i+1];
499 if (((sym[i] == AS || sym[i] == SH) && curr_idx >= sym_idx) ||
500     ( sym[i] == SY                  && curr_idx >  sym_idx) ) {
512 if (sym[0] != NS) plen0 = idx[1]+1;
520 int s1 = MIN(plen0-is_sh_pad0,len0);
529 if (p == vend-1 && buf_offset >= end_index) break;
531 for (i=1; i < order; i++){
533 act_max = virt_len[i];
537 act_max = MIN(act_max,idx[i+1]+1);
539 if (idx[i] >= act_max)
541 ASSERT(edge_len[i]%phase[i] == 0);
545 if (i >= order) break;
548 for (act_lda=0; act_lda < order; act_lda++){
549 phase_rank[act_lda] -= virt_rank[act_lda]*phys_phase[act_lda];
550 virt_rank[act_lda]++;
551 if (virt_rank[act_lda] >= virt_phase[act_lda])
552 virt_rank[act_lda] = 0;
553 phase_rank[act_lda] += virt_rank[act_lda]*phys_phase[act_lda];
554 if (virt_rank[act_lda] > 0)
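Inside each virtual block, zero_padding reconstructs the global index of a row in dimension i as idx[i]*phase[i]+phase_rank[i] and zeroes the row when that index reaches the padded tail (>= edge_len[i]-padding[i]) or breaks the symmetry relation with the next dimension; dimension 0 is handled separately through pad0/len0/plen0 so whole contiguous stretches can be set at once. A condensed sketch of the test for dimensions i >= 1, again with hypothetical stand-ins for the symmetry labels:

  // Sketch only: does the element at local indices idx[] (within one virtual block)
  // fall into the region that zero_padding overwrites with the additive identity?
  enum SymKind { NS_, SY_, AS_, SH_ };  // hypothetical stand-ins, as in the earlier sketch

  bool in_padded_region(int order, int const * idx,
                        int const * phase, int const * phase_rank,
                        int const * edge_len, int const * padding,
                        SymKind const * sym){
    for (int i = 1; i < order; i++){
      int curr_idx = idx[i]*phase[i] + phase_rank[i];  // global index in dimension i
      if (curr_idx >= edge_len[i] - padding[i])
        return true;                                   // past the logical edge: padding
      if (i < order-1){
        int sym_idx = idx[i+1]*phase[i+1] + phase_rank[i+1];
        if (((sym[i] == AS_ || sym[i] == SH_) && curr_idx >= sym_idx) ||
            ( sym[i] == SY_                   && curr_idx >  sym_idx))
          return true;                                 // outside the packed symmetric region
      }
    }
    return false;                                      // dimension 0 is treated separately
  }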
570 int const * edge_len,
574 int const * phys_phase,
575 int const * virt_phase,
576 int const * cphase_rank,
579 int const * sym_mask){
585 int i, act_lda, act_max;
587 int64_t p, buf_offset;
588 int * idx, * virt_rank, * phase_rank, * virt_len;
596 virt_len[dim] = edge_len[dim]/phase[dim];
599 memcpy(phase_rank, cphase_rank, order*sizeof(int));
600 memset(virt_rank, 0, sizeof(int)*order);
602 int tid, ntd, vst, vend;
604 tid = omp_get_thread_num();
605 ntd = omp_get_num_threads();
611 int * st_idx=NULL, * end_idx;
612 int64_t st_index = 0;
613 int64_t end_index = size/nvirt;
616 vst = (nvirt/ntd)*tid;
617 vst += MIN(nvirt%ntd,tid);
618 vend = vst+(nvirt/ntd);
619 if (tid < nvirt % ntd) vend++;
621 int64_t vrt_sz = size/nvirt;
622 int64_t chunk = size/ntd;
623 int64_t st_chunk = chunk*tid + MIN(tid,size%ntd);
624 int64_t end_chunk = st_chunk+chunk;
627 vst = st_chunk/vrt_sz;
628 vend = end_chunk/vrt_sz;
629 if ((end_chunk%vrt_sz) > 0) vend++;
631 st_index = st_chunk-vst*vrt_sz;
632 end_index = end_chunk-(vend-1)*vrt_sz;
647 calc_idx_arr(order, virt_len, ssym, end_index, end_idx);
652 st_index -= st_idx[0];
655 if (end_idx[0] != 0){
656 end_index -= end_idx[0];
661 ASSERT(tid != ntd-1 || vend == nvirt);
662 for (p=0; p<nvirt; p++){
663 if (st_index == end_index) break;
664 if (p>=vst && p<vend){
672 int plen0 = virt_len[0];
673 data = vdata + sr->el_size*p*(size/nvirt);
675 if (p==vst && st_index != 0){
677 memcpy(idx+1,st_idx+1,(order-1)*sizeof(int));
678 buf_offset = st_index;
681 memset(idx, 0, order*sizeof(int));
685 if (sym[0] != NS) plen0 = idx[1]+1;
687 for (i=1; i<order; i++){
688 if (sym_mask[i] == 1){
689 int curr_idx_i = idx[i]*phase[i]+phase_rank[i];
691 for (int j=i+1; j<order; j++){
692 if (sym_mask[j] == 1){
693 int curr_idx_j = idx[j]*phase[j]+phase_rank[j];
694 if (curr_idx_i == curr_idx_j) iperm++;
697 perm_factor *= iperm;
700 if (sym_mask[0] == 0){
701 if (perm_factor != 1){
704 sr->scal(plen0, scal_fact, data+buf_offset*sr->el_size, 1);
708 if (perm_factor != 1){
711 sr->scal(idx[1]+1, scal_fact, data+buf_offset*sr->el_size, 1);
713 int curr_idx_0 = idx[1]*phase[0]+phase_rank[0];
715 for (int j=1; j<order; j++){
716 if (sym_mask[j] == 1){
717 int curr_idx_j = idx[j]*phase[j]+phase_rank[j];
718 if (curr_idx_0 == curr_idx_j) iperm++;
723 sr->scal(1, scal_fact2, data+(buf_offset+idx[1])*sr->el_size, 1);
726 if (p == vend-1 && buf_offset >= end_index) break;
728 for (i=1; i < order; i++){
730 act_max = virt_len[i];
734 act_max = MIN(act_max,idx[i+1]+1);
736 if (idx[i] >= act_max)
738 ASSERT(edge_len[i]%phase[i] == 0);
742 if (i >= order) break;
745 for (act_lda=0; act_lda < order; act_lda++){
746 phase_rank[act_lda] -= virt_rank[act_lda]*phys_phase[act_lda];
747 virt_rank[act_lda]++;
748 if (virt_rank[act_lda] >= virt_phase[act_lda])
749 virt_rank[act_lda] = 0;
750 phase_rank[act_lda] += virt_rank[act_lda]*phys_phase[act_lda];
751 if (virt_rank[act_lda] > 0)
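scal_diag builds perm_factor as the product, over each group of equal global indices among the dimensions flagged by sym_mask, of that group's factorial, and then scales the element by 1/perm_factor (the factor is converted with cast_double and applied through sr->scal). Dimension 0 is again special-cased so whole rows can be scaled in one scal call. A standalone sketch of the counting step, with illustrative names:

  // Sketch only: number of permutations of the sym_mask'ed dimensions that map
  // this element onto itself; scal_diag divides the element by this count.
  #include <cstdint>

  int64_t count_perm_factor(int order, int const * global_idx, int const * sym_mask){
    int64_t perm_factor = 1;
    for (int i = 0; i < order; i++){
      if (sym_mask[i] != 1) continue;
      int64_t iperm = 1;                    // the dimension itself
      for (int j = i+1; j < order; j++){
        if (sym_mask[j] == 1 && global_idx[j] == global_idx[i])
          iperm++;                          // another flagged dimension with the same index
      }
      perm_factor *= iperm;                 // accumulates the factorial of each equal group
    }
    return perm_factor;
  }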
void calc_idx_arr(int order, int const *lens, int const *sym, int64_t idx, int *idx_arr)
void write(char const *buf, int64_t n=1)
sets internal pairs to provided data
void zero_padding(int order, int64_t size, int nvirt, int const *edge_len, int const *sym, int const *padding, int const *phase, int const *phys_phase, int const *virt_phase, int const *cphase_rank, char *vdata, algstrct const *sr)
sets to zero all values in padded region of tensor
void write_key(int64_t key)
sets key of head pair to key
void * alloc(int64_t len)
alloc abstraction
virtual char const * addid() const
returns the additive identity element (i.e. 0) of the algebraic structure
virtual void set(char *a, char const *b, int64_t n) const
sets n elements of array a to value b
int64_t k() const
returns key of pair at head of ptr
int alloc_ptr(int64_t len, void **const ptr)
alloc abstraction
virtual void scal(int n, char const *alpha, char *X, int incX) const
X["i"]=alpha*X["i"];.
int el_size
size of each element of algstrct in bytes
void pad_key(int order, int64_t num_pair, int const *edge_len, int const *padding, PairIterator pairs, algstrct const *sr, int const *offsets)
applies padding to keys
int cdealloc(void *ptr)
free abstraction
algstrct (algebraic structure) defines the elementwise operations computed in each tensor contraction...
void scal_diag(int order, int64_t size, int nvirt, int const *edge_len, int const *sym, int const *padding, int const *phase, int const *phys_phase, int const *virt_phase, int const *cphase_rank, char *vdata, algstrct const *sr, int const *sym_mask)
scales each element by 1/(number of entries equivalent to it after permutation of indices for which s...
void depad_tsr(int order, int64_t num_pair, int const *edge_len, int const *sym, int const *padding, int const *prepadding, char const *pairsb, char *new_pairsb, int64_t *new_num_pair, algstrct const *sr)
retrieves the unpadded pairs
virtual void cast_double(double d, char *c) const
c = &d