1 #ifndef __INT_SYMMETRIZATION_H__ 2 #define __INT_SYMMETRIZATION_H__ 5 #include "../tensor/untyped_tensor.h" 6 #include "../summation/summation.h" 7 #include "../contraction/contraction.h" 108 std::vector<int>& signs,
109 summation
const & new_perm,
122 std::vector<int>& signs,
123 contraction
const & new_perm,
138 std::vector<summation>& perms,
139 std::vector<int>& signs);
150 std::vector<contraction>& perms,
151 std::vector<int>& signs);
def sum(tensor, init_A, axis=None, dtype=None, out=None, keepdims=None)
void get_sym_perms(summation const &sum, std::vector< summation > &perms, std::vector< int > &signs)
finds all permutations of a summation that must be done for a broken symmetry
void order_perm(tensor const *A, tensor const *B, int *idx_arr, int off_A, int off_B, int *idx_A, int *idx_B, int &add_sign, int &mod)
orders the summation indices of one tensor that don't break summation symmetries
void desymmetrize(tensor *sym_tsr, tensor *nonsym_tsr, bool is_C)
unfolds the data of a tensor
void add_sym_perm(std::vector< summation > &perms, std::vector< int > &signs, summation const &new_perm, int new_sign)
puts a summation map into a nice ordering according to preserved symmetries, and adds it if it is distinct
void cmp_sym_perms(int ndim, int const *sym, int *nperm, int **perm, double *sign)
finds all permutations of a tensor according to a symmetry
void symmetrize(tensor *sym_tsr, tensor *nonsym_tsr)
folds the data of a tensor