17 double * pairs, * pairs_AB, * pairs_BC;
18 int64_t * indices, * indices_AB, * indices_BC;
20 MPI_Comm_rank(MPI_COMM_WORLD, &rank);
21 MPI_Comm_size(MPI_COMM_WORLD, &num_pes);
24 int shapeN4[] = {sym,
NS,sym,NS};
25 int sizeN4[] = {n,n,n,n};
35 for (i=0; i<
np; i++ ) pairs[i] = drand48()-.5;
36 A.
write(np, indices, pairs);
40 for (i=0; i<
np; i++ ) pairs[i] = drand48()-.5;
41 B.
write(np, indices, pairs);
45 for (i=0; i<
np; i++ ) pairs[i] = drand48()-.5;
46 C.
write(np, indices, pairs);
53 double t = MPI_Wtime();
54 for (i=0; i<niter; i++){
55 C[
"ijkl"] += (.3*i)*A[
"ijmn"]*B[
"mnkl"];
57 time = MPI_Wtime()- t;
59 double nd = (double)n;
61 if (sym ==
SY || sym ==
AS){
64 printf(
"%lf seconds/GEMM %lf GF\n",
65 time/niter,niter*c*nd*nd*nd*nd*nd*nd/time);
66 printf(
"Verifying associativity\n");
73 D[
"ijkl"] = A[
"ijmn"]*B[
"mnkl"];
74 D[
"ijkl"] = D[
"ijmn"]*C[
"mnkl"];
75 C[
"ijkl"] = B[
"ijmn"]*C[
"mnkl"];
76 C[
"ijkl"] = A[
"ijmn"]*C[
"mnkl"];
79 C.get_local_data(&np, &indices_BC, &pairs_BC);
83 if (fabs((
double)pairs_BC[i]-(
double)pairs_AB[i])>=1.E-6) pass = 0;
90 MPI_Reduce(MPI_IN_PLACE, &pass, 1, MPI_INT, MPI_MIN, 0, MPI_COMM_WORLD);
92 printf(
"{ (A[\"ijmn\"]*B[\"mnpq\"])*C[\"pqkl\"] = A[\"ijmn\"]*(B[\"mnpq\"]*C[\"pqkl\"]) } passed\n");
94 printf(
"{ (A[\"ijmn\"]*B[\"mnpq\"])*C[\"pqkl\"] = A[\"ijmn\"]*(B[\"mnpq\"]*C[\"pqkl\"]) } failed!\n");
96 MPI_Reduce(&pass, MPI_IN_PLACE, 1, MPI_INT, MPI_MIN, 0, MPI_COMM_WORLD);
105 char ** itr = std::find(begin, end, option);
106 if (itr != end && ++itr != end){
113 int main(
int argc,
char ** argv){
114 int rank,
np, niter, n, pass;
115 int const in_num = argc;
116 char ** input_str = argv;
118 MPI_Init(&argc, &argv);
119 MPI_Comm_rank(MPI_COMM_WORLD, &rank);
120 MPI_Comm_size(MPI_COMM_WORLD, &np);
123 n = atoi(
getCmdOption(input_str, input_str+in_num,
"-n"));
127 if (
getCmdOption(input_str, input_str+in_num,
"-niter")){
128 niter = atoi(
getCmdOption(input_str, input_str+in_num,
"-niter"));
129 if (niter < 0) niter = 3;
135 World dw(argc, argv);
138 printf(
"Computing C_ijkl = A_ijmn*B_klmn\n");
139 printf(
"Non-symmetric: NS = NS*NS gemm:\n");
144 printf(
"Symmetric: SY = SY*SY gemm:\n");
149 printf(
"(Anti-)Skew-symmetric: AS = AS*AS gemm:\n");
154 printf(
"Symmetric-hollow: SH = SH*SH gemm:\n");
char * getCmdOption(char **begin, char **end, const std::string &option)
an instance of the CTF library (world) on an MPI communicator
int gemm_4D(int const n, int const sym, int const niter, World &dw)
void align(CTF_int::tensor const &A)
aligns data mapping with tensor A
int main(int argc, char **argv)
void get_local_data(int64_t *npair, int64_t **global_idx, dtype **data, bool nonzeros_only=false, bool unpack_sym=false) const
Gives the global indices and values associated with the local data.
an instance of a tensor within a CTF world
void write(int64_t npair, int64_t const *global_idx, dtype const *data)
writes in values associated with any set of indices. The sparse data is defined in coordinate format...