/**
 * \brief declaration of try_mkl_coo_to_csr (it follows the algstrct.h include
 *        on the next line); no definition is visible in this listing.
 *
 * The only visible caller, seq_coo_to_csr, invokes it when sizeof(dtype) is
 * 4, 8 or 16, passing that size as el_size and the value arrays cast to char*.
 * NOTE(review): the name suggests an MKL-backed COO->CSR conversion whose bool
 * result reports whether it was performed -- confirm against the definition.
 *
 * \param[in] nz number of entries (presumably nonzeros -- confirm)
 * \param[in] nrow number of rows
 * \param[out] csr_vs CSR values, passed as raw bytes
 * \param[out] csr_ja CSR integer array (presumably column indices -- confirm)
 * \param[out] csr_ia CSR integer array (presumably row offsets -- confirm)
 * \param[in] coo_vs COO values, passed as raw bytes
 * \param[in] coo_rs COO row indices
 * \param[in] coo_cs COO column indices
 * \param[in] el_size size in bytes of one value
 */
4 #include "../tensor/algstrct.h" 12 bool try_mkl_coo_to_csr(int64_t nz,
int nrow,
char * csr_vs,
int * csr_ja,
int * csr_ia,
char const * coo_vs,
int const * coo_rs,
int const * coo_cs,
int el_size);
/**
 * \brief declaration of try_mkl_csr_to_coo; no definition is visible in this
 *        listing.
 *
 * The only visible caller, seq_csr_to_coo, invokes it when sizeof(dtype) is
 * 4, 8 or 16, passing that size as el_size and the value arrays cast to char*.
 * NOTE(review): the name suggests an MKL-backed CSR->COO conversion whose bool
 * result reports whether it was performed -- confirm against the definition.
 *
 * \param[in] nz number of entries (presumably nonzeros -- confirm)
 * \param[in] nrow number of rows
 * \param[in] csr_vs CSR values, passed as raw bytes
 * \param[in] csr_ja CSR integer array (presumably column indices -- confirm)
 * \param[in] csr_ia CSR integer array (presumably row offsets -- confirm)
 * \param[out] coo_vs COO values, passed as raw bytes
 * \param[out] coo_rs COO row indices
 * \param[out] coo_cs COO column indices
 * \param[in] el_size size in bytes of one value
 */
14 bool try_mkl_csr_to_coo(int64_t nz,
int nrow,
char const * csr_vs,
int const * csr_ja,
int const * csr_ia,
char * coo_vs,
int * coo_rs,
int * coo_cs,
int el_size);
16 template <
typename dtype>
17 void seq_coo_to_csr(int64_t nz,
int nrow,
dtype * csr_vs,
int * csr_ja,
int * csr_ia,
dtype const * coo_vs,
int const * coo_rs,
int const * coo_cs){
18 int sz =
sizeof(
dtype);
19 if (sz == 4 || sz == 8 || sz == 16){
20 bool b =
try_mkl_coo_to_csr(nz, nrow, (
char*)csr_vs, csr_ja, csr_ia, (
char const*)coo_vs, coo_rs, coo_cs, sz);
25 #pragma omp parallel for 27 for (
int i=1; i<nrow+1; i++){
31 #pragma omp parallel for 33 for (int64_t i=0; i<nz; i++){
37 #pragma omp parallel for 39 for (
int i=0; i<nrow; i++){
40 csr_ia[i+1] += csr_ia[i];
43 #pragma omp parallel for 45 for (int64_t i=0; i<nz; i++){
52 comp_ref(
int const * a_){ a = a_; }
53 bool operator()(
int u,
int v){
59 std::sort(csr_ja, csr_ja+nz, crc);
61 std::stable_sort(csr_ja, csr_ja+nz, crr);
65 #pragma omp parallel for 67 for (int64_t i=0; i<nz; i++){
71 csr_vs[i] = coo_vs[csr_ja[i]];
77 #pragma omp parallel for 79 for (int64_t i=0; i<nz; i++){
80 csr_ja[i] = coo_cs[csr_ja[i]];
84 template <
typename dtype>
85 void seq_csr_to_coo(int64_t nz,
int nrow,
dtype const * csr_vs,
int const * csr_ja,
int const * csr_ia,
dtype * coo_vs,
int * coo_rs,
int * coo_cs){
86 int sz =
sizeof(
dtype);
87 if (sz == 4 || sz == 8 || sz == 16){
88 bool b =
try_mkl_csr_to_coo(nz, nrow, (
char const*)csr_vs, csr_ja, csr_ia, (
char*)coo_vs, coo_rs, coo_cs, sz);
93 memcpy(coo_cs, csr_ja,
sizeof(
int)*nz);
94 for (
int i=0; i<nrow; i++){
95 std::fill(coo_rs+csr_ia[i]-1, coo_rs+csr_ia[i+1]-1, i+1);
99 template <
typename dtype>
100 void def_coo_to_csr(int64_t nz,
int nrow,
dtype * csr_vs,
int * csr_ja,
int * csr_ia,
dtype const * coo_vs,
int const * coo_rs,
int const * coo_cs){
101 seq_coo_to_csr<dtype>(nz, nrow, csr_vs, csr_ja, csr_ia, coo_vs, coo_rs, coo_cs);
104 template <
typename dtype>
105 void def_csr_to_coo(int64_t nz,
int nrow,
dtype const * csr_vs,
int const * csr_ja,
int const * csr_ia,
dtype * coo_vs,
int * coo_rs,
int * coo_cs){
106 seq_csr_to_coo<dtype>(nz, nrow, csr_vs, csr_ja, csr_ia, coo_vs, coo_rs, coo_cs);
109 template <
typename dtype>
114 template <
typename dtype,
bool is_ord>
115 inline typename std::enable_if<is_ord, dtype>::type
117 dtype b = default_addinv<dtype>(
a);
121 template <
typename dtype,
bool is_ord>
122 inline typename std::enable_if<!is_ord, dtype>::type
124 printf(
"CTF ERROR: cannot compute abs unless the set is ordered");
129 template <
typename dtype, dtype (*abs)(dtype)>
136 template <
typename dtype,
bool is_ord>
137 inline typename std::enable_if<is_ord, dtype>::type
142 template <
typename dtype,
bool is_ord>
143 inline typename std::enable_if<!is_ord, dtype>::type
145 printf(
"CTF ERROR: cannot compute a max unless the set is ordered");
150 template <
typename dtype,
bool is_ord>
151 inline typename std::enable_if<is_ord, dtype>::type
153 return std::numeric_limits<dtype>::max();
156 template <
typename dtype,
bool is_ord>
157 inline typename std::enable_if<!is_ord, dtype>::type
159 printf(
"CTF ERROR: cannot compute a max unless the set is ordered");
165 template <
typename dtype,
bool is_ord>
166 inline typename std::enable_if<is_ord, dtype>::type
168 return std::numeric_limits<dtype>::min();
171 template <
typename dtype,
bool is_ord>
172 inline typename std::enable_if<!is_ord, dtype>::type
174 printf(
"CTF ERROR: cannot compute a max unless the set is ordered");
180 template <
typename dtype,
bool is_ord>
181 inline typename std::enable_if<is_ord, dtype>::type
186 template <
typename dtype,
bool is_ord>
187 inline typename std::enable_if<!is_ord, dtype>::type
189 printf(
"CTF ERROR: cannot compute a min unless the set is ordered");
193 template <
typename dtype>
195 MPI_Datatype newtype;
196 MPI_Type_contiguous(
sizeof(
dtype), MPI_BYTE, &newtype);
197 MPI_Type_commit(&newtype);
209 inline MPI_Datatype get_default_mdtype< std::complex<double> >(
bool & is_custom){ is_custom=
false;
return MPI_CTF_DOUBLE_COMPLEX; }
229 inline MPI_Datatype get_default_mdtype< std::complex<float> >(
bool & is_custom){ is_custom=
false;
return MPI_COMPLEX; }
233 template <
typename dtype>
238 #define INST_ORD_TYPE(dtype) \ 240 constexpr bool get_default_is_ord<dtype>(){ \ 260 template <
typename dtype>
264 bool operator < (const dtypePair<dtype>& other)
const {
265 return (key < other.key);
279 template <
typename dtype=
double,
bool is_ord=CTF_
int::get_default_is_ord<dtype>()>
286 if (is_custom_mdtype) MPI_Type_free(&tmdtype);
291 tmdtype = CTF_int::get_default_mdtype<dtype>(is_custom_mdtype);
294 is_custom_mdtype =
false;
296 pair_sz =
sizeof(std::pair<int64_t,dtype>);
307 return ((std::pair<int64_t,dtype>
const *)a)->first;
311 return (
char*)&(((std::pair<int64_t,dtype>
const *)a)->second);
315 return (
char const *)&(((std::pair<int64_t,dtype>
const *)a)->second);
327 tmdtype = CTF_int::get_default_mdtype<dtype>(is_custom_mdtype);
328 set_abs_to_default();
329 pair_sz =
sizeof(std::pair<int64_t,dtype>);
333 abs = &CTF_int::char_abs< dtype, CTF_int::default_abs<dtype, is_ord> >;
343 ((
dtype*)c)[0] = CTF_int::default_min<dtype,is_ord>(((
dtype*)a)[0],((
dtype*)b)[0]);
349 ((
dtype*)c)[0] = CTF_int::default_max<dtype,is_ord>(((
dtype*)a)[0],((
dtype*)b)[0]);
352 void min(
char * c)
const {
353 ((
dtype*)c)[0] = CTF_int::default_min_lim<dtype,is_ord>();
356 void max(
char * c)
const {
357 ((
dtype*)c)[0] = CTF_int::default_max_lim<dtype,is_ord>();
362 printf(
"CTF ERROR: double cast not possible for this algebraic structure\n");
368 printf(
"CTF ERROR: integer cast not possible for this algebraic structure\n");
373 printf(
"CTF ERROR: double cast not possible for this algebraic structure\n");
380 printf(
"CTF ERROR: int cast not possible for this algebraic structure\n");
386 void print(
char const *
a, FILE * fp=stdout)
const {
387 for (
int i=0; i<el_size; i++){
388 fprintf(fp,
"%x",a[i]);
394 if (a == NULL && b == NULL)
return true;
395 if (a == NULL || b == NULL)
return false;
396 for (
int i=0; i<el_size; i++){
397 if (a[i] != b[i])
return false;
402 void coo_to_csr(int64_t nz,
int nrow,
char * csr_vs,
int * csr_ja,
int * csr_ia,
char const * coo_vs,
int const * coo_rs,
int const * coo_cs)
const {
406 void csr_to_coo(int64_t nz,
int nrow,
char const * csr_vs,
int const * csr_ja,
int const * csr_ia,
char * coo_vs,
int * coo_rs,
int * coo_cs)
const {
412 return (
char*)(
new std::pair<int64_t,dtype>[n]);
417 return (
char*)(
new dtype[n]);
421 return delete [] (
dtype*)ptr;
425 return delete [] (std::pair<int64_t,dtype>*)ptr;
429 void sort(int64_t n,
char * pairs)
const {
433 void copy(
char *
a,
char const *
b)
const {
437 void copy(
char *
a,
char const *
b, int64_t n)
const {
442 ((std::pair<int64_t,dtype> *)a)[0] = ((std::pair<int64_t,dtype>
const *)b)[0];
446 std::copy((std::pair<int64_t,dtype>
const *)b, ((std::pair<int64_t,dtype>
const *)b) + n, (std::pair<int64_t,dtype> *)a);
457 void set(
char *
a,
char const *
b, int64_t n)
const {
462 ((std::pair<int64_t,dtype> *)a)[0] = std::pair<int64_t,dtype>(
key,*((
dtype*)b));
465 void set_pairs(
char * a,
char const * b, int64_t n)
const {
466 std::fill((std::pair<int64_t,dtype> *)a, (std::pair<int64_t,dtype> *)a + n, *(std::pair<int64_t,dtype>
const*)b);
469 void copy(int64_t n,
char const * a,
int inc_a,
char * b,
int inc_b)
const {
472 for (int64_t i=0; i<n; i++){
473 db[inc_b*i] = da[inc_a*i];
482 int64_t lda_b)
const {
486 for (int64_t j=0; j<n; j++){
487 for (int64_t i=0; i<m; i++){
488 db[j*lda_b+i] = da[j*lda_a+i];
493 void init(int64_t n,
char * arr)
const {
503 for (
int i=0; i<n; i++){
504 memcpy(arr+i*el_size,(
char*)&dummy,el_size);
536 ((
float*)c)[0] = (float)d;
546 ((
long double*)c)[0] = (
long double)d;
551 ((
int*)c)[0] = (int)d;
556 ((uint64_t*)c)[0] = (uint64_t)d;
561 ((int64_t*)c)[0] = (int64_t)d;
566 ((std::complex<float>*)c)[0] = (std::complex<float>)d;
571 ((std::complex<double>*)c)[0] = (std::complex<double>)d;
576 ((std::complex<long double>*)c)[0] = (std::complex<long double>)d;
581 ((
float*)c)[0] = (float)d;
586 ((
double*)c)[0] = (double)d;
591 ((
long double*)c)[0] = (
long double)d;
596 ((
int*)c)[0] = (int)d;
601 ((uint64_t*)c)[0] = (uint64_t)d;
606 ((int64_t*)c)[0] = (int64_t)d;
610 inline void Set< std::complex<float>,
false >::cast_int(int64_t d,
char * c)
const {
611 ((std::complex<float>*)c)[0] = (std::complex<float>)d;
615 inline void Set< std::complex<double>,
false >::cast_int(int64_t d,
char * c)
const {
616 ((std::complex<double>*)c)[0] = (std::complex<double>)d;
620 inline void Set< std::complex<long double>,
false >::cast_int(int64_t d,
char * c)
const {
621 ((std::complex<long double>*)c)[0] = (std::complex<long double>)d;
626 return (
double)(((
float*)c)[0]);
631 return ((
double*)c)[0];
636 return (
double)(((
int*)c)[0]);
641 return (
double)(((uint64_t*)c)[0]);
646 return (
double)(((int64_t*)c)[0]);
652 return ((int64_t*)c)[0];
657 return (int64_t)(((
int*)c)[0]);
662 return (int64_t)(((
unsigned int*)c)[0]);
667 return (int64_t)(((uint64_t*)c)[0]);
672 return (int64_t)(((
bool*)c)[0]);
677 fprintf(fp,
"%11.5E",((
float*)a)[0]);
682 fprintf(fp,
"%11.5E",((
double*)a)[0]);
687 fprintf(fp,
"%ld",((int64_t*)a)[0]);
692 fprintf(fp,
"%d",((
int*)a)[0]);
696 inline void Set< std::complex<float>,
false >::print(
char const *
a, FILE * fp)
const {
697 fprintf(fp,
"(%11.5E,%11.5E)",((std::complex<float>*)a)[0].
real(),((std::complex<float>*)a)[0].
imag());
701 inline void Set< std::complex<double>,
false >::print(
char const *
a, FILE * fp)
const {
702 fprintf(fp,
"(%11.5E,%11.5E)",((std::complex<double>*)a)[0].
real(),((std::complex<double>*)a)[0].
imag());
706 inline void Set< std::complex<long double>,
false >::print(
char const *
a, FILE * fp)
const {
707 fprintf(fp,
"(%11.5LE,%11.5LE)",((std::complex<long double>*)a)[0].
real(),((std::complex<long double>*)a)[0].
imag());
712 if (a == NULL && b == NULL)
return true;
713 if (a == NULL || b == NULL)
return false;
714 return ((
float*)a)[0] == ((
float*)b)[0];
719 if (a == NULL && b == NULL)
return true;
720 if (a == NULL || b == NULL)
return false;
721 return ((
double*)a)[0] == ((
double*)b)[0];
726 if (a == NULL && b == NULL)
return true;
727 if (a == NULL || b == NULL)
return false;
728 return ((
int*)a)[0] == ((
int*)b)[0];
733 if (a == NULL && b == NULL)
return true;
734 if (a == NULL || b == NULL)
return false;
735 return ((uint64_t*)a)[0] == ((uint64_t*)b)[0];
740 if (a == NULL && b == NULL)
return true;
741 if (a == NULL || b == NULL)
return false;
742 return ((int64_t*)a)[0] == ((int64_t*)b)[0];
747 if (a == NULL && b == NULL)
return true;
748 if (a == NULL || b == NULL)
return false;
749 return ((
long double*)a)[0] == ((
long double*)b)[0];
753 inline bool Set< std::complex<float>,
false >::isequal(
char const *
a,
char const *
b)
const {
754 if (a == NULL && b == NULL)
return true;
755 if (a == NULL || b == NULL)
return false;
756 return (( std::complex<float> *)a)[0] == (( std::complex<float> *)b)[0];
760 inline bool Set< std::complex<double>,
false >::isequal(
char const *
a,
char const *
b)
const {
761 if (a == NULL && b == NULL)
return true;
762 if (a == NULL || b == NULL)
return false;
763 return (( std::complex<double> *)a)[0] == (( std::complex<double> *)b)[0];
767 inline bool Set< std::complex<long double>,
false >::isequal(
char const *
a,
char const *
b)
const {
768 if (a == NULL && b == NULL)
return true;
769 if (a == NULL || b == NULL)
return false;
770 return (( std::complex<long double> *)a)[0] == (( std::complex<long double> *)b)[0];
Set class defined by a datatype and a min/max function (if it is partially ordered i...
void set_pair(char *a, int64_t key, char const *b) const
sets the single pair at a to the given key and value
double cast_to_double(char const *c) const
return (double)*c
#define INST_ORD_TYPE(dtype)
char * pair_alloc(int64_t n) const
allocate space for n (int64_t,dtype) pairs, necessary for object types
bool try_mkl_coo_to_csr(int64_t nz, int nrow, char *csr_vs, int *csr_ja, int *csr_ia, char const *coo_vs, int const *coo_rs, int const *coo_cs, int el_size)
void min(char *c) const
c = minimum possible value
MPI_Datatype mdtype() const
MPI datatype.
void copy_pairs(char *a, char const *b, int64_t n) const
copies n pairs from array b to array a
std::enable_if< is_ord, dtype >::type default_max_lim()
MPI_Datatype get_default_mdtype< int >(bool &is_custom)
void max(char const *a, char const *b, char *c) const
c = max(a,b)
void pair_dealloc(char *ptr) const
deallocate given pointer containing contiguous array of pairs
virtual CTF_int::algstrct * clone() const
''copy constructor''
void copy(char *a, char const *b) const
copies element b to element a
MPI_Datatype get_default_mdtype< bool >(bool &is_custom)
int pair_size() const
gets pair size el_size plus the key size
void def_csr_to_coo(int64_t nz, int nrow, dtype const *csr_vs, int const *csr_ja, int const *csr_ia, dtype *coo_vs, int *coo_rs, int *coo_cs)
void char_abs(char const *a, char *b)
void seq_csr_to_coo(int64_t nz, int nrow, dtype const *csr_vs, int const *csr_ja, int const *csr_ia, dtype *coo_vs, int *coo_rs, int *coo_cs)
char const * get_const_value(char const *a) const
constexpr bool get_default_is_ord()
MPI_Datatype get_default_mdtype< long double >(bool &is_custom)
MPI_Datatype get_default_mdtype(bool &is_custom)
void set_abs_to_default()
std::enable_if< is_ord, dtype >::type default_max(dtype a, dtype b)
bool try_mkl_csr_to_coo(int64_t nz, int nrow, char const *csr_vs, int const *csr_ja, int const *csr_ia, char *coo_vs, int *coo_rs, int *coo_cs, int el_size)
void copy(int64_t n, char const *a, int inc_a, char *b, int inc_b) const
copies n elements TO array b with increment inc_b FROM array a with increment inc_a ...
bool isequal(char const *a, char const *b) const
returns true if algstrct elements a and b are equal
void print(char const *a, FILE *fp=stdout) const
prints the value
MPI_Datatype get_default_mdtype< double >(bool &is_custom)
void csr_to_coo(int64_t nz, int nrow, char const *csr_vs, int const *csr_ja, int const *csr_ia, char *coo_vs, int *coo_rs, int *coo_cs) const
converts CSR sparse matrix layout to coordinate (COO) layout
MPI_Datatype get_default_mdtype< unsigned int >(bool &is_custom)
void copy_pair(char *a, char const *b) const
copies pair b to element a
std::enable_if< is_ord, dtype >::type default_abs(dtype a)
void coo_to_csr(int64_t nz, int nrow, char *csr_vs, int *csr_ja, int *csr_ia, char const *coo_vs, int const *coo_rs, int const *coo_cs) const
converts coordinate sparse matrix layout to CSR layout
MPI_Datatype MPI_CTF_LONG_DOUBLE_COMPLEX
void(* abs)(char const *a, char *b)
b = max(a,addinv(a))
MPI_Datatype get_default_mdtype< float >(bool &is_custom)
void dealloc(char *ptr) const
deallocate given pointer containing contiguous array of values
void copy(int64_t m, int64_t n, char const *a, int64_t lda_a, char *b, int64_t lda_b) const
copies m-by-n submatrix from a with lda_a to b with lda_b
void max(char *c) const
c = maximum possible value
char * get_value(char *a) const
gets pointer to value from pair
void sort(int64_t n, char *pairs) const
sorts n sets of pairs using std::sort
char * alloc(int64_t n) const
allocate space for n items, necessary for object types
std::enable_if< is_ord, dtype >::type default_min(dtype a, dtype b)
std::enable_if< is_ord, dtype >::type default_min_lim()
void copy(char *a, char const *b, int64_t n) const
copies n elements from array b to array a
MPI_Datatype get_default_mdtype< char >(bool &is_custom)
void init(int64_t n, char *arr) const
initialize n objects to zero
MPI_Datatype get_default_mdtype< uint64_t >(bool &is_custom)
dtype default_addinv(dtype a)
void set_pairs(char *a, char const *b, int64_t n) const
sets n elements of array of pairs a to value b
algstrct (algebraic structure) defines the elementwise operations computed in each tensor contraction...
void def_coo_to_csr(int64_t nz, int nrow, dtype *csr_vs, int *csr_ja, int *csr_ia, dtype const *coo_vs, int const *coo_rs, int const *coo_cs)
void seq_coo_to_csr(int64_t nz, int nrow, dtype *csr_vs, int *csr_ja, int *csr_ia, dtype const *coo_vs, int const *coo_rs, int const *coo_cs)
void cast_double(double d, char *c) const
*c = d (converted to the element type)
virtual void init_shell(int64_t n, char *arr) const
initialize n objects to zero
MPI_Datatype get_default_mdtype< int64_t >(bool &is_custom)
int64_t cast_to_int(char const *c) const
return (int64_t)*c
int64_t get_key(char const *a) const
gets key from pair
void cast_int(int64_t i, char *c) const
*c = i (converted to the element type)
void min(char const *a, char const *b, char *c) const
c = min(a,b)
MPI_Datatype MPI_CTF_BOOL
MPI_Datatype MPI_CTF_DOUBLE_COMPLEX