6 #include "../tensor/algstrct.h" 7 #include "../interface/fun_term.h" 8 #include "../sparse_formats/csr.h" 30 virtual void apply_f(
char const *
a,
char const *
b,
char * c)
const = 0;
65 bool is_left_dist=
false,
66 bool is_right_dist=
false){
69 commutative = is_comm;
70 left_distributive = is_left_dist;
71 right_distributive = is_right_dist;
94 char * C)
const { assert(0); }
104 char * C)
const { assert(0); }
117 algstrct const * sr_C)
const { assert(0); }
132 algstrct const * sr_C)
const { assert(0); }
147 algstrct const * sr_C)
const { assert(0); }
153 char const * all_data,
155 char * C)
const { assert(0); }
174 virtual void run(
char * A,
char * B,
char * C) { printf(
"SHOULD NOTR\n"); };
177 virtual int64_t
mem_rec() {
return mem_fp(); };
213 void run(
char * A,
char * B,
char * C);
229 double est_time_fp(
int nlyr);
234 double est_time_rec(
int nlyr);
241 int const * phys_mapped,
252 #endif // __CTR_COMM_H__
a term is an abstract object representing some expression of tensors
virtual double est_time_fp(int nlyr)
virtual int64_t mem_rec()
virtual void ccsrmultcsr(int m, int n, int k, char const *A, int const *JA, int const *IA, int nnz_A, char const *B, int const *JB, int const *IB, int nnz_B, char *&C_CSR, algstrct const *sr_C) const
void operator()(Term const &A, Term const &B, Term const &C) const
evaluate C+=f(A,B) or f(A,B,C) if transform
untyped internal class for triply-typed bivariate function
virtual void coffload_csrmm(int m, int n, int k, char const *all_data, char const *B, char *C) const
class for execution distributed contraction of tensors
virtual void coffload_gemm(char tA, char tB, int m, int n, int k, char const *A, char const *B, char *C) const
virtual bool is_accumulator() const
virtual void ccoomm(int m, int n, int k, char const *A, int const *rows_A, int const *cols_A, int64_t nnz_A, char const *B, char *C) const
virtual void apply_f(char const *a, char const *b, char *c) const =0
apply function f to values stored at a and b
virtual void ccsrmultd(int m, int n, int k, char const *A, int const *JA, int const *IA, int nnz_A, char const *B, int const *JB, int const *IB, int nnz_B, char *C, algstrct const *sr_C) const
virtual double est_time_rec(int nlyr)
bivar_function(bool is_comm=false, bool is_left_dist=false, bool is_right_dist=false)
constructor sets function properties, pessimistic defaults
virtual void ccsrmm(int m, int n, int k, char const *A, int const *JA, int const *IA, int64_t nnz_A, char const *B, char *C, algstrct const *sr_C) const
virtual void run(char *A, char *B, char *C)
algstrct (algebraic structure) defines the elementwise operations computed in each tensor contraction...
virtual ~bivar_function()
virtual void cgemm(char tA, char tB, int m, int n, int k, char const *A, char const *B, char *C) const
virtual void acc_f(char const *a, char const *b, char *c, CTF_int::algstrct const *sr_C) const =0
compute c = c+f(a,b)