Cyclops Tensor Framework: parallel arithmetic on multidimensional arrays
sym_indices.cxx
Go to the documentation of this file.
1
3 #include "../interface/common.h"
4 #include "sym_indices.h"
5
6 using namespace CTF;
7
/**
 * \brief records where one index label occurs in the tensors of an operation
 *
 * Entries are built by align_symmetric_indices. The three-tensor overload
 * stores -1 in a pos_* field when the label does not occur in that tensor.
 */
struct index_locator_
{
  int sort;  ///< ordinal assigned when the entry joins a group
  int idx;   ///< the index label itself
  int pos_A; ///< mode of idx in tensor A
  int pos_B; ///< mode of idx in tensor B
  int pos_C; ///< mode of idx in tensor C

  index_locator_(int sort_, int idx_, int pos_A_, int pos_B_, int pos_C_)
  : sort(sort_), idx(idx_), pos_A(pos_A_), pos_B(pos_B_), pos_C(pos_C_) {}

  /** \brief strict ordering by position in A (usable with std::sort) */
  static bool sortA(const index_locator_& a, const index_locator_& b)
  {
    return a.pos_A < b.pos_A;
  }

  /** \brief strict ordering by position in B (usable with std::sort) */
  static bool sortB(const index_locator_& a, const index_locator_& b)
  {
    return a.pos_B < b.pos_B;
  }

  /** \brief strict ordering by position in C (usable with std::sort) */
  static bool sortC(const index_locator_& a, const index_locator_& b)
  {
    return a.pos_C < b.pos_C;
  }

  /** \brief true if this entry refers to index label \p label */
  bool operator==(int label) const
  {
    return idx == label;
  }
};
39 template <typename T>
40 int align_symmetric_indices(int order_A, T& idx_A, const int* sym_A,
41  int order_B, T& idx_B, const int* sym_B)
42 {
43  int fact = 1;
44
45  std::vector<index_locator_> indices;
46
47  for (int i = 0;i < order_A;i++)
48  {
49  int i_in_B; for (i_in_B = 0;i_in_B < order_B && idx_A[i] != idx_B[i_in_B];i_in_B++);
50  if (i_in_B == order_B) continue;
51
52  indices.push_back(index_locator_(0, idx_A[i], i, i_in_B, 0));
53  }
54
55  while (!indices.empty())
56  {
57  std::vector<index_locator_> group;
58  group.push_back(indices[0]);
59  group.back().sort = 0;
60  indices.erase(indices.begin());
61
62  int s = 1;
63  for (std::vector<index_locator_>::iterator it = indices.begin();;)
64  {
65  if (it == indices.end()) break;
66
67  if ((group[0].pos_A == -1 && it->pos_A != -1) ||
68  (group[0].pos_A != -1 && it->pos_A == -1) ||
69  (group[0].pos_B == -1 && it->pos_B != -1) ||
70  (group[0].pos_B != -1 && it->pos_B == -1))
71  {
72  ++it;
73  continue;
74  }
75
76  bool sym_in_A = false;
77  for (int k = group[0].pos_A-1;k >= 0 && sym_A[k] != NS;k--)
78  {
79  if (idx_A[k] == it->idx)
80  {
81  sym_in_A = true;
82  break;
83  }
84  }
85  for (int k = group[0].pos_A+1;k < order_A && sym_A[k-1] != NS;k++)
86  {
87  if (idx_A[k] == it->idx)
88  {
89  sym_in_A = true;
90  break;
91  }
92  }
93  if (!sym_in_A)
94  {
95  ++it;
96  continue;
97  }
98
99  bool sym_in_B = false;
100  for (int k = group[0].pos_B-1;k >= 0 && sym_B[k] != NS;k--)
101  {
102  if (idx_B[k] == it->idx)
103  {
104  sym_in_B = true;
105  break;
106  }
107  }
108  for (int k = group[0].pos_B+1;k < order_B && sym_B[k-1] != NS;k++)
109  {
110  if (idx_B[k] == it->idx)
111  {
112  sym_in_B = true;
113  break;
114  }
115  }
116  if (!sym_in_B)
117  {
118  ++it;
119  continue;
120  }
121
122  group.push_back(*it);
123  group.back().sort = s++;
124  it = indices.erase(it);
125  }
126
127  if (group.size() <= 1) continue;
128
129  std::vector<int> order_A, order_B;
130
131  for (int i = 0;i < (int)group.size();i++)
132  order_A.push_back(group[i].sort);
133
134  std::sort(group.begin(), group.end(), index_locator_::sortB);
135  for (int i = 0;i < (int)group.size();i++)
136  {
137  order_B.push_back(group[i].sort);
138  idx_B[group[group[i].sort].pos_B] = group[i].idx;
139  }
140  if (sym_B[group[0].pos_B] == AS)
141  fact *= relativeSign(order_A, order_B);
142  }
143
144  //if (fact != 1)
145  //{
146  // std::cout << "I got a -1 !!!!!" << std::endl;
147  // for (int i = 0;i < order_A;i++) std::cout << idx_A[i] << ' ';
148  // std::cout << std::endl;
149  // for (int i = 0;i < order_B;i++) std::cout << idx_B[i] << ' ';
150  // std::cout << std::endl;
151  //}
152
153  return fact;
154 }
155
156 template <typename T>
157 int align_symmetric_indices(int order_A, T& idx_A, const int* sym_A,
158  int order_B, T& idx_B, const int* sym_B,
159  int order_C, T& idx_C, const int* sym_C)
160 {
161  int fact = 1;
162
163  std::vector<index_locator_> indices;
164
165  for (int i = 0;i < order_A;i++)
166  {
167  int i_in_B; for (i_in_B = 0;i_in_B < order_B && idx_A[i] != idx_B[i_in_B];i_in_B++);
168  if (i_in_B == order_B) i_in_B = -1;
169
170  int i_in_C; for (i_in_C = 0;i_in_C < order_C && idx_A[i] != idx_C[i_in_C];i_in_C++);
171  if (i_in_C == order_C) i_in_C = -1;
172
173  if (i_in_B == -1 && i_in_C == -1) continue;
174
175  indices.push_back(index_locator_(0, idx_A[i], i, i_in_B, i_in_C));
176  }
177
178  for (int i = 0;i < order_B;i++)
179  {
180  int i_in_A; for (i_in_A = 0;i_in_A < order_A && idx_B[i] != idx_A[i_in_A];i_in_A++);
181  if (i_in_A == order_A) i_in_A = -1;
182
183  int i_in_C; for (i_in_C = 0;i_in_C < order_C && idx_B[i] != idx_C[i_in_C];i_in_C++);
184  if (i_in_C == order_C) i_in_C = -1;
185
186  if (i_in_A != -1 || i_in_C == -1) continue;
187
188  indices.push_back(index_locator_(0, idx_B[i], i_in_A, i, i_in_C));
189  }
190
191  while (!indices.empty())
192  {
193  std::vector<index_locator_> group;
194  group.push_back(indices[0]);
195  group.back().sort = 0;
196  indices.erase(indices.begin());
197
198  int s = 1;
199  for (std::vector<index_locator_>::iterator it = indices.begin();;)
200  {
201  if (it == indices.end()) break;
202
203  if ((group[0].pos_A == -1 && it->pos_A != -1) ||
204  (group[0].pos_A != -1 && it->pos_A == -1) ||
205  (group[0].pos_B == -1 && it->pos_B != -1) ||
206  (group[0].pos_B != -1 && it->pos_B == -1) ||
207  (group[0].pos_C == -1 && it->pos_C != -1) ||
208  (group[0].pos_C != -1 && it->pos_C == -1))
209  {
210  ++it;
211  continue;
212  }
213
214  if (group[0].pos_A != -1)
215  {
216  bool sym_in_A = false;
217  for (int k = group[0].pos_A-1;k >= 0 && sym_A[k] != NS;k--)
218  {
219  if (idx_A[k] == it->idx)
220  {
221  sym_in_A = true;
222  break;
223  }
224  }
225  for (int k = group[0].pos_A+1;k < order_A && sym_A[k-1] != NS;k++)
226  {
227  if (idx_A[k] == it->idx)
228  {
229  sym_in_A = true;
230  break;
231  }
232  }
233  if (!sym_in_A)
234  {
235  ++it;
236  continue;
237  }
238  }
239
240  if (group[0].pos_B != -1)
241  {
242  bool sym_in_B = false;
243  for (int k = group[0].pos_B-1;k >= 0 && sym_B[k] != NS;k--)
244  {
245  if (idx_B[k] == it->idx)
246  {
247  sym_in_B = true;
248  break;
249  }
250  }
251  for (int k = group[0].pos_B+1;k < order_B && sym_B[k-1] != NS;k++)
252  {
253  if (idx_B[k] == it->idx)
254  {
255  sym_in_B = true;
256  break;
257  }
258  }
259  if (!sym_in_B)
260  {
261  ++it;
262  continue;
263  }
264  }
265
266  if (group[0].pos_C != -1)
267  {
268  bool sym_in_C = false;
269  for (int k = group[0].pos_C-1;k >= 0 && sym_C[k] != NS;k--)
270  {
271  if (idx_C[k] == it->idx)
272  {
273  sym_in_C = true;
274  break;
275  }
276  }
277  for (int k = group[0].pos_C+1;k < order_C && sym_C[k-1] != NS;k++)
278  {
279  if (idx_C[k] == it->idx)
280  {
281  sym_in_C = true;
282  break;
283  }
284  }
285  if (!sym_in_C)
286  {
287  ++it;
288  continue;
289  }
290  }
291
292  group.push_back(*it);
293  group.back().sort = s++;
294  it = indices.erase(it);
295  }
296
297  if (group.size() <= 1) continue;
298
299  std::vector<int> order_A, order_B, order_C;
300
301  if (group[0].pos_A != -1)
302  {
303  for (int i = 0;i < (int)group.size();i++)
304  order_A.push_back(group[i].sort);
305
306  if (group[0].pos_B != -1)
307  {
308  std::sort(group.begin(), group.end(), index_locator_::sortB);
309  for (int i = 0;i < (int)group.size();i++)
310  {
311  order_B.push_back(group[i].sort);
312  idx_B[group[group[i].sort].pos_B] = group[i].idx;
313  }
314  if (sym_B[group[0].pos_B] == AS)
315  fact *= relativeSign(order_A, order_B);
316  }
317
318  if (group[0].pos_C != -1)
319  {
320  std::sort(group.begin(), group.end(), index_locator_::sortC);
321  for (int i = 0;i < (int)group.size();i++)
322  {
323  order_C.push_back(group[i].sort);
324  idx_C[group[group[i].sort].pos_C] = group[i].idx;
325  }
326  if (sym_C[group[0].pos_C] == AS)
327  fact *= relativeSign(order_A, order_C);
328  }
329  }
330  else
331  {
332  for (int i = 0;i < (int)group.size();i++)
333  order_B.push_back(group[i].sort);
334
335  std::sort(group.begin(), group.end(), index_locator_::sortC);
336  for (int i = 0;i < (int)group.size();i++)
337  {
338  order_C.push_back(group[i].sort);
339  idx_C[group[group[i].sort].pos_C] = group[i].idx;
340  }
341  if (sym_C[group[0].pos_C] == AS)
342  fact *= relativeSign(order_B, order_C);
343  }
344  }
345
346  //if (fact != 1)
347  //{
348  // std::cout << "I got a -1 !!!!!" << std::endl;
349  // for (int i = 0;i < order_A;i++) std::cout << idx_A[i] << ' ';
350  // std::cout << std::endl;
351  // for (int i = 0;i < order_B;i++) std::cout << idx_B[i] << ' ';
352  // std::cout << std::endl;
353  // for (int i = 0;i < order_C;i++) std::cout << idx_C[i] << ' ';
354  // std::cout << std::endl;
355  //}
356
357  return fact;
358 }
359
360 template <typename T>
361 int overcounting_factor(int order_A, const T& idx_A, const int* sym_A,
362  int order_B, const T& idx_B, const int* sym_B,
363  int order_C, const T& idx_C, const int* sym_C)
364 {
365  int fact = 1;
366
367  for (int i = 0;i < order_A;i++)
368  {
369  int j;
370  for (j = 0;j < order_B && idx_A[i] != idx_B[j];j++);
371  if (j == order_B) continue;
372
373  int k;
374  for (k = 0;k < order_C && idx_A[i] != idx_C[k];k++);
375  if (k != order_C) continue;
376
377  int ninarow = 1;
378  while (i < order_A &&
379  j < order_B &&
380  sym_A[i] != NS &&
381  sym_B[j] != NS &&
382  idx_A[i] == idx_B[j])
383  {
384  ninarow++;
385  i++;
386  j++;
387  }
388  if (i < order_A &&
389  j < order_B &&
390  idx_A[i] != idx_B[j]) ninarow--;
391
392  if (ninarow >= 2){
393  //if (sym_A[i-ninarow+1]!=SY)
394  for (;ninarow > 1;ninarow--) fact *= ninarow;
395  }
396  }
397
398  return fact;
399 }
400
401 template <typename T>
402 int overcounting_factor(int order_A, const T& idx_A, const int* sym_A,
403  int order_B, const T& idx_B, const int* sym_B)
404 {
405  int fact;
406  int ninarow;
407  fact = 1.0;
408
409  for (int i = 0;i < order_A;i++)
410  {
411  int j;
412  ninarow = 0;
413  for (j = 0;j < order_B && idx_A[i] != idx_B[j];j++);
414  if (j>=order_B){
415  ninarow = 1;
416  while (sym_A[i] != NS)
417  {
418  i++;
419  for (j = 0;j < order_B && idx_A[i] != idx_B[j];j++);
420  if (j>=order_B) ninarow++;
421  }
422  }
423  if (ninarow >= 2){
424  if (sym_A[i-ninarow+1]==AS) return 0.0;
425  if (sym_A[i-ninarow+1]==SY) {
426  /*printf("CTF error: sum over SY index pair currently not functional, ABORTING\n");
427  assert(0);*/
428  }
429  if (sym_A[i-ninarow+1]!=SY)
430  for (;ninarow > 1;ninarow--) fact *= ninarow;
431  }
432  }
433  return fact;
434 }
435
436
437 template int align_symmetric_indices<int*>(int order_A, int*& idx_A, const int* sym_A,
438  int order_B, int*& idx_B, const int* sym_B);
439
440 template int align_symmetric_indices<int*>(int order_A, int*& idx_A, const int* sym_A,
441  int order_B, int*& idx_B, const int* sym_B,
442  int order_C, int*& idx_C, const int* sym_C);
443
444 template int overcounting_factor<int*>(int order_A, int * const & idx_A, const int* sym_A,
445  int order_B, int * const & idx_B, const int* sym_B,
446  int order_C, int * const & idx_C, const int* sym_C);
447
448 template int overcounting_factor<int*>(int order_A, int * const & idx_A, const int* sym_A,
449  int order_B, int * const & idx_B, const int* sym_B);
450
451
452 template int align_symmetric_indices<std::string>(int order_A, std::string& idx_A, const int* sym_A,
453  int order_B, std::string& idx_B, const int* sym_B);
454
455 template int align_symmetric_indices<std::string>(int order_A, std::string& idx_A, const int* sym_A,
456  int order_B, std::string& idx_B, const int* sym_B,
457  int order_C, std::string& idx_C, const int* sym_C);
458
459 template int overcounting_factor<std::string>(int order_A, std::string const & idx_A, const int* sym_A,
460  int order_B, std::string const & idx_B, const int* sym_B,
461  int order_C, std::string const & idx_C, const int* sym_C);
462
463 template int overcounting_factor<std::string>(int order_A, std::string const & idx_A, const int* sym_A,
464  int order_B, std::string const & idx_B, const int* sym_B);
465
466
467
static bool sortB(const index_locator_ &a, const index_locator_ &b)
Definition: sym_indices.cxx:24
Definition: common.h:37
string
Definition: core.pyx:456
static bool sortA(const index_locator_ &a, const index_locator_ &b)
Definition: sym_indices.cxx:19
int overcounting_factor(int order_A, const T &idx_A, const int *sym_A, int order_B, const T &idx_B, const int *sym_B, int order_C, const T &idx_C, const int *sym_C)
int64_t fact(int64_t n)
Definition: util.cxx:277
template int align_symmetric_indices< int * >(int order_A, int *&idx_A, const int *sym_A, int order_B, int *&idx_B, const int *sym_B)
template int overcounting_factor< int * >(int order_A, int *const &idx_A, const int *sym_A, int order_B, int *const &idx_B, const int *sym_B, int order_C, int *const &idx_C, const int *sym_C)
int relativeSign(RAIterator s1b, RAIterator s1e, RAIterator s2b, RAIterator s2e)
Definition: sym_indices.h:9
static bool sortC(const index_locator_ &a, const index_locator_ &b)
Definition: sym_indices.cxx:29
Definition: apsp.cxx:17
bool operator==(int idx)
Definition: sym_indices.cxx:34
int align_symmetric_indices(int order_A, T &idx_A, const int *sym_A, int order_B, T &idx_B, const int *sym_B)
Definition: sym_indices.cxx:40
Definition: common.h:37
index_locator_(int sort, int idx, int pos_A, int pos_B, int pos_C)
Definition: sym_indices.cxx:16
Definition: common.h:37