Cyclops Tensor Framework
parallel arithmetic on multidimensional arrays
ctr_2d_general.h
Go to the documentation of this file.
1 /*Copyright (c) 2011, Edgar Solomonik, all rights reserved.*/
2 #include "ctr_comm.h"
3 
4 #ifndef __CTR_2D_GENERAL_H__
5 #define __CTR_2D_GENERAL_H__
6 
7 namespace CTF_int{
8  class tensor;
/**
 * \brief Sets up a ctr_2d_general (2D SUMMA) level. The generated
 *        documentation describes this as a level "where A is not
 *        communicated"; that upstream description is truncated in this
 *        listing.
 *
 * The argument list contains one identical group of parameters for each
 * of the three operand tensors A, B and C. Every parameter taken by
 * non-const reference (int &, int64_t &, bool &, CommData *&) can be
 * written by the function, so these are presumably outputs. The
 * virt_blk_len_* pointers are const and are therefore read-only.
 * NOTE(review): confirm the exact in/out roles against the implementation.
 *
 * \param[in] is_used      NOTE(review): meaning not visible here - confirm
 * \param[in] global_comm  communicator, passed by value
 * \param[in] i            NOTE(review): presumably the index of the dimension
 *                         being set up - confirm
 * \param     virt_dim     non-const int pointer; may be read and/or updated
 * \param     cg_edge_len  writable reference; presumably becomes
 *                         ctr_2d_general::edge_len - confirm
 * \param     total_iter   writable reference
 * \param[in] A            operand tensor A
 * \param[in] i_A          NOTE(review): presumably the index of the
 *                         contracted dimension of A - confirm
 * \param     cg_cdt_A     CommData pointer for A, returned by reference
 * \param     cg_ctr_lda_A     writable; presumably the local lda of
 *                             contraction dimension 'k' of A
 *                             (cf. ctr_2d_general::ctr_lda_A) - confirm
 * \param     cg_ctr_sub_lda_A writable; presumably the elements per local
 *                             lda of A (cf. ctr_2d_general::ctr_sub_lda_A)
 *                             - confirm
 * \param     cg_move_A    writable flag; presumably stored into
 *                         ctr_2d_general::move_A - confirm
 * \param     blk_len_A    non-const int pointer; presumably block lengths of
 *                         A, may be updated - confirm
 * \param     blk_sz_A     writable reference; presumably the block size of A
 * \param[in] virt_blk_len_A  read-only (int const *); presumably virtual
 *                            block lengths of A - confirm
 * \param     load_phase_A writable reference
 * \param     B, i_B, cg_cdt_B, ..., load_phase_B  same roles as the A group
 * \param     C, i_C, cg_cdt_C, ..., load_phase_C  same roles as the A group
 * \return NOTE(review): int return value; presumably a status code - confirm
 */
9  int ctr_2d_gen_build(int is_used,
10  CommData global_comm,
11  int i,
12  int * virt_dim,
13  int & cg_edge_len,
14  int & total_iter,
15  tensor * A,
16  int i_A,
17  CommData *& cg_cdt_A,
18  int64_t & cg_ctr_lda_A,
19  int64_t & cg_ctr_sub_lda_A,
20  bool & cg_move_A,
21  int * blk_len_A,
22  int64_t & blk_sz_A,
23  int const * virt_blk_len_A,
24  int & load_phase_A,
25  tensor * B,
26  int i_B,
27  CommData *& cg_cdt_B,
28  int64_t & cg_ctr_lda_B,
29  int64_t & cg_ctr_sub_lda_B,
30  bool & cg_move_B,
31  int * blk_len_B,
32  int64_t & blk_sz_B,
33  int const * virt_blk_len_B,
34  int & load_phase_B,
35  tensor * C,
36  int i_C,
37  CommData *& cg_cdt_C,
38  int64_t & cg_ctr_lda_C,
39  int64_t & cg_ctr_sub_lda_C,
40  bool & cg_move_C,
41  int * blk_len_C,
42  int64_t & blk_sz_C,
43  int const * virt_blk_len_C,
44  int & load_phase_C);
45 
46 
/**
 * \brief One ctr_2d_general (2D SUMMA) level of a distributed contraction;
 *        derives from ctr.
 *
 * Per the generated documentation, run() basically performs SUMMA, except
 * that it assumes an equal block size on each processor, and performs
 * rank-b updates. Also per those docs, most of the construction logic
 * lives in ctr_2d_gen_build rather than in the constructors declared here.
 *
 * NOTE(review): this listing skips original header lines 68-70 and 72 (the
 * numbering jumps 67 -> 71 -> 73). Some members are therefore not shown
 * here, including the one documented as "Class to be called on sub-blocks".
 */
47  class ctr_2d_general : public ctr {
48  public:
// NOTE(review): not assigned by the contraction* constructor below;
// presumably filled in from ctr_2d_gen_build's cg_edge_len - confirm.
49  int edge_len;
50 
51  int64_t ctr_lda_A; /* local lda_A of contraction dimension 'k' */
52  int64_t ctr_sub_lda_A; /* elements per local lda_A
53  of contraction dimension 'k' */
54  int64_t ctr_lda_B; /* local lda_B of contraction dimension 'k' */
55  int64_t ctr_sub_lda_B; /* elements per local lda_B
56  of contraction dimension 'k' */
57  int64_t ctr_lda_C; /* local lda_C of contraction dimension 'k' */
58  int64_t ctr_sub_lda_C; /* elements per local lda_C
59  of contraction dimension 'k' */
// alloc_host_buf exists only in builds with OFFLOAD defined.
// NOTE(review): its meaning is not visible in this header - confirm.
60  #ifdef OFFLOAD
61  bool alloc_host_buf;
62  #endif
63 
// Per-operand flags; the contraction* constructor below sets all three to 0
// (false). NOTE(review): presumably whether each operand is communicated at
// this level (cf. the cg_move_* outputs of ctr_2d_gen_build) - confirm.
64  bool move_A;
65  bool move_B;
66  bool move_C;
67 
// NOTE(review): the member this comment describes (original line 72) is
// missing from this listing.
71  /* Class to be called on sub-blocks */
73 
/** \brief prints the ctr object */
77  void print();
/**
 * \brief Basically performs SUMMA, except that it assumes an equal block
 *        size on each processor; performs rank-b updates (upstream
 *        description is truncated).
 * \param A, B, C raw data buffers of the three operands.
 *        NOTE(review): buffer layout is not visible here - confirm.
 */
83  void run(char * A, char * B, char * C);
/** \brief returns the number of bytes of buffer space this level needs */
89  int64_t mem_fp();
/** \brief returns the number of bytes of buffer space needed recursively */
94  int64_t mem_rec();
/**
 * \brief Documented as "returns the number of bytes this kernel will send
 *        per processor". NOTE(review): the name and double return type
 *        suggest an estimated time rather than a byte count - confirm units.
 * \param nlyr NOTE(review): presumably a number of layers - confirm
 */
99  double est_time_fp(int nlyr);
/**
 * \brief Documented as "returns the number of bytes sent by each processor
 *        recursively". NOTE(review): as with est_time_fp, confirm whether
 *        this is a byte count or a time estimate.
 */
104  double est_time_rec(int nlyr);
/** \brief returns a copy of this ctr. NOTE(review): ownership of the returned pointer is not visible here - confirm. */
105  ctr * clone();
106 
/**
 * \brief Determines the buffer and block sizes needed for ctr_2d_general.
 * All seven parameters are non-const int64_t references, so each can be
 * written by the function. NOTE(review): b_A/b_B/b_C are presumably
 * per-operand buffer sizes, s_A/s_B/s_C per-operand block sizes, and
 * aux_size an auxiliary buffer size - confirm.
 */
118  void find_bsizes(int64_t & b_A,
119  int64_t & b_B,
120  int64_t & b_C,
121  int64_t & s_A,
122  int64_t & s_B,
123  int64_t & s_C,
124  int64_t & aux_size);
/** \brief constructs this object as a copy of the ctr object \p other */
128  ctr_2d_general(ctr * other);
/** \brief deallocates the ctr_2d_general object */
132  ~ctr_2d_general();
/**
 * \brief Partial constructor: forwards c to the ctr base and sets
 *        move_A, move_B and move_C to 0 (false). Per the generated docs,
 *        most of the logic is in ctr_2d_gen_build.
 * NOTE(review): edge_len and the ctr_lda_X / ctr_sub_lda_X members are not
 * initialized by this constructor.
 */
137  ctr_2d_general(contraction * c) : ctr(c){ move_A=0; move_B=0; move_C=0; }
138  };
139 }
140 #endif
~ctr_2d_general()
deallocates the ctr_2d_general object
void print()
prints the ctr object
int64_t mem_rec()
returns the number of bytes of buffer space needed recursively
double est_time_rec(int nlyr)
returns the number of bytes sent by each processor recursively
class for executing distributed contractions of tensors
Definition: contraction.h:16
int64_t mem_fp()
returns the number of bytes of buffer space needed
void run(char *A, char *B, char *C)
Essentially performs SUMMA, except that it assumes an equal block size on each processor. Performs rank-b updates whe...
ctr_2d_general(ctr *other)
copies the ctr object
double est_time_fp(int nlyr)
returns the number of bytes this kernel will send per processor
void find_bsizes(int64_t &b_A, int64_t &b_B, int64_t &b_C, int64_t &s_A, int64_t &s_B, int64_t &s_C, int64_t &aux_size)
determines buffer and block sizes needed for ctr_2d_general
ctr_2d_general(contraction *c)
partial constructor; most of the logic is in the ctr_2d_gen_build function
int ctr_2d_gen_build(int is_used, CommData global_comm, int i, int *virt_dim, int &cg_edge_len, int &total_iter, tensor *A, int i_A, CommData *&cg_cdt_A, int64_t &cg_ctr_lda_A, int64_t &cg_ctr_sub_lda_A, bool &cg_move_A, int *blk_len_A, int64_t &blk_sz_A, int const *virt_blk_len_A, int &load_phase_A, tensor *B, int i_B, CommData *&cg_cdt_B, int64_t &cg_ctr_lda_B, int64_t &cg_ctr_sub_lda_B, bool &cg_move_B, int *blk_len_B, int64_t &blk_sz_B, int const *virt_blk_len_B, int &load_phase_B, tensor *C, int i_C, CommData *&cg_cdt_C, int64_t &cg_ctr_lda_C, int64_t &cg_ctr_sub_lda_C, bool &cg_move_C, int *blk_len_C, int64_t &blk_sz_C, int const *virt_blk_len_C, int &load_phase_C)
sets up a ctr_2d_general (2D SUMMA) level where A is not communicated; the function will be called with A/...