Cyclops Tensor Framework
parallel arithmetic on multidimensional arrays
ctr_comm.h
Go to the documentation of this file.
1 /*Copyright (c) 2011, Edgar Solomonik, all rights reserved.*/
2 
3 #ifndef __CTR_COMM_H__
4 #define __CTR_COMM_H__
5 
6 #include "../tensor/algstrct.h"
7 #include "../interface/fun_term.h"
8 #include "../sparse_formats/csr.h"
9 
10 namespace CTF_int{
11  class contraction; /* forward declaration only; the cross-reference index at the end of this listing places the definition in contraction.h */
12 
17  public:
18  bool has_kernel; /* set to false by the bivar_function constructor; NOTE(review): presumably flags that a subclass supplies a specialized kernel -- confirm */
23 
/* Apply function f to the values stored at a and b.
   Pure virtual: every concrete subclass must implement it.
   c is the only non-const pointer, so presumably it receives the result -- TODO confirm. */
30  virtual void apply_f(char const * a, char const * b, char * c) const = 0;
31 
/* Compute c = c + f(a,b).
   Pure virtual: every concrete subclass must implement it.
   NOTE(review): sr_C is presumably the algebraic structure of C that supplies
   the addition used for the accumulation -- confirm against implementations. */
39  virtual void acc_f(char const * a, char const * b, char * c, CTF_int::algstrct const * sr_C) const = 0;
40 
41 
/* Evaluate C += f(A,B), or f(A,B,C) if transform (wording from the
   cross-reference index of this listing). Declared here only; the index
   places the definition in ctr_comm.cxx. */
48  void operator()(Term const & A, Term const & B, Term const & C) const;
49 
/* Two-operand call form: takes terms A and B and returns a Bifun_Term by value.
   Declared here only; the definition is not part of this header.
   NOTE(review): presumably the returned term represents f(A,B) for later
   evaluation into an output -- confirm. */
56  Bifun_Term operator()(Term const & A, Term const & B) const;
57 
/* Constructor: records the algebraic properties of the function.
   All three flags default to false (the pessimistic choice), and both
   has_kernel and has_off_gemm start out false.
   NOTE(review): the declarations of has_off_gemm, commutative,
   left_distributive and right_distributive are not visible in this listing. */
64  bivar_function(bool is_comm=false,
65  bool is_left_dist=false,
66  bool is_right_dist=false){
67  has_kernel = false;
68  has_off_gemm = false;
69  commutative = is_comm;
70  left_distributive = is_left_dist;
71  right_distributive = is_right_dist;
72  }
73 
74  virtual ~bivar_function(){} /* virtual with an empty body, so objects can be destroyed through a bivar_function pointer */
75 
76  virtual bool is_accumulator() const { return false; } /* default answer is false; virtual so subclasses can override it */
77 
/* Overridable matrix-multiply hook on untyped (char *) buffers.
   The default implementation does nothing at all.
   NOTE(review): tA/tB are presumably transpose flags and m, n, k the GEMM
   dimensions -- confirm. The other kernel hooks in this class default to
   assert(0); this one is silently a no-op, which may be unintended. */
78  virtual void cgemm(char tA,
79  char tB,
80  int m,
81  int n,
82  int k,
83  char const * A,
84  char const * B,
85  char * C) const {}
86 
/* Overridable hook with the same parameter list as cgemm.
   The default implementation is assert(0): it aborts when assertions are
   enabled and does nothing when NDEBUG is defined, so it must be overridden
   before being called.
   NOTE(review): presumably the offloaded (accelerator) GEMM variant -- confirm. */
87  virtual void coffload_gemm(char tA,
88  char tB,
89  int m,
90  int n,
91  int k,
92  char const * A,
93  char const * B,
94  char * C) const { assert(0); }
95 
/* Overridable hook taking A together with row/column index arrays and a
   nonzero count, plus buffers B and C.
   The default implementation is assert(0), i.e. unimplemented.
   NOTE(review): presumably A is sparse in coordinate (COO) format and B, C
   are dense -- confirm. */
96  virtual void ccoomm(int m,
97  int n,
98  int k,
99  char const * A,
100  int const * rows_A,
101  int const * cols_A,
102  int64_t nnz_A,
103  char const * B,
104  char * C) const { assert(0); }
105 
106 
/* Overridable hook taking A with index arrays JA/IA and a nonzero count,
   buffers B and C, and the algebraic structure sr_C.
   The default implementation is assert(0), i.e. unimplemented.
   NOTE(review): presumably A is sparse in CSR format and B, C are dense --
   confirm. */
107  virtual void ccsrmm
108  (int m,
109  int n,
110  int k,
111  char const * A,
112  int const * JA,
113  int const * IA,
114  int64_t nnz_A,
115  char const * B,
116  char * C,
117  algstrct const * sr_C) const { assert(0); }
118 
/* Overridable hook where both A and B come with index arrays (JA/IA, JB/IB)
   and nonzero counts, writing into the buffer C.
   The default implementation is assert(0), i.e. unimplemented.
   NOTE(review): presumably CSR times CSR producing a dense C -- confirm.
   NOTE(review): nnz_A and nnz_B are int here but nnz_A is int64_t in ccsrmm. */
119  virtual void ccsrmultd
120  (int m,
121  int n,
122  int k,
123  char const * A,
124  int const * JA,
125  int const * IA,
126  int nnz_A,
127  char const * B,
128  int const * JB,
129  int const * IB,
130  int nnz_B,
131  char * C,
132  algstrct const * sr_C) const { assert(0); }
133 
/* Overridable hook with the same inputs as ccsrmultd, except that the output
   C_CSR is a reference to a pointer, so an override can change the pointer
   the caller holds.
   The default implementation is assert(0), i.e. unimplemented.
   NOTE(review): presumably CSR times CSR producing a CSR result that the
   override allocates -- confirm, including who frees it. */
134  virtual void ccsrmultcsr
135  (int m,
136  int n,
137  int k,
138  char const * A,
139  int const * JA,
140  int const * IA,
141  int nnz_A,
142  char const * B,
143  int const * JB,
144  int const * IB,
145  int nnz_B,
146  char *& C_CSR,
147  algstrct const * sr_C) const { assert(0); }
148 
149 
/* Overridable hook taking a single all_data buffer in place of separate
   value/index arrays, plus buffers B and C.
   The default implementation is assert(0), i.e. unimplemented.
   NOTE(review): all_data presumably packs a whole CSR matrix for an
   offloaded multiply -- confirm the layout. */
150  virtual void coffload_csrmm(int m,
151  int n,
152  int k,
153  char const * all_data,
154  char const * B,
155  char * C) const { assert(0); }
156 
157 
158  };
159 
/* Base class for contraction execution objects. It stores one algstrct
   pointer per operand (A, B, C), a beta pointer and layer bookkeeping, and
   declares the virtual hooks that subclasses such as ctr_replicate override.
   Every hook has a trivial default so the class itself is not abstract. */
165  class ctr {
166  public:
167  algstrct const * sr_A; /* algebraic structure for A (by name) */
168  algstrct const * sr_B; /* algebraic structure for B (by name) */
169  algstrct const * sr_C; /* algebraic structure for C (by name) */
170  char const * beta; /* NOTE(review): presumably the scaling factor applied to C, held as untyped bytes -- confirm */
171  int num_lyr; /* number of copies of this matrix being computed on */
172  int idx_lyr; /* the index of this copy */
173 
174  virtual void run(char * A, char * B, char * C) { printf("SHOULD NOTR\n"); }; /* default only prints a diagnostic; NOTE(review): the message text looks misspelled or truncated */
175  virtual void print() { }; /* default prints nothing */
176  virtual int64_t mem_fp() { return 0; }; /* default reports 0 */
177  virtual int64_t mem_rec() { return mem_fp(); }; /* default forwards to mem_fp() */
178  virtual double est_time_fp(int nlyr) { return 0; }; /* default reports 0 and ignores nlyr */
179  virtual double est_time_rec(int nlyr) { return est_time_fp(nlyr); }; /* default forwards to est_time_fp(nlyr) */
180  virtual ctr * clone() { return NULL; }; /* default returns NULL, so callers of the base version get no copy */
181 
185  virtual ~ctr(); /* declared here, defined outside this header */
186 
190  ctr(ctr * other); /* construct from an existing ctr; declared here, defined outside this header */
191 
195  ctr(contraction const * c); /* construct from a contraction; declared here, defined outside this header */
196  };
197 
/* ctr subclass that records, per operand, how many processor dimensions the
   operand is replicated along and the size of its blocks. It redeclares
   run, mem_fp, mem_rec, est_time_fp, est_time_rec, print and clone with the
   same signatures as the virtual functions in ctr, so they override them;
   the definitions are not in this header.
   NOTE(review): this listing skips lines 207-209 and 211. The
   cross-reference index at the end of the listing names cdt_A, cdt_B and
   cdt_C (CommData **) at 207-209. */
198  class ctr_replicate : public ctr {
199  public:
200  int ncdt_A; /* number of processor dimensions to replicate A along */
201  int ncdt_B; /* number of processor dimensions to replicate B along */
202  int ncdt_C; /* number of processor dimensions to replicate C along */
203  int64_t size_A; /* size of A blocks */
204  int64_t size_B; /* size of B blocks */
205  int64_t size_C; /* size of C blocks */
206 
210  /* Class to be called on sub-blocks */
212 
213  void run(char * A, char * B, char * C);
219  int64_t mem_fp();
224  int64_t mem_rec();
229  double est_time_fp(int nlyr);
234  double est_time_rec(int nlyr);
235  void print();
236  ctr * clone();
237 
238  ctr_replicate(ctr * other); /* construct from an existing ctr */
239  ~ctr_replicate();
240  ctr_replicate(contraction const * c, /* construct from a contraction plus mapping and block-size inputs */
241  int const * phys_mapped,
242  int64_t blk_sz_A,
243  int64_t blk_sz_B,
244  int64_t blk_sz_C);
245  };
251 }
252 #endif // __CTR_COMM_H__
a term is an abstract object representing some expression of tensors
Definition: term.h:33
virtual double est_time_fp(int nlyr)
Definition: ctr_comm.h:178
virtual int64_t mem_fp()
Definition: ctr_comm.h:176
CommData ** cdt_C
Definition: ctr_comm.h:209
virtual int64_t mem_rec()
Definition: ctr_comm.h:177
virtual void ccsrmultcsr(int m, int n, int k, char const *A, int const *JA, int const *IA, int nnz_A, char const *B, int const *JB, int const *IB, int nnz_B, char *&C_CSR, algstrct const *sr_C) const
Definition: ctr_comm.h:135
void operator()(Term const &A, Term const &B, Term const &C) const
evaluate C+=f(A,B), or f(A,B,C) if the function is a transform
Definition: ctr_comm.cxx:15
untyped internal class for triply-typed bivariate function
Definition: ctr_comm.h:16
algstrct const * sr_B
Definition: ctr_comm.h:168
virtual void coffload_csrmm(int m, int n, int k, char const *all_data, char const *B, char *C) const
Definition: ctr_comm.h:150
algstrct const * sr_C
Definition: ctr_comm.h:169
class for executing distributed contraction of tensors
Definition: contraction.h:16
virtual void coffload_gemm(char tA, char tB, int m, int n, int k, char const *A, char const *B, char *C) const
Definition: ctr_comm.h:87
virtual void print()
Definition: ctr_comm.h:175
virtual bool is_accumulator() const
Definition: ctr_comm.h:76
virtual void ccoomm(int m, int n, int k, char const *A, int const *rows_A, int const *cols_A, int64_t nnz_A, char const *B, char *C) const
Definition: ctr_comm.h:96
algstrct const * sr_A
Definition: ctr_comm.h:167
CommData ** cdt_A
Definition: ctr_comm.h:207
virtual void apply_f(char const *a, char const *b, char *c) const =0
apply function f to values stored at a and b
virtual ctr * clone()
Definition: ctr_comm.h:180
virtual void ccsrmultd(int m, int n, int k, char const *A, int const *JA, int const *IA, int nnz_A, char const *B, int const *JB, int const *IB, int nnz_B, char *C, algstrct const *sr_C) const
Definition: ctr_comm.h:120
virtual double est_time_rec(int nlyr)
Definition: ctr_comm.h:179
bivar_function(bool is_comm=false, bool is_left_dist=false, bool is_right_dist=false)
constructor sets function properties, using pessimistic defaults
Definition: ctr_comm.h:64
virtual void ccsrmm(int m, int n, int k, char const *A, int const *JA, int const *IA, int64_t nnz_A, char const *B, char *C, algstrct const *sr_C) const
Definition: ctr_comm.h:108
virtual void run(char *A, char *B, char *C)
Definition: ctr_comm.h:174
algstrct (algebraic structure) defines the elementwise operations computed in each tensor contraction...
Definition: algstrct.h:34
virtual ~bivar_function()
Definition: ctr_comm.h:74
CommData ** cdt_B
Definition: ctr_comm.h:208
char const * beta
Definition: ctr_comm.h:170
virtual void cgemm(char tA, char tB, int m, int n, int k, char const *A, char const *B, char *C) const
Definition: ctr_comm.h:78
virtual void acc_f(char const *a, char const *b, char *c, CTF_int::algstrct const *sr_C) const =0
compute c = c+f(a,b)