Cyclops Tensor Framework
parallel arithmetic on multidimensional arrays
lapack_symbs.cxx
Go to the documentation of this file.
#include <assert.h>
#include <stdlib.h>

#include <algorithm>
#include <complex>
#include <vector>

#include "lapack_symbs.h"
5 
// Linker-level names of the Fortran LAPACK / ScaLAPACK / BLACS routines used
// in this file.  When FTN_UNDERSCORE is nonzero, CTF_FTN_SYMBOL(x) yields x_
// (presumably for Fortran compilers that append an underscore to external
// names — confirm against the build configuration); otherwise it yields x.
#if FTN_UNDERSCORE
#define CTF_FTN_SYMBOL(lower) lower##_
#else
#define CTF_FTN_SYMBOL(lower) lower
#endif

// Serial LAPACK
#define DGELSD          CTF_FTN_SYMBOL(dgelsd)
#define DGEQRF          CTF_FTN_SYMBOL(dgeqrf)
#define DORMQR          CTF_FTN_SYMBOL(dormqr)

// ScaLAPACK: singular value decomposition
#define PDGESVD         CTF_FTN_SYMBOL(pdgesvd)
#define PSGESVD         CTF_FTN_SYMBOL(psgesvd)
#define PCGESVD         CTF_FTN_SYMBOL(pcgesvd)
#define PZGESVD         CTF_FTN_SYMBOL(pzgesvd)

// ScaLAPACK: QR factorization
#define PSGEQRF         CTF_FTN_SYMBOL(psgeqrf)
#define PDGEQRF         CTF_FTN_SYMBOL(pdgeqrf)
#define PCGEQRF         CTF_FTN_SYMBOL(pcgeqrf)
#define PZGEQRF         CTF_FTN_SYMBOL(pzgeqrf)

// ScaLAPACK: explicit Q from a QR factorization
#define PSORGQR         CTF_FTN_SYMBOL(psorgqr)
#define PDORGQR         CTF_FTN_SYMBOL(pdorgqr)
#define PCUNGQR         CTF_FTN_SYMBOL(pcungqr)
#define PZUNGQR         CTF_FTN_SYMBOL(pzungqr)

// ScaLAPACK tools / BLACS (Fortran interface)
#define DESCINIT        CTF_FTN_SYMBOL(descinit)
#define BLACS_GRIDINFO  CTF_FTN_SYMBOL(blacs_gridinfo)
#define BLACS_GRIDINIT  CTF_FTN_SYMBOL(blacs_gridinit)
45 
namespace CTF_LAPACK{
#ifdef USE_LAPACK
  // Raw Fortran LAPACK entry points: every argument is passed by address.
  // LAPACK documents DGELSD's A as input/output ("On exit, A has been
  // destroyed"), so A is declared non-const here to describe the routine
  // truthfully instead of hiding the write behind a const pointer.
  extern "C"
  void DGELSD(int const * M, int const * N, int const * NRHS,
              double * A, int const * LDA,
              double * B, int const * LDB,
              double * S, double const * RCOND, int * RANK,
              double * WORK, int const * LWORK, int * IWORK, int * INFO);

  extern "C"
  void DGEQRF(int const * M, int const * N, double * A, int const * LDA,
              double * TAU2, double * WORK, int const * LWORK, int * INFO);

  extern "C"
  void DORMQR(char const * SIDE, char const * TRANS,
              int const * M, int const * N, int const * K,
              double const * A, int const * LDA, double const * TAU2,
              double * C, int const * LDC,
              double * WORK, int const * LWORK, int * INFO);
#endif

  /**
   * \brief By-value wrapper around LAPACK DGELSD (SVD-based least squares).
   *
   * Every argument is forwarded to DGELSD in order, scalars by address.
   * Positionally, m/n are DGELSD's M/N, k is NRHS, cond is RCOND, and info
   * receives DGELSD's INFO status.
   *
   * \warning A is typed const for API compatibility, but DGELSD overwrites
   *          the matrix; pass a scratch copy if A must be preserved.
   *
   * Without USE_LAPACK this is a stub that fires assert(0) and writes no outputs.
   */
  void cdgelsd(int m, int n, int k, double const * A, int lda_A, double * B, int lda_B, double * S, double cond, int * rank, double * work, int lwork, int * iwork, int * info){
#ifdef USE_LAPACK
    // const_cast makes LAPACK's documented in-place destruction of A explicit.
    DGELSD(&m, &n, &k, const_cast<double *>(A), &lda_A, B, &lda_B, S, &cond, rank, work, &lwork, iwork, info);
#else
    assert(0);
#endif
  }

  /**
   * \brief By-value wrapper around LAPACK DGEQRF (QR factorization of an
   *        M-by-N matrix A with leading dimension LDA).
   *
   * Arguments are forwarded unchanged (scalars by address); A is updated in
   * place, TAU2 receives the reflector scalars, INFO the LAPACK status.
   * Without USE_LAPACK this is a stub that fires assert(0).
   */
  void cdgeqrf(int M, int N, double * A, int LDA, double * TAU2, double * WORK, int LWORK, int * INFO){
#ifdef USE_LAPACK
    DGEQRF(&M, &N, A, &LDA, TAU2, WORK, &LWORK, INFO);
#else
    assert(0);
#endif
  }

  /**
   * \brief By-value wrapper around LAPACK DORMQR: applies the Q encoded by
   *        (A, TAU2) from a DGEQRF factorization to C, as selected by SIDE/TRANS.
   *
   * Arguments are forwarded unchanged (scalars by address); C is overwritten
   * with the product and INFO receives the LAPACK status.
   * Without USE_LAPACK this is a stub that fires assert(0).
   */
  void cdormqr(char SIDE, char TRANS, int M, int N, int K, double const * A, int LDA, double const * TAU2, double * C, int LDC, double * WORK, int LWORK, int * INFO){
#ifdef USE_LAPACK
    DORMQR(&SIDE, &TRANS, &M, &N, &K, A, &LDA, TAU2, C, &LDC, WORK, &LWORK, INFO);
#else
    assert(0);
#endif
  }
}
82 
83 namespace CTF_SCALAPACK{
84 #ifdef USE_SCALAPACK
85  extern "C"
86  void BLACS_GRIDINFO(int * icontxt, int * nprow, int * npcol, int * iprow, int * ipcol);
87 
88  extern "C"
89  void BLACS_GRIDINIT(int * icontxt, char * order, int * nprow, int * npcol);
90 
91  extern "C"
92  void PDGESVD( char *,
93  char *,
94  int *,
95  int *,
96  double *,
97  int *,
98  int *,
99  int *,
100  double *,
101  double *,
102  int *,
103  int *,
104  int *,
105  double *,
106  int *,
107  int *,
108  int *,
109  double *,
110  int *,
111  int *);
112 
113  extern "C"
114  void PSGESVD( char *,
115  char *,
116  int *,
117  int *,
118  float *,
119  int *,
120  int *,
121  int *,
122  float *,
123  float *,
124  int *,
125  int *,
126  int *,
127  float *,
128  int *,
129  int *,
130  int *,
131  float *,
132  int *,
133  int *);
134 
135  extern "C"
136  void PCGESVD( char *,
137  char *,
138  int *,
139  int *,
140  std::complex<float> *,
141  int *,
142  int *,
143  int *,
144  float *,
145  std::complex<float> *,
146  int *,
147  int *,
148  int *,
149  std::complex<float> *,
150  int *,
151  int *,
152  int *,
153  float *,
154  int *,
155  float *,
156  int *);
157 
158  extern "C"
159  void PZGESVD( char *,
160  char *,
161  int *,
162  int *,
163  std::complex<double> *,
164  int *,
165  int *,
166  int *,
167  double *,
168  std::complex<double> *,
169  int *,
170  int *,
171  int *,
172  std::complex<double> *,
173  int *,
174  int *,
175  int *,
176  double *,
177  int *,
178  double *,
179  int *);
180 
181 
182  extern "C"
183  void PSGEQRF(int *,
184  int *,
185  float *,
186  int *,
187  int *,
188  int const *,
189  float *,
190  float *,
191  int *,
192  int *);
193 
194  extern "C"
195  void PDGEQRF(int *,
196  int *,
197  double *,
198  int *,
199  int *,
200  int const *,
201  double *,
202  double *,
203  int *,
204  int *);
205 
206 
207  extern "C"
208  void PCGEQRF(int *,
209  int *,
210  std::complex<float> *,
211  int *,
212  int *,
213  int const *,
214  std::complex<float> *,
215  std::complex<float> *,
216  int *,
217  int *);
218 
219  extern "C"
220  void PZGEQRF(int *,
221  int *,
222  std::complex<double> *,
223  int *,
224  int *,
225  int const *,
226  std::complex<double> *,
227  std::complex<double> *,
228  int *,
229  int *);
230 
231 
232  extern "C"
233  void PSORGQR(int *,
234  int *,
235  int *,
236  float *,
237  int *,
238  int *,
239  int const *,
240  float *,
241  float *,
242  int *,
243  int *);
244 
245 
246 
247  extern "C"
248  void PDORGQR(int *,
249  int *,
250  int *,
251  double *,
252  int *,
253  int *,
254  int const *,
255  double *,
256  double *,
257  int *,
258  int *);
259 
260 
261  extern "C"
262  void PCUNGQR(int *,
263  int *,
264  int *,
265  std::complex<float> *,
266  int *,
267  int *,
268  int const *,
269  std::complex<float> *,
270  std::complex<float> *,
271  int *,
272  int *);
273 
274 
275 
276  extern "C"
277  void PZUNGQR(int *,
278  int *,
279  int *,
280  std::complex<double> *,
281  int *,
282  int *,
283  int const *,
284  std::complex<double> *,
285  std::complex<double> *,
286  int *,
287  int *);
288 
289 
290  extern "C"
291  void DESCINIT(int *, int *,
292 
293  int *, int *,
294 
295  int *, int *,
296 
297  int *, int *,
298 
299  int *, int *);
300 
301  extern "C"
302  void Cblacs_pinfo(int*, int*);
303  extern "C"
304  void Cblacs_get(int, int, int*);
305  extern "C"
306  void Cblacs_gridinit(int*, char*, int, int);
307  extern "C"
308  void Cblacs_gridinfo(int, int*, int*, int*, int*);
309  extern "C"
310  void Cblacs_gridmap(int*, int*, int, int, int);
311  extern "C"
312  void Cblacs_barrier(int , char*);
313  extern "C"
314  void Cblacs_gridexit(int);
315 #endif
316 
317 
318  template <>
319  void pgesvd<float>(char JOBU,
320  char JOBVT,
321  int M,
322  int N,
323  float * A,
324  int IA,
325  int JA,
326  int * DESCA,
327  float * S,
328  float * U,
329  int IU,
330  int JU,
331  int * DESCU,
332  float * VT,
333  int IVT,
334  int JVT,
335  int * DESCVT,
336  float * WORK,
337  int LWORK,
338  int * info) {
339 #ifdef USE_SCALAPACK
340  PSGESVD(&JOBU, &JOBVT, &M, &N, A, &IA, &JA, DESCA, S, U, &IU, &JU, DESCU, VT, &IVT, &JVT, DESCVT, WORK, &LWORK, info);
341 #else
342  assert(0);
343 #endif
344  }
345 
346  template <>
347  void pgesvd<double>(char JOBU,
348  char JOBVT,
349  int M,
350  int N,
351  double * A,
352  int IA,
353  int JA,
354  int * DESCA,
355  double * S,
356  double * U,
357  int IU,
358  int JU,
359  int * DESCU,
360  double * VT,
361  int IVT,
362  int JVT,
363  int * DESCVT,
364  double * WORK,
365  int LWORK,
366  int * info) {
367 #ifdef USE_SCALAPACK
368  PDGESVD(&JOBU, &JOBVT, &M, &N, A, &IA, &JA, DESCA, S, U, &IU, &JU, DESCU, VT, &IVT, &JVT, DESCVT, WORK, &LWORK, info);
369 #else
370  assert(0);
371 #endif
372  }
373 
374 
375  template <>
376  void pgesvd< std::complex<float> >(char JOBU,
377  char JOBVT,
378  int M,
379  int N,
380  std::complex<float> * A,
381  int IA,
382  int JA,
383  int * DESCA,
384  std::complex<float> * cS,
385  std::complex<float> * U,
386  int IU,
387  int JU,
388  int * DESCU,
389  std::complex<float> * VT,
390  int IVT,
391  int JVT,
392  int * DESCVT,
393  std::complex<float> * WORK,
394  int LWORK,
395  int * info) {
396 #ifdef USE_SCALAPACK
397  float * S = (float*)cS;
398  float * rwork;
399  rwork = new float[4*std::max(M,N)+1];
400  PCGESVD(&JOBU, &JOBVT, &M, &N, A, &IA, &JA, DESCA, S, U, &IU, &JU, DESCU, VT, &IVT, &JVT, DESCVT, (float*)WORK, &LWORK, rwork, info);
401  delete [] rwork;
402  if (LWORK != -1){
403  for (int i=std::min(M,N)-1; i>=0; i--){
404  cS[i].real(S[i]);
405  cS[i].imag(0.0);
406  }
407  }
408 #else
409  assert(0);
410 #endif
411  }
412 
413 
414  template <>
415  void pgesvd< std::complex<double> >(char JOBU,
416  char JOBVT,
417  int M,
418  int N,
419  std::complex<double> * A,
420  int IA,
421  int JA,
422  int * DESCA,
423  std::complex<double> * cS,
424  std::complex<double> * U,
425  int IU,
426  int JU,
427  int * DESCU,
428  std::complex<double> * VT,
429  int IVT,
430  int JVT,
431  int * DESCVT,
432  std::complex<double> * WORK,
433  int LWORK,
434  int * info){
435 #ifdef USE_SCALAPACK
436  double * S = (double*)cS;
437  double * rwork;
438  rwork = new double[4*std::max(M,N)+1];
439  PZGESVD(&JOBU, &JOBVT, &M, &N, A, &IA, &JA, DESCA, S, U, &IU, &JU, DESCU, VT, &IVT, &JVT, DESCVT, (double*)WORK, &LWORK, rwork, info);
440  delete [] rwork;
441  if (LWORK != -1){
442  for (int i=std::min(M,N)-1; i>=0; i--){
443  cS[i].real(S[i]);
444  cS[i].imag(0.0);
445  }
446  }
447 #else
448  assert(0);
449 #endif
450  }
451 
452  template <>
453  void pgeqrf<float>(int M,
454  int N,
455  float * A,
456  int IA,
457  int JA,
458  int const * DESCA,
459  float * TAU2,
460  float * WORK,
461  int LWORK,
462  int * INFO){
463 #ifdef USE_SCALAPACK
464  PSGEQRF(&M,&N,A,&IA,&JA,DESCA,TAU2,WORK,&LWORK,INFO);
465 #else
466  assert(0);
467 #endif
468  }
469 
470  template <>
471  void pgeqrf<double>(int M,
472  int N,
473  double * A,
474  int IA,
475  int JA,
476  int const * DESCA,
477  double * TAU2,
478  double * WORK,
479  int LWORK,
480  int * INFO){
481 #ifdef USE_SCALAPACK
482  PDGEQRF(&M,&N,A,&IA,&JA,DESCA,TAU2,WORK,&LWORK,INFO);
483 #else
484  assert(0);
485 #endif
486  }
487 
488  template <>
489  void pgeqrf< std::complex<float> >(int M,
490  int N,
491  std::complex<float> * A,
492  int IA,
493  int JA,
494  int const * DESCA,
495  std::complex<float> * TAU2,
496  std::complex<float> * WORK,
497  int LWORK,
498  int * INFO){
499 #ifdef USE_SCALAPACK
500  PCGEQRF(&M,&N,A,&IA,&JA,DESCA,TAU2,WORK,&LWORK,INFO);
501 #else
502  assert(0);
503 #endif
504  }
505 
506 
507  template <>
508  void pgeqrf< std::complex<double> >(int M,
509  int N,
510  std::complex<double> * A,
511  int IA,
512  int JA,
513  int const * DESCA,
514  std::complex<double> * TAU2,
515  std::complex<double> * WORK,
516  int LWORK,
517  int * INFO){
518 #ifdef USE_SCALAPACK
519  PZGEQRF(&M,&N,A,&IA,&JA,DESCA,TAU2,WORK,&LWORK,INFO);
520 #else
521  assert(0);
522 #endif
523  }
524 
525 
526  template <>
527  void porgqr<float>(int M,
528  int N,
529  int K,
530  float * A,
531  int IA,
532  int JA,
533  int const * DESCA,
534  float * TAU2,
535  float * WORK,
536  int LWORK,
537  int * INFO){
538 #ifdef USE_SCALAPACK
539  PSORGQR(&M,&N,&K,A,&IA,&JA,DESCA,TAU2,WORK,&LWORK,INFO);
540 #else
541  assert(0);
542 #endif
543  }
544 
545  template <>
546  void porgqr<double>(int M,
547  int N,
548  int K,
549  double * A,
550  int IA,
551  int JA,
552  int const * DESCA,
553  double * TAU2,
554  double * WORK,
555  int LWORK,
556  int * INFO){
557 #ifdef USE_SCALAPACK
558  PDORGQR(&M,&N,&K,A,&IA,&JA,DESCA,TAU2,WORK,&LWORK,INFO);
559 #else
560  assert(0);
561 #endif
562  }
563 
564  template <>
565  void porgqr< std::complex<float> >(int M,
566  int N,
567  int K,
568  std::complex<float> * A,
569  int IA,
570  int JA,
571  int const * DESCA,
572  std::complex<float> * TAU2,
573  std::complex<float> * WORK,
574  int LWORK,
575  int * INFO){
576 #ifdef USE_SCALAPACK
577  PCUNGQR(&M,&N,&K,A,&IA,&JA,DESCA,TAU2,WORK,&LWORK,INFO);
578 #else
579  assert(0);
580 #endif
581  }
582 
583 
584  template <>
585  void porgqr< std::complex<double> >(int M,
586  int N,
587  int K,
588  std::complex<double> * A,
589  int IA,
590  int JA,
591  int const * DESCA,
592  std::complex<double> * TAU2,
593  std::complex<double> * WORK,
594  int LWORK,
595  int * INFO){
596 #ifdef USE_SCALAPACK
597  PZUNGQR(&M,&N,&K,A,&IA,&JA,DESCA,TAU2,WORK,&LWORK,INFO);
598 #else
599  assert(0);
600 #endif
601  }
602 
603 
604  void cdescinit(int * desc,
605  int m,
606  int n,
607  int mb,
608  int nb,
609  int irsrc,
610  int icsrc,
611  int ictxt,
612  int LLD,
613  int * info){
614 #ifdef USE_SCALAPACK
615  DESCINIT(desc,&m,&n,&mb,&nb,&irsrc,&icsrc,&ictxt, &LLD, info);
616 #else
617  assert(0);
618 #endif
619  }
620 
621  void cblacs_pinfo(int * mypnum, int * nprocs){
622 #ifdef USE_SCALAPACK
623  Cblacs_pinfo(mypnum, nprocs);
624 #else
625  assert(0);
626 #endif
627  }
628 
629  void cblacs_get(int contxt, int what, int * val){
630 #ifdef USE_SCALAPACK
631  Cblacs_get(contxt, what, val);
632 #else
633  assert(0);
634 #endif
635  }
636 
637  void cblacs_gridinit(int * contxt, char * row, int nprow, int npcol){
638 #ifdef USE_SCALAPACK
639  Cblacs_gridinit(contxt, row, nprow, npcol);
640 #else
641  assert(0);
642 #endif
643  }
644 
645  void cblacs_gridinfo(int contxt, int * nprow, int * npcol, int * myprow, int * mypcol){
646 #ifdef USE_SCALAPACK
647  Cblacs_gridinfo(contxt, nprow, npcol, myprow, mypcol);
648 #else
649  assert(0);
650 #endif
651  }
652 
653  void cblacs_gridmap(int * contxt, int * usermap, int ldup, int nprow0, int npcol0){
654 #ifdef USE_SCALAPACK
655  Cblacs_gridmap(contxt, usermap, ldup, nprow0, npcol0);
656 #else
657  assert(0);
658 #endif
659  }
660 
661  void cblacs_barrier(int contxt, char * scope){
662 #ifdef USE_SCALAPACK
663  Cblacs_barrier(contxt, scope);
664 #else
665  assert(0);
666 #endif
667  }
668 
669  void cblacs_gridexit(int contxt){
670 #ifdef USE_SCALAPACK
671  Cblacs_gridexit(contxt);
672 #else
673  assert(0);
674 #endif
675  }
676 
677 }
678 
#define PDGESVD
#define DORMQR
void pgeqrf< double >(int M, int N, double *A, int IA, int JA, int const *DESCA, double *TAU2, double *WORK, int LWORK, int *INFO)
void Cblacs_gridinit(int *, char *, int, int)
#define DGEQRF
#define BLACS_GRIDINFO
void cblacs_get(int contxt, int what, int *val)
void pgesvd< float >(char JOBU, char JOBVT, int M, int N, float *A, int IA, int JA, int *DESCA, float *S, float *U, int IU, int JU, int *DESCU, float *VT, int IVT, int JVT, int *DESCVT, float *WORK, int LWORK, int *info)
#define PCUNGQR
void cblacs_pinfo(int *mypnum, int *nprocs)
#define PCGEQRF
#define BLACS_GRIDINIT
def rank(self)
Definition: core.pyx:312
void Cblacs_get(int, int, int *)
int icontxt
Definition: hosvd.cxx:7
#define PDORGQR
#define PCGESVD
#define PSGEQRF
#define PSGESVD
#define DGELSD
void Cblacs_gridinfo(int, int *, int *, int *, int *)
void Cblacs_gridexit(int)
void porgqr< double >(int M, int N, int K, double *A, int IA, int JA, int const *DESCA, double *TAU2, double *WORK, int LWORK, int *INFO)
#define DESCINIT
#define PZUNGQR
void cblacs_gridinit(int *contxt, char *row, int nprow, int npcol)
void cdormqr(char SIDE, char TRANS, int M, int N, int K, double const *A, int LDA, double const *TAU2, double *C, int LDC, double *WORK, int LWORK, int *INFO)
void cdgeqrf(int M, int N, double *A, int LDA, double *TAU2, double *WORK, int LWORK, int *INFO)
void cblacs_gridexit(int contxt)
void Cblacs_gridmap(int *, int *, int, int, int)
void pgeqrf< float >(int M, int N, float *A, int IA, int JA, int const *DESCA, float *TAU2, float *WORK, int LWORK, int *INFO)
void porgqr< float >(int M, int N, int K, float *A, int IA, int JA, int const *DESCA, float *TAU2, float *WORK, int LWORK, int *INFO)
void Cblacs_barrier(int, char *)
#define PZGESVD
void cdgelsd(int m, int n, int k, double const *A, int lda_A, double *B, int lda_B, double *S, double cond, int *rank, double *work, int lwork, int *iwork, int *info)
#define PDGEQRF
void pgesvd< double >(char JOBU, char JOBVT, int M, int N, double *A, int IA, int JA, int *DESCA, double *S, double *U, int IU, int JU, int *DESCU, double *VT, int IVT, int JVT, int *DESCVT, double *WORK, int LWORK, int *info)
void cdescinit(int *desc, int m, int n, int mb, int nb, int irsrc, int icsrc, int ictxt, int LLD, int *info)
#define PZGEQRF
void cblacs_gridmap(int *contxt, int *usermap, int ldup, int nprow0, int npcol0)
void cblacs_gridinfo(int contxt, int *nprow, int *npcol, int *myprow, int *mypcol)
void cblacs_barrier(int contxt, char *scope)
void Cblacs_pinfo(int *, int *)
#define PSORGQR