16 int rank, i, num_pes, cnum_pes;
21 MPI_Comm pcomm, ccomm;
25 MPI_Comm_rank(pcomm, &rank);
26 MPI_Comm_size(pcomm, &num_pes);
28 if (num_pes % 7 == 0){
30 MPI_Comm_split(pcomm, rank/cnum_pes, rank%cnum_pes, &ccomm);
39 printf(
"n = %d, p = %d\n",
57 for (i=0; i<
np; i++ ) pairs[i] = drand48()-.5;
58 A.
write(np, indices, pairs);
62 for (i=0; i<
np; i++ ) pairs[i] = drand48()-.5;
63 B.
write(np, indices, pairs);
72 C[
"ij"] = A[
"ik"]*B[
"kj"];
74 int off_00[2] = {0, 0};
75 int off_01[2] = {0, n/2};
76 int off_10[2] = {n/2, 0};
77 int off_11[2] = {n/2, n/2};
78 int end_11[2] = {n/2, n/2};
79 int end_21[2] = {n, n/2};
80 int end_12[2] = {n/2, n};
81 int end_22[2] = {n, n};
83 int snhalf[2] = {n/2, n/2};
84 int sym_ns[2] = {
NS,
NS};
86 if (ccomm != dw.
comm){
97 Tensor<> dummy(0, 0, NULL, NULL, cdw);
99 switch (rank/cnum_pes){
101 cA.
slice(off_00, end_11, 1.0, A, off_00, end_11, 1.0);
102 cA.
slice(off_00, end_11, 1.0, A, off_11, end_22, 1.0);
103 cB.
slice(off_00, end_11, 1.0, B, off_00, end_11, 1.0);
104 cB.
slice(off_00, end_11, 1.0, B, off_11, end_22, 1.0);
105 cC[
"ij"] = cA[
"ik"]*cB[
"kj"];
106 Cs.
slice(off_00, end_11, 1.0, cC, off_00, end_11, 1.0);
107 Cs.
slice(off_11, end_22, 1.0, cC, off_00, end_11, 1.0);
110 cA.
slice(off_00, end_11, 1.0, A, off_00, end_11, 1.0);
111 cA[
"ik"] = -1.0*cA[
"ik"];
112 cA.
slice(off_00, end_11, 1.0, A, off_10, end_21, 1.0);
113 cB.slice(off_00, end_11, 1.0, B, off_00, end_11, 1.0);
114 cB.slice(off_00, end_11, 1.0, B, off_01, end_12, 1.0);
115 cC[
"ij"] = cA[
"ik"]*cB[
"kj"];
116 Cs.
slice(off_11, end_22, 1.0, cC, off_00, end_11, 1.0);
117 Cs.
slice(off_00, off_00, 1.0, dummy, NULL, NULL, 1.0);
120 cA.
slice(off_00, end_11, 1.0, A, off_11, end_22, 1.0);
121 cA[
"ik"] = -1.0*cA[
"ik"];
122 cA.
slice(off_00, end_11, 1.0, A, off_01, end_12, 1.0);
123 cB.slice(off_00, end_11, 1.0, B, off_11, end_22, 1.0);
124 cB.slice(off_00, end_11, 1.0, B, off_10, end_21, 1.0);
125 cC[
"ij"] = cA[
"ik"]*cB[
"kj"];
126 Cs.
slice(off_00, end_11, 1.0, cC, off_00, end_11, 1.0);
127 Cs.
slice(off_00, off_00, 1.0, dummy, NULL, NULL, 1.0);
130 cA.
slice(off_00, end_11, 1.0, A, off_11, end_22, 1.0);
131 cA.
slice(off_00, end_11, 1.0, A, off_10, end_21, 1.0);
132 cB.slice(off_00, end_11, 1.0, B, off_00, end_11, 1.0);
133 dummy.
slice(NULL, NULL, 1.0, B, off_00, off_00, 1.0);
134 cC[
"ij"] = cA[
"ik"]*cB[
"kj"];
135 Cs.
slice(off_10, end_21, 1.0, cC, off_00, end_11, 1.0);
136 cC[
"ij"] = -1.0*cC[
"ij"];
137 Cs.
slice(off_11, end_22, 1.0, cC, off_00, end_11, 1.0);
140 cA.
slice(off_00, end_11, 1.0, A, off_01, end_12, 1.0);
141 cA.
slice(off_00, end_11, 1.0, A, off_00, end_11, 1.0);
142 cB.slice(off_00, end_11, 1.0, B, off_11, end_22, 1.0);
143 dummy.
slice(NULL, NULL, 1.0, B, off_00, off_00, 1.0);
144 cC[
"ij"] = -1.0*cA[
"ik"]*cB[
"kj"];
145 Cs.
slice(off_00, end_11, 1.0, cC, off_00, end_11, 1.0);
146 cC[
"ij"] = -1.0*cC[
"ij"];
147 Cs.
slice(off_01, end_12, 1.0, cC, off_00, end_11, 1.0);
150 cA.
slice(off_00, end_11, 1.0, A, off_00, end_11, 1.0);
151 dummy.
slice(NULL, NULL, 1.0, A, off_00, off_00, 1.0);
152 cB.slice(off_00, end_11, 1.0, B, off_11, end_22, 1.0);
153 cB[
"kj"] = -1.0*cB[
"kj"];
154 cB.slice(off_00, end_11, 1.0, B, off_01, end_12, 1.0);
155 cC[
"ij"] = cA[
"ik"]*cB[
"kj"];
156 Cs.
slice(off_01, end_12, 1.0, cC, off_00, end_11, 1.0);
157 Cs.
slice(off_11, end_22, 1.0, cC, off_00, end_11, 1.0);
160 cA.
slice(off_00, end_11, 1.0, A, off_11, end_22, 1.0);
161 dummy.
slice(NULL, NULL, 1.0, A, off_00, off_00, 1.0);
162 cB.slice(off_00, end_11, 1.0, B, off_00, end_11, 1.0);
163 cB[
"kj"] = -1.0*cB[
"kj"];
164 cB.slice(off_00, end_11, 1.0, B, off_10, end_21, 1.0);
165 cC[
"ij"] = cA[
"ik"]*cB[
"kj"];
166 Cs.
slice(off_10, end_21, 1.0, cC, off_00, end_11, 1.0);
167 Cs.
slice(off_00, end_11, 1.0, cC, off_00, end_11, 1.0);
175 if (sym ==
SY || sym ==
SH){
176 A12[
"ij"] = A21[
"ji"];
179 A12[
"ij"] = -1.0*A21[
"ji"];
182 A12 = A.
slice(off_01, end_12);
190 if (sym ==
SY || sym ==
SH){
191 B12[
"ij"] = B21[
"ji"];
194 B12[
"ij"] = -1.0*B21[
"ji"];
197 B12 = B.
slice(off_01, end_12);
201 M1[
"ij"] = (A11[
"ik"]+A22[
"ik"])*(B22[
"kj"]+B11[
"kj"]);
202 M6[
"ij"] = (A21[
"ik"]-A11[
"ik"])*(B11[
"kj"]+B12[
"kj"]);
203 M7[
"ij"] = (A12[
"ik"]-A22[
"ik"])*(B22[
"kj"]+B21[
"kj"]);
204 M2[
"ij"] = (A21[
"ik"]+A22[
"ik"])*B11[
"kj"];
205 M5[
"ij"] = (A11[
"ik"]+A12[
"ik"])*B22[
"kj"];
206 M3[
"ij"] = A11[
"ik"]*(B12[
"kj"]-B22[
"kj"]);
207 M4[
"ij"] = A22[
"ik"]*(B21[
"kj"]-B11[
"kj"]);
217 Cs.
slice(off_00, end_11, 0.0, M1, off_00, end_11, 1.0);
218 Cs.
slice(off_00, end_11, 1.0, M4, off_00, end_11, 1.0);
219 Cs.
slice(off_00, end_11, 1.0, M5, off_00, end_11, -1.0);
220 Cs.
slice(off_00, end_11, 1.0, M7, off_00, end_11, 1.0);
221 Cs.
slice(off_01, end_12, 0.0, M3, off_00, end_11, 1.0);
222 Cs.
slice(off_01, end_12, 1.0, M5, off_00, end_11, 1.0);
223 Cs.
slice(off_10, end_21, 0.0, M2, off_00, end_11, 1.0);
224 Cs.
slice(off_10, end_21, 1.0, M4, off_00, end_11, 1.0);
225 Cs.
slice(off_11, end_22, 0.0, M1, off_00, end_11, 1.0);
226 Cs.
slice(off_11, end_22, 1.0, M2, off_00, end_11, -1.0);
227 Cs.
slice(off_11, end_22, 1.0, M3, off_00, end_11, 1.0);
228 Cs.
slice(off_11, end_22, 1.0, M6, off_00, end_11, 1.0);
231 err = ((1./n)/n)*(C[
"ij"]-Cs[
"ij"])*(C[
"ij"]-Cs[
"ij"]);
236 printf(
"{ Strassen's algorithm via slicing } passed\n");
238 printf(
"{ Strassen's algorithm via slicing } FAILED, error norm = %E\n",err);
248 char ** itr = std::find(begin, end, option);
249 if (itr != end && ++itr != end){
255 int main(
int argc,
char ** argv){
257 int const in_num = argc;
258 char ** input_str = argv;
260 MPI_Init(&argc, &argv);
261 MPI_Comm_rank(MPI_COMM_WORLD, &rank);
262 MPI_Comm_size(MPI_COMM_WORLD, &np);
265 n = atoi(
getCmdOption(input_str, input_str+in_num,
"-n"));
272 World dw(MPI_COMM_WORLD, argc, argv);
275 printf(
"Non-symmetric: NS = NS*NS strassen:\n");
280 printf(
"(Anti-)Skew-symmetric: NS = AS*AS strassen:\n");
285 printf(
"Symmetric: NS = SY*SY strassen:\n");
290 printf(
"Symmetric-hollow: NS = SH*SH strassen:\n");
Matrix class which encapsulates a 2D tensor.
Tensor< dtype > slice(int const *offsets, int const *ends) const
cuts out a slice (block) of this tensor, A[offsets,ends); the result will always be fully nonsymmetric ...
char * getCmdOption(char **begin, char **end, const std::string &option)
an instance of the CTF library (world) on an MPI communicator
int main(int argc, char **argv)
void get_local_data(int64_t *npair, int64_t **global_idx, dtype **data, bool nonzeros_only=false, bool unpack_sym=false) const
Gives the global indices and values associated with the local data.
an instance of a tensor within a CTF world
int strassen(int const n, int const sym, World &dw)
void write(int64_t npair, int64_t const *global_idx, dtype const *data)
writes in values associated with any set of indices. The sparse data is defined in coordinate format...
MPI_Comm comm
set of processors making up this world