Cyclops Tensor Framework
parallel arithmetic on multidimensional arrays
contraction.h
Go to the documentation of this file.
1 #ifndef __INT_CONTRACTION_H__
2 #define __INT_CONTRACTION_H__
3 
4 #include <assert.h>
5 #include "ctr_tsr.h"
6 
7 namespace CTF_int {
/* Forward declarations: within this header these types are only used
   through pointers (tensor *, topology *, distribution const *,
   mapping const *), so their full definitions are not needed here. */
8  class tensor;
9  class topology;
10  class distribution;
11  class mapping;
12 
/**
 * \brief class for executing distributed contractions of tensors
 *
 * Holds the two operands (A, B), the output (C), their index arrays and
 * the alpha/beta scaling factors. The public interface is the
 * constructors, execute(), estimate_time() and is_equal(); everything
 * after `private:` is internal machinery.
 *
 * NOTE(review): the documentation index for this file lists a member
 * `bivar_function const * func` ("function to execute on elements",
 * contraction.h:39) that does not appear in this listing -- confirm
 * against the real header.
 */
16  class contraction {
17  public:
/** \brief left operand */
19  tensor * A;
/** \brief right operand */
21  tensor * B;
/** \brief output */
23  tensor * C;
24 
/** \brief scaling of A*B */
26  char const * alpha;
/** \brief scaling of existing C */
28  char const * beta;
29 
/** \brief indices of left operand */
31  int * idx_A;
/** \brief indices of right operand */
33  int * idx_B;
/** \brief indices of output */
35  int * idx_C;
/** \brief whether there is an elementwise custom function */
37  bool is_custom;
40 
/**
 * \brief lazy constructor
 *
 * Sets idx_A, idx_B, idx_C, alpha and beta to NULL and is_custom to 0.
 * The tensor pointers A, B and C are NOT assigned here.
 */
42  contraction(){ idx_A = NULL; idx_B = NULL; idx_C=NULL; is_custom=0; alpha=NULL; beta=NULL; };
43 
/** \brief destructor */
45  ~contraction();
46 
/** \brief copy constructor */
48  contraction(contraction const & other);
49 
/* Disabled overload (kept commented out in the original): same parameter
   list as the int-index constructor below, minus the trailing func. */
63  /*contraction(tensor * A,
64  int const * idx_A,
65  tensor * B,
66  int const * idx_B,
67  char const * alpha,
68  tensor * C,
69  int const * idx_C,
70  char const * beta);*/
71 
/**
 * \brief constructor taking the index maps as int arrays
 *
 * Parameter names mirror the members of the same name (operands A/B,
 * output C, their index arrays, alpha/beta scalings).
 * \param[in] func optional, defaults to NULL
 *            NOTE(review): presumably the custom elementwise function
 *            that is_custom refers to -- confirm in contraction.cxx
 */
87  contraction(tensor * A,
88  int const * idx_A,
89  tensor * B,
90  int const * idx_B,
91  char const * alpha,
92  tensor * C,
93  int const * idx_C,
94  char const * beta,
95  bivar_function const * func=NULL);
/**
 * \brief constructor taking the index maps as char arrays
 *
 * Identical parameter list to the overload above except that idx_A,
 * idx_B and idx_C are `char const *` instead of `int const *`.
 */
96  contraction(tensor * A,
97  char const * idx_A,
98  tensor * B,
99  char const * idx_B,
100  char const * alpha,
101  tensor * C,
102  char const * idx_C,
103  char const * beta,
104  bivar_function const * func=NULL);
105 
106 
/** \brief run contraction */
108  void execute();
109 
/** \brief predicts execution time in seconds using performance models */
111  double estimate_time();
112 
/** \brief returns 1 if contractions have same tensors and index map */
117  int is_equal(contraction const & os);
118 
/* ---- internal helpers --------------------------------------------------
   None of the private members below carries a description in this
   listing. Notes marked NOTE(review) are inferred from names and
   signatures only and must be confirmed against contraction.cxx. */
119  private:
/* NOTE(review): presumably reports whether the contraction involves
   sparse operands -- confirm */
123  bool is_sparse();
124 
/* NOTE(review): both arguments look like outputs (count and an allocated
   index array); ownership of *fold_idx is not visible here -- confirm */
132  void get_fold_indices(int * num_fold,
133  int ** fold_idx);
134 
/* NOTE(review): returns int, not bool; meaning of the value not visible */
139  int can_fold();
140 
141 
/* All seven parameters are non-const references, i.e. outputs filled in
   by the call: a contraction pointer plus a dimension count and a
   length array for each of A, B and C. */
152  void get_fold_ctr(contraction *& fold_ctr,
153  int & all_fdim_A,
154  int & all_fdim_B,
155  int & all_fdim_C,
156  int *& all_flen_A,
157  int *& all_flen_B,
158  int *& all_flen_C);
159 
160 
/* Takes the same quantities get_fold_ctr() produces, here as read-only
   inputs, and writes its results through the three trailing references
   (bperm_order, btime, iprm). */
174  void select_ctr_perm(contraction const * fold_ctr,
175  int all_fdim_A,
176  int all_fdim_B,
177  int all_fdim_C,
178  int const * all_flen_A,
179  int const * all_flen_B,
180  int const * all_flen_C,
181  int & bperm_order,
182  double & btime,
183  iparam & iprm);
184 
/* do_transp defaults to true */
190  iparam map_fold(bool do_transp=true);
191 
/* NOTE(review): presumably a time estimate for the folded contraction,
   units not visible here -- confirm */
196  double est_time_fold();
197 
198 
/* NOTE(review): new_contraction looks like an output pointer; ownership
   of the pointed-to object is not visible here -- confirm */
204  int unfold_broken_sym(contraction ** new_contraction);
205 
211  bool check_consistency();
212 
217  int check_mapping();
218 
230  int map_to_topology(topology * topo,
231  int order);
232 /* int * idx_arr,
233  int * idx_ctr,
234  int * idx_extra,
235  int * idx_no_ctr,
236  int * idx_weigh);*/
237 
238 
/* Two overloads: the second additionally reports four counts through
   its reference parameters. */
243  int get_num_map_variants(topology const * topo);
244 
245  int get_num_map_variants(topology const * topo,
246  int & nmax_ctr_2d,
247  int & nAB,
248  int & nAC,
249  int & nBC);
250 
251  bool switch_topo_perm();
252 
262  bool exh_map_to_topo(topology const * topo,
263  int variant);
267  int try_topo_morph();
268 
/* get_best_sel_map / get_best_exh_map share their first eleven
   parameters (distributions, previous topologies and previous mappings
   of A, B and C, then idx and time as outputs); the "exh" variant takes
   one extra input, init_best_time. */
269  void get_best_sel_map(distribution const * dA, distribution const * dB, distribution const * dC, topology * old_topo_A, topology * old_topo_B, topology * old_topo_C, mapping const * old_map_A, mapping const * old_map_B, mapping const * old_map_C, int & idx, double & time);
270 
271  void get_best_exh_map(distribution const * dA, distribution const * dB, distribution const * dC, topology * old_topo_A, topology * old_topo_B, topology * old_topo_C, mapping const * old_map_A, mapping const * old_map_B, mapping const * old_map_C, int & idx, double & time, double init_best_time);
272 
/* do_remap defaults to 1 (true); ctrf is an output pointer */
279  int map(ctr ** ctrf, bool do_remap=1);
280 
/* Every parameter has a default: is_inner=0, inner_params=NULL,
   nvirt_all=NULL, is_used=1. */
290  ctr * construct_ctr(int is_inner=0,
291  iparam const * inner_params=NULL,
292  int * nvirt_all=NULL,
293  int is_used=1);
294 
/* Dense and sparse variants: same first four parameters as
   construct_ctr() but without defaults, plus a phys_mapped array.
   NOTE(review): presumably construct_ctr() dispatches to one of these
   -- confirm */
295  ctr * construct_dense_ctr(int is_inner,
296  iparam const * inner_params,
297  int * nvirt_all,
298  int is_used,
299  int const * phys_mapped);
300 
301  ctr * construct_sparse_ctr(int is_inner,
302  iparam const * inner_params,
303  int * nvirt_all,
304  int is_used,
305  int const * phys_mapped);
306 
307 
308 
313  int contract();
314 
319  int sym_contract();
320 
326  int home_contract();
327 
331  void prescale_operands();
332 
336  bool need_prescale_operands();
337 
339  void print();
340  };
341 
342 
/**
 * \brief sets up a ctr_2d_general (2D SUMMA) level
 *
 * The brief for this function in the listing is truncated ("... where A
 * is not communicated function will be called with A/..."), so the
 * per-parameter contract is not visible here; see contraction.cxx.
 *
 * The signature is: six general parameters, then the same block of ten
 * per-tensor parameters repeated for A, B and C (tensor, i_*, cg_cdt_*,
 * cg_ctr_lda_*, cg_ctr_sub_lda_*, cg_move_*, blk_len_*, blk_sz_*,
 * virt_blk_len_*, load_phase_*). Non-const reference parameters are
 * writable by the callee; virt_blk_len_* is read-only.
 */
366  int ctr_2d_gen_build(int is_used,
367  CommData global_comm,
368  int i,
369  int * virt_dim,
370  int & cg_edge_len,
371  int & total_iter,
/* --- per-tensor block for A --- */
372  tensor * A,
373  int i_A,
374  CommData *& cg_cdt_A,
375  int64_t & cg_ctr_lda_A,
376  int64_t & cg_ctr_sub_lda_A,
377  bool & cg_move_A,
378  int * blk_len_A,
379  int64_t & blk_sz_A,
380  int const * virt_blk_len_A,
381  int & load_phase_A,
/* --- per-tensor block for B --- */
382  tensor * B,
383  int i_B,
384  CommData *& cg_cdt_B,
385  int64_t & cg_ctr_lda_B,
386  int64_t & cg_ctr_sub_lda_B,
387  bool & cg_move_B,
388  int * blk_len_B,
389  int64_t & blk_sz_B,
390  int const * virt_blk_len_B,
391  int & load_phase_B,
/* --- per-tensor block for C --- */
392  tensor * C,
393  int i_C,
394  CommData *& cg_cdt_C,
395  int64_t & cg_ctr_lda_C,
396  int64_t & cg_ctr_sub_lda_C,
397  bool & cg_move_C,
398  int * blk_len_C,
399  int64_t & blk_sz_C,
400  int const * virt_blk_len_C,
401  int & load_phase_C);
402 }
403 
404 #endif
~contraction()
destructor
Definition: contraction.cxx:29
void execute()
run contraction
Definition: contraction.cxx:99
tensor * A
left operand
Definition: contraction.h:19
tensor * B
right operand
Definition: contraction.h:21
untyped internal class for triply-typed bivariate function
Definition: ctr_comm.h:16
bivar_function const * func
function to execute on elements
Definition: contraction.h:39
char const * beta
scaling of existing C
Definition: contraction.h:28
bool is_custom
whether there is an elementwise custom function
Definition: contraction.h:37
class for executing distributed contractions of tensors
Definition: contraction.h:16
contraction()
lazy constructor
Definition: contraction.h:42
int * idx_B
indices of right operand
Definition: contraction.h:33
tensor * C
output
Definition: contraction.h:23
int is_equal(contraction const &os)
returns 1 if contractions have same tensors and index map
double estimate_time()
predicts execution time in seconds using performance models
int * idx_C
indices of output
Definition: contraction.h:35
internal distributed tensor class
char const * alpha
scaling of A*B
Definition: contraction.h:26
int ctr_2d_gen_build(int is_used, CommData global_comm, int i, int *virt_dim, int &cg_edge_len, int &total_iter, tensor *A, int i_A, CommData *&cg_cdt_A, int64_t &cg_ctr_lda_A, int64_t &cg_ctr_sub_lda_A, bool &cg_move_A, int *blk_len_A, int64_t &blk_sz_A, int const *virt_blk_len_A, int &load_phase_A, tensor *B, int i_B, CommData *&cg_cdt_B, int64_t &cg_ctr_lda_B, int64_t &cg_ctr_sub_lda_B, bool &cg_move_B, int *blk_len_B, int64_t &blk_sz_B, int const *virt_blk_len_B, int &load_phase_B, tensor *C, int i_C, CommData *&cg_cdt_C, int64_t &cg_ctr_lda_C, int64_t &cg_ctr_sub_lda_C, bool &cg_move_C, int *blk_len_C, int64_t &blk_sz_C, int const *virt_blk_len_C, int &load_phase_C)
sets up a ctr_2d_general (2D SUMMA) level where A is not communicated; the function will be called with A/...
int * idx_A
indices of left operand
Definition: contraction.h:31