1 #ifndef __INT_SEMIRING_H__ 2 #define __INT_SEMIRING_H__ 4 #include "../interface/common.h" 20 char *
b)
const { assert(0); }
26 virtual void init_shell(int64_t n,
char * arr)
const { assert(0); };
42 void (*
abs)(
char const *
a,
74 virtual algstrct * clone()
const = 0;
77 virtual bool is_ordered()
const = 0;
80 virtual MPI_Op addmop()
const;
83 virtual MPI_Datatype mdtype()
const;
89 virtual char const * addid()
const;
92 virtual char const * mulid()
const;
95 virtual void addinv(
char const * a,
char *
b)
const;
98 virtual void safeaddinv(
char const * a,
char *& b)
const;
101 virtual void add(
char const * a,
106 virtual void accum(
char const * a,
char * b)
const;
109 virtual bool has_mul()
const {
return false; }
112 virtual void mul(
char const * a,
117 virtual void safemul(
char const * a,
122 virtual void min(
char const * a,
127 virtual void max(
char const * a,
132 virtual void min(
char * c)
const;
135 virtual void max(
char * c)
const;
138 virtual void cast_int(int64_t i,
char * c)
const;
141 virtual void cast_double(
double d,
char * c)
const;
144 virtual int64_t cast_to_int(
char const * c)
const;
147 virtual double cast_to_double(
char const * c)
const;
150 virtual void print(
char const * a, FILE * fp=stdout)
const;
153 virtual void scal(
int n,
159 virtual void axpy(
int n,
167 virtual void gemm(
char tA,
202 virtual bool is_offloadable()
const;
205 virtual void coomm(
int m,
219 virtual void csrmm(
int m,
233 virtual void csrmultd
248 virtual void csrmultcsr
262 char *& C_CSR)
const;
265 virtual bool isequal(
char const * a,
char const * b)
const;
268 virtual void coo_to_csr(int64_t nz,
int nrow,
char * csr_vs,
int * csr_cs,
int * csr_rs,
char const * coo_vs,
int const * coo_rs,
int const * coo_cs)
const;
271 virtual void csr_to_coo(int64_t nz,
int nrow,
char const * csr_vs,
int const * csr_ja,
int const * csr_ia,
char * coo_vs,
int * coo_rs,
int * coo_cs)
const;
274 virtual char * csr_add(
char * cA,
char * cB)
const;
277 virtual char * csr_reduce(
char * cA,
int root, MPI_Comm cm)
const;
283 virtual char * pair_alloc(int64_t n)
const;
289 virtual char *
alloc(int64_t n)
const;
292 virtual int64_t get_key(
char const * a)
const;
295 virtual char * get_value(
char * a)
const;
296 virtual char const * get_const_value(
char const * a)
const;
301 virtual void dealloc(
char * ptr)
const;
307 virtual void pair_dealloc(
char * ptr)
const;
313 virtual void init(int64_t n,
char * arr)
const;
318 virtual void sort(int64_t n,
char * pairs)
const;
321 double estimate_csr_red_time(int64_t msg_sz,
CommData const * cdt)
const;
324 void acc(
char * b,
char const * beta,
char const * a,
char const * alpha)
const;
327 void accmul(
char * c,
char const * a,
char const * b,
char const * alpha)
const;
330 virtual void copy(
char * a,
char const * b)
const;
333 void safecopy(
char *& a,
char const * b)
const;
336 virtual void copy(
char * a,
char const * b, int64_t n)
const;
339 virtual void copy(int64_t n,
char const * a,
int inc_a,
char * b,
int inc_b)
const;
348 int64_t lda_b)
const;
359 char const * beta)
const;
362 virtual void copy_pair(
char * a,
char const * b)
const;
365 virtual void copy_pairs(
char * a,
char const * b, int64_t n)
const;
368 virtual void set(
char *
a,
char const *
b, int64_t n)
const;
371 virtual void set_pair(
char * a, int64_t
key,
char const * vb)
const;
374 virtual void set_pairs(
char * a,
char const * b, int64_t n)
const;
398 char const * d()
const;
405 void read(
char * buf, int64_t n=1)
const;
411 void read_val(
char * buf)
const;
417 void permute(int64_t n,
int order,
int const * old_lens, int64_t
const * new_lda,
PairIterator wA);
422 void pin(int64_t n,
int order,
int const * lens,
int const * divisor,
PairIterator pi_new);
432 void depin(
algstrct const * sr,
int order,
int const * lens,
int const * divisor,
int nvirt,
int const * virt_dim,
int const * phys_rank,
char * X, int64_t & new_nnz_B, int64_t * nnz_blk,
char *& new_B,
bool check_padding);
456 void read(
char * buf, int64_t n=1)
const;
462 void read_val(
char * buf)
const;
469 void write(
char const * buf, int64_t n=1);
489 void write_val(
char const * buf);
495 void write_key(int64_t
key);
500 void sort(int64_t n);
void permute(int order, int const *perm, int *arr)
permute an array
bool has_coo_ker
whether there was a custom COO CSRMM kernel provided for this algebraic structure ...
virtual int pair_size() const
gets pair size el_size plus the key size
int bivar_function(int n, World &dw)
void * alloc(int64_t len)
alloc abstraction
virtual void accum(char const *a, char *b) const
b+=a
void gemm(char tA, char tB, int m, int n, int k, dtype alpha, dtype const *A, dtype const *B, dtype beta, dtype *C)
algstrct()
default constructor
untyped internal class for triply-typed bivariate function
virtual bool has_mul() const
void depin(algstrct const *sr, int order, int const *lens, int const *divisor, int nvirt, int const *virt_dim, int const *phys_rank, char *X, int64_t &new_nnz_B, int64_t *nnz_blk, char *&new_B, bool check_padding)
depins keys of n pairs
abstract class that knows how to add
void gemm_batch(char taA, char taB, int l, int m, int n, int k, dtype alpha, dtype const *A, dtype const *B, dtype beta, dtype *C)
virtual ~algstrct()=0
destructor
int el_size
size of each element of algstrct in bytes
algstrct (algebraic structure) defines the elementwise operations computed in each tensor contraction...
void offload_gemm(char tA, char tB, int m, int n, int k, dtype alpha, offload_tsr &A, int lda_A, offload_tsr &B, int lda_B, dtype beta, offload_tsr &C, int lda_C)
virtual void init_shell(int64_t n, char *arr) const
initialize n objects to zero