#include "../shared/iter_tsr.h"
#include "../shared/offload.h"
#include "../shared/util.h"

// excerpted parameter list of the recursive kernel spA_dnB_dnC_ctrloop<idim>
                              int const *          edge_len_A,
                              int64_t const *      lda_A,
                              int const *          idx_map_A,
                              int const *          edge_len_B,
                              int const *          idx_map_B,
                              uint64_t * const *   offsets_B,
                              int const *          edge_len_C,
                              int const *          idx_map_C,
                              uint64_t * const *   offsets_C,
                              int const *          rev_idx_map,
  int rA = rev_idx_map[3*idim+0];   // position of global index idim in A (-1 if absent)
  int rB = rev_idx_map[3*idim+1];   // position in B
  int rC = rev_idx_map[3*idim+2];   // position in C
  ASSERT(!(rA != -1 && rB == -1 && rC == -1));
  int imax;
  if (rB != -1)
    imax = edge_len_B[rB];
  else if (rC != -1)
    imax = edge_len_C[rC];
  // tighten imax so symmetric (SY) index groups are traversed in packed order
  if (rA != -1 && sym_A[rA] != NS){
    int rrA = rA;
    do {
      if (idx_map_A[rrA+1] > idim)
        imax = idx[idx_map_A[rrA+1]]+1;
      rrA++;
    } while (sym_A[rrA] != NS && idx_map_A[rrA] < idim);
  }
  if (rB != -1 && sym_B[rB] != NS){
    int rrB = rB;
    do {
      if (idx_map_B[rrB+1] > idim)
        imax = std::min(imax,idx[idx_map_B[rrB+1]]+1);
      rrB++;
    } while (sym_B[rrB] != NS && idx_map_B[rrB] < idim);
  }
  if (rC != -1 && sym_C[rC] != NS){
    int rrC = rC;
    do {
      if (idx_map_C[rrC+1] > idim)
        imax = std::min(imax,idx[idx_map_C[rrC+1]]+1);
      rrC++;
    } while (sym_C[rrC] != NS && idx_map_C[rrC] < idim);
  }
  int imin = 0;
  // raise imin when an earlier index of a symmetric group has already been fixed
  if (rA > 0 && sym_A[rA-1] != NS){
    int rrA = rA;
    do {
      if (idx_map_A[rrA-1] > idim)
        imin = idx[idx_map_A[rrA-1]];
      rrA--;
    } while (rrA>0 && sym_A[rrA-1] != NS && idx_map_A[rrA] < idim);
  }
  if (rB > 0 && sym_B[rB-1] != NS){
    int rrB = rB;
    do {
      if (idx_map_B[rrB-1] > idim)
        imin = std::max(imin,idx[idx_map_B[rrB-1]]);
      rrB--;
    } while (rrB>0 && sym_B[rrB-1] != NS && idx_map_B[rrB] < idim);
  }
  if (rC > 0 && sym_C[rC-1] != NS){
    int rrC = rC;
    do {
      if (idx_map_C[rrC-1] > idim)
        imin = std::max(imin,idx[idx_map_C[rrC-1]]);
      rrC--;
    } while (rrC>0 && sym_C[rrC-1] != NS && idx_map_C[rrC] < idim);
  }
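The clipping above restricts the loop over global index idim so that indices belonging to a symmetric (SY) group are only visited in packed, non-decreasing order. A minimal standalone sketch of that effect (plain loops, none of the CTF types), assuming a fully symmetric 3-index tensor:

#include <cstdio>

// Count the packed entries i0 <= i1 <= i2 of a symmetric 3-index tensor of
// edge length n; the per-dimension bounds play the role of imin/imax above.
int main(){
  int n = 4, cnt = 0;
  for (int i2=0; i2<n; i2++)
    for (int i1=0; i1<=i2; i1++)     // imax for i1 is i2+1
      for (int i0=0; i0<=i1; i0++)   // imax for i0 is i1+1
        cnt++;
  printf("packed entries: %d\n", cnt);  // prints 20 = n*(n+1)*(n+2)/6
  return 0;
}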
  int64_t key_offset = 0;
  for (int i=0; i<order_A; i++){
    key_offset += idx[idx_map_A[i]]*lda_A[i];   // key of A implied by the current index assignment
  }

  for (int i=imin; i<imax; i++){
    if ((rA == -1 && A[0].k() != key_offset) ||
        (rA != -1 && (A[0].k()/lda_A[rA]/edge_len_A[rA]) != key_offset/lda_A[rA]/edge_len_A[rA])){
      // ...
    }
    if (rA != -1 && (A[0].k()/lda_A[rA])%edge_len_A[rA] != i){
      ASSERT((A[0].k()/lda_A[rA])%edge_len_A[rA] > i);
      // ...
    }
    int64_t new_size_A = size_A;
    memcpy(nidx, idx, idx_max*sizeof(int));
    spA_dnB_dnC_ctrloop<idim-1>(alpha, cpiA, new_size_A, sr_A, order_A, edge_len_A, lda_A, sym_A, idx_map_A,
                                B+offsets_B[idim][nidx[idim]], sr_B, order_B, edge_len_B, sym_B, idx_map_B, offsets_B,
                                beta, C+offsets_C[idim][nidx[idim]], sr_C, order_C, edge_len_C, sym_C, idx_map_C, offsets_C,
                                func, nidx, rev_idx_map, idx_max);
    if (size_A == new_size_A){
      // ...
    }
    size_A = new_size_A-1;
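For orientation, here is a hedged standalone sketch, with made-up names and a simple row-major key layout, of the pattern this recursion implements: a dense loop nest over the indices of B and C advances through a sorted pair list holding the sparse entries of A, so that each entry is consumed once.

#include <cstdint>
#include <vector>

// Stand-in for a (key, value) pair of the sparse operand A.
struct SpPair { int64_t k; double d; };

// Sketch of C[i][j] += A[i][k] * B[k][j] with sparse A stored as pairs whose
// key is i*K + k (an assumed layout), dense B (K x J) and dense C (I x J).
void sparse_A_dense_BC(const std::vector<SpPair> &A, int K, int J,
                       const std::vector<double> &B, std::vector<double> &C){
  for (const SpPair &p : A){
    int64_t i = p.k / K;                  // decode coordinates from the key,
    int64_t k = p.k % K;                  // mirroring the lda_A divisions above
    for (int j=0; j<J; j++)
      C[i*J + j] += p.d * B[k*J + j];
  }
}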
// excerpted parameter list of the base-case specialization spA_dnB_dnC_ctrloop<0>
                              int const *          edge_len_A,
                              int64_t const *      lda_A,
                              int const *          idx_map_A,
                              int const *          edge_len_B,
                              int const *          idx_map_B,
                              uint64_t * const *   offsets_B,
                              int const *          edge_len_C,
                              int const *          idx_map_C,
                              uint64_t * const *   offsets_C,
                              int const *          rev_idx_map,
  int rA = rev_idx_map[0];
  int rB = rev_idx_map[1];
  int rC = rev_idx_map[2];
  ASSERT(!(rA != -1 && rB == -1 && rC == -1));
  int imax;
  if (rB != -1)
    imax = edge_len_B[rB];
  else if (rC != -1)
    imax = edge_len_C[rC];
  if (rA != -1 && sym_A[rA] != NS)
    imax = idx[idx_map_A[rA+1]]+1;
  if (rB != -1 && sym_B[rB] != NS)
    imax = std::min(imax,idx[idx_map_B[rB+1]]+1);
  if (rC != -1 && sym_C[rC] != NS)
    imax = std::min(imax,idx[idx_map_C[rC+1]]+1);

  int imin = 0;
  if (rA > 0 && sym_A[rA-1] != NS)
    imin = idx[idx_map_A[rA-1]];
  if (rB > 0 && sym_B[rB-1] != NS)
    imin = std::max(imin,idx[idx_map_B[rB-1]]);
  if (rC > 0 && sym_C[rC-1] != NS)
    imin = std::max(imin,idx[idx_map_C[rC-1]]);
  if (alpha == NULL || sr_C->isequal(alpha,sr_C->mulid())){
    // alpha is (or defaults to) the multiplicative identity
    for (int i=imin; i<imax; i++){
      char tmp[sr_C->el_size];
      // ...
    }
  } else {
    // general alpha
    for (int i=imin; i<imax; i++){
      char tmp[sr_C->el_size];
      // ...
    }
  }
  if (alpha == NULL || sr_C->isequal(alpha,sr_C->mulid())){
    for (int i=imin; i<imax; i++){
      func->acc_f(A[0].d(), /* ... */);
      // ...
    }
  } else {
    for (int i=imin; i<imax; i++){
      char tmp[sr_C->el_size];
      func->apply_f(A[0].d(), /* ... */);
      // ...
    }
  }
  int64_t key_offset = 0;
  for (int i=0; i<order_A; i++){
    key_offset += idx[idx_map_A[i]]*lda_A[i];
  }
  ASSERT(func == NULL && alpha != NULL && beta != NULL);
  assert(func == NULL && alpha != NULL && beta != NULL);
  do {
    int64_t sk = A[0].k()-key_offset;
    int i = sk/lda_A[rA];                 // coordinate of this sparse entry along dimension rA
    char tmp[sr_C->el_size];
    // ...
  } while (size_A > 0);
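The sparse entries of A are addressed through a single integer key built from the lda_A strides: key_offset encodes the indices fixed by the outer recursion, and the leftover sk divided by lda_A[rA] recovers the remaining coordinate. A small self-contained sketch of that encode/decode, with made-up edge lengths, assuming lda is the usual prefix product:

#include <cassert>
#include <cstdint>

int main(){
  // Strides as prefix products of edge lengths, as built for lda_A below.
  int len[3]  = {5, 3, 4};
  int64_t lda[3];
  for (int i=0; i<3; i++)
    lda[i] = (i==0) ? 1 : lda[i-1]*len[i-1];

  // Encode a multi-index into a single key (the role of key_offset).
  int idx[3] = {2, 1, 3};
  int64_t key = 0;
  for (int i=0; i<3; i++) key += idx[i]*lda[i];

  // Decode one coordinate back, as the base case does with sk/lda_A[rA].
  int r = 1;                               // dimension to recover
  assert((key / lda[r]) % len[r] == idx[r]);
  return 0;
}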
// excerpted parameter list of a further spA_dnB_dnC_ctrloop variant (see the prototypes below)
                              int const *          edge_len_A,
                              int64_t const *      lda_A,
                              int const *          idx_map_A,
                              int const *          edge_len_B,
                              int const *          idx_map_B,
                              uint64_t * const *   offsets_B,
                              int const *          edge_len_C,
                              int const *          idx_map_C,
                              uint64_t * const *   offsets_C,
                              int const *          rev_idx_map,
// excerpted parameter list of spA_dnB_dnC_seq_ctr, the top-level sequential kernel
                        int const *          edge_len_A,
                        int const *          idx_map_A,
                        int const *          edge_len_B,
                        int const *          idx_map_B,
                        int const *          edge_len_C,
                        int const *          idx_map_C,

  int idx_max;
  int * rev_idx_map;
  int * dlen_A, * dlen_B, * dlen_C;
  inv_idx(order_A, idx_map_A, order_B, idx_map_B, order_C, idx_map_C,
          &idx_max, &rev_idx_map);
  if (alpha == NULL && beta == NULL){
    // ...
  } else if (alpha == NULL){
    char tmp[sr_C->el_size];
    sr_C->mul(A, B, tmp);        // tmp = A*B
    sr_C->mul(C, beta, C);       // C   = beta*C
    sr_C->add(tmp, C, C);        // C   = tmp + C
  } else {
    char tmp[sr_C->el_size];
    sr_C->mul(A, B, tmp);        // tmp = A*B
    sr_C->mul(tmp, alpha, tmp);  // tmp = alpha*tmp
    sr_C->mul(C, beta, C);       // C   = beta*C
    sr_C->add(tmp, C, C);        // C   = alpha*A*B + beta*C
  }
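Written out for plain doubles, the sequence of mul/add calls above computes C = alpha*(A*B) + beta*C. A hedged sketch of the same update, using doubles as a stand-in for one concrete algstrct rather than the CTF class itself:

// Sketch of the scalar update above over doubles; alpha == NULL is treated
// as an implicit 1, matching the branch structure of the excerpt.
static void scalar_update(const double *alpha, double A, double B,
                          double beta, double *C){
  double tmp = A * B;          // sr_C->mul(A, B, tmp)
  if (alpha) tmp *= *alpha;    // sr_C->mul(tmp, alpha, tmp)
  *C = beta * (*C) + tmp;      // sr_C->mul(C, beta, C); sr_C->add(tmp, C, C)
}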
  memcpy(dlen_A, edge_len_A, sizeof(int)*order_A);
  memcpy(dlen_B, edge_len_B, sizeof(int)*order_B);
  memcpy(dlen_C, edge_len_C, sizeof(int)*order_C);
  sr_C->scal(sz, beta, C, 1);   // scale all sz elements of C by beta
  uint64_t ** offsets_A;
  uint64_t ** offsets_B;
  uint64_t ** offsets_C;
  compute_syoffs(sr_A, order_A, edge_len_A, sym_A, idx_map_A,
                 sr_B, order_B, edge_len_B, sym_B, idx_map_B,
                 sr_C, order_C, edge_len_C, sym_C, idx_map_C,
                 idx_max, rev_idx_map, offsets_A, offsets_B, offsets_C);
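compute_syoffs precomputes, for each global index and each value it can take, the byte offset added to B and C (used above as B+offsets_B[idim][nidx[idim]]). A hedged sketch of what such a table looks like for a purely dense, non-symmetric layout; the real routine also accounts for symmetric packing:

#include <cstdint>
#include <vector>

// Sketch: offsets[d][i] = byte offset contributed by setting dimension d to
// value i in a dense, non-symmetric tensor with elements of size el_size.
// Assumed layout only; compute_syoffs additionally handles SY-packed groups.
std::vector< std::vector<uint64_t> > dense_offsets(std::vector<int> const &len, int el_size){
  std::vector< std::vector<uint64_t> > offsets(len.size());
  uint64_t stride = el_size;
  for (size_t d=0; d<len.size(); d++){
    offsets[d].resize(len[d]);
    for (int i=0; i<len[d]; i++) offsets[d][i] = i*stride;
    stride *= len[d];
  }
  return offsets;
}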
  memset(idx_glb, 0, sizeof(int)*idx_max);

  int64_t lda_A[order_A];
  for (int i=0; i<order_A; i++){
    if (i==0) lda_A[i] = 1;
    else      lda_A[i] = lda_A[i-1]*edge_len_A[i-1];   // prefix-product strides of A
  }
  SWITCH_ORD_CALL(spA_dnB_dnC_ctrloop, idx_max-1, alpha, pA, size_A, sr_A, order_A, edge_len_A, lda_A, sym_A, idx_map_A,
                  B, sr_B, order_B, edge_len_B, sym_B, idx_map_B, offsets_B,
                  beta, C, sr_C, order_C, edge_len_C, sym_C, idx_map_C, offsets_C,
                  func, idx_glb, rev_idx_map, idx_max);
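SWITCH_ORD_CALL turns the runtime depth idx_max-1 into a compile-time template argument for the loop nest. A hedged sketch of that dispatch pattern (my own illustration, not the actual macro, whose case range and fallback are defined in the shared headers):

#include <cstdio>

template <int idim>
void loop_nest_sketch(){ printf("instantiated loop nest of depth %d\n", idim+1); }

// Map a runtime order onto templated instantiations, the role played by
// SWITCH_ORD_CALL(spA_dnB_dnC_ctrloop, idx_max-1, ...) above.
void dispatch_sketch(int act_ord){
  switch (act_ord){
    case 0:  loop_nest_sketch<0>(); break;
    case 1:  loop_nest_sketch<1>(); break;
    case 2:  loop_nest_sketch<2>(); break;
    default: loop_nest_sketch<3>(); break;   // a real dispatcher covers every supported order
  }
}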
  for (int l=0; l<idx_max; l++){
void compute_syoffs(algstrct const *sr_A, int order_A, int const *edge_len_A, int const *sym_A, int const *idx_map_A, algstrct const *sr_B, int order_B, int const *edge_len_B, int const *sym_B, int const *idx_map_B, algstrct const *sr_C, int order_C, int const *edge_len_C, int const *sym_C, int const *idx_map_C, int tot_order, int const *rev_idx_map, uint64_t **&offsets_A, uint64_t **&offsets_B, uint64_t **&offsets_C)
virtual bool isequal(char const *a, char const *b) const
returns true if algstrct elements a and b are equal
void inv_idx(int order_A, int const *idx_A, int order_B, int const *idx_B, int order_C, int const *idx_C, int *order_tot, int **idx_arr)
invert index map
void * alloc(int64_t len)
alloc abstraction
virtual char const * addid() const
identity element for addition i.e. 0
untyped internal class for triply-typed bivariate function
virtual void set(char *a, char const *b, int64_t n) const
sets n elements of array a to value b
void spA_dnB_dnC_ctrloop< 0 >(char const *alpha, ConstPairIterator &A, int64_t &size_A, algstrct const *sr_A, int order_A, int const *edge_len_A, int64_t const *lda_A, int const *sym_A, int const *idx_map_A, char const *B, algstrct const *sr_B, int order_B, int const *edge_len_B, int const *sym_B, int const *idx_map_B, uint64_t *const *offsets_B, char const *beta, char *C, algstrct const *sr_C, int order_C, int const *edge_len_C, int const *sym_C, int const *idx_map_C, uint64_t *const *offsets_C, bivar_function const *func, int const *idx, int const *rev_idx_map, int idx_max)
#define SWITCH_ORD_CALL(F, act_ord,...)
virtual void scal(int n, char const *alpha, char *X, int incX) const
X["i"]=alpha*X["i"];.
void spA_dnB_dnC_seq_ctr(char const *alpha, char const *A, int64_t size_A, algstrct const *sr_A, int order_A, int const *edge_len_A, int const *sym_A, int const *idx_map_A, char const *B, algstrct const *sr_B, int order_B, int const *edge_len_B, int const *sym_B, int const *idx_map_B, char const *beta, char *C, algstrct const *sr_C, int order_C, int const *edge_len_C, int const *sym_C, int const *idx_map_C, bivar_function const *func)
virtual void add(char const *a, char const *b, char *c) const
c = a+b
int el_size
size of each element of algstrct in bytes
int cdealloc(void *ptr)
free abstraction
algstrct (algebraic structure) defines the elementwise operations computed in each tensor contraction...
void spA_dnB_dnC_ctrloop(char const *alpha, ConstPairIterator &A, int64_t &size_A, algstrct const *sr_A, int order_A, int const *edge_len_A, int64_t const *lda_A, int const *sym_A, int const *idx_map_A, char const *B, algstrct const *sr_B, int order_B, int const *edge_len_B, int const *sym_B, int const *idx_map_B, uint64_t *const *offsets_B, char const *beta, char *C, algstrct const *sr_C, int order_C, int const *edge_len_C, int const *sym_C, int const *idx_map_C, uint64_t *const *offsets_C, bivar_function const *func, int const *idx, int const *rev_idx_map, int idx_max)
void spA_dnB_dnC_ctrloop< MAX_ORD >(char const *alpha, ConstPairIterator &A, int64_t &size_A, algstrct const *sr_A, int order_A, int const *edge_len_A, int64_t const *lda_A, int const *sym_A, int const *idx_map_A, char const *B, algstrct const *sr_B, int order_B, int const *edge_len_B, int const *sym_B, int const *idx_map_B, uint64_t *const *offsets_B, char const *beta, char *C, algstrct const *sr_C, int order_C, int const *edge_len_C, int const *sym_C, int const *idx_map_C, uint64_t *const *offsets_C, bivar_function const *func, int const *idx, int const *rev_idx_map, int idx_max)
virtual void mul(char const *a, char const *b, char *c) const
c = a*b
virtual char const * mulid() const
identity element for multiplication i.e. 1
int64_t sy_packed_size(int order, const int *len, const int *sym)
computes the size of a tensor in SY (NOT HOLLOW) packed symmetric layout
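For reference, a group of m indices in SY (non-hollow) packed layout with edge length n stores one entry per non-decreasing index tuple, i.e. binomial(n+m-1, m) entries. A small sketch of that count (a hypothetical helper, not the CTF routine, which also handles tensors with several symmetry groups):

#include <cstdint>
#include <cstdio>

// Packed size of one group of m symmetric (SY, non-hollow) indices of length n:
// the number of non-decreasing m-tuples, binomial(n+m-1, m).
int64_t sy_group_size(int n, int m){
  int64_t num = 1, den = 1;
  for (int i=0; i<m; i++){ num *= n+i; den *= i+1; }
  return num/den;
}

int main(){
  printf("%lld\n", (long long)sy_group_size(4, 3));   // prints 20, matching the packed-loop sketch earlier
  return 0;
}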