Cyclops Tensor Framework
parallel arithmetic on multidimensional arrays
strassen.cxx
Go to the documentation of this file.
1 /*Copyright (c) 2011, Edgar Solomonik, all rights reserved.*/
2 
10 #include <ctf.hpp>
11 using namespace CTF;
12 
13 int strassen(int const n,
14  int const sym,
15  World &dw){
16  int rank, i, num_pes, cnum_pes;
17  int64_t np;
18  double * pairs, err;
19  int64_t * indices;
20 
21  MPI_Comm pcomm, ccomm;
22 
23  pcomm = dw.comm;
24 
25  MPI_Comm_rank(pcomm, &rank);
26  MPI_Comm_size(pcomm, &num_pes);
27 
28  if (num_pes % 7 == 0){
29  cnum_pes = num_pes/7;
30  MPI_Comm_split(pcomm, rank/cnum_pes, rank%cnum_pes, &ccomm);
31  } else {
32  cnum_pes = 1;
33  ccomm = dw.comm;
34  }
35  World cdw(ccomm);
36 
37 #ifndef TEST_SUITE
38  if (rank == 0)
39  printf("n = %d, p = %d\n",
40  n,num_pes);
41 #endif
42 
43  Matrix<> A(n, n, sym, dw);
44  Matrix<> B(n, n, sym, dw);
45  Matrix<> C(n, n, NS, dw);
46  Matrix<> Cs(n, n, NS, dw);
47  Matrix<> M1(n/2, n/2, NS, dw);
48  Matrix<> M2(n/2, n/2, NS, dw);
49  Matrix<> M3(n/2, n/2, NS, dw);
50  Matrix<> M4(n/2, n/2, NS, dw);
51  Matrix<> M5(n/2, n/2, NS, dw);
52  Matrix<> M6(n/2, n/2, NS, dw);
53  Matrix<> M7(n/2, n/2, NS, dw);
54 
55  srand48(13*rank);
56  A.get_local_data(&np, &indices, &pairs);
57  for (i=0; i<np; i++ ) pairs[i] = drand48()-.5;
58  A.write(np, indices, pairs);
59  delete [] pairs;
60  free(indices);
61  B.get_local_data(&np, &indices, &pairs);
62  for (i=0; i<np; i++ ) pairs[i] = drand48()-.5;
63  B.write(np, indices, pairs);
64  delete [] pairs;
65  free(indices);
66  /*C.get_local_data(&np, &indices, &pairs);
67  for (i=0; i<np; i++ ) pairs[i] = 0.0;
68  C.write(np, indices, pairs);
69  delete [] pairs;
70  free(indices);*/
71 
72  C["ij"] = A["ik"]*B["kj"];
73 
74  int off_00[2] = {0, 0};
75  int off_01[2] = {0, n/2};
76  int off_10[2] = {n/2, 0};
77  int off_11[2] = {n/2, n/2};
78  int end_11[2] = {n/2, n/2};
79  int end_21[2] = {n, n/2};
80  int end_12[2] = {n/2, n};
81  int end_22[2] = {n, n};
82 
83  int snhalf[2] = {n/2, n/2};
84  int sym_ns[2] = {NS, NS};
85 
86  if (ccomm != dw.comm){
87  /*int off_ij[2] = {ri * n/2, rj * n/2};
88  int end_ij[2] = {ri * n/2 + n/2, rj * n/2 + n/2};
89  int off_ik[2] = {ri * n/2, rk * n/2};
90  int end_ik[2] = {ri * n/2 + n/2, rk * n/2 + n/2};
91  int off_kj[2] = {rk * n/2, rj * n/2};
92  int end_kj[2] = {rk * n/2 + n/2, rj * n/2 + n/2};*/
93  Matrix<> cA(n/2, n/2, NS, cdw);
94  Matrix<> cB(n/2, n/2, NS, cdw);
95  Matrix<> cC(n/2, n/2, NS, cdw);
96 
97  Tensor<> dummy(0, 0, NULL, NULL, cdw);
98 
99  switch (rank/cnum_pes){
100  case 0: //M1
101  cA.slice(off_00, end_11, 1.0, A, off_00, end_11, 1.0);
102  cA.slice(off_00, end_11, 1.0, A, off_11, end_22, 1.0);
103  cB.slice(off_00, end_11, 1.0, B, off_00, end_11, 1.0);
104  cB.slice(off_00, end_11, 1.0, B, off_11, end_22, 1.0);
105  cC["ij"] = cA["ik"]*cB["kj"];
106  Cs.slice(off_00, end_11, 1.0, cC, off_00, end_11, 1.0);
107  Cs.slice(off_11, end_22, 1.0, cC, off_00, end_11, 1.0);
108  break;
109  case 1: //M6
110  cA.slice(off_00, end_11, 1.0, A, off_00, end_11, 1.0);
111  cA["ik"] = -1.0*cA["ik"];
112  cA.slice(off_00, end_11, 1.0, A, off_10, end_21, 1.0);
113  cB.slice(off_00, end_11, 1.0, B, off_00, end_11, 1.0);
114  cB.slice(off_00, end_11, 1.0, B, off_01, end_12, 1.0);
115  cC["ij"] = cA["ik"]*cB["kj"];
116  Cs.slice(off_11, end_22, 1.0, cC, off_00, end_11, 1.0);
117  Cs.slice(off_00, off_00, 1.0, dummy, NULL, NULL, 1.0);
118  break;
119  case 2: //M7
120  cA.slice(off_00, end_11, 1.0, A, off_11, end_22, 1.0);
121  cA["ik"] = -1.0*cA["ik"];
122  cA.slice(off_00, end_11, 1.0, A, off_01, end_12, 1.0);
123  cB.slice(off_00, end_11, 1.0, B, off_11, end_22, 1.0);
124  cB.slice(off_00, end_11, 1.0, B, off_10, end_21, 1.0);
125  cC["ij"] = cA["ik"]*cB["kj"];
126  Cs.slice(off_00, end_11, 1.0, cC, off_00, end_11, 1.0);
127  Cs.slice(off_00, off_00, 1.0, dummy, NULL, NULL, 1.0);
128  break;
129  case 3: //M2
130  cA.slice(off_00, end_11, 1.0, A, off_11, end_22, 1.0);
131  cA.slice(off_00, end_11, 1.0, A, off_10, end_21, 1.0);
132  cB.slice(off_00, end_11, 1.0, B, off_00, end_11, 1.0);
133  dummy.slice(NULL, NULL, 1.0, B, off_00, off_00, 1.0);
134  cC["ij"] = cA["ik"]*cB["kj"];
135  Cs.slice(off_10, end_21, 1.0, cC, off_00, end_11, 1.0);
136  cC["ij"] = -1.0*cC["ij"];
137  Cs.slice(off_11, end_22, 1.0, cC, off_00, end_11, 1.0);
138  break;
139  case 4: //M5
140  cA.slice(off_00, end_11, 1.0, A, off_01, end_12, 1.0);
141  cA.slice(off_00, end_11, 1.0, A, off_00, end_11, 1.0);
142  cB.slice(off_00, end_11, 1.0, B, off_11, end_22, 1.0);
143  dummy.slice(NULL, NULL, 1.0, B, off_00, off_00, 1.0);
144  cC["ij"] = -1.0*cA["ik"]*cB["kj"];
145  Cs.slice(off_00, end_11, 1.0, cC, off_00, end_11, 1.0);
146  cC["ij"] = -1.0*cC["ij"];
147  Cs.slice(off_01, end_12, 1.0, cC, off_00, end_11, 1.0);
148  break;
149  case 5: //M3
150  cA.slice(off_00, end_11, 1.0, A, off_00, end_11, 1.0);
151  dummy.slice(NULL, NULL, 1.0, A, off_00, off_00, 1.0);
152  cB.slice(off_00, end_11, 1.0, B, off_11, end_22, 1.0);
153  cB["kj"] = -1.0*cB["kj"];
154  cB.slice(off_00, end_11, 1.0, B, off_01, end_12, 1.0);
155  cC["ij"] = cA["ik"]*cB["kj"];
156  Cs.slice(off_01, end_12, 1.0, cC, off_00, end_11, 1.0);
157  Cs.slice(off_11, end_22, 1.0, cC, off_00, end_11, 1.0);
158  break;
159  case 6: //M4
160  cA.slice(off_00, end_11, 1.0, A, off_11, end_22, 1.0);
161  dummy.slice(NULL, NULL, 1.0, A, off_00, off_00, 1.0);
162  cB.slice(off_00, end_11, 1.0, B, off_00, end_11, 1.0);
163  cB["kj"] = -1.0*cB["kj"];
164  cB.slice(off_00, end_11, 1.0, B, off_10, end_21, 1.0);
165  cC["ij"] = cA["ik"]*cB["kj"];
166  Cs.slice(off_10, end_21, 1.0, cC, off_00, end_11, 1.0);
167  Cs.slice(off_00, end_11, 1.0, cC, off_00, end_11, 1.0);
168  break;
169  }
170  } else {
171 
172  Tensor<> A21 = A.slice(off_10, end_21);
173  Tensor<> A11 = A.slice(off_00, end_11);
174  Tensor<> A12(2,snhalf,sym_ns,dw);
175  if (sym == SY || sym == SH){
176  A12["ij"] = A21["ji"];
177  }
178  if (sym == AS){
179  A12["ij"] = -1.0*A21["ji"];
180  }
181  if (sym == NS){
182  A12 = A.slice(off_01, end_12);
183  }
184  Tensor<> A22 = A.slice(off_11, end_22);
185 
186  Tensor<> B11 = B.slice(off_00, end_11);
187  Tensor<> B21 = B.slice(off_10, end_21);
188 
189  Tensor<> B12(2,snhalf,sym_ns,dw);
190  if (sym == SY || sym == SH){
191  B12["ij"] = B21["ji"];
192  }
193  if (sym == AS){
194  B12["ij"] = -1.0*B21["ji"];
195  }
196  if (sym == NS){
197  B12 = B.slice(off_01, end_12);
198  }
199  Tensor<> B22 = B.slice(off_11, end_22);
200 
201  M1["ij"] = (A11["ik"]+A22["ik"])*(B22["kj"]+B11["kj"]);
202  M6["ij"] = (A21["ik"]-A11["ik"])*(B11["kj"]+B12["kj"]);
203  M7["ij"] = (A12["ik"]-A22["ik"])*(B22["kj"]+B21["kj"]);
204  M2["ij"] = (A21["ik"]+A22["ik"])*B11["kj"];
205  M5["ij"] = (A11["ik"]+A12["ik"])*B22["kj"];
206  M3["ij"] = A11["ik"]*(B12["kj"]-B22["kj"]);
207  M4["ij"] = A22["ik"]*(B21["kj"]-B11["kj"]);
208 
209  /*printf("[0] %lf\n", M1.norm2());
210  printf("[1] %lf\n", M6.norm2());
211  printf("[2] %lf\n", M7.norm2());
212  printf("[3] %lf\n", M2.norm2());
213  printf("[4] %lf\n", M5.norm2());
214  printf("[5] %lf\n", M3.norm2());
215  printf("[6] %lf\n", M4.norm2());*/
216 
217  Cs.slice(off_00, end_11, 0.0, M1, off_00, end_11, 1.0);
218  Cs.slice(off_00, end_11, 1.0, M4, off_00, end_11, 1.0);
219  Cs.slice(off_00, end_11, 1.0, M5, off_00, end_11, -1.0);
220  Cs.slice(off_00, end_11, 1.0, M7, off_00, end_11, 1.0);
221  Cs.slice(off_01, end_12, 0.0, M3, off_00, end_11, 1.0);
222  Cs.slice(off_01, end_12, 1.0, M5, off_00, end_11, 1.0);
223  Cs.slice(off_10, end_21, 0.0, M2, off_00, end_11, 1.0);
224  Cs.slice(off_10, end_21, 1.0, M4, off_00, end_11, 1.0);
225  Cs.slice(off_11, end_22, 0.0, M1, off_00, end_11, 1.0);
226  Cs.slice(off_11, end_22, 1.0, M2, off_00, end_11, -1.0);
227  Cs.slice(off_11, end_22, 1.0, M3, off_00, end_11, 1.0);
228  Cs.slice(off_11, end_22, 1.0, M6, off_00, end_11, 1.0);
229  }
230 
231  err = ((1./n)/n)*(C["ij"]-Cs["ij"])*(C["ij"]-Cs["ij"]);
232 
233  if (rank == 0){
234  //printf("{ Strassen's error norm = %E\n",err);
235  if (err<1.E-10)
236  printf("{ Strassen's algorithm via slicing } passed\n");
237  else
238  printf("{ Strassen's algorithm via slicing } FAILED, error norm = %E\n",err);
239  }
240  return err<1.E-10;
241 }
242 
243 
244 #ifndef TEST_SUITE
/**
 * Looks up a command-line flag in [begin, end) and returns the argument that
 * immediately follows its first occurrence.
 * Returns a null pointer if the flag is absent, or if its first occurrence is
 * the last element of the range.
 */
char* getCmdOption(char ** begin,
                   char ** end,
                   const std::string & option){
  for (char ** cursor = begin; cursor != end; ++cursor){
    if (option == *cursor){
      char ** value = cursor + 1;
      return (value == end) ? nullptr : *value;
    }
  }
  return nullptr;
}
254 
255 int main(int argc, char ** argv){
256  int rank, np, n;
257  int const in_num = argc;
258  char ** input_str = argv;
259 
260  MPI_Init(&argc, &argv);
261  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
262  MPI_Comm_size(MPI_COMM_WORLD, &np);
263 
264  if (getCmdOption(input_str, input_str+in_num, "-n")){
265  n = atoi(getCmdOption(input_str, input_str+in_num, "-n"));
266  if (n < 0) n = 256;
267  } else n = 256;
268 
269  assert(n%2 == 0);
270 
271  {
272  World dw(MPI_COMM_WORLD, argc, argv);
273  int pass;
274  if (rank == 0){
275  printf("Non-symmetric: NS = NS*NS strassen:\n");
276  }
277  pass = strassen(n, NS, dw);
278  assert(pass);
279  if (rank == 0){
280  printf("(Anti-)Skew-symmetric: NS = AS*AS strassen:\n");
281  }
282  pass = strassen(n, AS, dw);
283  assert(pass);
284  if (rank == 0){
285  printf("Symmetric: NS = SY*SY strassen:\n");
286  }
287  pass = strassen(n, SY, dw);
288  assert(pass);
289  if (rank == 0){
290  printf("Symmetric-hollow: NS = SH*SH strassen:\n");
291  }
292  pass = strassen(n, SH, dw);
293  assert(pass);
294  }
295 
296  MPI_Finalize();
297  return 0;
298 }
299 
305 #endif
Matrix class which encapsulates a 2D tensor.
Definition: matrix.h:18
def rank(self)
Definition: core.pyx:312
Tensor< dtype > slice(int const *offsets, int const *ends) const
cuts out a slice (block) of this tensor, A[offsets, ends); the result will always be fully nonsymmetric ...
Definition: tensor.cxx:643
Definition: common.h:37
char * getCmdOption(char **begin, char **end, const std::string &option)
Definition: strassen.cxx:245
an instance of the CTF library (world) on an MPI communicator
Definition: world.h:19
string
Definition: core.pyx:456
int main(int argc, char **argv)
Definition: strassen.cxx:255
void get_local_data(int64_t *npair, int64_t **global_idx, dtype **data, bool nonzeros_only=false, bool unpack_sym=false) const
Gives the global indices and values associated with the local data.
Definition: tensor.cxx:159
Definition: apsp.cxx:17
an instance of a tensor within a CTF world
Definition: tensor.h:74
int strassen(int const n, int const sym, World &dw)
Definition: strassen.cxx:13
Definition: common.h:37
Definition: common.h:37
void write(int64_t npair, int64_t const *global_idx, dtype const *data)
writes in values associated with any set of indices. The sparse data is defined in coordinate format...
Definition: tensor.cxx:264
Definition: common.h:37
MPI_Comm comm
set of processors making up this world
Definition: world.h:22
def np(self)
Definition: core.pyx:315