5 #include "../sparse_formats/csr.h" 13 template <
typename dtype>
18 template <
typename dtype>
25 for (
int i=0; i<n; i++){
26 Y[incY*i] += alpha*X[incX*i];
32 (int,float,
float const *,int,
float *,int);
36 (int,double,
double const *,int,
double *,int);
39 void default_axpy< std::complex<float> >
40 (int,std::complex<float>,std::complex<float>
const *,int,std::complex<float> *,int);
43 void default_axpy< std::complex<double> >
44 (int,std::complex<double>,std::complex<double>
const *,int,std::complex<double> *,int);
46 template <
typename dtype>
51 for (
int i=0; i<n; i++){
63 void default_scal< std::complex<float> >
64 (
int n, std::complex<float> alpha, std::complex<float> * X,
int incX);
67 void default_scal< std::complex<double> >
68 (
int n, std::complex<double> alpha, std::complex<double> * X,
int incX);
70 template<
typename dtype>
82 int istride_A, lstride_A, jstride_B, lstride_B;
84 if (tA ==
'N' || tA ==
'n'){
91 if (tB ==
'N' || tB ==
'n'){
102 C[j*m+i] += alpha*A[istride_A*i+lstride_A*l]*B[lstride_B*l+jstride_B*j];
109 template<
typename dtype>
115 #pragma omp parallel for 117 for (
int i=0; i<ngrp; i++){
118 data_ptrs[i] = ((
dtype*)data)+i*grp_sz;
123 template <
typename dtype>
137 template <
typename dtype>
161 CTF_int::gemm<float>(tA,tB,m,n,k,alpha,A,B,beta,C);
176 CTF_int::gemm<double>(tA,tB,m,n,k,alpha,A,B,beta,C);
180 inline void default_gemm< std::complex<float> >
186 std::complex<float> alpha,
187 std::complex<float>
const * A,
188 std::complex<float>
const * B,
189 std::complex<float> beta,
190 std::complex<float> * C){
191 CTF_int::gemm< std::complex<float> >(tA,tB,m,n,k,alpha,A,B,beta,C);
195 inline void default_gemm< std::complex<double> >
201 std::complex<double> alpha,
202 std::complex<double>
const * A,
203 std::complex<double>
const * B,
204 std::complex<double> beta,
205 std::complex<double> * C){
206 CTF_int::gemm< std::complex<double> >(tA,tB,m,n,k,alpha,A,B,beta,C);
209 template<
typename dtype>
222 if (m == 1 && n == 1 && k == 1){
223 for (
int i=0; i<l; i++){
224 C[i] = C[i]*beta + alpha*A[i]*B[i];
227 for (
int i=0; i<l; i++){
228 default_gemm<dtype>(taA, taB, m, n, k, alpha, A+i*m*k, B+i*k*n, beta, C+i*m*n);
246 CTF_int::gemm_batch<float>(taA, taB, l, m, n, k, alpha, A, B, beta, C);
262 CTF_int::gemm_batch<double>(taA, taB, l, m, n, k, alpha, A, B, beta, C);
266 inline void default_gemm_batch<std::complex<float>>
273 std::complex<float> alpha,
274 std::complex<float>
const* A,
275 std::complex<float>
const* B,
276 std::complex<float> beta,
277 std::complex<float> * C){
278 CTF_int::gemm_batch< std::complex<float> >(taA, taB, l, m, n, k, alpha, A, B, beta, C);
282 inline void default_gemm_batch<std::complex<double>>
289 std::complex<double> alpha,
290 std::complex<double>
const* A,
291 std::complex<double>
const* B,
292 std::complex<double> beta,
293 std::complex<double> * C){
294 CTF_int::gemm_batch< std::complex<double> >(taA, taB, l, m, n, k, alpha, A, B, beta, C);
297 template <
typename dtype>
311 for (
int j=0; j<n; j++){
312 for (
int i=0; i<m; i++){
316 for (
int i=0; i<nnz_A; i++){
317 int row_A = rows_A[i]-1;
318 int col_A = cols_A[i]-1;
319 for (
int col_C=0; col_C<n; col_C++){
320 C[col_C*m+row_A] += alpha*A[i]*B[col_C*k+col_A];
328 (int,int,int,float,
float const *,
int const *,
int const *,int,
float const *,float,
float *);
332 (int,int,int,double,
double const *,
int const *,
int const *,int,
double const *,double,
double *);
335 void default_coomm< std::complex<float> >
336 (int,int,int,std::complex<float>,std::complex<float>
const *,
int const *,
int const *,int,std::complex<float>
const *,std::complex<float>,std::complex<float> *);
339 void default_coomm< std::complex<double> >
340 (int,int,int,std::complex<double>,std::complex<double>
const *,
int const *,
int const *,int,std::complex<double>
const *,std::complex<double>,std::complex<double> *);
358 template <
typename dtype=
double,
bool is_ord=CTF_
int::get_default_is_ord<dtype>()>
363 void (*fscal)(int,
dtype,dtype*,int);
364 void (*faxpy)(int,
dtype,dtype
const*,int,dtype*,int);
366 void (*fgemm)(char,char,int,int,int,
dtype,dtype
const*,dtype
const*,
dtype,dtype*);
367 void (*fcoomm)(int,int,int,
dtype,dtype
const*,
int const*,
int const*,int,dtype
const*,
dtype,dtype*);
368 void (*fgemm_batch)(char,char,int,int,int,int,
dtype,dtype
const*,dtype
const*,
dtype,dtype*);
374 this->tmulid = other.
tmulid;
375 this->fscal = other.
fscal;
376 this->faxpy = other.
faxpy;
377 this->fmul = other.
fmul;
378 this->fgemm = other.
fgemm;
379 this->fcoomm = other.
fcoomm;
380 this->is_def = other.
is_def;
401 dtype (*fadd_)(dtype a, dtype
b),
404 dtype (*fmul_)(dtype a, dtype b),
405 void (*gemm_)(
char,
char,
int,
int,
int,dtype,dtype
const*,dtype
const*,dtype,dtype*)=NULL,
406 void (*axpy_)(
int,dtype,dtype
const*,
int,dtype*,
int)=NULL,
407 void (*scal_)(
int,dtype,dtype*,
int)=NULL,
408 void (*coomm_)(
int,
int,
int,dtype,dtype
const*,
int const*,
int const*,
int,dtype
const*,dtype,dtype*)=NULL,
409 void (*fgemm_batch_)(
char,
char,
int,
int,
int,
int,dtype,dtype
const*,dtype
const*,dtype,dtype*)=NULL)
410 :
Monoid<dtype, is_ord>(addid_, fadd_, addmop_) {
416 fgemm_batch = fgemm_batch_;
419 this->has_coo_ker = (coomm_ != NULL);
428 fmul = &CTF_int::default_mul<dtype>;
429 fgemm = &CTF_int::default_gemm<dtype>;
430 faxpy = &CTF_int::default_axpy<dtype>;
431 fscal = &CTF_int::default_scal<dtype>;
432 fcoomm = &CTF_int::default_coomm<dtype>;
433 fgemm_batch = &CTF_int::default_gemm_batch<dtype>;
440 ((dtype*)c)[0] = fmul(((dtype*)a)[0],((dtype*)b)[0]);
446 if (a == NULL && b == NULL){
449 }
else if (a == NULL) {
451 memcpy(c,b,this->el_size);
452 }
else if (b == NULL) {
// b == NULL stands for the multiplicative identity, so a*b is a: copy from a.
// (Copying from b here would read through a NULL pointer.)
454 memcpy(c,a,this->el_size);
457 ((dtype*)c)[0] = fmul(((dtype*)a)[0],((dtype*)b)[0]);
462 return (
char const *)&tmulid;
472 if (fscal != NULL) fscal(n, ((dtype
const *)alpha)[0], (dtype *)X, incX);
474 dtype
const a = ((dtype*)alpha)[0];
475 dtype * dX = (dtype*) X;
476 for (int64_t i=0; i<n; i++){
// Generic fallback when no fscal kernel is set: scale every incX-th element,
// matching the strided indexing of default_axpy (unchanged when incX == 1).
477 dX[incX*i] = fmul(a,dX[incX*i]);
489 if (faxpy != NULL) faxpy(n, ((dtype
const *)alpha)[0], (dtype
const *)X, incX, (dtype *)Y, incY);
493 dtype a = ((dtype*)alpha)[0];
494 dtype
const * dX = (dtype*) X;
495 dtype * dY = (dtype*) Y;
496 for (int64_t i=0; i<n; i++){
// Generic fallback when no faxpy kernel is set: apply the incX/incY strides
// exactly as default_axpy does (unchanged when both strides are 1).
497 dY[incY*i] = this->fadd(fmul(a,dX[incX*i]), dY[incY*i]);
514 fgemm(tA, tB, m, n, k, ((dtype
const *)alpha)[0], (dtype
const *)A, (dtype
const *)B, ((dtype
const *)beta)[0], (dtype *)C);
517 dtype
const * dA = (dtype
const *) A;
518 dtype
const * dB = (dtype
const *) B;
519 dtype * dC = (dtype*) C;
520 if (!this->isequal(beta, this->mulid())){
521 scal(m*n, beta, C, 1);
523 int lda_Cj, lda_Ci, lda_Al, lda_Ai, lda_Bj, lda_Bl;
542 if (!this->isequal(alpha, this->mulid())){
543 dtype a = ((dtype*)alpha)[0];
544 for (int64_t j=0; j<n; j++){
545 for (int64_t i=0; i<m; i++){
546 for (int64_t l=0; l<k; l++){
548 dC[j*lda_Cj+i*lda_Ci] = this->fadd(fmul(a,fmul(dA[l*lda_Al+i*lda_Ai],dB[j*lda_Bj+l*lda_Bl])), dC[j*lda_Cj+i*lda_Ci]);
553 for (int64_t j=0; j<n; j++){
554 for (int64_t i=0; i<m; i++){
555 for (int64_t l=0; l<k; l++){
557 dC[j*lda_Cj+i*lda_Ci] = this->fadd(fmul(dA[l*lda_Al+i*lda_Ai],dB[j*lda_Bj+l*lda_Bl]), dC[j*lda_Cj+i*lda_Ci]);
577 if (fgemm_batch != NULL) {
578 fgemm_batch(tA, tB, l, m, n, k, ((dtype
const *)alpha)[0], ((dtype
const *)A), ((dtype
const *)B), ((dtype
const *)beta)[0], ((dtype *)C));
580 for (
int i=0; i<l; i++){
581 gemm(tA, tB, m, n, k, alpha, A+m*k*i*
sizeof(dtype), B+k*n*i*
sizeof(dtype), beta, C+m*n*i*
sizeof(dtype));
596 printf(
"CTF ERROR: offload gemm not present for this semiring\n");
605 void coomm(
int m,
int n,
int k,
char const * alpha,
char const * A,
int const * rows_A,
int const * cols_A, int64_t nnz_A,
char const * B,
char const * beta,
char * C,
CTF_int::bivar_function const * func)
const {
606 if (func == NULL && alpha != NULL && fcoomm != NULL){
607 fcoomm(m, n, k, ((dtype
const *)alpha)[0], (dtype
const *)A, rows_A, cols_A, nnz_A, (dtype
const *)B, ((dtype
const *)beta)[0], (dtype *)C);
610 if (func == NULL && alpha != NULL && this->isequal(beta,mulid())){
612 dtype
const * dA = (dtype
const*)A;
613 dtype
const * dB = (dtype
const*)B;
614 dtype * dC = (dtype*)C;
615 dtype a = ((dtype*)alpha)[0];
616 if (!this->isequal(beta, this->mulid())){
617 scal(m*n, beta, C, 1);
619 for (int64_t i=0; i<nnz_A; i++){
620 int row_A = rows_A[i]-1;
621 int col_A = cols_A[i]-1;
622 for (
int col_C=0; col_C<n; col_C++){
623 dC[col_C*m+row_A] = this->fadd(fmul(a,fmul(dA[i],dB[col_C*k+col_A])), dC[col_C*m+row_A]);
627 }
else { assert(0); }
644 #pragma omp parallel for 646 for (
int row_A=0; row_A<m; row_A++){
648 #pragma omp parallel for 650 for (
int col_B=0; col_B<n; col_B++){
651 C[col_B*m+row_A] = this->fmul(beta,C[col_B*m+row_A]);
652 if (IA[row_A] < IA[row_A+1]){
653 int i_A1 = IA[row_A]-1;
654 int col_A1 = JA[i_A1]-1;
655 dtype tmp = this->fmul(A[i_A1],B[col_B*k+col_A1]);
656 for (
int i_A=IA[row_A]; i_A<IA[row_A+1]-1; i_A++){
657 int col_A = JA[i_A]-1;
658 tmp = this->fadd(tmp, this->fmul(A[i_A],B[col_B*k+col_A]));
660 C[col_B*m+row_A] = this->fadd(C[col_B*m+row_A], this->fmul(alpha,tmp));
681 assert(!this->has_coo_ker);
682 assert(func == NULL);
683 this->default_csrmm(m,n,k,((dtype*)alpha)[0],(dtype*)A,JA,IA,nnz_A,(dtype*)B,((dtype*)beta)[0],(dtype*)C);
686 void default_csrmultd
702 if (!this->isequal((
char const*)&beta, this->mulid())){
703 this->scal(m*n, (
char const *)&beta, (
char*)C, 1);
706 #pragma omp parallel for 708 for (
int row_A=0; row_A<m; row_A++){
709 for (
int i_A=IA[row_A]-1; i_A<IA[row_A+1]-1; i_A++){
710 int row_B = JA[i_A]-1;
711 for (
int i_B=IB[row_B]-1; i_B<IB[row_B+1]-1; i_B++){
712 int col_B = JB[i_B]-1;
713 if (!this->isequal((
char const*)&alpha, this->mulid()))
// fadd returns the sum by value; store it back into C, otherwise the product is dropped.
714 C[col_B*m+row_A] = this->fadd(C[col_B*m+row_A], this->fmul(alpha,this->fmul(A[i_A],B[i_B])));
716 C[col_B*m+row_A] = this->fadd(C[col_B*m+row_A], this->fmul(A[i_A],B[i_B]));
736 char *& C_CSR)
const {
738 memset(IC, 0,
sizeof(
int)*(m+1));
746 #pragma omp for schedule(dynamic) // TO DO test other strategies 748 for (
int i=0; i<m; i++){
749 memset(has_col, 0,
sizeof(
int)*(n+1));
751 for (
int j=0; j<IA[i+1]-IA[i]; j++){
752 int row_B = JA[IA[i]+j-1]-1;
753 for (
int kk=0; kk<IB[row_B+1]-IB[row_B]; kk++){
754 int idx_B = IB[row_B]+kk-1;
755 if (has_col[JB[idx_B]] == 0){
757 has_col[JB[idx_B]] = 1;
768 for(
int i=0;i < m+1; i++){
773 dtype * vC = (dtype*)C.
vals();
774 this->
set((
char *)vC, this->addid(), IC[m]+1);
776 memcpy(C.
IA(), IC,
sizeof(int)*(m+1));
789 for (
int i=0; i<m; i++){
790 std::fill(acc_data, acc_data+n, this->taddid);
791 memset(dcol, 0,
sizeof(
int)*(n));
793 for (
int j=0; j<IA[i+1]-IA[i]; j++){
794 int row_b = JA[IA[i]+j-1]-1;
795 int idx_a = IA[i]+j-1;
796 for (
int ii = 0; ii < IB[row_b+1]-IB[row_b]; ii++){
797 int col_b = IB[row_b]+ii-1;
798 int col_c = JB[col_b]-1;
799 dtype val = fmul(A[idx_a], B[col_b]);
800 if (dcol[col_c] == 0){
801 dcol[col_c] = JB[col_b];
804 acc_data[col_c]= this->fadd(acc_data[col_c], val);
807 for(
int jj = 0; jj < n; jj++){
809 JC[IC[i]+ins-1] = dcol[jj];
810 vC[IC[i]+ins-1] = acc_data[jj];
821 if (!this->isequal((
char const *)&alpha, this->mulid())){
822 this->scal(C.
nnz(), (
char const *)&alpha, C.
vals(), 1);
824 if (C_CSR == NULL || C_in.
nnz() == 0 || this->isequal((
char const *)&beta, this->addid())){
827 if (!this->isequal((
char const *)&beta, this->mulid())){
828 this->scal(C_in.
nnz(), (
char const *)&beta, C_in.
vals(), 1);
830 char * ans = this->csr_add(C_CSR, C.
all_data);
911 void default_csrmultcsr
925 char *& C_CSR)
const {
926 this->gen_csrmultcsr(m,n,k,alpha,A,JA,IA,nnz_A,B,JB,IB,nnz_B,beta,C_CSR);
945 this->default_csrmultd(m,n,k,((dtype
const*)alpha)[0],(dtype
const*)A,JA,IA,nnz_A,(dtype
const*)B,JB,IB,nnz_B,((dtype
const*)beta)[0],(dtype*)C);
963 char *& C_CSR)
const {
966 this->default_csrmultcsr(m,n,k,((dtype
const*)alpha)[0],(dtype
const*)A,JA,IA,nnz_A,(dtype
const*)B,JB,IB,nnz_B,((dtype
const*)beta)[0],C_CSR);
968 this->gen_csrmultcsr(m,n,k,((dtype
const*)alpha)[0],(dtype
const*)A,JA,IA,nnz_A,(dtype
const*)B,JB,IB,nnz_B,((dtype
const*)beta)[0],C_CSR);
979 void CTF::Semiring<float,1>::default_csrmm(
int,
int,
int,
float,
float const *,
int const *,
int const *,
int,
float const *,
float,
float *)
const;
981 void CTF::Semiring<double,1>::default_csrmm(
int,
int,
int,
double,
double const *,
int const *,
int const *,
int,
double const *,
double,
double *)
const;
983 void CTF::Semiring<std::complex<float>,0>::default_csrmm(
int,
int,
int,std::complex<float>,std::complex<float>
const *,
int const *,
int const *,
int,std::complex<float>
const *,std::complex<float>,std::complex<float> *)
const;
985 void CTF::Semiring<std::complex<double>,0>::default_csrmm(
int,
int,
int,std::complex<double>,std::complex<double>
const *,
int const *,
int const *,
int,std::complex<double>
const *,std::complex<double>,std::complex<double> *)
const;
989 void CTF::Semiring<float,1>::default_csrmultd(
int,
int,
int,
float,
float const *,
int const *,
int const *,
int,
float const *,
int const *,
int const *,
int,
float,
float *)
const;
991 void CTF::Semiring<double,1>::default_csrmultd(
int,
int,
int,
double,
double const *,
int const *,
int const *,
int,
double const *,
int const *,
int const *,
int,
double,
double *)
const;
993 void CTF::Semiring<std::complex<float>,0>::default_csrmultd(
int,
int,
int,std::complex<float>,std::complex<float>
const *,
int const *,
int const *,
int,std::complex<float>
const *,
int const *,
int const *,
int,std::complex<float>,std::complex<float> *)
const;
995 void CTF::Semiring<std::complex<double>,0>::default_csrmultd(
int,
int,
int,std::complex<double>,std::complex<double>
const *,
int const *,
int const *,
int,std::complex<double>
const *,
int const *,
int const *,
int,std::complex<double>,std::complex<double> *)
const;
998 void CTF::Semiring<float,1>::default_csrmultcsr(
int,
int,
int,
float,
float const *,
int const *,
int const *,
int,
float const *,
int const *,
int const *,
int,
float,
char *&)
const;
1000 void CTF::Semiring<double,1>::default_csrmultcsr(
int,
int,
int,
double,
double const *,
int const *,
int const *,
int,
double const *,
int const *,
int const *,
int,
double,
char *&)
const;
1002 void CTF::Semiring<std::complex<float>,0>::default_csrmultcsr(
int,
int,
int,std::complex<float>,std::complex<float>
const *,
int const *,
int const *,
int,std::complex<float>
const *,
int const *,
int const *,
int,std::complex<float>,
char *&)
const;
1004 void CTF::Semiring<std::complex<double>,0>::default_csrmultcsr(
int,
int,
int,std::complex<double>,std::complex<double>
const *,
int const *,
int const *,
int,std::complex<double>
const *,
int const *,
int const *,
int,std::complex<double>,
char *&)
const;
1012 bool CTF::Semiring<std::complex<float>,0>::is_offloadable()
const;
1014 bool CTF::Semiring<std::complex<double>,0>::is_offloadable()
const;
// NOTE(review): this declaration is repeated verbatim by the one that follows
// (both name Semiring<double,1>), while the neighbouring declaration groups list
// float, double, complex<float>, complex<double>. Presumably this one was meant
// for Semiring<float,1> -- confirm against the out-of-line definitions before changing.
1017 void CTF::Semiring<double,1>::offload_gemm(
char,
char,
int,
int,
int,
char const *,
char const *,
char const *,
char const *,
char *)
const;
1019 void CTF::Semiring<double,1>::offload_gemm(
char,
char,
int,
int,
int,
char const *,
char const *,
char const *,
char const *,
char *)
const;
1021 void CTF::Semiring<std::complex<float>,0>
::offload_gemm(
char,
char,
int,
int,
int,
char const *,
char const *,
char const *,
char const *,
char *)
const;
1023 void CTF::Semiring<std::complex<double>,0>
::offload_gemm(
char,
char,
int,
int,
int,
char const *,
char const *,
char const *,
char const *,
char *)
const;
void(* fscal)(int, dtype, dtype *, int)
void scal(int n, char const *alpha, char *X, int incX) const
X["i"]=alpha*X["i"];.
void mul(char const *a, char const *b, char *c) const
c = a*b
void default_scal(int n, dtype alpha, dtype *X, int incX)
int * IA() const
retrieves prefix sum of number of nonzeros for each row (of size nrow()+1) out of all_data ...
void default_coomm< double >(int m, int n, int k, double alpha, double const *A, int const *rows_A, int const *cols_A, int nnz_A, double const *B, double beta, double *C)
void coomm(int m, int n, int k, char const *alpha, char const *A, int const *rows_A, int const *cols_A, int64_t nnz_A, char const *B, char const *beta, char *C, CTF_int::bivar_function const *func) const
sparse version of gemm using coordinate format for A
dtype ** get_grp_ptrs(int64_t grp_sz, int64_t ngrp, dtype const *data)
dtype(* fmul)(dtype a, dtype b)
Semiring is a Monoid with an additional multiplication function; addition must have an identity and be as...
void * alloc(int64_t len)
alloc abstraction
void(* faxpy)(int, dtype, dtype const *, int, dtype *, int)
void default_csrmm(int m, int n, int k, dtype alpha, dtype const *A, int const *JA, int const *IA, int nnz_A, dtype const *B, dtype beta, dtype *C) const
void csrmm(int m, int n, int k, char const *alpha, char const *A, int const *JA, int const *IA, int64_t nnz_A, char const *B, char const *beta, char *C, CTF_int::bivar_function const *func) const
sparse version of gemm using CSR format for A
void safemul(char const *a, char const *b, char *&c) const
c = a*b, with NULL treated as mulid
void gemm(char tA, char tB, int m, int n, int k, dtype alpha, dtype const *A, dtype const *B, dtype beta, dtype *C)
untyped internal class for triply-typed bivariate function
void default_scal< double >(int n, double alpha, double *X, int incX)
void default_coomm(int m, int n, int k, dtype alpha, dtype const *A, int const *rows_A, int const *cols_A, int nnz_A, dtype const *B, dtype beta, dtype *C)
void default_gemm_batch(char taA, char taB, int l, int m, int n, int k, dtype alpha, dtype const *A, dtype const *B, dtype beta, dtype *C)
void default_coomm< float >(int m, int n, int k, float alpha, float const *A, int const *rows_A, int const *cols_A, int nnz_A, float const *B, float beta, float *C)
void default_axpy(int n, dtype alpha, dtype const *X, int incX, dtype *Y, int incY)
virtual CTF_int::algstrct * clone() const
''copy constructor''
void gemm_batch(char tA, char tB, int l, int m, int n, int k, char const *alpha, char const *A, char const *B, char const *beta, char *C) const
C["ijl"]=beta*C["ijl"]+alpha*A^tA["ikl"]*B^tB["kjl"];
void(* fgemm)(char, char, int, int, int, dtype, dtype const *, dtype const *, dtype, dtype *)
void offload_gemm(char tA, char tB, int m, int n, int k, char const *alpha, char const *A, char const *B, char const *beta, char *C) const
void default_gemm_batch< double >(char taA, char taB, int l, int m, int n, int k, double alpha, double const *A, double const *B, double beta, double *C)
int * JA() const
retrieves column indices of each value in vals stored in sorted form by row
void default_csrmultcsr(int m, int n, int k, dtype alpha, dtype const *A, int const *JA, int const *IA, int nnz_A, dtype const *B, int const *JB, int const *IB, int nnz_B, dtype beta, char *&C_CSR) const
void axpy(int n, char const *alpha, char const *X, int incX, char *Y, int incY) const
Y["i"]+=alpha*X["i"];.
void gemm(char tA, char tB, int m, int n, int k, char const *alpha, char const *A, char const *B, char const *beta, char *C) const
C["ij"]=beta*C["ij"]+alpha*A^tA["ik"]*B^tB["kj"];
int64_t nnz() const
retrieves number of nonzeros out of all_data
char const * mulid() const
identity element for multiplication i.e. 1
abstraction for a serialized sparse matrix stored in compressed-sparse-row (CSR) layout ...
void default_csrmultd(int m, int n, int k, dtype alpha, dtype const *A, int const *JA, int const *IA, int nnz_A, dtype const *B, int const *JB, int const *IB, int nnz_B, dtype beta, dtype *C) const
void gemm_batch(char taA, char taB, int l, int m, int n, int k, dtype alpha, dtype const *A, dtype const *B, dtype beta, dtype *C)
void default_gemm< double >(char tA, char tB, int m, int n, int k, double alpha, double const *A, double const *B, double beta, double *C)
Semiring()
constructor for algstrct equipped with default * and +
bool is_offloadable() const
Semiring(dtype addid_, dtype(*fadd_)(dtype a, dtype b), MPI_Op addmop_, dtype mulid_, dtype(*fmul_)(dtype a, dtype b), void(*gemm_)(char, char, int, int, int, dtype, dtype const *, dtype const *, dtype, dtype *)=NULL, void(*axpy_)(int, dtype, dtype const *, int, dtype *, int)=NULL, void(*scal_)(int, dtype, dtype *, int)=NULL, void(*coomm_)(int, int, int, dtype, dtype const *, int const *, int const *, int, dtype const *, dtype, dtype *)=NULL, void(*fgemm_batch_)(char, char, int, int, int, int, dtype, dtype const *, dtype const *, dtype, dtype *)=NULL)
constructor for algstrct equipped with * and +
void default_gemm< float >(char tA, char tB, int m, int n, int k, float alpha, float const *A, float const *B, float beta, float *C)
char * all_data
serialized buffer containing all info, index, and values related to matrix
char * vals() const
retrieves array of values out of all_data
int cdealloc(void *ptr)
free abstraction
algstrct (algebraic structure) defines the elementwise operations computed in each tensor contraction...
void default_scal< float >(int n, float alpha, float *X, int incX)
A Monoid is a Set equipped with a binary addition operator '+' or a custom function; addition must hav...
void default_axpy< float >(int n, float alpha, float const *X, int incX, float *Y, int incY)
void(* fgemm_batch)(char, char, int, int, int, int, dtype, dtype const *, dtype const *, dtype, dtype *)
Semiring(Semiring const &other)
void offload_gemm(char tA, char tB, int m, int n, int k, dtype alpha, offload_tsr &A, int lda_A, offload_tsr &B, int lda_B, dtype beta, offload_tsr &C, int lda_C)
void default_gemm(char tA, char tB, int m, int n, int k, dtype alpha, dtype const *A, dtype const *B, dtype beta, dtype *C)
void default_axpy< double >(int n, double alpha, double const *X, int incX, double *Y, int incY)
void(* fcoomm)(int, int, int, dtype, dtype const *, int const *, int const *, int, dtype const *, dtype, dtype *)
void default_gemm_batch< float >(char taA, char taB, int l, int m, int n, int k, float alpha, float const *A, float const *B, float beta, float *C)
dtype default_mul(dtype a, dtype b)