5 #include "../shared/blas_symbs.h" 6 #include "../shared/lapack_symbs.h" 19 operator const int*()
const 28 template<
typename dtype>
34 template<
typename dtype>
42 template<
typename dtype>
65 template<
typename dtype>
73 world_, sr_, name_, profile_) {
79 template<
typename dtype>
88 world_, sr_, name_, profile_) {
94 template<
typename dtype>
102 world_, sr_, name_, profile_) {
109 template<
typename dtype>
118 world_, sr_, name_, profile_) {
125 template<
typename dtype>
137 world_, idx, prl, blk, name_, profile_, sr_) {
143 template<
typename dtype>
150 for (
int j=0; j<
ncol; j++){
151 this->
sr->
print((
char*)&(data[j*nrow+
i]));
152 if (j!=ncol-1) printf(
" ");
160 template<
typename dtype>
173 nmyr = mb*(nrow/mb/pr);
174 if ((nrow/mb)%pr > (rank+pr-rsrc)%pr){
177 if (((nrow/mb)%pr) == (rank+pr-rsrc)%pr){
180 nmyc = nb*(ncol/nb/pc);
181 if ((ncol/nb)%pc > (rank/pr+pc-csrc)%pc){
184 if (((ncol/nb)%pc) == (rank/pr+pc-csrc)%pc){
189 int cblk = (rank/pr+pc-csrc)%pc;
190 for (int64_t
i=0;
i<nmyc;
i++){
191 int rblk = (rank+pr-rsrc)%pr;
192 for (int64_t j=0; j<nmyr; j++){
193 pairs[
i*nmyr+j].
k = (cblk*nb+(
i%nb))*nrow+rblk*mb+(j%mb);
196 if ((j+1)%mb == 0) rblk += pr;
198 if ((
i+1)%nb == 0) cblk += pc;
202 template<
typename dtype>
210 dtype const * data_){
211 if (mb==1 && nb==1 &&
nrow%pr==0 &&
ncol%pc==0 && rsrc==0 && csrc==0){
216 for (int64_t
i=0;
i<
ncol/pc;
i++){
222 (*this)[
"ab"] = M[
"ab"];
227 get_my_kv_pair(this->
wrld->
rank,
nrow,
ncol, mb, nb, pr, pc, rsrc, csrc, nmyr, nmyc, pairs);
231 for (int64_t
i=0;
i<nmyr*nmyc;
i++){
232 pairs[
i].
d = data_[
i];
235 for (int64_t
i=0;
i<nmyc;
i++){
236 for (int64_t j=0; j<nmyr; j++){
237 pairs[
i*nmyr+j].
d = data_[
i*lda+j];
241 this->
write(nmyr*nmyc, pairs);
246 template<
typename dtype>
256 if (!this->
is_sparse && (mb==1 && nb==1 &&
nrow%pr==0 &&
ncol%pc==0 && rsrc==0 && csrc==0)){
261 for (int64_t
i=0;
i<
ncol/pc;
i++){
266 int plens[] = {pr, pc};
269 M[
"ab"] = (*this)[
"ab"];
270 M.
read_mat(mb, nb, pr, pc, rsrc, csrc, lda, data_);
275 get_my_kv_pair(this->
wrld->
rank,
nrow,
ncol, mb, nb, pr, pc, rsrc, csrc, nmyr, nmyc, pairs);
277 this->
read(nmyr*nmyc, pairs);
279 for (int64_t
i=0;
i<nmyr*nmyc;
i++){
280 data_[
i] = pairs[
i].
d;
284 for (int64_t
i=0;
i<nmyc;
i++){
285 for (int64_t j=0; j<nmyr; j++){
286 data_[
i*lda+j] = pairs[
i*nmyr+j].
d;
295 template<
typename dtype>
319 desc = (
int*)malloc(
sizeof(
int)*9);
331 template<
typename dtype>
335 int pr, pc, ipr, ipc;
340 read_mat(desc[4],desc[5],pr,pc,desc[6],desc[7],desc[8],data_);
343 template<
typename dtype>
359 wrld_, sr_, name_, profile_) {
363 write_mat(mb,nb,pr,pc,rsrc,csrc,lda,data);
368 static inline Idx_Partition get_map_from_desc(
int const * desc){
371 int pr, pc, ipr, ipc;
376 template<
typename dtype>
384 wrld_,
"ij", get_map_from_desc(desc),
Idx_Partition(), name_, profile_, sr_) {
389 int pr, pc, ipr, ipc;
395 write_mat(desc[4],desc[5],pr,pc,desc[6],desc[7],desc[8],data_);
398 template <
typename dtype>
413 inline int get_int_fromreal<std::complex<float>>(std::complex<float> r){
414 return (
int)r.real();
417 inline int get_int_fromreal<std::complex<double>>(std::complex<double> r){
418 return (
int)r.real();
424 template<
typename dtype>
442 CTF_SCALAPACK::pgeqrf<dtype>(m,n,A,1,1,desca,tau,(
dtype*)&dlwork,-1,&info);
443 int lwork = get_int_fromreal<dtype>(dlwork);
445 CTF_SCALAPACK::pgeqrf<dtype>(m,n,A,1,1,desca,tau,work,lwork,&info);
456 R = R.
slice(0,((int64_t)m)*(n-1)+n-1);
461 CTF_SCALAPACK::porgqr<dtype>(m,n,n,dQ,1,1,desca,tau,(
dtype*)&dlwork,-1,&info);
462 lwork = get_int_fromreal<dtype>(dlwork);
463 work = (
dtype*)malloc(((int64_t)lwork)*
sizeof(
dtype));
464 CTF_SCALAPACK::porgqr<dtype>(m,n,n,dQ,1,1,desca,tau,work,lwork,&info);
470 int syns[] = {
SY,
NS};
472 int nsns[] = {
NS, NS};
480 template<
typename dtype>
487 int k = std::min(m,n);
491 int * descu = (
int*)malloc(9*
sizeof(
int));
492 int * descvt = (
int*)malloc(9*
sizeof(
int));
501 int64_t mpr = m/pr + (m % pr != 0);
502 int64_t kpr = k/pr + (k % pr != 0);
503 int64_t kpc = k/pc + (k % pc != 0);
504 int64_t npc = n/pc + (n % pc != 0);
519 CTF_SCALAPACK::pgesvd<dtype>(
'V',
'V', m, n, NULL, 1, 1, desca, NULL, NULL, 1, 1, descu, vt, 1, 1, descvt, &dlwork, -1, &info);
521 lwork = get_int_fromreal<dtype>(dlwork);
524 CTF_SCALAPACK::pgesvd<dtype>(
'V',
'V', m, n, A, 1, 1, desca, s, u, 1, 1, descu, vt, 1, 1, descvt, work, lwork, &info);
534 int phase = S.
edge_map[0].calc_phase();
535 if ((
int)(this->
wrld->
rank) < phase){
536 for (
int i = S.
edge_map[0].calc_phys_rank(S.
topo);
i < k;
i += phase) {
537 s_data[
i/phase] = s[
i];
540 if (rank > 0 && rank < k) {
541 S = S.
slice(0, rank-1);
542 U = U.slice(0, rank*((int64_t)m)-1);
543 VT = VT.
slice(0, k*((int64_t)n)-(k-rank+1));
void write_mat(int mb, int nb, int pr, int pc, int rsrc, int csrc, int lda, dtype const *data)
writes a nonsymmetric matrix from a block-cyclic initial distribution; this is `cheap' if mb=nb=1...
int * sym
symmetries among tensor dimensions
void qr(Matrix< dtype > &Q, Matrix< dtype > &R)
int calc_phase() const
compute the phase of a mapping
dtype d
tensor value associated with index
void cblacs_get(int contxt, int what, int *val)
Matrix class which encapsulates a 2D tensor.
Typ_Idx_Tensor< dtype > i(char const *idx_map)
Tensor< dtype > slice(int const *offsets, int const *ends) const
cuts out a slice (block) of this tensor, A[offsets,ends); the result will always be fully nonsymmetric ...
int * pad_edge_len
padded tensor edge lengths
int64_t size
current size of local tensor data chunk (mapping-dependent)
Vector class which encapsulates a 1D tensor.
void * alloc(int64_t len)
alloc abstraction
void read_all(int64_t *npair, dtype **data, bool unpack=false)
collects the entire tensor data on each process (not memory scalable)
an instance of the CTF library (world) on an MPI communicator
void read_mat(int mb, int nb, int pr, int pc, int rsrc, int csrc, int lda, dtype *data)
reads a nonsymmetric matrix into a block-cyclic initial distribution; this is `cheap' if mb=nb=1...
bool is_sparse
whether only the non-zero elements of the tensor are stored
int order
number of tensor dimensions
int64_t k
key, global index [i1,i2,...] specified as i1+len[0]*i2+...
index-value pair used for tensor data input
CTF::World * wrld
distributed processor context on which tensor is defined
int rank
rank of local processor
int * lens
unpadded tensor edge lengths
void svd(Matrix< dtype > &U, Vector< dtype > &S, Matrix< dtype > &VT, int rank=0)
int get_int_fromreal< double >(double r)
std::set< grid_wrapper > scalapack_grids
index for ScaLAPACK processor grids
void get_desc(int &ictxt, int *&desc)
get a ScaLAPACK descriptor for this Matrix, will always be in pure cyclic layout
algstrct * sr
algstrct on which tensor elements and operations are defined
mapping * edge_map
mappings of each tensor dimension onto topology dimensions
void cblacs_gridinit(int *contxt, char *row, int nprow, int npcol)
int get_int_fromreal(dtype r)
virtual void print(char const *a, FILE *fp=stdout) const
prints the value
int get_int_fromreal< float >(float r)
void get_my_kv_pair(int rank, int nrow, int ncol, int mb, int nb, int pr, int pc, int rsrc, int csrc, int64_t &nmyr, int64_t &nmyc, Pair< dtype > *&pairs)
int cdealloc(void *ptr)
free abstraction
dtype * get_raw_data(int64_t *size) const
gives the raw current local data with padding included
algstrct (algebraic structure) defines the elementwise operations computed in each tensor contraction...
char * data
tensor data; either the data or the key-value pairs should exist at any given time ...
an instance of a tensor within a CTF world
Matrix()
default constructor for a matrix
void read(int64_t npair, Pair< dtype > *pairs)
Gives the values associated with any set of indices.
topology * topo
topology to which the tensor is mapped
void cdescinit(int *desc, int m, int n, int mb, int nb, int irsrc, int icsrc, int ictxt, int LLD, int *info)
void cblacs_gridinfo(int contxt, int *nprow, int *npcol, int *myprow, int *mypcol)
void write(int64_t npair, int64_t const *global_idx, dtype const *data)
writes in values associated with any set of indices. The sparse data is defined in coordinate format...
int np
number of processors
MPI_Comm comm
set of processors making up this world