2 #include "../shared/util.h" 3 #include "../shared/blas_symbs.h" 6 #include "../sparse_formats/csr.h" 19 bool operator < (
const CompPair& other)
const {
20 return (key < other.
key);
27 bool operator < (
const IntPair& other)
const {
28 return (key < other.
key);
36 return (key < other.
key);
43 bool operator < (
const BoolPair& other)
const {
44 return (key < other.
key);
63 return (key < other.
key);
68 algstrct::algstrct(
int el_size_){
73 MPI_Op algstrct::addmop()
const {
74 printf(
"CTF ERROR: no addition MPI_Op present for this algebraic structure\n");
80 MPI_Datatype algstrct::mdtype()
const {
81 printf(
"CTF ERROR: no MPI_Datatype present for this algebraic structure\n");
89 char const * algstrct::addid()
const {
93 char const * algstrct::mulid()
const {
97 void algstrct::safeaddinv(
char const *
a,
char *&
b)
const {
98 printf(
"CTF ERROR: no additive inverse present for this algebraic structure\n");
103 void algstrct::addinv(
char const *
a,
char *
b)
const {
104 printf(
"CTF ERROR: no additive inverse present for this algebraic structure\n");
109 void algstrct::add(
char const *
a,
char const *
b,
char * c)
const {
110 printf(
"CTF ERROR: no addition operation present for this algebraic structure\n");
115 void algstrct::accum(
char const *
a,
char *
b)
const {
120 void algstrct::mul(
char const *
a,
char const *
b,
char * c)
const {
121 printf(
"CTF ERROR: no multiplication operation present for this algebraic structure\n");
126 void algstrct::safemul(
char const *
a,
char const *
b,
char *& c)
const {
127 printf(
"CTF ERROR: no multiplication operation present for this algebraic structure\n");
132 void algstrct::min(
char const *
a,
char const *
b,
char * c)
const {
133 printf(
"CTF ERROR: no min operation present for this algebraic structure\n");
138 void algstrct::max(
char const *
a,
char const *
b,
char * c)
const {
139 printf(
"CTF ERROR: no max operation present for this algebraic structure\n");
144 void algstrct::cast_int(int64_t i,
char * c)
const {
145 printf(
"CTF ERROR: integer scaling not possible for this algebraic structure\n");
150 void algstrct::cast_double(
double d,
char * c)
const {
151 printf(
"CTF ERROR: double scaling not possible for this algebraic structure\n");
156 double algstrct::cast_to_double(
char const * c)
const {
157 printf(
"CTF ERROR: double cast not possible for this algebraic structure\n");
163 int64_t algstrct::cast_to_int(
char const * c)
const {
164 printf(
"CTF ERROR: int cast not possible for this algebraic structure\n");
170 void algstrct::print(
char const *
a, FILE * fp)
const {
171 for (
int i=0; i<el_size; i++){
172 fprintf(fp,
"%x",a[i]);
176 void algstrct::min(
char * c)
const {
177 printf(
"CTF ERROR: min limit not present for this algebraic structure\n");
182 void algstrct::max(
char * c)
const {
183 printf(
"CTF ERROR: max limit not present for this algebraic structure\n");
188 void algstrct::sort(int64_t n,
char * pairs)
const {
189 switch (this->el_size){
234 #pragma omp parallel for 236 for (int64_t i=0; i<n; i++){
237 idx_pairs[i].
key = *(int64_t*)(pairs+i*(
sizeof(int64_t)+this->el_size));
238 idx_pairs[i].
idx = i;
241 char * swap_buffer = this->pair_alloc(n);
243 this->copy_pairs(swap_buffer, pairs, n);
245 std::sort(idx_pairs, idx_pairs+n);
251 #pragma omp parallel for 253 for (int64_t i=0; i<n; i++){
254 pip[i].
write_val(piw[idx_pairs[i].idx].d());
256 this->pair_dealloc(swap_buffer);
262 void algstrct::scal(
int n,
266 if (isequal(alpha, addid())){
267 if (incX == 1)
set(X, addid(), n);
269 for (
int i=0; i<n; i++){
270 copy(X+i*el_size, addid());
274 printf(
"CTF ERROR: scal not present for this algebraic structure\n");
280 void algstrct::axpy(
int n,
286 printf(
"CTF ERROR: axpy not present for this algebraic structure\n");
302 printf(
"CTF ERROR: gemm_batch not present for this algebraic structure\n");
317 printf(
"CTF ERROR: gemm not present for this algebraic structure\n");
332 printf(
"CTF ERROR: offload gemm not present for this algebraic structure\n");
336 bool algstrct::is_offloadable()
const {
340 bool algstrct::isequal(
char const *
a,
char const *
b)
const {
341 if (a == NULL && b == NULL)
return true;
342 if (a == NULL || b == NULL)
return false;
344 for (
int i=0; i<el_size; i++) {
345 if (a[i] != b[i]) iseq =
false;
350 void algstrct::coo_to_csr(int64_t nz,
int nrow,
char * csr_vs,
int * csr_cs,
int * csr_rs,
char const * coo_vs,
int const * coo_rs,
int const * coo_cs)
const {
351 printf(
"CTF ERROR: cannot convert elements of this algebraic structure to CSR\n");
355 void algstrct::csr_to_coo(int64_t nz,
int nrow,
char const * csr_vs,
int const * csr_ja,
int const * csr_ia,
char * coo_vs,
int * coo_rs,
int * coo_cs)
const {
356 printf(
"CTF ERROR: cannot convert elements of this algebraic structure to CSR\n");
362 char * algstrct::csr_add(
char * cA,
char * cB)
const {
367 char * algstrct::csr_reduce(
char * cA,
int root, MPI_Comm cm)
const {
369 MPI_Comm_rank(cm, &r);
370 MPI_Comm_size(cm, &p);
374 double t_st = MPI_Wtime();
375 while (p%s != 0) s++;
379 MPI_Comm_split(cm, r/s, sr, &scm);
380 MPI_Comm_split(cm, sr, r/s, &rcm);
383 int64_t sz_A = A.
size();
390 int64_t tot_buf_size = 0;
391 for (
int i=0; i<s; i++){
392 if (i==sr) snd_szs[i] = 0;
393 else snd_szs[i] = parts[i]->
size();
394 tot_buf_size += snd_szs[i];
397 MPI_Alltoall(snd_szs, 1, MPI_INT, rcv_szs, 1, MPI_INT, scm);
398 int64_t tot_rcv_sz = 0;
399 for (
int i=0; i<s; i++){
401 tot_rcv_sz += rcv_szs[i];
408 for (
int i=0; i<s; i++){
409 if (i>0) rcv_displs[i] = rcv_szs[i-1]+rcv_displs[i-1];
411 if (i==sr) smnds[i] = parts[i]->
all_data;
412 else smnds[i] = rcv_buf + rcv_displs[i];
417 MPI_Alltoallv(parts[0]->all_data, snd_szs, snd_displs, MPI_CHAR, rcv_buf, rcv_szs, rcv_displs, MPI_CHAR, scm);
418 for (
int i=0; i<s; i++){
438 for (
int z=1; z<s; z<<=1){
439 for (
int i=0; i<s-z; i+=2*z){
440 char * csr_new = csr_add(smnds[i], smnds[i+z]);
441 if ((smnds[i] < parts_buffer ||
442 smnds[i] > parts_buffer+tot_buf_size) &&
443 (smnds[i] < rcv_buf ||
444 smnds[i] > rcv_buf+tot_rcv_sz))
446 if ((smnds[i+z] < parts_buffer ||
447 smnds[i+z] > parts_buffer+tot_buf_size) &&
448 (smnds[i+z] < rcv_buf ||
449 smnds[i+z] > rcv_buf+tot_rcv_sz))
457 char * red_sum = csr_reduce(smnds[0], root/s, rcm);
459 if (smnds[0] != red_sum)
cdealloc(smnds[0]);
465 if (sroot == sr) sz = 0;
466 MPI_Gather(&sz, 1, MPI_INT, cb_sizes, 1, MPI_INT, sroot, scm);
467 int64_t tot_cb_size = 0;
470 for (
int i=0; i<s; i++){
471 cb_displs[i] = tot_cb_size;
472 tot_cb_size += cb_sizes[i];
476 MPI_Gatherv(red_sum, sz, MPI_CHAR, cb_bufs, cb_sizes, cb_displs, MPI_CHAR, sroot, scm);
480 for (
int i=0; i<s; i++){
481 smnds[i] = cb_bufs + cb_displs[i];
482 if (i==sr) smnds[i] = red_sum;
487 double t_end = MPI_Wtime() - t_st;
488 double tps[] = {t_end, 1.0, log2((
double)p), (double)sz_A};
508 double algstrct::estimate_csr_red_time(int64_t msg_sz,
CommData const * cdt)
const {
510 double ps[] = {1.0, log2((
double)cdt->
np), (double)msg_sz};
514 void algstrct::acc(
char *
b,
char const * beta,
char const *
a,
char const * alpha)
const {
521 void algstrct::accmul(
char * c,
char const *
a,
char const *
b,
char const * alpha)
const {
524 mul(tmp, alpha, tmp);
529 void algstrct::safecopy(
char *&
a,
char const *
b)
const {
539 memcpy(a, b, el_size);
542 void algstrct::copy_pair(
char *
a,
char const *
b)
const {
543 memcpy(a, b, pair_size());
547 memcpy(a, b, el_size*n);
550 void algstrct::copy_pairs(
char *
a,
char const *
b, int64_t n)
const {
551 memcpy(a, b, pair_size()*n);
565 CTF_BLAS::ZCOPY(&n, (std::complex<double>
const*)a, &inc_a, (std::complex<double>*)b, &inc_b);
569 #pragma omp parallel for 571 for (int64_t i=0; i<nn; i++){
572 copy(b+el_size*inc_b*i, a+el_size*inc_a*i);
583 int64_t lda_b)
const {
584 if (lda_a == m && lda_b == n){
585 memcpy(b,a,el_size*m*n);
587 for (
int i=0; i<n; i++){
588 memcpy(b+el_size*lda_b*i,a+el_size*lda_a*i,m*el_size);
600 char const * beta)
const {
601 if (!isequal(beta, mulid())){
602 if (isequal(beta, addid())){
604 set(
b, addid(), m*n);
606 for (
int i=0; i<n; i++){
607 set(b+i*lda_b*el_size, addid(), m);
612 scal(m*n, beta, b, 1);
614 for (
int i=0; i<n; i++){
615 scal(m, beta, b+i*lda_b*el_size, 1);
620 if (lda_a == m && lda_b == m){
621 axpy(m*n, alpha, a, 1, b, 1);
623 for (
int i=0; i<n; i++){
624 axpy(m, alpha, a+el_size*lda_a*i, 1, b+el_size*lda_b*i, 1);
629 void algstrct::set(
char *
a,
char const *
b, int64_t n)
const {
632 float * ia = (
float*)a;
633 float ib = *((
float*)b);
634 std::fill(ia, ia+n, ib);
638 double * ia = (
double*)a;
639 double ib = *((
double*)b);
640 std::fill(ia, ia+n, ib);
644 std::complex<double> * ia = (std::complex<double>*)a;
645 std::complex<double> ib = *((std::complex<double>*)b);
646 std::fill(ia, ia+n, ib);
650 for (
int i=0; i<n; i++) {
651 memcpy(a+i*el_size, b, el_size);
658 void algstrct::set_pair(
char *
a, int64_t
key,
char const * vb)
const {
659 memcpy(a, &key,
sizeof(int64_t));
660 memcpy(get_value(a), vb, el_size);
663 void algstrct::set_pairs(
char *
a,
char const *
b, int64_t n)
const {
664 for (
int i=0; i<n; i++) {
665 memcpy(a + i*pair_size(), b, pair_size());
669 int64_t algstrct::get_key(
char const *
a)
const {
673 char * algstrct::get_value(
char *
a)
const {
674 return a+
sizeof(int64_t);
677 char const * algstrct::get_const_value(
char const *
a)
const {
678 return a+
sizeof(int64_t);
681 char * algstrct::pair_alloc(int64_t n)
const {
689 void algstrct::dealloc(
char * ptr)
const {
693 void algstrct::pair_dealloc(
char * ptr)
const {
697 void algstrct::init(int64_t n,
char * arr)
const {
703 void algstrct::coomm(
int m,
int n,
int k,
char const * alpha,
char const * A,
int const * rows_A,
int const * cols_A, int64_t nnz_A,
char const * B,
char const * beta,
char * C,
bivar_function const * func)
const {
704 printf(
"CTF ERROR: coomm not present for this algebraic structure\n");
708 void algstrct::csrmm(
int m,
int n,
int k,
char const * alpha,
char const * A,
int const * JA,
int const * IA, int64_t nnz_A,
char const * B,
char const * beta,
char * C,
bivar_function const * func)
const {
709 printf(
"CTF ERROR: csrmm not present for this algebraic structure\n");
713 void algstrct::csrmultd
728 printf(
"CTF ERROR: csrmultd not present for this algebraic structure\n");
732 void algstrct::csrmultcsr
746 char *& C_CSR)
const {
748 printf(
"CTF ERROR: csrmultcsr not present for this algebraic structure\n");
753 sr=pi.
sr; ptr=pi.
ptr;
756 ConstPairIterator::ConstPairIterator(
algstrct const * sr_,
char const * ptr_){
764 int64_t ConstPairIterator::k()
const {
765 return ((int64_t*)ptr)[0];
768 char const * ConstPairIterator::d()
const {
769 return sr->get_const_value(ptr);
772 void ConstPairIterator::read(
char * buf, int64_t n)
const {
773 memcpy(buf, ptr, sr->pair_size()*n);
776 void ConstPairIterator::read_val(
char * buf)
const {
777 memcpy(buf, sr->get_const_value(ptr), sr->el_size);
780 PairIterator::PairIterator(
algstrct const * sr_,
char * ptr_){
789 int64_t PairIterator::k()
const {
790 return ((int64_t*)ptr)[0];
793 char * PairIterator::d()
const {
794 return sr->get_value(ptr);
797 void PairIterator::read(
char * buf, int64_t n)
const {
798 sr->copy_pair(buf, ptr);
801 void PairIterator::read_val(
char * buf)
const {
802 sr->copy(buf, sr->get_const_value(ptr));
805 void PairIterator::write(
char const * buf, int64_t n){
806 sr->copy_pairs(ptr, buf, n);
810 this->write(iter.
ptr, n);
814 this->write(iter.
ptr, n);
817 void PairIterator::write_val(
char const * buf){
818 sr->copy(sr->get_value(ptr), buf);
821 void PairIterator::write_key(int64_t
key){
822 ((int64_t*)ptr)[0] =
key;
825 void PairIterator::sort(int64_t n){
832 #pragma omp parallel for 834 for (int64_t i=0; i<n; i++){
835 int64_t k = rA[i].
k();
837 for (
int j=0; j<order; j++){
838 k_new += (k%old_lens[j])*new_lda[j];
841 ((int64_t*)wA[i].ptr)[0] = k_new;
849 void ConstPairIterator::pin(int64_t n,
int order,
int const * lens,
int const * divisor,
PairIterator pi_new){
853 alloc_ptr(order*
sizeof(
int), (
void**)&div_lens);
854 for (
int j=0; j<order; j++){
855 div_lens[j] = (lens[j]/divisor[j] + (lens[j]%divisor[j] > 0));
859 #pragma omp parallel for 861 for (int64_t i=0; i<n; i++){
862 int64_t
key = pi[i].
k();
866 for (
int j=0; j<order; j++){
869 new_key += ((key%lens[j])/divisor[j])*lda;
873 ((int64_t*)pi_new[i].ptr)[0] = new_key;
883 void depin(
algstrct const * sr,
int order,
int const * lens,
int const * divisor,
int nvirt,
int const * virt_dim,
int const * phys_rank,
char * X, int64_t & new_nnz_B, int64_t * nnz_blk,
char *& new_B,
bool check_padding){
888 alloc_ptr(order*
sizeof(
int), (
void**)&div_lens);
889 for (
int j=0; j<order; j++){
890 div_lens[j] = (lens[j]/divisor[j] + (lens[j]%divisor[j] > 0));
894 check_padding =
false;
895 for (
int v=0; v<nvirt; v++){
897 for (
int j=0; j<order; j++){
898 int vo = (vv%virt_dim[j])*(divisor[j]/virt_dim[j])+phys_rank[j];
899 if (lens[j]%divisor[j] != 0 && vo >= lens[j]%divisor[j]){
900 check_padding =
true;
906 int64_t * old_nnz_blk_B = nnz_blk;
911 memcpy(old_nnz_blk_B, nnz_blk,
sizeof(int64_t)*nvirt);
912 memset(nnz_blk, 0,
sizeof(int64_t)*nvirt);
916 alloc_ptr(order*
sizeof(
int), (
void**)&virt_offset);
920 for (
int v=0; v<nvirt; v++){
923 for (
int j=0; j<order; j++){
924 virt_offset[j] = (vv%virt_dim[j])*(divisor[j]/virt_dim[j])+phys_rank[j];
929 int64_t new_nnz_blk = 0;
932 for (int64_t i=0; i<old_nnz_blk_B[v]; i++){
933 int64_t
key = vpi[i].
k();
936 bool is_outside =
false;
937 for (
int j=0; j<order; j++){
939 if (((key%div_lens[j])*divisor[j]+virt_offset[j])>=lens[j]){
943 new_key += ((key%div_lens[j])*divisor[j]+virt_offset[j])*lda;
945 key = key/div_lens[j];
949 ((int64_t*)vpi_new[new_nnz_blk].ptr)[0] = new_key;
950 vpi_new[new_nnz_blk].write_val(vpi[i].d());
954 nnz_blk[v] = new_nnz_blk;
955 new_nnz_B += nnz_blk[v];
956 nnz_off += old_nnz_blk_B[v];
962 #pragma omp parallel for 964 for (int64_t i=0; i<nnz_blk[v]; i++){
965 int64_t
key = vpi[i].
k();
969 for (int64_t j=0; j<order; j++){
970 new_key += ((key%div_lens[j])*divisor[j]+virt_offset[j])*lda;
972 key = key/div_lens[j];
974 ((int64_t*)vpi_new[i].ptr)[0] = new_key;
977 nnz_off += nnz_blk[v];
992 switch (sr->el_size){
1026 #pragma omp parallel for 1028 for (int64_t i=0; i<n; i++){
1029 keys[i] = (*this)[i].k();
1031 return std::lower_bound(keys, keys+n, op.
k())-keys;
void permute(int order, int const *perm, int *arr)
permute an array
virtual int pair_size() const
gets pair size el_size plus the key size
virtual char * pair_alloc(int64_t n) const
allocate space for n (int64_t,dtype) pairs, necessary for object types
static char * csr_add(char *cA, char *cB, accumulatable const *adder)
LinModel< 3 > csrred_mdl_cst(csrred_mdl_cst_init,"csrred_mdl_cst")
LinModel< 3 > csrred_mdl(csrred_mdl_init,"csrred_mdl")
void SCOPY(const int *n, const float *dX, const int *incX, float *dY, const int *incY)
void * alloc(int64_t len)
alloc abstraction
void gemm(char tA, char tB, int m, int n, int k, dtype alpha, dtype const *A, dtype const *B, dtype beta, dtype *C)
untyped internal class for triply-typed bivariate function
void DCOPY(const int *n, const double *dX, const int *incX, double *dY, const int *incY)
void depin(algstrct const *sr, int order, int const *lens, int const *divisor, int nvirt, int const *virt_dim, int const *phys_rank, char *X, int64_t &new_nnz_B, int64_t *nnz_blk, char *&new_B, bool check_padding)
depins keys of n pairs
int64_t k() const
returns key of pair at head of ptr
int alloc_ptr(int64_t len, void **const ptr)
alloc abstraction
abstraction for a serialized sparse matrix stored in column-sparse-row (CSR) layout ...
int64_t k() const
returns key of pair at head of ptr
void gemm_batch(char taA, char taB, int l, int m, int n, int k, dtype alpha, dtype const *A, dtype const *B, dtype beta, dtype *C)
char * all_data
serialized buffer containing all info, index, and values related to matrix
int cdealloc(void *ptr)
free abstraction
algstrct (algebraic structure) defines the elementwise operations computed in each tensor contraction...
void write_val(char const *buf)
sets value of head pair to what is in buf
void ZCOPY(const int *n, const std::complex< double > *dX, const int *incX, std::complex< double > *dY, const int *incY)
void partition(int s, char **parts_buffer, CSR_Matrix **parts)
splits CSR matrix into s submatrices (returned) corresponding to subsets of rows, all parts allocated...
double csrred_mdl_cst_init[]
void offload_gemm(char tA, char tB, int m, int n, int k, dtype alpha, offload_tsr &A, int lda_A, offload_tsr &B, int lda_B, dtype beta, offload_tsr &C, int lda_C)
int64_t size() const
retrieves buffer size out of all_data