Cyclops Tensor Framework
parallel arithmetic on multidimensional arrays
algstrct.h
Go to the documentation of this file.
1 #ifndef __INT_SEMIRING_H__
2 #define __INT_SEMIRING_H__
3 
4 #include "../interface/common.h"
5 
6 namespace CTF_int {
7 
8  class bivar_function;
9 
/**
 * \brief abstract class that knows how to add
 *
 * Polymorphic base of algstrct.  The default member implementations here only
 * execute assert(0); derived classes are expected to override them.
 */
class accumulatable {
  public:
    /** \brief size of each element of algstrct in bytes */
    int el_size;

    /** \brief zero-initializes el_size so it is never read uninitialized */
    accumulatable() : el_size(0) {}

    /**
     * \brief virtual destructor
     *
     * This class has virtual members and is used as a base class (algstrct
     * derives from it); without a virtual destructor, deleting a derived
     * object through an accumulatable pointer is undefined behavior.
     */
    virtual ~accumulatable() {}

    /** \brief b+=a; the base-class version only executes assert(0) */
    virtual void accum(char const * a,
                       char * b) const { assert(0); }

    /** \brief initialize n objects to zero; the base-class version only executes assert(0) */
    virtual void init_shell(int64_t n, char * arr) const { assert(0); }

};

34  class algstrct : public accumulatable {
35  public:
39 // MPI_Datatype pmdtype;
40 
42  void (*abs)(char const * a,
43  char * b);
44 
46  virtual int pair_size() const { return el_size + sizeof(int64_t); }
47 
48 
52  algstrct(){ has_coo_ker = false; }
53 
58  //algstrct(algstrct const &other);
59 
64  algstrct(int el_size);
65 
69  virtual ~algstrct() = 0;
70 
74  virtual algstrct * clone() const = 0;
75 // return new algstrct(el_size);
76 
77  virtual bool is_ordered() const = 0;
78 
80  virtual MPI_Op addmop() const;
81 
83  virtual MPI_Datatype mdtype() const;
84 
86 // MPI_Datatype pair_mdtype();
87 
89  virtual char const * addid() const;
90 
92  virtual char const * mulid() const;
93 
95  virtual void addinv(char const * a, char * b) const;
96 
98  virtual void safeaddinv(char const * a, char *& b) const;
99 
101  virtual void add(char const * a,
102  char const * b,
103  char * c) const;
104 
106  virtual void accum(char const * a, char * b) const;
107 
109  virtual bool has_mul() const { return false; }
110 
112  virtual void mul(char const * a,
113  char const * b,
114  char * c) const;
115 
117  virtual void safemul(char const * a,
118  char const * b,
119  char *& c) const;
120 
122  virtual void min(char const * a,
123  char const * b,
124  char * c) const;
125 
127  virtual void max(char const * a,
128  char const * b,
129  char * c) const;
130 
132  virtual void min(char * c) const;
133 
135  virtual void max(char * c) const;
136 
138  virtual void cast_int(int64_t i, char * c) const;
139 
141  virtual void cast_double(double d, char * c) const;
142 
144  virtual int64_t cast_to_int(char const * c) const;
145 
147  virtual double cast_to_double(char const * c) const;
148 
150  virtual void print(char const * a, FILE * fp=stdout) const;
151 
153  virtual void scal(int n,
154  char const * alpha,
155  char * X,
156  int incX) const;
157 
159  virtual void axpy(int n,
160  char const * alpha,
161  char const * X,
162  int incX,
163  char * Y,
164  int incY) const;
165 
167  virtual void gemm(char tA,
168  char tB,
169  int m,
170  int n,
171  int k,
172  char const * alpha,
173  char const * A,
174  char const * B,
175  char const * beta,
176  char * C) const;
177 
179  virtual void gemm_batch(char tA,
180  char tB,
181  int l,
182  int m,
183  int n,
184  int k,
185  char const * alpha,
186  char const * A,
187  char const * B,
188  char const * beta,
189  char * C) const;
190 
191  virtual void offload_gemm(char tA,
192  char tB,
193  int m,
194  int n,
195  int k,
196  char const * alpha,
197  char const * A,
198  char const * B,
199  char const * beta,
200  char * C) const;
201 
202  virtual bool is_offloadable() const;
203 
205  virtual void coomm(int m,
206  int n,
207  int k,
208  char const * alpha,
209  char const * A,
210  int const * rows_A,
211  int const * cols_A,
212  int64_t nnz_A,
213  char const * B,
214  char const * beta,
215  char * C,
216  bivar_function const * func) const;
217 
219  virtual void csrmm(int m,
220  int n,
221  int k,
222  char const * alpha,
223  char const * A,
224  int const * JA,
225  int const * IA,
226  int64_t nnz_A,
227  char const * B,
228  char const * beta,
229  char * C,
230  bivar_function const * func) const;
231 
233  virtual void csrmultd
234  (int m,
235  int n,
236  int k,
237  char const * alpha,
238  char const * A,
239  int const * JA,
240  int const * IA,
241  int64_t nnz_A,
242  char const * B,
243  int const * JB,
244  int const * IB,
245  int64_t nnz_B,
246  char const * beta,
247  char * C) const;
248  virtual void csrmultcsr
249  (int m,
250  int n,
251  int k,
252  char const * alpha,
253  char const * A,
254  int const * JA,
255  int const * IA,
256  int64_t nnz_A,
257  char const * B,
258  int const * JB,
259  int const * IB,
260  int64_t nnz_B,
261  char const * beta,
262  char *& C_CSR) const;
263 
265  virtual bool isequal(char const * a, char const * b) const;
266 
268  virtual void coo_to_csr(int64_t nz, int nrow, char * csr_vs, int * csr_cs, int * csr_rs, char const * coo_vs, int const * coo_rs, int const * coo_cs) const;
269 
271  virtual void csr_to_coo(int64_t nz, int nrow, char const * csr_vs, int const * csr_ja, int const * csr_ia, char * coo_vs, int * coo_rs, int * coo_cs) const;
272 
274  virtual char * csr_add(char * cA, char * cB) const;
275 
277  virtual char * csr_reduce(char * cA, int root, MPI_Comm cm) const;
278 
283  virtual char * pair_alloc(int64_t n) const;
284 
289  virtual char * alloc(int64_t n) const;
290 
292  virtual int64_t get_key(char const * a) const;
293 
295  virtual char * get_value(char * a) const;
296  virtual char const * get_const_value(char const * a) const;
301  virtual void dealloc(char * ptr) const;
302 
307  virtual void pair_dealloc(char * ptr) const;
308 
313  virtual void init(int64_t n, char * arr) const;
314 
318  virtual void sort(int64_t n, char * pairs) const;
319 
321  double estimate_csr_red_time(int64_t msg_sz, CommData const * cdt) const;
322 
324  void acc(char * b, char const * beta, char const * a, char const * alpha) const;
325 
327  void accmul(char * c, char const * a, char const * b, char const * alpha) const;
328 
330  virtual void copy(char * a, char const * b) const;
331 
333  void safecopy(char *& a, char const * b) const;
334 
336  virtual void copy(char * a, char const * b, int64_t n) const;
337 
339  virtual void copy(int64_t n, char const * a, int inc_a, char * b, int inc_b) const;
340 
342  virtual void copy(
343  int64_t m,
344  int64_t n,
345  char const * a,
346  int64_t lda_a,
347  char * b,
348  int64_t lda_b) const;
349 
351  virtual void copy(
352  int64_t m,
353  int64_t n,
354  char const * a,
355  int64_t lda_a,
356  char const * alpha,
357  char * b,
358  int64_t lda_b,
359  char const * beta) const;
360 
362  virtual void copy_pair(char * a, char const * b) const;
363 
365  virtual void copy_pairs(char * a, char const * b, int64_t n) const;
366 
368  virtual void set(char * a, char const * b, int64_t n) const;
369 
371  virtual void set_pair(char * a, int64_t key, char const * vb) const;
372 
374  virtual void set_pairs(char * a, char const * b, int64_t n) const;
375 
376  };
377 
378  class PairIterator;
379 
381  public:
382  algstrct const * sr;
383  char const * ptr;
384 
386  ConstPairIterator(PairIterator const & pi);
387 
389  ConstPairIterator(algstrct const * sr_, char const * ptr_);
390 
392  ConstPairIterator operator[](int n) const;
393 
395  int64_t k() const;
396 
398  char const * d() const;
399 
405  void read(char * buf, int64_t n=1) const;
406 
411  void read_val(char * buf) const;
412 
413 
417  void permute(int64_t n, int order, int const * old_lens, int64_t const * new_lda, PairIterator wA);
418 
422  void pin(int64_t n, int order, int const * lens, int const * divisor, PairIterator pi_new);
423 
424 
425  };
426  //http://stackoverflow.com/questions/630950/pure-virtual-destructor-in-c
428 
432  void depin(algstrct const * sr, int order, int const * lens, int const * divisor, int nvirt, int const * virt_dim, int const * phys_rank, char * X, int64_t & new_nnz_B, int64_t * nnz_blk, char *& new_B, bool check_padding);
433 
434  class PairIterator {
435  public:
436  algstrct const * sr;
437  char * ptr;
438 
440  PairIterator(algstrct const * sr_, char * ptr_);
441 
443  PairIterator operator[](int n) const;
444 
446  int64_t k() const;
447 
449  char * d() const;
450 
456  void read(char * buf, int64_t n=1) const;
457 
462  void read_val(char * buf) const;
463 
469  void write(char const * buf, int64_t n=1);
470 
476  void write(PairIterator const iter, int64_t n=1);
477 
483  void write(ConstPairIterator const iter, int64_t n=1);
484 
489  void write_val(char const * buf);
490 
495  void write_key(int64_t key);
496 
500  void sort(int64_t n);
501 
505  int64_t lower_bound(int64_t n, ConstPairIterator op);
506  };
507 
508 }
509 
510 
511 #endif
void permute(int order, int const *perm, int *arr)
permute an array
Definition: util.cxx:205
bool has_coo_ker
whether there was a custom COO CSRMM kernel provided for this algebraic structure ...
Definition: algstrct.h:37
virtual int pair_size() const
gets the pair size: el_size plus the size of the int64_t key
Definition: algstrct.h:46
algstrct const * sr
Definition: algstrct.h:436
int bivar_function(int n, World &dw)
void * alloc(int64_t len)
alloc abstraction
Definition: memcontrol.cxx:365
virtual void accum(char const *a, char *b) const
accumulates a into b, i.e. b += a
Definition: algstrct.h:19
void gemm(char tA, char tB, int m, int n, int k, dtype alpha, dtype const *A, dtype const *B, dtype beta, dtype *C)
Definition: semiring.cxx:82
algstrct()
default constructor
Definition: algstrct.h:52
untyped internal class for triply-typed bivariate function
Definition: ctr_comm.h:16
algstrct const * sr
Definition: algstrct.h:382
virtual bool has_mul() const
Definition: algstrct.h:109
void depin(algstrct const *sr, int order, int const *lens, int const *divisor, int nvirt, int const *virt_dim, int const *phys_rank, char *X, int64_t &new_nnz_B, int64_t *nnz_blk, char *&new_B, bool check_padding)
depins keys of n pairs
Definition: algstrct.cxx:883
abstract class that knows how to add
Definition: algstrct.h:13
void gemm_batch(char taA, char taB, int l, int m, int n, int k, dtype alpha, dtype const *A, dtype const *B, dtype beta, dtype *C)
Definition: semiring.cxx:15
virtual ~algstrct()=0
destructor
Definition: algstrct.h:427
def abs(initA)
Definition: core.pyx:5440
def copy(tensor, A)
Definition: core.pyx:3583
int el_size
size of each element of algstrct in bytes
Definition: algstrct.h:16
algstrct (algebraic structure) defines the elementwise operations computed in each tensor contraction...
Definition: algstrct.h:34
void offload_gemm(char tA, char tB, int m, int n, int k, dtype alpha, offload_tsr &A, int lda_A, offload_tsr &B, int lda_B, dtype beta, offload_tsr &C, int lda_C)
int64_t key
Definition: back_comp.h:66
virtual void init_shell(int64_t n, char *arr) const
initialize n objects to zero
Definition: algstrct.h:26