#include "../shared/iter_tsr.h"
#include "../shared/offload.h"
#include "../shared/util.h"

template <int idim>
void sym_seq_ctr_loop(char const *           alpha,
                      char const *           A,
                      algstrct const *       sr_A,
                      int                    order_A,
                      int const *            edge_len_A,
                      int const *            sym_A,
                      int const *            idx_map_A,
                      uint64_t *const *      offsets_A,
                      char const *           B,
                      algstrct const *       sr_B,
                      int                    order_B,
                      int const *            edge_len_B,
                      int const *            sym_B,
                      int const *            idx_map_B,
                      uint64_t *const *      offsets_B,
                      char const *           beta,
                      char *                 C,
                      algstrct const *       sr_C,
                      int                    order_C,
                      int const *            edge_len_C,
                      int const *            sym_C,
                      int const *            idx_map_C,
                      uint64_t *const *      offsets_C,
                      bivar_function const * func,
                      int const *            idx,
                      int const *            rev_idx_map,
                      int                    idx_max){
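  /* rev_idx_map stores three entries per global contraction index: the
   * position of that index within A, B, and C respectively, or -1 if the
   * tensor does not carry it. Each template level idim handles one global
   * index and recurses on idim-1, so the depth of the loop nest is fixed at
   * compile time and the innermost level can be fully specialized. */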
  int imax=0;
  int rA = rev_idx_map[3*idim+0];
  int rB = rev_idx_map[3*idim+1];
  int rC = rev_idx_map[3*idim+2];

  if (rA != -1)
    imax = edge_len_A[rA];
  else if (rB != -1)
    imax = edge_len_B[rB];
  else if (rC != -1)
    imax = edge_len_C[rC];
  if (rA != -1 && sym_A[rA] != NS){
    int rrA = rA;
    do {
      if (idx_map_A[rrA+1] > idim)
        imax = idx[idx_map_A[rrA+1]]+1;
      rrA++;
    } while (sym_A[rrA] != NS && idx_map_A[rrA] < idim);
  }

  if (rB != -1 && sym_B[rB] != NS){
    int rrB = rB;
    do {
      if (idx_map_B[rrB+1] > idim)
        imax = std::min(imax,idx[idx_map_B[rrB+1]]+1);
      rrB++;
    } while (sym_B[rrB] != NS && idx_map_B[rrB] < idim);
  }

  if (rC != -1 && sym_C[rC] != NS){
    int rrC = rC;
    do {
      if (idx_map_C[rrC+1] > idim)
        imax = std::min(imax,idx[idx_map_C[rrC+1]]+1);
      rrC++;
    } while (sym_C[rrC] != NS && idx_map_C[rrC] < idim);
  }
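  /* Upper-bound tightening for symmetric-packed indices: only the wedge
   * i_0 <= i_1 <= ... of an SY group is stored, so when this index is
   * symmetric with a later index that an outer recursion level has already
   * fixed (idx_map[rr+1] > idim), the bound is clipped to that value + 1.
   * E.g. for A["ij"] with i <= j and j fixed, i only runs over [0, j]. */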
  int imin = 0;

  if (rA > 0 && sym_A[rA-1] != NS){
    int rrA = rA;
    do {
      if (idx_map_A[rrA-1] > idim)
        imin = idx[idx_map_A[rrA-1]];
      rrA--;
    } while (rrA>0 && sym_A[rrA-1] != NS && idx_map_A[rrA] < idim);
  }

  if (rB > 0 && sym_B[rB-1] != NS){
    int rrB = rB;
    do {
      if (idx_map_B[rrB-1] > idim)
        imin = std::max(imin,idx[idx_map_B[rrB-1]]);
      rrB--;
    } while (rrB>0 && sym_B[rrB-1] != NS && idx_map_B[rrB] < idim);
  }

  if (rC > 0 && sym_C[rC-1] != NS){
    int rrC = rC;
    do {
      if (idx_map_C[rrC-1] > idim)
        imin = std::max(imin,idx[idx_map_C[rrC-1]]);
      rrC--;
    } while (rrC>0 && sym_C[rrC-1] != NS && idx_map_C[rrC] < idim);
  }
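  /* Mirror of the upper-bound logic: if the preceding index of a symmetric
   * group was fixed by an outer level, this index may not drop below it,
   * keeping the iteration inside the packed wedge from below. */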
  if (rC != -1){
    /* when this index is mapped into C, threaded iterations write disjoint
       slices of the output (worksharing guard reconstructed) */
#ifdef USE_OMP
    #pragma omp for
#endif
    for (int i=imin; i<imax; i++){
      int nidx[idx_max];
      memcpy(nidx, idx, idx_max*sizeof(int));
      nidx[idim] = i;
      sym_seq_ctr_loop<idim-1>(alpha, A+offsets_A[idim][nidx[idim]], sr_A, order_A, edge_len_A, sym_A, idx_map_A, offsets_A, B+offsets_B[idim][nidx[idim]], sr_B, order_B, edge_len_B, sym_B, idx_map_B, offsets_B, beta, C+offsets_C[idim][nidx[idim]], sr_C, order_C, edge_len_C, sym_C, idx_map_C, offsets_C, func, nidx, rev_idx_map, idx_max);
    }
  } else {
    for (int i=imin; i<imax; i++){
      int nidx[idx_max];
      memcpy(nidx, idx, idx_max*sizeof(int));
      nidx[idim] = i;
      sym_seq_ctr_loop<idim-1>(alpha, A+offsets_A[idim][nidx[idim]], sr_A, order_A, edge_len_A, sym_A, idx_map_A, offsets_A, B+offsets_B[idim][nidx[idim]], sr_B, order_B, edge_len_B, sym_B, idx_map_B, offsets_B, beta, C+offsets_C[idim][nidx[idim]], sr_C, order_C, edge_len_C, sym_C, idx_map_C, offsets_C, func, nidx, rev_idx_map, idx_max);
    }
  }
}
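/* The recursion bottoms out in the explicit sym_seq_ctr_loop<0> specialization
 * below, which performs the actual multiply-accumulate over the innermost
 * fiber; offsets_X[idim][i] is the precomputed byte offset of slice i of
 * dimension idim (see compute_syoff), so packed symmetric strides are never
 * recomputed inside the loop nest. */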
template <>
void sym_seq_ctr_loop<0>(char const *           alpha,
                         char const *           A,
                         algstrct const *       sr_A,
                         int                    order_A,
                         int const *            edge_len_A,
                         int const *            sym_A,
                         int const *            idx_map_A,
                         uint64_t *const *      offsets_A,
                         char const *           B,
                         algstrct const *       sr_B,
                         int                    order_B,
                         int const *            edge_len_B,
                         int const *            sym_B,
                         int const *            idx_map_B,
                         uint64_t *const *      offsets_B,
                         char const *           beta,
                         char *                 C,
                         algstrct const *       sr_C,
                         int                    order_C,
                         int const *            edge_len_C,
                         int const *            sym_C,
                         int const *            idx_map_C,
                         uint64_t *const *      offsets_C,
                         bivar_function const * func,
                         int const *            idx,
                         int const *            rev_idx_map,
                         int                    idx_max){
  int imax=0;
  int rA = rev_idx_map[0];
  int rB = rev_idx_map[1];
  int rC = rev_idx_map[2];

  if (rA != -1)
    imax = edge_len_A[rA];
  else if (rB != -1)
    imax = edge_len_B[rB];
  else if (rC != -1)
    imax = edge_len_C[rC];

  if (rA != -1 && sym_A[rA] != NS)
    imax = idx[idx_map_A[rA+1]]+1;
  if (rB != -1 && sym_B[rB] != NS)
    imax = std::min(imax,idx[idx_map_B[rB+1]]+1);
  if (rC != -1 && sym_C[rC] != NS)
    imax = std::min(imax,idx[idx_map_C[rC+1]]+1);

  int imin = 0;

  if (rA > 0 && sym_A[rA-1] != NS)
    imin = idx[idx_map_A[rA-1]];
  if (rB > 0 && sym_B[rB-1] != NS)
    imin = std::max(imin,idx[idx_map_B[rB-1]]);
  if (rC > 0 && sym_C[rC-1] != NS)
    imin = std::max(imin,idx[idx_map_C[rC-1]]);
  if (func == NULL){
    if (alpha == NULL || sr_C->isequal(alpha,sr_C->mulid())){
      for (int i=imin; i<imax; i++){
        char tmp[sr_C->el_size];
        sr_C->mul(A+offsets_A[0][i],
                  B+offsets_B[0][i],
                  tmp);
        sr_C->add(tmp,
                  C+offsets_C[0][i],
                  C+offsets_C[0][i]);
      }
    } else {
      for (int i=imin; i<imax; i++){
        char tmp[sr_C->el_size];
        sr_C->mul(A+offsets_A[0][i],
                  B+offsets_B[0][i],
                  tmp);
        sr_C->mul(tmp, alpha, tmp);
        sr_C->add(tmp,
                  C+offsets_C[0][i],
                  C+offsets_C[0][i]);
      }
    }
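    /* The algstrct semiring operates on untyped char buffers: mul(a,b,c)
     * computes c = a*b and add(a,b,c) computes c = a+b on elements of
     * el_size bytes. When alpha is NULL or equal to the multiplicative
     * identity (mulid), the extra scaling multiply is skipped entirely. */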
  } else {
    if (alpha == NULL || sr_C->isequal(alpha,sr_C->mulid())){
      for (int i=imin; i<imax; i++){
        func->acc_f(A+offsets_A[0][i],
                    B+offsets_B[0][i],
                    C+offsets_C[0][i],
                    sr_C);
      }
    } else {
      for (int i=imin; i<imax; i++){
        char tmp[sr_C->el_size];
        func->apply_f(A+offsets_A[0][i],
                      B+offsets_B[0][i],
                      tmp);
        sr_C->mul(tmp, alpha, tmp);
        sr_C->add(tmp,
                  C+offsets_C[0][i],
                  C+offsets_C[0][i]);
      }
    }
  }
}
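/* With a user-supplied bivar_function, acc_f(a,b,c,sr_C) fuses the update
 * c = c + f(a,b) in one call; when alpha still has to be applied, apply_f
 * first materializes f(a,b) in a scratch element so it can be scaled before
 * accumulation. */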
template
void sym_seq_ctr_loop< MAX_ORD >
                     (char const *           alpha,
                      char const *           A,
                      algstrct const *       sr_A,
                      int                    order_A,
                      int const *            edge_len_A,
                      int const *            sym_A,
                      int const *            idx_map_A,
                      uint64_t *const *      offsets_A,
                      char const *           B,
                      algstrct const *       sr_B,
                      int                    order_B,
                      int const *            edge_len_B,
                      int const *            sym_B,
                      int const *            idx_map_B,
                      uint64_t *const *      offsets_B,
                      char const *           beta,
                      char *                 C,
                      algstrct const *       sr_C,
                      int                    order_C,
                      int const *            edge_len_C,
                      int const *            sym_C,
                      int const *            idx_map_C,
                      uint64_t *const *      offsets_C,
                      bivar_function const * func,
                      int const *            idx,
                      int const *            rev_idx_map,
                      int                    idx_max);
void compute_syoff(int              r,
                   int              len,
                   algstrct const * sr,
                   int const *      edge_len,
                   int const *      sym,
                   uint64_t *       offsets){
  if (r == -1){
    std::fill(offsets, offsets+len, 0);
  } else if (r == 0){
    for (int i=0; i<len; i++){
      offsets[i] = i*sr->el_size;
    }
  } else if (sym[r-1] == NS){
    int64_t sz = sy_packed_size(r, edge_len, sym)*sr->el_size;
    for (int i=0; i<len; i++){
      offsets[i] = i*sz;
    }
  } else {
    int medge_len[r+1];
    memcpy(medge_len, edge_len, r*sizeof(int));
    int rr = r;
    while (rr>0 && sym[rr-1] != NS) rr--;
    for (int i=0; i<len; i++){
      std::fill(medge_len+rr,medge_len+r+1, i);
      offsets[i] = sy_packed_size(r+1, medge_len, sym)*sr->el_size;
    }
  }
}
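/* A small standalone sketch (hypothetical values, not part of the build) of
 * what compute_syoff produces for the second index of an SY pair i <= j:
 * column j of the packed wedge is preceded by j*(j+1)/2 elements, matching
 * sy_packed_size over the group with both edge lengths set to j. */
#if 0
#include <cstdint>
#include <cstdio>
int main(){
  int      n       = 4;   // hypothetical edge length
  uint64_t el_size = 8;   // hypothetical element size in bytes
  for (int j = 0; j < n; j++){
    // byte offset of the fiber with the second (symmetric) index fixed to j
    uint64_t off = (uint64_t)j*(j+1)/2*el_size;
    printf("offsets[%d] = %llu\n", j, (unsigned long long)off);
  }
  return 0;
}
#endif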
void compute_syoffs(algstrct const * sr_A,
                    int              order_A,
                    int const *      edge_len_A,
                    int const *      sym_A,
                    int const *      idx_map_A,
                    algstrct const * sr_B,
                    int              order_B,
                    int const *      edge_len_B,
                    int const *      sym_B,
                    int const *      idx_map_B,
                    algstrct const * sr_C,
                    int              order_C,
                    int const *      edge_len_C,
                    int const *      sym_C,
                    int const *      idx_map_C,
                    int              tot_order,
                    int const *      rev_idx_map,
                    uint64_t **&     offsets_A,
                    uint64_t **&     offsets_B,
                    uint64_t **&     offsets_C){
  offsets_A = (uint64_t**)CTF_int::alloc(sizeof(uint64_t*)*tot_order);
  offsets_B = (uint64_t**)CTF_int::alloc(sizeof(uint64_t*)*tot_order);
  offsets_C = (uint64_t**)CTF_int::alloc(sizeof(uint64_t*)*tot_order);

  for (int idim=0; idim<tot_order; idim++){
    int len=0;

    int rA = rev_idx_map[3*idim+0];
    int rB = rev_idx_map[3*idim+1];
    int rC = rev_idx_map[3*idim+2];

    if (rA != -1)
      len = edge_len_A[rA];
    else if (rB != -1)
      len = edge_len_B[rB];
    else if (rC != -1)
      len = edge_len_C[rC];

    offsets_A[idim] = (uint64_t*)CTF_int::alloc(sizeof(uint64_t)*len);
    offsets_B[idim] = (uint64_t*)CTF_int::alloc(sizeof(uint64_t)*len);
    offsets_C[idim] = (uint64_t*)CTF_int::alloc(sizeof(uint64_t)*len);
    compute_syoff(rA, len, sr_A, edge_len_A, sym_A, offsets_A[idim]);
    compute_syoff(rB, len, sr_B, edge_len_B, sym_B, offsets_B[idim]);
    compute_syoff(rC, len, sr_C, edge_len_C, sym_C, offsets_C[idim]);
  }
}
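/* compute_syoffs builds one offset table per global index and per tensor:
 * tot_order tables of len entries each, where len is the edge length of the
 * dimension the index maps to. The callers free the tables once the loop
 * nest has run. */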
int sym_seq_ctr_ref(char const *     alpha,
                    char const *     A,
                    algstrct const * sr_A,
                    int              order_A,
                    int const *      edge_len_A,
                    int const *      sym_A,
                    int const *      idx_map_A,
                    char const *     B,
                    algstrct const * sr_B,
                    int              order_B,
                    int const *      edge_len_B,
                    int const *      sym_B,
                    int const *      idx_map_B,
                    char const *     beta,
                    char *           C,
                    algstrct const * sr_C,
                    int              order_C,
                    int const *      edge_len_C,
                    int const *      sym_C,
                    int const *      idx_map_C){
  int idx, i, idx_max, imin, imax, iA, iB, iC, j, k;
  int off_idx, sym_pass;
  int * idx_glb, * rev_idx_map;
  int * dlen_A, * dlen_B, * dlen_C;
  int64_t idx_A, idx_B, idx_C, off_lda;

  inv_idx(order_A, idx_map_A,
          order_B, idx_map_B,
          order_C, idx_map_C,
          &idx_max, &rev_idx_map);
  if (idx_max == 0){
    /* scalar contraction: no indices to iterate over */
    if (alpha == NULL && beta == NULL){
      sr_C->mul(A, B, C);
    } else if (alpha == NULL){
      char tmp[sr_C->el_size];
      sr_C->mul(A, B, tmp);
      sr_C->mul(C, beta, C);
      sr_C->add(tmp, C, C);
    } else {
      char tmp[sr_C->el_size];
      sr_C->mul(A, B, tmp);
      sr_C->mul(tmp, alpha, tmp);
      sr_C->mul(C, beta, C);
      sr_C->add(tmp, C, C);
    }
    return 0;
  }
  dlen_A = (int*)CTF_int::alloc(sizeof(int)*order_A);
  dlen_B = (int*)CTF_int::alloc(sizeof(int)*order_B);
  dlen_C = (int*)CTF_int::alloc(sizeof(int)*order_C);
  memcpy(dlen_A, edge_len_A, sizeof(int)*order_A);
  memcpy(dlen_B, edge_len_B, sizeof(int)*order_B);
  memcpy(dlen_C, edge_len_C, sizeof(int)*order_C);
  if (!sr_C->isequal(beta, sr_C->mulid())){
    int64_t sz = sy_packed_size(order_C, edge_len_C, sym_C);
    if (beta == NULL || sr_C->isequal(beta, sr_C->addid())){
      sr_C->set(C, sr_C->addid(), sz);
    } else {
      sr_C->scal(sz, beta, C, 1);
    }
  }
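  /* Note that the element count used here is the SY packed size, not the
   * product of edge lengths: for an order-2 SY group of edge n it is
   * n*(n+1)/2, since only the wedge i <= j is stored. */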
  uint64_t ** offsets_A;
  uint64_t ** offsets_B;
  uint64_t ** offsets_C;
  if (idx_max-1 <= MAX_ORD){ /* loop-nest depth fits the largest instantiated template */
    compute_syoffs(sr_A, order_A, edge_len_A, sym_A, idx_map_A, sr_B, order_B, edge_len_B, sym_B, idx_map_B, sr_C, order_C, edge_len_C, sym_C, idx_map_C, idx_max, rev_idx_map, offsets_A, offsets_B, offsets_C);

    if (order_C > 1 || (order_C > 0 && idx_map_C[0] != 0)){
#ifdef USE_OMP
      #pragma omp parallel
#endif
      {
        int * idx_glb = (int*)CTF_int::alloc(sizeof(int)*idx_max);
        memset(idx_glb, 0, sizeof(int)*idx_max);

        SWITCH_ORD_CALL(sym_seq_ctr_loop, idx_max-1, alpha, A, sr_A, order_A, edge_len_A, sym_A, idx_map_A, offsets_A, B, sr_B, order_B, edge_len_B, sym_B, idx_map_B, offsets_B, beta, C, sr_C, order_C, edge_len_C, sym_C, idx_map_C, offsets_C, NULL, idx_glb, rev_idx_map, idx_max);
        cdealloc(idx_glb);
      }
    } else {
      idx_glb = (int*)CTF_int::alloc(sizeof(int)*idx_max);
      memset(idx_glb, 0, sizeof(int)*idx_max);

      SWITCH_ORD_CALL(sym_seq_ctr_loop, idx_max-1, alpha, A, sr_A, order_A, edge_len_A, sym_A, idx_map_A, offsets_A, B, sr_B, order_B, edge_len_B, sym_B, idx_map_B, offsets_B, beta, C, sr_C, order_C, edge_len_C, sym_C, idx_map_C, offsets_C, NULL, idx_glb, rev_idx_map, idx_max);
      cdealloc(idx_glb);
    }
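    /* SWITCH_ORD_CALL maps the runtime depth idx_max-1 onto a compile-time
     * instantiation of sym_seq_ctr_loop (one case per order up to MAX_ORD),
     * so the compiler sees a fixed-depth loop nest it can optimize. The
     * parallel region is opened only for shapes of C where concurrent
     * iterations avoid writing the same output elements. */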
    for (int l=0; l<idx_max; l++){
      cdealloc(offsets_A[l]);
      cdealloc(offsets_B[l]);
      cdealloc(offsets_C[l]);
    }
    cdealloc(offsets_A);
    cdealloc(offsets_B);
    cdealloc(offsets_C);
  } else {
    /* general fallback: odometer iteration over all global indices */
    idx_glb = (int*)CTF_int::alloc(sizeof(int)*idx_max);
    memset(idx_glb, 0, sizeof(int)*idx_max);
    idx_A = 0, idx_B = 0, idx_C = 0;
    sym_pass = 1;
    for (;;){
      if (sym_pass){
        char tmp[sr_C->el_size];
        sr_C->mul(A+idx_A*sr_A->el_size, B+idx_B*sr_B->el_size, tmp);
        sr_C->mul(tmp, alpha, tmp);
        sr_C->add(tmp, C+idx_C*sr_C->el_size, C+idx_C*sr_C->el_size);
      }
      for (idx=0; idx<idx_max; idx++){
        imin = 0, imax = INT_MAX;
        GET_MIN_MAX(A,0,3);
        GET_MIN_MAX(B,1,3);
        GET_MIN_MAX(C,2,3);
        ASSERT(idx_glb[idx] >= imin && idx_glb[idx] < imax);
        idx_glb[idx]++;
        if (idx_glb[idx] >= imax){
          idx_glb[idx] = imin;
        }
        if (idx_glb[idx] != imin) {
          break;
        }
      }
      if (idx == idx_max)
        break;
      CHECK_SYM(A);
      if (!sym_pass) continue;
      CHECK_SYM(B);
      if (!sym_pass) continue;
      CHECK_SYM(C);
      if (!sym_pass) continue;
      if (order_A > 0) RESET_IDX(A);
      if (order_B > 0) RESET_IDX(B);
      if (order_C > 0) RESET_IDX(C);
    }
    cdealloc(idx_glb);
  }
  cdealloc(dlen_A); cdealloc(dlen_B); cdealloc(dlen_C);
  return 0;
}
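/* A minimal sketch (standalone, illustrative only) of the odometer increment
 * the fallback loop performs, with fixed per-digit bounds in place of the
 * GET_MIN_MAX recomputation done above: */
#if 0
static bool next_index(int * idx_glb, int const * imin, int const * imax, int idx_max){
  for (int idx = 0; idx < idx_max; idx++){
    idx_glb[idx]++;
    if (idx_glb[idx] >= imax[idx])
      idx_glb[idx] = imin[idx];   // wrap this digit, carry to the next
    if (idx_glb[idx] != imin[idx])
      return true;                // no carry: a fresh index tuple is ready
  }
  return false;                   // carried past the last digit: iteration done
}
#endif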
int sym_seq_ctr_cust(char const *           alpha,
                     char const *           A,
                     algstrct const *       sr_A,
                     int                    order_A,
                     int const *            edge_len_A,
                     int const *            sym_A,
                     int const *            idx_map_A,
                     char const *           B,
                     algstrct const *       sr_B,
                     int                    order_B,
                     int const *            edge_len_B,
                     int const *            sym_B,
                     int const *            idx_map_B,
                     char const *           beta,
                     char *                 C,
                     algstrct const *       sr_C,
                     int                    order_C,
                     int const *            edge_len_C,
                     int const *            sym_C,
                     int const *            idx_map_C,
                     bivar_function const * func){
  int idx, i, idx_max, imin, imax, iA, iB, iC, j, k;
  int off_idx, sym_pass;
  int * idx_glb, * rev_idx_map;
  int * dlen_A, * dlen_B, * dlen_C;
  int64_t idx_A, idx_B, idx_C, off_lda;

  inv_idx(order_A, idx_map_A,
          order_B, idx_map_B,
          order_C, idx_map_C,
          &idx_max, &rev_idx_map);
  dlen_A = (int*)CTF_int::alloc(sizeof(int)*order_A);
  dlen_B = (int*)CTF_int::alloc(sizeof(int)*order_B);
  dlen_C = (int*)CTF_int::alloc(sizeof(int)*order_C);
  memcpy(dlen_A, edge_len_A, sizeof(int)*order_A);
  memcpy(dlen_B, edge_len_B, sizeof(int)*order_B);
  memcpy(dlen_C, edge_len_C, sizeof(int)*order_C);

  idx_glb = (int*)CTF_int::alloc(sizeof(int)*idx_max);
  memset(idx_glb, 0, sizeof(int)*idx_max);
  if (!sr_C->isequal(beta, sr_C->mulid())){
    int64_t sz = sy_packed_size(order_C, edge_len_C, sym_C);
    if (beta == NULL || sr_C->isequal(beta, sr_C->addid())){
      sr_C->set(C, sr_C->addid(), sz);
    } else {
      sr_C->scal(sz, beta, C, 1);
    }
  }
  uint64_t ** offsets_A;
  uint64_t ** offsets_B;
  uint64_t ** offsets_C;
  if (idx_max-1 <= MAX_ORD){ /* loop-nest depth fits the largest instantiated template */
    compute_syoffs(sr_A, order_A, edge_len_A, sym_A, idx_map_A, sr_B, order_B, edge_len_B, sym_B, idx_map_B, sr_C, order_C, edge_len_C, sym_C, idx_map_C, idx_max, rev_idx_map, offsets_A, offsets_B, offsets_C);

    if (order_C > 1 || (order_C > 0 && idx_map_C[0] != 0)){
#ifdef USE_OMP
      #pragma omp parallel
#endif
      {
        int * idx_glb = (int*)CTF_int::alloc(sizeof(int)*idx_max);
        memset(idx_glb, 0, sizeof(int)*idx_max);

        SWITCH_ORD_CALL(sym_seq_ctr_loop, idx_max-1, alpha, A, sr_A, order_A, edge_len_A, sym_A, idx_map_A, offsets_A, B, sr_B, order_B, edge_len_B, sym_B, idx_map_B, offsets_B, beta, C, sr_C, order_C, edge_len_C, sym_C, idx_map_C, offsets_C, func, idx_glb, rev_idx_map, idx_max);
        cdealloc(idx_glb);
      }
    } else {
      memset(idx_glb, 0, sizeof(int)*idx_max);

      SWITCH_ORD_CALL(sym_seq_ctr_loop, idx_max-1, alpha, A, sr_A, order_A, edge_len_A, sym_A, idx_map_A, offsets_A, B, sr_B, order_B, edge_len_B, sym_B, idx_map_B, offsets_B, beta, C, sr_C, order_C, edge_len_C, sym_C, idx_map_C, offsets_C, func, idx_glb, rev_idx_map, idx_max);
    }
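    /* Same dispatch as in sym_seq_ctr_ref, except func is passed through, so
     * the innermost specialization applies the user-defined bivariate
     * function (acc_f/apply_f) instead of the semiring multiply. */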
    for (int l=0; l<idx_max; l++){
      cdealloc(offsets_A[l]);
      cdealloc(offsets_B[l]);
      cdealloc(offsets_C[l]);
    }
    cdealloc(offsets_A);
    cdealloc(offsets_B);
    cdealloc(offsets_C);
  } else {
    idx_A = 0, idx_B = 0, idx_C = 0;
    sym_pass = 1;
    for (;;){
      if (sym_pass){
        if (alpha == NULL || sr_C->isequal(alpha, sr_C->mulid())){
          func->acc_f(A+idx_A*sr_A->el_size,
                      B+idx_B*sr_B->el_size,
                      C+idx_C*sr_C->el_size,
                      sr_C);
        } else {
          char tmp[sr_C->el_size];
          func->apply_f(A+idx_A*sr_A->el_size, B+idx_B*sr_B->el_size, tmp);
          sr_C->mul(tmp, alpha, tmp);
          sr_C->add(tmp, C+idx_C*sr_C->el_size, C+idx_C*sr_C->el_size);
        }
      }
      for (idx=0; idx<idx_max; idx++){
        imin = 0, imax = INT_MAX;
        GET_MIN_MAX(A,0,3);
        GET_MIN_MAX(B,1,3);
        GET_MIN_MAX(C,2,3);
        ASSERT(idx_glb[idx] >= imin && idx_glb[idx] < imax);
        idx_glb[idx]++;
        if (idx_glb[idx] >= imax){
          idx_glb[idx] = imin;
        }
        if (idx_glb[idx] != imin) {
          break;
        }
      }
      if (idx == idx_max)
        break;
      CHECK_SYM(A);
      if (!sym_pass) continue;
      CHECK_SYM(B);
      if (!sym_pass) continue;
      CHECK_SYM(C);
      if (!sym_pass) continue;
      if (order_A > 0) RESET_IDX(A);
      if (order_B > 0) RESET_IDX(B);
      if (order_C > 0) RESET_IDX(C);
    }
  }
  cdealloc(dlen_A); cdealloc(dlen_B); cdealloc(dlen_C);
  cdealloc(idx_glb);
  return 0;
}
int sym_seq_ctr_inr(char const *           alpha,
                    char const *           A,
                    algstrct const *       sr_A,
                    int                    order_A,
                    int const *            edge_len_A,
                    int const *            sym_A,
                    int const *            idx_map_A,
                    char const *           B,
                    algstrct const *       sr_B,
                    int                    order_B,
                    int const *            edge_len_B,
                    int const *            sym_B,
                    int const *            idx_map_B,
                    char const *           beta,
                    char *                 C,
                    algstrct const *       sr_C,
                    int                    order_C,
                    int const *            edge_len_C,
                    int const *            sym_C,
                    int const *            idx_map_C,
                    iparam const *         prm,
                    bivar_function const * func){
  int idx, i, idx_max, imin, imax, iA, iB, iC, j, k;
  int off_idx, sym_pass, stride_A, stride_B, stride_C;
  int * idx_glb, * rev_idx_map;
  int * dlen_A, * dlen_B, * dlen_C;
  int64_t idx_A, idx_B, idx_C, off_lda;
  stride_A = prm->m*prm->k*prm->l;
  stride_B = prm->k*prm->n*prm->l;
  stride_C = prm->m*prm->n*prm->l;

  inv_idx(order_A, idx_map_A,
          order_B, idx_map_B,
          order_C, idx_map_C,
          &idx_max, &rev_idx_map);
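  /* Each mapped index advances an operand pointer by one whole inner block:
   * stride_A = m*k*l elements, stride_B = k*n*l, stride_C = m*n*l, i.e. the
   * volume of one l-fold batched gemm operand described by iparam. */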
  dlen_A = (int*)CTF_int::alloc(sizeof(int)*order_A);
  dlen_B = (int*)CTF_int::alloc(sizeof(int)*order_B);
  dlen_C = (int*)CTF_int::alloc(sizeof(int)*order_C);
  memcpy(dlen_A, edge_len_A, sizeof(int)*order_A);
  memcpy(dlen_B, edge_len_B, sizeof(int)*order_B);
  memcpy(dlen_C, edge_len_C, sizeof(int)*order_C);

  idx_glb = (int*)CTF_int::alloc(sizeof(int)*idx_max);
  memset(idx_glb, 0, sizeof(int)*idx_max);

  idx_A = 0, idx_B = 0, idx_C = 0;
  sym_pass = 1;
  for (;;){
    if (sym_pass){
      /* dispatch one blocked inner kernel per surviving index tuple; when C
         is stored transposed (prm->tC != 'n'), operands are swapped using
         C^T = (A*B)^T = B^T * A^T. Guards reconstructed around the preserved
         call arguments; iparam's offload flag assumed to select the device
         path. */
      if (prm->tC == 'n'){
        if (prm->offload){
          if (func == NULL){
            sr_C->offload_gemm(prm->tA, prm->tB, prm->m, prm->n, prm->k, alpha,
                               A+idx_A*stride_A*sr_A->el_size,
                               B+idx_B*stride_B*sr_B->el_size, beta,
                               C+idx_C*stride_C*sr_C->el_size);
          } else {
            func->coffload_gemm(prm->tA, prm->tB, prm->m, prm->n, prm->k,
                                A+idx_A*stride_A*sr_A->el_size,
                                B+idx_B*stride_B*sr_B->el_size,
                                C+idx_C*stride_C*sr_C->el_size);
          }
        } else {
          if (func == NULL){
            sr_C->gemm_batch(prm->tA, prm->tB, prm->l, prm->m, prm->n, prm->k, alpha,
                             A+idx_A*stride_A*sr_A->el_size,
                             B+idx_B*stride_B*sr_B->el_size, beta,
                             C+idx_C*stride_C*sr_C->el_size);
          } else {
            func->cgemm(prm->tA, prm->tB, prm->m, prm->n, prm->k,
                        A+idx_A*stride_A*sr_A->el_size,
                        B+idx_B*stride_B*sr_B->el_size,
                        C+idx_C*stride_C*sr_C->el_size);
          }
        }
      } else {
        if (prm->offload){
          if (func == NULL){
            sr_C->offload_gemm(prm->tB, prm->tA, prm->n, prm->m, prm->k, alpha,
                               B+idx_B*stride_B*sr_B->el_size,
                               A+idx_A*stride_A*sr_A->el_size, beta,
                               C+idx_C*stride_C*sr_C->el_size);
          } else {
            func->coffload_gemm(prm->tB, prm->tA, prm->n, prm->m, prm->k,
                                B+idx_B*stride_B*sr_B->el_size,
                                A+idx_A*stride_A*sr_A->el_size,
                                C+idx_C*stride_C*sr_C->el_size);
          }
        } else {
          if (func == NULL){
            sr_C->gemm_batch(prm->tB, prm->tA, prm->l, prm->n, prm->m, prm->k, alpha,
                             B+idx_B*stride_B*sr_B->el_size,
                             A+idx_A*stride_A*sr_A->el_size, beta,
                             C+idx_C*stride_C*sr_C->el_size);
          } else {
            func->cgemm(prm->tB, prm->tA, prm->n, prm->m, prm->k,
                        B+idx_B*stride_B*sr_B->el_size,
                        A+idx_A*stride_A*sr_A->el_size,
                        C+idx_C*stride_C*sr_C->el_size);
          }
        }
      }
      CTF_FLOPS_ADD((2 * (int64_t)prm->l * (int64_t)prm->n * (int64_t)prm->m * (int64_t)(prm->k+1)));
    }
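    /* The count above, 2*l*m*n*(k+1), is 2*l*m*n*k multiply-adds for the
     * batched gemm plus 2*l*m*n operations for accumulating into C. */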
    for (idx=0; idx<idx_max; idx++){
      imin = 0, imax = INT_MAX;
      GET_MIN_MAX(A,0,3);
      GET_MIN_MAX(B,1,3);
      GET_MIN_MAX(C,2,3);
      ASSERT(idx_glb[idx] >= imin && idx_glb[idx] < imax);
      idx_glb[idx]++;
      if (idx_glb[idx] >= imax){
        idx_glb[idx] = imin;
      }
      if (idx_glb[idx] != imin) {
        break;
      }
    }
    if (idx == idx_max)
      break;
    CHECK_SYM(A);
    if (!sym_pass) continue;
    CHECK_SYM(B);
    if (!sym_pass) continue;
    CHECK_SYM(C);
    if (!sym_pass) continue;
    if (order_A > 0) RESET_IDX(A);
    if (order_B > 0) RESET_IDX(B);
    if (order_C > 0) RESET_IDX(C);
  }
  cdealloc(dlen_A); cdealloc(dlen_B); cdealloc(dlen_C);
  cdealloc(idx_glb);
  return 0;
}
// ---- Reference: CTF_int symbols used above ----
//
// void compute_syoffs(algstrct const *sr_A, int order_A, int const *edge_len_A, int const *sym_A, int const *idx_map_A, algstrct const *sr_B, int order_B, int const *edge_len_B, int const *sym_B, int const *idx_map_B, algstrct const *sr_C, int order_C, int const *edge_len_C, int const *sym_C, int const *idx_map_C, int tot_order, int const *rev_idx_map, uint64_t **&offsets_A, uint64_t **&offsets_B, uint64_t **&offsets_C)
//
// virtual bool isequal(char const *a, char const *b) const
//   returns true if algstrct elements a and b are equal
//
// void inv_idx(int order_A, int const *idx_A, int order_B, int const *idx_B, int order_C, int const *idx_C, int *order_tot, int **idx_arr)
//   invert index map
//
// void * alloc(int64_t len)
//   alloc abstraction
//
// virtual char const * addid() const
//   identity element for addition, i.e. 0
//
// void gemm(char tA, char tB, int m, int n, int k, dtype alpha, dtype const *A, dtype const *B, dtype beta, dtype *C)
//
// class bivar_function
//   untyped internal class for triply-typed bivariate function
//
// #define GET_MIN_MAX(__X, nr, wd)
//
// virtual void gemm_batch(char tA, char tB, int l, int m, int n, int k, char const *alpha, char const *A, char const *B, char const *beta, char *C) const
//   beta*C["ijl"]=alpha*A^tA["ikl"]*B^tB["kjl"];
//
// virtual void set(char *a, char const *b, int64_t n) const
//   sets n elements of array a to value b
//
// virtual void coffload_gemm(char tA, char tB, int m, int n, int k, char const *A, char const *B, char *C) const
//
// #define SWITCH_ORD_CALL(F, act_ord, ...)
//
// int sym_seq_ctr_cust(char const *alpha, char const *A, algstrct const *sr_A, int order_A, int const *edge_len_A, int const *sym_A, int const *idx_map_A, char const *B, algstrct const *sr_B, int order_B, int const *edge_len_B, int const *sym_B, int const *idx_map_B, char const *beta, char *C, algstrct const *sr_C, int order_C, int const *edge_len_C, int const *sym_C, int const *idx_map_C, bivar_function const *func)
//   performs symmetric contraction with custom elementwise function
//
// int sym_seq_ctr_ref(char const *alpha, char const *A, algstrct const *sr_A, int order_A, int const *edge_len_A, int const *sym_A, int const *idx_map_A, char const *B, algstrct const *sr_B, int order_B, int const *edge_len_B, int const *sym_B, int const *idx_map_B, char const *beta, char *C, algstrct const *sr_C, int order_C, int const *edge_len_C, int const *sym_C, int const *idx_map_C)
//   performs symmetric contraction with reference (unblocked) kernel
//
// virtual void scal(int n, char const *alpha, char *X, int incX) const
//   X["i"]=alpha*X["i"];
//
// void sym_seq_ctr_loop(char const *alpha, char const *A, algstrct const *sr_A, int order_A, int const *edge_len_A, int const *sym_A, int const *idx_map_A, uint64_t *const *offsets_A, char const *B, algstrct const *sr_B, int order_B, int const *edge_len_B, int const *sym_B, int const *idx_map_B, uint64_t *const *offsets_B, char const *beta, char *C, algstrct const *sr_C, int order_C, int const *edge_len_C, int const *sym_C, int const *idx_map_C, uint64_t *const *offsets_C, bivar_function const *func, int const *idx, int const *rev_idx_map, int idx_max)
//
// int sym_seq_ctr_inr(char const *alpha, char const *A, algstrct const *sr_A, int order_A, int const *edge_len_A, int const *sym_A, int const *idx_map_A, char const *B, algstrct const *sr_B, int order_B, int const *edge_len_B, int const *sym_B, int const *idx_map_B, char const *beta, char *C, algstrct const *sr_C, int order_C, int const *edge_len_C, int const *sym_C, int const *idx_map_C, iparam const *prm, bivar_function const *func)
//   performs symmetric contraction with blocked gemm
//
// virtual void add(char const *a, char const *b, char *c) const
//   c = a+b
//
// int el_size
//   size of each element of algstrct in bytes
//
// int cdealloc(void *ptr)
//   free abstraction
//
// class algstrct
//   algstrct (algebraic structure) defines the elementwise operations computed in each tensor contraction
//
// void sym_seq_ctr_loop<0>(...)
//   explicit specialization of the loop template; same parameters as sym_seq_ctr_loop above
//
// void compute_syoff(int r, int len, algstrct const *sr, int const *edge_len, int const *sym, uint64_t *offsets)
//
// virtual void mul(char const *a, char const *b, char *c) const
//   c = a*b
//
// template void sym_seq_ctr_loop<MAX_ORD>(...)
//   explicit instantiation; same parameters as sym_seq_ctr_loop above
//
// virtual char const * mulid() const
//   identity element for multiplication, i.e. 1
//
// virtual void offload_gemm(char tA, char tB, int m, int n, int k, char const *alpha, char const *A, char const *B, char const *beta, char *C) const
//
// int64_t sy_packed_size(int order, const int *len, const int *sym)
//   computes the size of a tensor in SY (not hollow) packed symmetric layout
//
// virtual void cgemm(char tA, char tB, int m, int n, int k, char const *A, char const *B, char *C) const
//
// virtual void acc_f(char const *a, char const *b, char *c, CTF_int::algstrct const *sr_C) const = 0
//   compute c = c+f(a,b)