#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <string>
#include <ctf.hpp>
using namespace CTF;

void recursive_matmul(int n, int m, int k, Tensor<> & A, Tensor<> & B, Tensor<> & C){
  int rank, num_pes, cnum_pes, ri, rj, rk, ni, nj, nk, div;
  MPI_Comm pcomm, ccomm;

  pcomm = C.wrld->comm;
  MPI_Comm_rank(pcomm, &rank);
  MPI_Comm_size(pcomm, &num_pes);
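
  // base case: with a single process or a unit dimension, contract directly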
  if (num_pes == 1 || m == 1 || n == 1 || k == 1){
    C["ij"] += 1.0*A["ik"]*B["kj"];
  } else {
    for (div=2; num_pes%div!=0; div++){}
    cnum_pes = num_pes / div;
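
    // split pcomm into div subcommunicators of cnum_pes processes each:
    // rank/cnum_pes picks the group (color), rank%cnum_pes the rank in it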
    MPI_Comm_split(pcomm, rank/cnum_pes, rank%cnum_pes, &ccomm);

    World cdw(ccomm);
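
    // partition the largest of the three matrix dimensions into div blocks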
    if (m >= n && m >= k){
      ni = div; nj = 1;   nk = 1;
    } else if (n >= m && n >= k){
      ni = 1;   nj = div; nk = 1;
    } else if (k >= m && k >= n){
      ni = 1;   nj = 1;   nk = div;
    }
    // block coordinates of this process group; exactly one of ni, nj, nk
    // is div, so only one coordinate is nonzero (derived from the offset
    // computation below)
    ri = (rank/cnum_pes)%ni;
    rj = ((rank/cnum_pes)/ni)%nj;
    rk = (rank/cnum_pes)/(ni*nj);
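
    // each group owns the half-open block [off, end) of each operand; e.g.
    // with div = 2 splitting m, group 0 gets rows [0, m/2) and group 1 gets
    // rows [m/2, m) (dimensions are assumed divisible by the block counts)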
    int off_ij[2] = {ri * m/ni, rj * n/nj};
    int end_ij[2] = {ri * m/ni + m/ni, rj * n/nj + n/nj};
    int off_ik[2] = {ri * m/ni, rk * k/nk};
    int end_ik[2] = {ri * m/ni + m/ni, rk * k/nk + k/nk};
    int off_kj[2] = {rk * k/nk, rj * n/nj};
    int end_kj[2] = {rk * k/nk + k/nk, rj * n/nj + n/nj};

    // extract this group's blocks of A and B into the child world, then
    // recurse on the subproblem (slice-to-subworld is assumed here)
    Tensor<> cA = A.slice(off_ik, end_ik, &cdw);
    Tensor<> cB = B.slice(off_kj, end_kj, &cdw);
    Matrix<> cC(m/ni, n/nj, NS, cdw);

    recursive_matmul(n/nj, m/ni, k/nk, cA, cB, cC);

    int off_00[2] = {0, 0};
    int end_11[2] = {m/ni, n/nj};
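    // write the child block back: C[off_ij:end_ij) gets beta*C + alpha*cC;
    // with beta = alpha = 1.0 the partial products over k accumulate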
    C.slice(off_ij, end_ij, 1.0, cC, off_00, end_11, 1.0);
    MPI_Comm_free(&ccomm);
  }
}

int test_recursive_matmul(int n, int m, int k, World & dw){
  int rank, num_pes;
  int64_t i, np;
  double * pairs, err;
  int64_t * indices;

  Matrix<> A(m, k, NS, dw);
  Matrix<> B(k, n, NS, dw);
  Matrix<> C(m, n, NS, dw);
  Matrix<> C_ans(m, n, NS, dw);

  MPI_Comm pcomm = dw.comm;
  MPI_Comm_rank(pcomm, &rank);
  MPI_Comm_size(pcomm, &num_pes);
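
  // fill A and B with uniform random values in [-0.5, 0.5); each rank
  // seeds the generator differently and overwrites only its local data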
  srand48(13*rank); // rank-dependent seed (any distinct seed per rank works)
  A.get_local_data(&np, &indices, &pairs);
  for (i=0; i<np; i++) pairs[i] = drand48()-.5;
  A.write(np, indices, pairs);
  delete [] pairs;
  free(indices);

  B.get_local_data(&np, &indices, &pairs);
  for (i=0; i<np; i++) pairs[i] = drand48()-.5;
  B.write(np, indices, pairs);
  delete [] pairs;
  free(indices);
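
  // reference answer via a single CTF contraction; recursive_matmul must
  // reproduce it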
  C_ans["ij"] += 1.0*A["ik"]*B["kj"];

  recursive_matmul(n, m, k, A, B, C);
  C_ans["ij"] -= C["ij"];

  err = C_ans.norm2();

  if (rank == 0){
    if (err < 1.E-9) // pass tolerance, chosen here for illustration
      printf("{ GEMM with parallel slicing } passed\n");
    else
      printf("{ GEMM with parallel slicing } FAILED, error norm = %E\n", err);
  }
  return err < 1.E-9;
}

char * getCmdOption(char ** begin, char ** end, const std::string & option){
  char ** itr = std::find(begin, end, option);
  if (itr != end && ++itr != end){
    return *itr;
  }
  return 0;
}
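
// e.g. with argv = {"./exe", "-n", "64"}, getCmdOption(argv, argv+3, "-n")
// returns "64"; it returns 0 when the flag is missing or has no value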

int main(int argc, char ** argv){
  int rank, np, n, m, k, pass;
  int const in_num = argc;
  char ** input_str = argv;

  MPI_Init(&argc, &argv);
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &np);
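
  // parse the -n, -m, -k flags; the fallback size (7) is an assumed default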
  if (getCmdOption(input_str, input_str+in_num, "-n")){
    n = atoi(getCmdOption(input_str, input_str+in_num, "-n"));
    if (n < 0) n = 7;
  } else n = 7;
  if (getCmdOption(input_str, input_str+in_num, "-m")){
    m = atoi(getCmdOption(input_str, input_str+in_num, "-m"));
    if (m < 0) m = 7;
  } else m = 7;
  if (getCmdOption(input_str, input_str+in_num, "-k")){
    k = atoi(getCmdOption(input_str, input_str+in_num, "-k"));
    if (k < 0) k = 7;
  } else k = 7;

  {
    World dw(MPI_COMM_WORLD, argc, argv);

    if (rank == 0){
      printf("Non-symmetric: NS = NS*NS test_recursive_matmul:\n");
    }
    pass = test_recursive_matmul(n, m, k, dw);
    assert(pass);
  } // dw must be destroyed before MPI_Finalize

  MPI_Finalize();
  return 0;
}

// ---------------------------------------------------------------------------
// CTF API used above (descriptions from the library documentation):
//   CTF::World                    an instance of the CTF library (world) on
//                                 an MPI communicator
//   CTF::World::comm              set of processors making up this world
//   CTF::Matrix                   Matrix class which encapsulates a 2D tensor
//   CTF::Tensor                   an instance of a tensor within a CTF world
//   Tensor::wrld                  distributed processor context on which the
//                                 tensor is defined
//   Tensor::slice(offsets, ends)  cuts out a slice (block) of this tensor,
//                                 A[offsets, ends); the result is always
//                                 fully nonsymmetric
//   Tensor::get_local_data(npair, global_idx, data, nonzeros_only=false,
//                          unpack_sym=false)
//                                 gives the global indices and values
//                                 associated with the local data
//   Tensor::write(npair, global_idx, data)
//                                 writes in values associated with any set of
//                                 indices; the sparse data is defined in
//                                 coordinate format
//   Tensor::norm2()               computes the Frobenius norm of the tensor
//                                 (needs sqrt()!)