2 #include "../shared/util.h"     3 #include "../shared/blas_symbs.h"     6 #include "../sparse_formats/csr.h"    19     bool operator < (
const CompPair& other)
 const {
    20       return (key < other.
key);
    27     bool operator < (
const IntPair& other)
 const {
    28       return (key < other.
key);
    36       return (key < other.
key);
    43     bool operator < (
const BoolPair& other)
 const {
    44       return (key < other.
key);
    63       return (key < other.
key);
    68   algstrct::algstrct(
int el_size_){
    73   MPI_Op algstrct::addmop()
 const {
    74     printf(
"CTF ERROR: no addition MPI_Op present for this algebraic structure\n");
    80   MPI_Datatype algstrct::mdtype()
 const {
    81     printf(
"CTF ERROR: no MPI_Datatype present for this algebraic structure\n");
    89   char const * algstrct::addid()
 const {
    93   char const * algstrct::mulid()
 const {
    97   void algstrct::safeaddinv(
char const * 
a, 
char *& 
b)
 const {
    98     printf(
"CTF ERROR: no additive inverse present for this algebraic structure\n");
   103   void algstrct::addinv(
char const * 
a, 
char * 
b)
 const {
   104     printf(
"CTF ERROR: no additive inverse present for this algebraic structure\n");
   109   void algstrct::add(
char const * 
a, 
char const * 
b, 
char * c)
 const {
   110     printf(
"CTF ERROR: no addition operation present for this algebraic structure\n");
   115   void algstrct::accum(
char const * 
a, 
char * 
b)
 const {
   120   void algstrct::mul(
char const * 
a, 
char const * 
b, 
char * c)
 const {
   121     printf(
"CTF ERROR: no multiplication operation present for this algebraic structure\n");
   126   void algstrct::safemul(
char const * 
a, 
char const * 
b, 
char *& c)
 const {
   127     printf(
"CTF ERROR: no multiplication operation present for this algebraic structure\n");
   132   void algstrct::min(
char const * 
a, 
char const * 
b, 
char * c)
 const {
   133     printf(
"CTF ERROR: no min operation present for this algebraic structure\n");
   138   void algstrct::max(
char const * 
a, 
char const * 
b, 
char * c)
 const {
   139     printf(
"CTF ERROR: no max operation present for this algebraic structure\n");
   144   void algstrct::cast_int(int64_t i, 
char * c)
 const {
   145     printf(
"CTF ERROR: integer scaling not possible for this algebraic structure\n");
   150   void algstrct::cast_double(
double d, 
char * c)
 const {
   151     printf(
"CTF ERROR: double scaling not possible for this algebraic structure\n");
   156   double algstrct::cast_to_double(
char const * c)
 const {
   157     printf(
"CTF ERROR: double cast not possible for this algebraic structure\n");
   163   int64_t algstrct::cast_to_int(
char const * c)
 const {
   164     printf(
"CTF ERROR: int cast not possible for this algebraic structure\n");
   170   void algstrct::print(
char const * 
a, FILE * fp)
 const {
   171     for (
int i=0; i<el_size; i++){
   172       fprintf(fp,
"%x",a[i]);
   176   void algstrct::min(
char * c)
 const {
   177     printf(
"CTF ERROR: min limit not present for this algebraic structure\n");
   182   void algstrct::max(
char * c)
 const {
   183     printf(
"CTF ERROR: max limit not present for this algebraic structure\n");
   188   void algstrct::sort(int64_t n, 
char * pairs)
 const {
   189     switch (this->el_size){
   234         #pragma omp parallel for   236         for (int64_t i=0; i<n; i++){
   237           idx_pairs[i].
key = *(int64_t*)(pairs+i*(
sizeof(int64_t)+this->el_size));
   238           idx_pairs[i].
idx = i;
   241         char * swap_buffer = this->pair_alloc(n);
   243         this->copy_pairs(swap_buffer, pairs, n);
   245         std::sort(idx_pairs, idx_pairs+n);
   251         #pragma omp parallel for   253         for (int64_t i=0; i<n; i++){
   254           pip[i].
write_val(piw[idx_pairs[i].idx].d());
   256         this->pair_dealloc(swap_buffer);
   262   void algstrct::scal(
int          n,
   266     if (isequal(alpha, addid())){
   267       if (incX == 1) 
set(X, addid(), n);
   269         for (
int i=0; i<n; i++){
   270           copy(X+i*el_size, addid());
   274       printf(
"CTF ERROR: scal not present for this algebraic structure\n");
   280   void algstrct::axpy(
int          n,
   286     printf(
"CTF ERROR: axpy not present for this algebraic structure\n");
   302     printf(
"CTF ERROR: gemm_batch not present for this algebraic structure\n");
   317     printf(
"CTF ERROR: gemm not present for this algebraic structure\n");
   332     printf(
"CTF ERROR: offload gemm not present for this algebraic structure\n");
   336   bool algstrct::is_offloadable()
 const {
   340   bool algstrct::isequal(
char const * 
a, 
char const * 
b)
 const {
   341     if (a == NULL && b == NULL) 
return true;
   342     if (a == NULL || b == NULL) 
return false;
   344     for (
int i=0; i<el_size; i++) {
   345       if (a[i] != b[i]) iseq = 
false;
   350   void algstrct::coo_to_csr(int64_t nz, 
int nrow, 
char * csr_vs, 
int * csr_cs, 
int * csr_rs, 
char const * coo_vs, 
int const * coo_rs, 
int const * coo_cs)
 const {
   351     printf(
"CTF ERROR: cannot convert elements of this algebraic structure to CSR\n");
   355   void algstrct::csr_to_coo(int64_t nz, 
int nrow, 
char const * csr_vs, 
int const * csr_ja, 
int const * csr_ia, 
char * coo_vs, 
int * coo_rs, 
int * coo_cs)
 const {
   356     printf(
"CTF ERROR: cannot convert elements of this algebraic structure to CSR\n");
   362   char * algstrct::csr_add(
char * cA, 
char * cB)
 const {
   367   char * algstrct::csr_reduce(
char * cA, 
int root, MPI_Comm cm)
 const {
   369     MPI_Comm_rank(cm, &r);
   370     MPI_Comm_size(cm, &p);
   374     double t_st = MPI_Wtime();
   375     while (p%s != 0) s++;
   379     MPI_Comm_split(cm, r/s, sr, &scm);
   380     MPI_Comm_split(cm, sr, r/s, &rcm);
   383     int64_t sz_A = A.
size();
   390     int64_t tot_buf_size = 0;
   391     for (
int i=0; i<s; i++){
   392       if (i==sr) snd_szs[i] = 0;
   393       else snd_szs[i] = parts[i]->
size();
   394       tot_buf_size += snd_szs[i];
   397     MPI_Alltoall(snd_szs, 1, MPI_INT, rcv_szs, 1, MPI_INT, scm);
   398     int64_t tot_rcv_sz = 0;
   399     for (
int i=0; i<s; i++){
   401       tot_rcv_sz += rcv_szs[i];
   408     for (
int i=0; i<s; i++){
   409       if (i>0) rcv_displs[i] = rcv_szs[i-1]+rcv_displs[i-1];
   411       if (i==sr) smnds[i] = parts[i]->
all_data;
   412       else smnds[i] = rcv_buf + rcv_displs[i];
   417     MPI_Alltoallv(parts[0]->all_data, snd_szs, snd_displs, MPI_CHAR, rcv_buf, rcv_szs, rcv_displs, MPI_CHAR, scm);
   418     for (
int i=0; i<s; i++){
   438     for (
int z=1; z<s; z<<=1){
   439       for (
int i=0; i<s-z; i+=2*z){
   440         char * csr_new = csr_add(smnds[i], smnds[i+z]);
   441         if ((smnds[i] < parts_buffer ||
   442              smnds[i] > parts_buffer+tot_buf_size) &&
   443             (smnds[i] < rcv_buf ||
   444              smnds[i] > rcv_buf+tot_rcv_sz))
   446         if ((smnds[i+z] < parts_buffer ||
   447              smnds[i+z] > parts_buffer+tot_buf_size) &&
   448             (smnds[i+z] < rcv_buf ||
   449              smnds[i+z] > rcv_buf+tot_rcv_sz))
   457     char * red_sum = csr_reduce(smnds[0], root/s, rcm);
   459     if (smnds[0] != red_sum) 
cdealloc(smnds[0]);
   465       if (sroot == sr) sz = 0;
   466       MPI_Gather(&sz, 1, MPI_INT, cb_sizes, 1, MPI_INT, sroot, scm);
   467       int64_t tot_cb_size = 0;
   470         for (
int i=0; i<s; i++){
   471           cb_displs[i] = tot_cb_size;
   472           tot_cb_size += cb_sizes[i];
   476       MPI_Gatherv(red_sum, sz, MPI_CHAR, cb_bufs, cb_sizes, cb_displs, MPI_CHAR, sroot, scm);
   480         for (
int i=0; i<s; i++){
   481           smnds[i] = cb_bufs + cb_displs[i];
   482           if (i==sr) smnds[i] = red_sum;
   487         double t_end = MPI_Wtime() - t_st;
   488         double tps[] = {t_end, 1.0, log2((
double)p), (double)sz_A};
   508   double algstrct::estimate_csr_red_time(int64_t msg_sz, 
CommData const * cdt)
 const {
   510     double ps[] = {1.0, log2((
double)cdt->
np), (double)msg_sz};
   514   void algstrct::acc(
char * 
b, 
char const * beta, 
char const * 
a, 
char const * alpha)
 const {
   521   void algstrct::accmul(
char * c, 
char const * 
a, 
char const * 
b, 
char const * alpha)
 const {
   524     mul(tmp, alpha, tmp);
   529   void algstrct::safecopy(
char *& 
a, 
char const * 
b)
 const {
   539     memcpy(a, b, el_size);
   542   void algstrct::copy_pair(
char * 
a, 
char const * 
b)
 const {
   543     memcpy(a, b, pair_size());
   547     memcpy(a, b, el_size*n);
   550   void algstrct::copy_pairs(
char * 
a, 
char const * 
b, int64_t n)
 const {
   551     memcpy(a, b, pair_size()*n);
   565         CTF_BLAS::ZCOPY(&n, (std::complex<double> 
const*)a, &inc_a, (std::complex<double>*)b, &inc_b);
   569         #pragma omp parallel for   571         for (int64_t i=0; i<nn; i++){
   572           copy(b+el_size*inc_b*i, a+el_size*inc_a*i);
   583                       int64_t      lda_b)
 const {
   584     if (lda_a == m && lda_b == n){
   585       memcpy(b,a,el_size*m*n);
   587       for (
int i=0; i<n; i++){
   588         memcpy(b+el_size*lda_b*i,a+el_size*lda_a*i,m*el_size);
   600                       char const * beta)
 const {
   601     if (!isequal(beta, mulid())){
   602       if (isequal(beta, addid())){
   604           set(
b, addid(), m*n);
   606           for (
int i=0; i<n; i++){
   607             set(b+i*lda_b*el_size, addid(), m);
   612           scal(m*n, beta, b, 1);
   614           for (
int i=0; i<n; i++){
   615             scal(m, beta, b+i*lda_b*el_size, 1);
   620     if (lda_a == m && lda_b == m){
   621       axpy(m*n, alpha, a, 1, b, 1);
   623       for (
int i=0; i<n; i++){
   624         axpy(m, alpha, a+el_size*lda_a*i, 1, b+el_size*lda_b*i, 1);
   629   void algstrct::set(
char * 
a, 
char const * 
b, int64_t n)
 const {
   632           float * ia = (
float*)a;
   633           float ib = *((
float*)b);
   634           std::fill(ia, ia+n, ib);
   638           double * ia = (
double*)a;
   639           double ib = *((
double*)b);
   640           std::fill(ia, ia+n, ib);
   644           std::complex<double> * ia = (std::complex<double>*)a;
   645           std::complex<double> ib = *((std::complex<double>*)b);
   646           std::fill(ia, ia+n, ib);
   650           for (
int i=0; i<n; i++) {
   651             memcpy(a+i*el_size, b, el_size);
   658   void algstrct::set_pair(
char * 
a, int64_t 
key, 
char const * vb)
 const {
   659     memcpy(a, &key, 
sizeof(int64_t));
   660     memcpy(get_value(a), vb, el_size);
   663   void algstrct::set_pairs(
char * 
a, 
char const * 
b, int64_t n)
 const {
   664     for (
int i=0; i<n; i++) {
   665       memcpy(a + i*pair_size(), b, pair_size());
   669   int64_t algstrct::get_key(
char const * 
a)
 const {
   673   char * algstrct::get_value(
char * 
a)
 const {
   674     return a+
sizeof(int64_t);
   677   char const * algstrct::get_const_value(
char const * 
a)
 const {
   678     return a+
sizeof(int64_t);
   681   char * algstrct::pair_alloc(int64_t n)
 const {
   689   void algstrct::dealloc(
char * ptr)
 const {
   693   void algstrct::pair_dealloc(
char * ptr)
 const {
   697   void algstrct::init(int64_t n, 
char * arr)
 const {
   703   void algstrct::coomm(
int m, 
int n, 
int k, 
char const * alpha, 
char const * A, 
int const * rows_A, 
int const * cols_A, int64_t nnz_A, 
char const * B, 
char const * beta, 
char * C, 
bivar_function const * func)
 const {
   704     printf(
"CTF ERROR: coomm not present for this algebraic structure\n");
   708   void algstrct::csrmm(
int m, 
int n, 
int k, 
char const * alpha, 
char const * A, 
int const * JA, 
int const * IA, int64_t nnz_A, 
char const * B, 
char const * beta, 
char * C, 
bivar_function const * func)
 const {
   709     printf(
"CTF ERROR: csrmm not present for this algebraic structure\n");
   713   void algstrct::csrmultd
   728     printf(
"CTF ERROR: csrmultd not present for this algebraic structure\n");
   732   void algstrct::csrmultcsr
   746                  char *&      C_CSR) 
const {
   748     printf(
"CTF ERROR: csrmultcsr not present for this algebraic structure\n");
   753     sr=pi.
sr; ptr=pi.
ptr;
   756   ConstPairIterator::ConstPairIterator(
algstrct const * sr_, 
char const * ptr_){
   764   int64_t ConstPairIterator::k()
 const {
   765     return ((int64_t*)ptr)[0];
   768   char const * ConstPairIterator::d()
 const {
   769     return sr->get_const_value(ptr);
   772   void ConstPairIterator::read(
char * buf, int64_t n)
 const {
   773     memcpy(buf, ptr, sr->pair_size()*n);
   776   void ConstPairIterator::read_val(
char * buf)
 const {
   777     memcpy(buf, sr->get_const_value(ptr), sr->el_size);
   780   PairIterator::PairIterator(
algstrct const * sr_, 
char * ptr_){
   789   int64_t PairIterator::k()
 const {
   790     return ((int64_t*)ptr)[0];
   793   char * PairIterator::d()
 const {
   794     return sr->get_value(ptr);
   797   void PairIterator::read(
char * buf, int64_t n)
 const {
   798     sr->copy_pair(buf, ptr);
   801   void PairIterator::read_val(
char * buf)
 const {
   802     sr->copy(buf, sr->get_const_value(ptr));
   805   void PairIterator::write(
char const * buf, int64_t n){
   806     sr->copy_pairs(ptr, buf, n);
   810     this->write(iter.
ptr, n);
   814     this->write(iter.
ptr, n);
   817   void PairIterator::write_val(
char const * buf){
   818     sr->copy(sr->get_value(ptr), buf);
   821   void PairIterator::write_key(int64_t 
key){
   822     ((int64_t*)ptr)[0] = 
key;
   825   void PairIterator::sort(int64_t n){
   832     #pragma omp parallel for   834     for (int64_t i=0; i<n; i++){
   835       int64_t k = rA[i].
k();
   837       for (
int j=0; j<order; j++){
   838         k_new += (k%old_lens[j])*new_lda[j];
   841       ((int64_t*)wA[i].ptr)[0] = k_new;
   849   void ConstPairIterator::pin(int64_t n, 
int order, 
int const * lens, 
int const * divisor, 
PairIterator pi_new){
   853     alloc_ptr(order*
sizeof(
int), (
void**)&div_lens);
   854     for (
int j=0; j<order; j++){
   855       div_lens[j] = (lens[j]/divisor[j] + (lens[j]%divisor[j] > 0));
   859     #pragma omp parallel for   861     for (int64_t i=0; i<n; i++){
   862       int64_t 
key = pi[i].
k();
   866       for (
int j=0; j<order; j++){
   869         new_key += ((key%lens[j])/divisor[j])*lda;
   873       ((int64_t*)pi_new[i].ptr)[0] = new_key;
   883   void depin(
algstrct const * sr, 
int order, 
int const * lens, 
int const * divisor, 
int nvirt, 
int const * virt_dim, 
int const * phys_rank, 
char * X, int64_t & new_nnz_B, int64_t * nnz_blk, 
char *& new_B, 
bool check_padding){
   888     alloc_ptr(order*
sizeof(
int), (
void**)&div_lens);
   889     for (
int j=0; j<order; j++){
   890       div_lens[j] = (lens[j]/divisor[j] + (lens[j]%divisor[j] > 0));
   894       check_padding = 
false;
   895       for (
int v=0; v<nvirt; v++){
   897         for (
int j=0; j<order; j++){
   898           int vo = (vv%virt_dim[j])*(divisor[j]/virt_dim[j])+phys_rank[j];
   899           if (lens[j]%divisor[j] != 0 && vo >= lens[j]%divisor[j]){
   900             check_padding = 
true;
   906     int64_t * old_nnz_blk_B = nnz_blk;
   911       memcpy(old_nnz_blk_B, nnz_blk, 
sizeof(int64_t)*nvirt);
   912       memset(nnz_blk, 0, 
sizeof(int64_t)*nvirt);
   916     alloc_ptr(order*
sizeof(
int), (
void**)&virt_offset);
   920     for (
int v=0; v<nvirt; v++){
   923       for (
int j=0; j<order; j++){
   924         virt_offset[j] = (vv%virt_dim[j])*(divisor[j]/virt_dim[j])+phys_rank[j];
   929         int64_t new_nnz_blk = 0;
   932         for (int64_t i=0; i<old_nnz_blk_B[v]; i++){
   933           int64_t 
key = vpi[i].
k();
   936           bool is_outside = 
false;
   937           for (
int j=0; j<order; j++){
   939             if (((key%div_lens[j])*divisor[j]+virt_offset[j])>=lens[j]){
   943             new_key += ((key%div_lens[j])*divisor[j]+virt_offset[j])*lda;
   945             key = key/div_lens[j];
   949             ((int64_t*)vpi_new[new_nnz_blk].ptr)[0] = new_key;
   950             vpi_new[new_nnz_blk].write_val(vpi[i].d());
   954         nnz_blk[v] = new_nnz_blk;
   955         new_nnz_B += nnz_blk[v];
   956         nnz_off += old_nnz_blk_B[v];
   962         #pragma omp parallel for   964         for (int64_t i=0; i<nnz_blk[v]; i++){
   965           int64_t 
key = vpi[i].
k();
   969           for (int64_t j=0; j<order; j++){
   970             new_key += ((key%div_lens[j])*divisor[j]+virt_offset[j])*lda;
   972             key = key/div_lens[j];
   974           ((int64_t*)vpi_new[i].ptr)[0] = new_key;
   977         nnz_off += nnz_blk[v];
   992     switch (sr->el_size){
  1026         #pragma omp parallel for  1028         for (int64_t i=0; i<n; i++){
  1029           keys[i] = (*this)[i].k();
  1031         return std::lower_bound(keys, keys+n, op.
k())-keys;
 void permute(int order, int const *perm, int *arr)
permute an array 
virtual int pair_size() const 
gets pair size el_size plus the key size 
virtual char * pair_alloc(int64_t n) const 
allocate space for n (int64_t,dtype) pairs, necessary for object types 
static char * csr_add(char *cA, char *cB, accumulatable const *adder)
LinModel< 3 > csrred_mdl_cst(csrred_mdl_cst_init,"csrred_mdl_cst")
LinModel< 3 > csrred_mdl(csrred_mdl_init,"csrred_mdl")
void SCOPY(const int *n, const float *dX, const int *incX, float *dY, const int *incY)
void * alloc(int64_t len)
alloc abstraction 
void gemm(char tA, char tB, int m, int n, int k, dtype alpha, dtype const *A, dtype const *B, dtype beta, dtype *C)
untyped internal class for triply-typed bivariate function 
void DCOPY(const int *n, const double *dX, const int *incX, double *dY, const int *incY)
void depin(algstrct const *sr, int order, int const *lens, int const *divisor, int nvirt, int const *virt_dim, int const *phys_rank, char *X, int64_t &new_nnz_B, int64_t *nnz_blk, char *&new_B, bool check_padding)
depins keys of n pairs 
int64_t k() const 
returns key of pair at head of ptr 
int alloc_ptr(int64_t len, void **const ptr)
alloc abstraction 
abstraction for a serialized sparse matrix stored in column-sparse-row (CSR) layout ...
int64_t k() const 
returns key of pair at head of ptr 
void gemm_batch(char taA, char taB, int l, int m, int n, int k, dtype alpha, dtype const *A, dtype const *B, dtype beta, dtype *C)
char * all_data
serialized buffer containing all info, index, and values related to matrix 
int cdealloc(void *ptr)
free abstraction 
algstrct (algebraic structure) defines the elementwise operations computed in each tensor contraction...
void write_val(char const *buf)
sets value of head pair to what is in buf 
void ZCOPY(const int *n, const std::complex< double > *dX, const int *incX, std::complex< double > *dY, const int *incY)
void partition(int s, char **parts_buffer, CSR_Matrix **parts)
splits CSR matrix into s submatrices (returned) corresponding to subsets of rows, all parts allocated...
double csrred_mdl_cst_init[]
void offload_gemm(char tA, char tB, int m, int n, int k, dtype alpha, offload_tsr &A, int lda_A, offload_tsr &B, int lda_B, dtype beta, offload_tsr &C, int lda_C)
int64_t size() const 
retrieves buffer size out of all_data