// parity(a, b, c, len_A, len_B): parity of the permutation that maps the
// concatenation of index strings a and b onto index string c
int parity(char const * a, char const * b, char const * c,
           int len_A, int len_B){
  // ...
  for (int i=0; i<len_A; i++){
    // ...
  }
  for (int i=0; i<len_B; i++){
    // ...
  }
  // ...
  // pairwise pass over the concatenated string (inversion counting)
  for (int i=0; i<len_A+len_B; i++){
    // ...
    for (j=i+1; j<len_A+len_B; j++){
      // ...
    }
    assert(j<len_A+len_B);
    // ...
  }
  // ...
}
// parity(a, c, len_A, len_C): overload that infers the complement string
// b = c \ a and defers to the two-operand version above
int parity(char const * a, char const * c, int len_A, int len_C){
  if (len_A == len_C) return parity(a, NULL, c, len_A, 0);
  // ...
  // collect the characters of c that do not appear in a into b
  for (int i=0; i<len_C; i++){
    // ...
    for (j=0; j<len_A; j++){
      if (c[i] == a[j]) break;
    }
    // ...
  }
  assert(ib == len_C-len_A);
  return parity(a, b, c, len_A, len_C-len_A);
}
// presumably inside sign(): an odd parity maps to -1.0, an even one to +1.0
if (par%2 == 1) return -1.0;
// ...
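For reference, a self-contained sketch of the inversion-counting idea behind parity() and sign(). The names permutation_parity and perm_sign, and the use of std::string, are illustrative assumptions rather than the example's own code, which works on raw char arrays; it assumes the two strings hold the same distinct index letters.

#include <cassert>
#include <string>

// Hypothetical helper: parity of the permutation taking `from` to `to`,
// computed by counting inversions of `to` relative to `from`.
static int permutation_parity(std::string const &from, std::string const &to){
  assert(from.size() == to.size());
  int inversions = 0;
  for (size_t i = 0; i < to.size(); i++){
    for (size_t j = i + 1; j < to.size(); j++){
      // pair (i,j) is inverted if `from` orders the two characters oppositely
      if (from.find(to[i]) > from.find(to[j])) inversions++;
    }
  }
  return inversions % 2;
}

// map parity to the sign of the permutation
static double perm_sign(int par){ return (par % 2 == 1) ? -1.0 : 1.0; }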
// in get_rand_as_tsr(tsr, seed): fetch the locally owned entries,
// overwrite them with uniform random values, and write them back
int64_t * indices;
int64_t size;
// ...
for (int64_t i=0; i<size; i++){
  values[i] = drand48();
}
// ...
tsr.write(size, indices, values);
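The elided body follows CTF's read_local/write round trip (both signatures appear in the reference list at the end of this page). A hedged sketch of the pattern; the per-process seeding and the use of free() on the returned buffers are assumptions:

#include <cstdlib>
#include <ctf.hpp>
using namespace CTF;

// sketch: randomize a distributed tensor in place
void randomize_tensor(Tensor<> &tsr, int seed){
  int64_t size;
  int64_t * indices;
  double * values;
  srand48(seed);                             // seeding assumed per-process
  tsr.read_local(&size, &indices, &values);  // fetch locally owned entries
  for (int64_t i = 0; i < size; i++)
    values[i] = drand48();                   // overwrite with uniform randoms
  tsr.write(size, indices, values);          // scatter the values back
  free(indices);
  free(values);
}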
// still in get_rand_as_tsr: a random vector v is antisymmetrized against
// the order-(dim-1) antisymmetric tensor tsr_n1 to build an order-dim tsr
for (int64_t i=0; i<size; i++){
  values[i] = drand48();
}
// ...
v.write(size, indices, values);
// ...
for (int i=0; i<dim; i++){
  // ...
}
for (int i=0; i<dim; i++){
  // str_n1 = index string str with its ith character removed
  for (int j=0; j<dim-1; j++){
    if (j>=i) str_n1[j] = str[j+1];
    else      str_n1[j] = str[j];
  }
  // ...
  assert(sign(parity(str+i, str_n1, str, 1, dim-1)) == sgn);
  tsr[str] += sgn*v[str+i]*tsr_n1[str_n1];
}
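Read as an equation, the loop builds the order-d antisymmetric tensor recursively (assuming sgn alternates as (-1)^{k-1}, consistent with the parity assertion above):

T_{i_1 \cdots i_d} \;=\; \sum_{k=1}^{d} (-1)^{k-1}\, v_{i_k}\, T^{(d-1)}_{i_1 \cdots i_{k-1}\, i_{k+1} \cdots i_d}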
// in check_asym(tsr): for every transposition of two indices, the tensor
// plus its transposed copy must vanish (pstr is presumably str with the
// characters at positions i and j swapped)
for (int i=0; i<dim; i++){
  // ...
}
for (int i=0; i<dim; i++){
  for (int j=i+1; j<dim; j++){
    // ...
    ptsr[str] += tsr[str];
    ptsr[str] += tsr[pstr];
    if (ptsr.norm2() > 1.E-6)
      return false;
  }
}
// in check_sym(tsr): the tensor minus its transposed copy must vanish for
// every transposition of two indices
for (int i=0; i<dim; i++){
  // ...
}
for (int i=0; i<dim; i++){
  for (int j=i+1; j<dim; j++){
    // ...
    ptsr[str] += tsr[str];
    ptsr[str] -= tsr[pstr];
    if (ptsr.norm2() > 1.E-6)
      return false;
  }
}
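Both checks reduce to a norm condition per index transposition \pi_{ij}; the tolerance 10^{-6} is the code's own:

\text{antisymmetric: } \|T + T\circ\pi_{ij}\|_2 \le 10^{-6}, \qquad
\text{symmetric: } \|T - T\circ\pi_{ij}\|_2 \le 10^{-6}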
// inside a factorial helper used by choose()/chchoose()
for (int64_t i=1; i<=n; i++){
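The counting helpers are elided in this listing; a plausible self-contained version, with chchoose read as choose-with-repetition (an assumption based on its name and use):

#include <cstdint>

// factorial
int64_t fact(int64_t n){
  int64_t f = 1;
  for (int64_t i = 1; i <= n; i++) f *= i;
  return f;
}

// binomial coefficient: n choose k
int64_t choose(int64_t n, int64_t k){
  return fact(n) / (fact(k) * fact(n - k));
}

// multiset coefficient (choose with repetition); an assumption about
// what chchoose computes
int64_t chchoose(int64_t n, int64_t k){
  return fact(n + k - 1) / (fact(k) * fact(n - 1));
}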
// chi(idx, idx_len, p_len, q_len, npair, idx_p, idx_q): enumerate every way
// of choosing two disjoint ordered subsequences of idx, of lengths p_len
// and q_len (signature completed from the reference list below)
void chi(char const * idx, int idx_len, int p_len, int q_len,
         int * npair, char *** idx_p, char *** idx_q){
  // ...
  // impossible request: nothing to enumerate
  if (p_len+q_len > idx_len){
    // ...
  }
  // base case: a single empty split
  if (idx_len == 0 || (p_len == 0 && q_len == 0)){
    char ** ip = (char**)malloc(sizeof(char*));
    char ** iq = (char**)malloc(sizeof(char*));
    // ...
  }
  // total number of splits: pick p_len of the idx_len characters, then
  // q_len of the remaining idx_len-p_len
  np  = choose(idx_len, p_len);
  np *= choose(idx_len-p_len, q_len);
  // ...
  char ** ip = (char**)malloc(sizeof(char*)*np);
  char ** iq = (char**)malloc(sizeof(char*)*np);
  // ...
  for (int i=0; i<np; i++){
    ip[i] = (char*)malloc(sizeof(char)*p_len);
    iq[i] = (char*)malloc(sizeof(char)*q_len);
  }
  // ...
  if (q_len == 0){ // condition inferred from the recursive calls below
    // splits whose p-part contains the last character of idx
    chi(idx, idx_len-1, p_len-1, 0, &n1_len, &n1_ip, &qnull);
    // ...
    for (int i=0; i<n1_len; i++){
      memcpy(ip[i], n1_ip[i], sizeof(char)*(p_len-1));
      ip[i][p_len-1] = idx[idx_len-1];
    }
    // ...
    // splits whose p-part omits the last character
    chi(idx, idx_len-1, p_len, 0, &n2_len, &n2_ip, &qnull);
    assert(n2_len + n1_len == np);
    // ...
    for (int i=0; i<n2_len; i++){
      memcpy(ip[i+n1_len], n2_ip[i], sizeof(char)*p_len);
    }
    // ...
  }
  else if (p_len == 0){
    // mirror image: distribute the characters over q only
    chi(idx, idx_len-1, 0, q_len-1, &n1_len, &pnull, &n1_iq);
    // ...
    for (int i=0; i<n1_len; i++){
      memcpy(iq[i], n1_iq[i], sizeof(char)*(q_len-1));
      iq[i][q_len-1] = idx[idx_len-1];
    }
    // ...
    chi(idx, idx_len-1, 0, q_len, &n2_len, &pnull, &n2_iq);
    assert(n2_len + n1_len == np);
    // ...
    for (int i=0; i<n2_len; i++){
      memcpy(iq[i+n1_len], n2_iq[i], sizeof(char)*q_len);
    }
    // ...
  }
  else {
    // general case: the last character of idx goes into p, into q, or
    // into neither
    chi(idx, idx_len-1, p_len-1, q_len, &n1_len, &n1_ip, &n1_iq);
    // ...
    for (int i=0; i<n1_len; i++){
      memcpy(ip[i], n1_ip[i], sizeof(char)*(p_len-1));
      ip[i][p_len-1] = idx[idx_len-1];
      memcpy(iq[i], n1_iq[i], sizeof(char)*q_len);
    }
    // ...
    chi(idx, idx_len-1, p_len, q_len-1, &n2_len, &n2_ip, &n2_iq);
    // ...
    for (int i=0; i<n2_len; i++){
      memcpy(ip[i+n1_len], n2_ip[i], sizeof(char)*p_len);
      memcpy(iq[i+n1_len], n2_iq[i], sizeof(char)*(q_len-1));
      iq[i+n1_len][q_len-1] = idx[idx_len-1];
    }
    // ...
    chi(idx, idx_len-1, p_len, q_len, &n3_len, &n3_ip, &n3_iq);
    // ...
    for (int i=0; i<n3_len; i++){
      memcpy(ip[i+n1_len+n2_len], n3_ip[i], sizeof(char)*p_len);
      memcpy(iq[i+n1_len+n2_len], n3_iq[i], sizeof(char)*q_len);
    }
    // ...
    assert(n1_len+n2_len+n3_len==np);
  }
  // ...
}
// convenience overload: q is the complement, i.e. all splits of idx into a
// p-part of length p_len and the remaining idx_len-p_len characters
void chi(char const * idx, int idx_len, int p_len,
         int * npair, char *** idx_p){
  char ** idx_q; // discarded complement output (declaration inferred)
  // ...
  chi(idx, idx_len, p_len, idx_len-p_len, npair, idx_p, &idx_q);
  // ...
}
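The same recursion is easier to follow in a compact standalone form; a sketch using std::string and std::vector in place of the example's malloc'd char arrays (all names here are illustrative):

#include <string>
#include <utility>
#include <vector>

// Enumerate all ways to choose disjoint ordered subsequences p (length
// p_len) and q (length q_len) from idx: the last character of idx goes
// into p, into q, or into neither.
std::vector<std::pair<std::string,std::string>>
splits(std::string const &idx, int p_len, int q_len){
  std::vector<std::pair<std::string,std::string>> out;
  if (p_len + q_len > (int)idx.size()) return out;      // no valid splits
  if (p_len == 0 && q_len == 0){                        // base case
    out.push_back({"", ""});
    return out;
  }
  std::string rest = idx.substr(0, idx.size()-1);
  char last = idx.back();
  if (p_len > 0)                                        // last char -> p
    for (auto &pq : splits(rest, p_len-1, q_len))
      out.push_back({pq.first + last, pq.second});
  if (q_len > 0)                                        // last char -> q
    for (auto &pq : splits(rest, p_len, q_len-1))
      out.push_back({pq.first, pq.second + last});
  for (auto &pq : splits(rest, p_len, q_len))           // last char unused
    out.push_back(pq);
  return out;
}

For example, splits("ab", 1, 1) yields the two pairs ("b","a") and ("a","b"), matching np = choose(2,1)*choose(1,1) = 2.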
// in fast_tensor_ctr(n, s, t, v, ctf):
MPI_Comm_rank(ctf.comm, &rank);
// index labels: idx_C uses the first s+t letters; the v contracted indices
// use the letters after those (assignment branches elided in this listing)
for (i=0; i<s+v; i++){
  // ...
  idx_A[i] = 'a'+(s+t)+(i-s);
  // ...
}
for (i=0; i<t+v; i++){
  // ...
  idx_B[i] = 'a'+(s+t)+i;
  // ...
  idx_B[i] = 'a'+s+(i-v);
  // ...
}
for (i=0; i<s+t; i++){
  // ...
}
// ...
// operands and outputs: A has order s+v, B order t+v, C order s+t; C_int
// holds the direct (partially symmetric) product and C_ans the fully
// symmetrized reference answer
Tensor<> A(s+v, len_A, sym_A, ctf, "A", 1);
Tensor<> B(t+v, len_B, sym_B, ctf, "B", 1);
Tensor<> C(s+t, len_C, sym_C, ctf, "C", 1);
Tensor<> C_int(s+t, len_C, sym_C, ctf, "C_psym", 1);
Tensor<> C_ans(s+t, len_C, sym_C, ctf, "C_ans", 1);
// reference answer: direct contraction, then explicit symmetrization of the
// output over all (s,t) splits of idx_C as enumerated by chi
C_int[idx_C] += A[idx_A]*B[idx_B];
// ...
chi(idx_C, s+t, s, t, &ncperms, &idx_As, &idx_Bt);
// ...
for (i=0; i<ncperms; i++){
  // ...
  memcpy(idx_C_int,   idx_As[i], sizeof(char)*s);
  memcpy(idx_C_int+s, idx_Bt[i], sizeof(char)*t);
  C_ans[idx_C] += C_int[idx_C_int];
}
// ...
if (is_C_sym) printf("C_ans is symmetric\n");
else          printf("C_ans is NOT symmetric!!\n");
// the fast algorithm forms a single order-(s+t+v) intermediate Z as an
// elementwise product of antisymmetrized copies of A and B
for (i=0; i<s+v+t; i++){
  // ...
}
// ...
Tensor<> Z_A_ops(s+v+t, len_Z, sym_Z, ctf, "Z_A", 1);
Tensor<> Z_B_ops(s+v+t, len_Z, sym_Z, ctf, "Z_B", 1);
Tensor<> Z_mults(s+v+t, len_Z, sym_Z, ctf, "Z", 1);
// ...
// antisymmetrize A over all (s+v)-subsequences of idx_Z
chi(idx_Z, s+t+v, s+v, &nAperms, &idx_Asv);
// ...
for (i=0; i<nAperms; i++){
  Z_A_ops[idx_Z] += sign(parity(idx_Asv[i], idx_Z, s+v, s+t+v))*A[idx_Asv[i]];
}
// ...
if (is_A_asym) printf("Z_A_ops is antisymmetric\n");
else           printf("Z_A_ops is NOT antisymmetric!!\n");
// ...
// antisymmetrize B over all (t+v)-subsequences of idx_Z
chi(idx_Z, s+t+v, t+v, &nBperms, &idx_Btv);
// ...
for (i=0; i<nBperms; i++){
  Z_B_ops[idx_Z] += sign(parity(idx_Btv[i], idx_Z, t+v, s+t+v))*B[idx_Btv[i]];
}
// ...
if (is_B_asym) printf("Z_B_ops is antisymmetric\n");
else           printf("Z_B_ops is NOT antisymmetric!!\n");
// ...
// elementwise (Hadamard) product of the two antisymmetrized operands
Z_mults[idx_Z] = Z_A_ops[idx_Z]*Z_B_ops[idx_Z];
// ...
if (is_Z_sym) printf("Z_mults is symmetric\n");
else          printf("Z_mults is NOT symmetric!!\n");
// contract the v auxiliary indices of Z: index labels appearing on the
// right-hand side but not in idx_C are summed over
memcpy(idx_Z, idx_C, (s+t)*sizeof(char));
for (i=s+t; i<s+t+v; i++){
  idx_Z[i] = idx_Z[s+t-1]+(i-s-t+1);
}
// ...
C[idx_C] += sign(s*t+s*v)*Z_mults[idx_Z];
// first correction term V (subtracted from C below), accumulated over
// triples (r, p, q) of auxiliary index counts
Tensor<> V(s+t, len_C, sym_C, ctf, "V");
for (int r=0; r<v; r++){
  for (int p=std::max(v-t-r,0); p<=v-r; p++){
    for (int q=std::max(v-s-r,0); q<=v-p-r; q++){
      // ...
      double sgn_V = sign(s*t+v*(t+p+r)+p*q+(q+r)*(t+1));
      // auxiliary index groups: k_r shared between A and B, k_p only in A,
      // k_q only in B (surrounding loops elided)
      // ...
      idx_kr[i] = 'a'+s+t+i;
      // ...
      idx_kp[i] = 'a'+s+t+r+i;
      // ...
      idx_kq[i] = 'a'+s+t+r+p+i;
      // A-side operand of V for this (r, p, q)
      Tensor<> V_A_ops(s+t+r, len_Z, sym_Z, ctf, "V_A_ops");
      // ...
      memcpy(idx_VA,     idx_C,  (s+t)*sizeof(char));
      memcpy(idx_VA+s+t, idx_kr, r*sizeof(char));
      // ...
      chi(idx_C, s+t, s+v-p-r, &nvAperms, &idx_VAsvpr);
      for (i=0; i<nvAperms; i++){
        // idx_VAA: (s+v-p-r chars of idx_C) + k_r + k_p, an order-(s+v)
        // index string for A
        memcpy(idx_VAA,         idx_VAsvpr[i], (s+v-p-r)*sizeof(char));
        memcpy(idx_VAA+s+v-p-r, idx_kr,        r*sizeof(char));
        memcpy(idx_VAA+s+v-p,   idx_kp,        p*sizeof(char));
        double sgn_VA = sign(parity(idx_VAsvpr[i], idx_C, s+v-p-r, s+t));
        V_A_ops[idx_VA] += sgn_VA*A[idx_VAA];
      // B-side operand of V for this (r, p, q)
      Tensor<> V_B_ops(s+t+r, len_Z, sym_Z, ctf, "V_B_ops");
      // ...
      memcpy(idx_VB,     idx_C,  (s+t)*sizeof(char));
      memcpy(idx_VB+s+t, idx_kr, r*sizeof(char));
      // ...
      chi(idx_C, s+t, t+v-q-r, &nvBperms, &idx_VBtvqr);
      for (i=0; i<nvBperms; i++){
        // idx_VBB: k_r + k_q + (t+v-q-r chars of idx_C), an order-(t+v)
        // index string for B
        memcpy(idx_VBB,     idx_kr,        r*sizeof(char));
        memcpy(idx_VBB+r,   idx_kq,        q*sizeof(char));
        memcpy(idx_VBB+r+q, idx_VBtvqr[i], (t+v-q-r)*sizeof(char));
        // ...
        double sgn_VB = sign(parity(idx_VBtvqr[i], idx_C, t+v-q-r, s+t));
        V_B_ops[idx_VB] += sgn_VB*B[idx_VBB];
      }
      // ...
      V[idx_C] += prefact*V_A_ops[idx_VA]*V_B_ops[idx_VB];
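Schematically (a reading of the loop above, with prefact and sgn_V as computed in the code), each (r, p, q) iteration contributes a product of the two partially antisymmetrized operands, contracted over the r shared auxiliary indices k_r:

V_{c} \;=\; \sum_{r,\,p,\,q} \mathrm{prefact}(r,p,q) \sum_{k_r} V^{A}_{c\,k_r}\, V^{B}_{c\,k_r}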
// second correction term W (subtracted from C below), built for each count
// r of index labels shared between the A-side and B-side groups
Tensor<> W(s+t, len_C, sym_C, ctf, "W");
for (int r=1; r<=std::min(s,t); r++){
  // k_r: the r shared labels; k_v: the v contracted labels
  for (int i=0; i<r; i++){
    idx_kr[i] = 'a'+s+t+i;
  }
  // ...
  for (int i=0; i<v; i++){
    idx_kv[i] = 'a'+s+t+r+i;
  }
  // ...
  // U: product of A and B with r index labels identified between them
  Tensor<> U(s+t-r, len_C, sym_C, ctf, "U");
  // ...
  memcpy(idx_U,          idx_kr, sizeof(char)*r);
  memcpy(idx_UA,         idx_kr, sizeof(char)*r);
  memcpy(idx_UB+t+v-r,   idx_kr, sizeof(char)*r);
  // enumerate assignments of the remaining s+t-2r output indices to A and B
  char ** idxj, ** idxl;
  chi(idx_C, s+t-2*r, s-r, t-r, &npermU, &idxj, &idxl);
  memcpy(idx_U+r, idx_C, s+t-2*r);
  for (int i=0; i<npermU; i++){
    memcpy(idx_UA+r, idxj[i], s-r);
    memcpy(idx_UB+v, idxl[i], t-r);
    memcpy(idx_UA+s, idx_kv, sizeof(char)*v);
    // ...
    memcpy(idx_UB,   idx_kv, sizeof(char)*v);
    // ...
    U[idx_U] += A[idx_UA]*B[idx_UB];
  // scatter U back into W over all placements of the r repeated indices
  chi(idx_C, s+t, s+t-r, &npermW1, &idxh1);
  for (int j=0; j<npermW1; j++){
    // ...
    char ** idxh, ** idxr;
    chi(idxh1[j], s+t-r, r, s+t-2*r, &npermW, &idxr, &idxh);
    for (int i=0; i<npermW; i++){
      memcpy(idx_U,   idxr[i], r);
      memcpy(idx_U+r, idxh[i], s+t-2*r);
      // ...
      W[idx_C] -= U[idx_U];
// subtract both correction terms from the Z contribution
C[idx_C] -= V[idx_C];
C[idx_C] -= W[idx_C];
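Altogether, with \langle Z \rangle_v denoting summation over the v auxiliary indices of Z and sign(x) = (-1)^x, the fast algorithm assembles (a schematic reading of the code above):

C \;=\; (-1)^{s t + s v}\,\langle Z \rangle_{v} \;-\; V \;-\; W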
// form the residual against the reference answer
C[idx_C] -= C_ans[idx_C];
double nrm = C.norm2();
// ...
printf("error 2-norm is %.4E\n", nrm);
// ...
int pass = (nrm <= 1.E-3);
// ...
if (pass) printf("{ fast symmetric tensor contraction algorithm test } passed\n");
else      printf("{ fast symmetric tensor contraction algorithm test } failed\n");
// in getCmdOption(begin, end, option): return the argument following a flag
char ** itr = std::find(begin, end, option);
if (itr != end && ++itr != end){
int main(int argc, char ** argv){
  int const in_num = argc;
  char ** input_str = argv;
  // ...
  MPI_Init(&argc, &argv);
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &np);
  // parse -n (dimension) and -s, -t, -v (index-group sizes) from argv
  // ...
  n = atoi(getCmdOption(input_str, input_str+in_num, "-n"));
  // ...
  s = atoi(getCmdOption(input_str, input_str+in_num, "-s"));
  // ...
  t = atoi(getCmdOption(input_str, input_str+in_num, "-t"));
  // ...
  v = atoi(getCmdOption(input_str, input_str+in_num, "-v"));
  // ...
  World dw(MPI_COMM_WORLD, argc, argv);
  // ...
  printf("Contracting symmetric A of order %d with B of order %d into C of order %d, all with dimension %d\n",
         s+v, t+v, s+t, n);
References (symbols used by this example, from the generated documentation):

Functions defined in this example:
  int parity(char const *a, char const *b, char const *c, int len_A, int len_B)
  void chi(char const *idx, int idx_len, int p_len, int q_len, int *npair, char ***idx_p, char ***idx_q)
  void get_rand_as_tsr(Tensor<> &tsr, int seed)
  bool check_sym(Tensor<> &tsr)
  bool check_asym(Tensor<> &tsr)
  int64_t choose(int64_t n, int64_t k)
  int64_t chchoose(int64_t n, int64_t k)
  int fast_tensor_ctr(int n, int s, int t, int v, World &ctf)
  char * getCmdOption(char **begin, char **end, const std::string &option)
  int main(int argc, char **argv)

CTF classes:
  World: an instance of the CTF library (world) on an MPI communicator
  Tensor: an instance of a tensor within a CTF world
  Vector: Vector class which encapsulates a 1D tensor

CTF members referenced:
  MPI_Comm comm: set of processors making up this world
  int order: number of tensor dimensions
  int * lens: unpadded tensor edge lengths
  int * sym: symmetries among tensor dimensions
  CTF::World * wrld: distributed processor context on which the tensor is defined
  void read_local(int64_t *npair, int64_t **global_idx, dtype **data, bool unpack_sym=false) const:
    using get_local_data(), which returns an array that must be freed with delete [], is more efficient...
  void write(int64_t npair, int64_t const *global_idx, dtype const *data):
    writes in values associated with any set of indices; the sparse data is defined in coordinate format...
  dtype norm2(): computes the Frobenius norm of the tensor (needs sqrt()!)