41 #include "gmm_blas_interface.h"
42 #include "gmm_dense_lu.h"
43 #include "gmm_dense_qr.h"
45 #if defined(GMM_USES_LAPACK) && !defined(GMM_MATLAB_INTERFACE)
47 namespace gmm {
49  /* ********************************************************************** */
50  /* Operations interfaced for T = float, double, std::complex<float> */
51  /* or std::complex<double> : */
52  /* */
53  /* lu_factor(dense_matrix<T>, std::vector<long>) */
54  /* lu_solve(dense_matrix<T>, std::vector<T>, std::vector<T>) */
55  /* lu_solve(dense_matrix<T>, std::vector<long>, std::vector<T>, */
56  /* std::vector<T>) */
57  /* lu_solve_transposed(dense_matrix<T>, std::vector<long>, std::vector<T>,*/
58  /* std::vector<T>) */
59  /* lu_inverse(dense_matrix<T>) */
60  /* lu_inverse(dense_matrix<T>, std::vector<long>, dense_matrix<T>) */
61  /* */
62  /* qr_factor(dense_matrix<T>, dense_matrix<T>, dense_matrix<T>) */
63  /* */
64  /* implicit_qr_algorithm(dense_matrix<T>, std::vector<T>) */
65  /* implicit_qr_algorithm(dense_matrix<T>, std::vector<T>, */
66  /* dense_matrix<T>) */
67  /* implicit_qr_algorithm(dense_matrix<T>, std::vector<std::complex<T> >) */
68  /* implicit_qr_algorithm(dense_matrix<T>, std::vector<std::complex<T> >, */
69  /* dense_matrix<T>) */
70  /* */
71  /* geev_interface_right */
72  /* geev_interface_left */
73  /* */
74  /* schur(dense_matrix<T>, dense_matrix<T>, dense_matrix<T>) */
75  /* */
76  /* svd(dense_matrix<T>, dense_matrix<T>, dense_matrix<T>, std::vector<T>) */
77  /* svd(dense_matrix<T>, dense_matrix<T>, dense_matrix<T>, */
78  /* std::vector<std::complex<T> >) */
79  /* */
80  /* ********************************************************************** */
82  /* ********************************************************************** */
83  /* LAPACK functions used. */
84  /* ********************************************************************** */
86  extern "C" {
87  void sgetrf_(...); void dgetrf_(...); void cgetrf_(...); void zgetrf_(...);
88  void sgetrs_(...); void dgetrs_(...); void cgetrs_(...); void zgetrs_(...);
89  void sgetri_(...); void dgetri_(...); void cgetri_(...); void zgetri_(...);
90  void sgeqrf_(...); void dgeqrf_(...); void cgeqrf_(...); void zgeqrf_(...);
91  void sorgqr_(...); void dorgqr_(...); void cungqr_(...); void zungqr_(...);
92  void sormqr_(...); void dormqr_(...); void cunmqr_(...); void zunmqr_(...);
93  void sgees_ (...); void dgees_ (...); void cgees_ (...); void zgees_ (...);
94  void sgeev_ (...); void dgeev_ (...); void cgeev_ (...); void zgeev_ (...);
95  void sgeesx_(...); void dgeesx_(...); void cgeesx_(...); void zgeesx_(...);
96  void sgesvd_(...); void dgesvd_(...); void cgesvd_(...); void zgesvd_(...);
97  }
99  /* ********************************************************************** */
100  /* LU decomposition. */
101  /* ********************************************************************** */
103 # define getrf_interface(lapack_name, base_type) inline \
104  size_type lu_factor(dense_matrix<base_type> &A, lapack_ipvt &ipvt) { \
105  GMMLAPACK_TRACE("getrf_interface"); \
106  const BLAS_INT m=BLAS_INT(mat_nrows(A)), n=BLAS_INT(mat_ncols(A)), lda(m);\
107  BLAS_INT info(-1); \
108  if (m && n) lapack_name(&m, &n, &A(0,0), &lda, &ipvt[0], &info); \
109  return size_type(abs(info)); \
110  }
112  getrf_interface(sgetrf_, BLAS_S)
113  getrf_interface(dgetrf_, BLAS_D)
114  getrf_interface(cgetrf_, BLAS_C)
115  getrf_interface(zgetrf_, BLAS_Z)
117  /* ********************************************************************* */
118  /* LU solve. */
119  /* ********************************************************************* */
121 # define getrs_interface(f_name, trans, lapack_name, base_type) inline \
122  void f_name(const dense_matrix<base_type> &A, \
123  const lapack_ipvt &ipvt, std::vector<base_type> &x, \
124  const std::vector<base_type> &b) { \
125  GMMLAPACK_TRACE("getrs_interface"); \
126  const BLAS_INT n=BLAS_INT(mat_nrows(A)), nrhs(1); \
127  BLAS_INT info(0); gmm::copy(b, x); trans; \
128  if (n) \
129  lapack_name(&t, &n, &nrhs, &A(0,0), &n, &ipvt[0], &x[0], &n, &info); \
130  }
132 # define getrs_trans_n const char t = 'N'
133 # define getrs_trans_t const char t = 'T'
135  getrs_interface(lu_solve, getrs_trans_n, sgetrs_, BLAS_S)
136  getrs_interface(lu_solve, getrs_trans_n, dgetrs_, BLAS_D)
137  getrs_interface(lu_solve, getrs_trans_n, cgetrs_, BLAS_C)
138  getrs_interface(lu_solve, getrs_trans_n, zgetrs_, BLAS_Z)
139  getrs_interface(lu_solve_transposed, getrs_trans_t, sgetrs_, BLAS_S)
140  getrs_interface(lu_solve_transposed, getrs_trans_t, dgetrs_, BLAS_D)
141  getrs_interface(lu_solve_transposed, getrs_trans_t, cgetrs_, BLAS_C)
142  getrs_interface(lu_solve_transposed, getrs_trans_t, zgetrs_, BLAS_Z)
144  /* ********************************************************************* */
145  /* LU inverse. */
146  /* ********************************************************************* */
148 # define getri_interface(lapack_name, base_type) \
149  inline void lu_inverse(const dense_matrix<base_type> &LU, \
150  const lapack_ipvt &ipvt, \
151  dense_matrix<base_type> &A) { \
152  GMMLAPACK_TRACE("getri_interface"); \
153  const BLAS_INT n=BLAS_INT(mat_nrows(A)); \
154  BLAS_INT info(0), lwork(-1); base_type work1; \
155  if (n) { \
156  gmm::copy(LU, A); \
157  lapack_name(&n, &A(0,0), &n, &ipvt[0], &work1, &lwork, &info); \
158  lwork = int(gmm::real(work1)); \
159  std::vector<base_type> work(lwork); \
160  lapack_name(&n, &A(0,0), &n, &ipvt[0], &work[0], &lwork, &info); \
161  } \
162  }
164  getri_interface(sgetri_, BLAS_S)
165  getri_interface(dgetri_, BLAS_D)
166  getri_interface(cgetri_, BLAS_C)
167  getri_interface(zgetri_, BLAS_Z)
169  /* ********************************************************************** */
170  /* QR factorization. */
171  /* ********************************************************************** */
173 # define geqrf_interface(lapack_name, base_type) \
174  inline void qr_factor(dense_matrix<base_type> &A) { \
175  GMMLAPACK_TRACE("geqrf_interface"); \
176  const BLAS_INT m=BLAS_INT(mat_nrows(A)), n=BLAS_INT(mat_ncols(A)); \
177  BLAS_INT info(0), lwork(-1); base_type work1; \
178  if (m && n) { \
179  std::vector<base_type> tau(n); \
180  lapack_name(&m, &n, &A(0,0), &m, &tau[0], &work1, &lwork, &info); \
181  lwork = BLAS_INT(gmm::real(work1)); \
182  std::vector<base_type> work(lwork); \
183  lapack_name(&m, &n, &A(0,0), &m, &tau[0], &work[0], &lwork, &info); \
184  GMM_ASSERT1(!info, "QR factorization failed"); \
185  } \
186  }
188  geqrf_interface(sgeqrf_, BLAS_S)
189  geqrf_interface(dgeqrf_, BLAS_D)
190  // For complex values, housholder vectors are not the same as in
191  // gmm::lu_factor. Impossible to interface for the moment.
192  // geqrf_interface(cgeqrf_, BLAS_C)
193  // geqrf_interface(zgeqrf_, BLAS_Z)
195 # define geqrf_interface2(lapack_name1, lapack_name2, base_type) inline \
196  void qr_factor(const dense_matrix<base_type> &A, \
197  dense_matrix<base_type> &Q, dense_matrix<base_type> &R) { \
198  GMMLAPACK_TRACE("geqrf_interface2"); \
199  const BLAS_INT m=BLAS_INT(mat_nrows(A)), n=BLAS_INT(mat_ncols(A)); \
200  BLAS_INT info(0), lwork(-1); base_type work1; \
201  if (m && n) { \
202  std::copy(A.begin(), A.end(), Q.begin()); \
203  std::vector<base_type> tau(n); \
204  lapack_name1(&m, &n, &Q(0,0), &m, &tau[0], &work1 , &lwork, &info); \
205  lwork = BLAS_INT(gmm::real(work1)); \
206  std::vector<base_type> work(lwork); \
207  lapack_name1(&m, &n, &Q(0,0), &m, &tau[0], &work[0], &lwork, &info); \
208  GMM_ASSERT1(!info, "QR factorization failed"); \
209  base_type *p = &R(0,0), *q = &Q(0,0); \
210  for (BLAS_INT j = 0; j < n; ++j, q += m-n) \
211  for (BLAS_INT i = 0; i < n; ++i, ++p, ++q) \
212  *p = (j < i) ? base_type(0) : *q; \
213  lapack_name2(&m, &n, &n, &Q(0,0), &m,&tau[0],&work[0],&lwork,&info); \
214  } \
215  else gmm::clear(Q); \
216  }
218  geqrf_interface2(sgeqrf_, sorgqr_, BLAS_S)
219  geqrf_interface2(dgeqrf_, dorgqr_, BLAS_D)
220  geqrf_interface2(cgeqrf_, cungqr_, BLAS_C)
221  geqrf_interface2(zgeqrf_, zungqr_, BLAS_Z)
223  /* ********************************************************************** */
224  /* QR algorithm for eigenvalues search. */
225  /* ********************************************************************** */
/* Generates implicit_qr_algorithm(A, eigval_, Q, tol, compvect) for a     */
/* real scalar type.                                                       */
/* A is copied into a local matrix H which is processed by the LAPACK      */
/* routine (first call with a workspace size of -1 to obtain the size,     */
/* second call with the allocated workspace). Q is handed to LAPACK as the */
/* vector output; compvect selects job 'V' or 'N'. No sorting is requested */
/* (sort = 'N'), so a null selection function is passed. The eigenvalues   */
/* are finally read from H by extract_eig(H, eigval_, tol).                */
/* Returns immediately when A has no rows; a nonzero info code triggers    */
/* GMM_ASSERT1.                                                            */
# define gees_interface(lapack_name, base_type)                             \
  template <typename VECT>                                                  \
  inline void implicit_qr_algorithm(                                        \
      const dense_matrix<base_type> &A, VECT &eigval_,                      \
      dense_matrix<base_type> &Q,                                           \
      double tol=gmm::default_tol(base_type()), bool compvect = true) {     \
    GMMLAPACK_TRACE("gees_interface");                                      \
    typedef bool (*L_fp)(...);                                              \
    L_fp no_select = 0;                                                     \
    BLAS_INT n = BLAS_INT(mat_nrows(A));                                    \
    BLAS_INT info = 0, wsize = -1, sdim;                                    \
    base_type wquery;                                                       \
    if (n == 0) return;                                                     \
    dense_matrix<base_type> H(n,n);                                         \
    gmm::copy(A, H);                                                        \
    char jobvs = 'N', sort = 'N';                                           \
    if (compvect) jobvs = 'V';                                              \
    std::vector<double> rwork(n), eig_re(n), eig_im(n);                     \
    lapack_name(&jobvs, &sort, no_select, &n, &H(0,0), &n, &sdim,           \
                &eig_re[0], &eig_im[0], &Q(0,0), &n, &wquery, &wsize,       \
                &rwork[0], &info);                                          \
    wsize = BLAS_INT(gmm::real(wquery));                                    \
    std::vector<base_type> workspace(wsize);                                \
    lapack_name(&jobvs, &sort, no_select, &n, &H(0,0), &n, &sdim,           \
                &eig_re[0], &eig_im[0], &Q(0,0), &n, &workspace[0],         \
                &wsize, &rwork[0], &info);                                  \
    GMM_ASSERT1(!info, "QR algorithm failed");                              \
    extract_eig(H, eigval_, tol);                                           \
  }
250 # define gees_interface2(lapack_name, base_type) \
251  template <typename VECT> inline void implicit_qr_algorithm( \
252  const dense_matrix<base_type> &A, VECT &eigval_, \
253  dense_matrix<base_type> &Q, \
254  double tol=gmm::default_tol(base_type()), bool compvect = true) { \
255  GMMLAPACK_TRACE("gees_interface2"); \
256  typedef bool (*L_fp)(...); L_fp p = 0; \
257  BLAS_INT n=BLAS_INT(mat_nrows(A)), info(0), lwork(-1), sdim; \
258  base_type work1; \
259  if (!n) return; \
260  dense_matrix<base_type> H(n,n); gmm::copy(A, H); \
261  char jobvs = (compvect ? 'V' : 'N'), sort = 'N'; \
262  std::vector<double> rwork(n), eigvv(n*2); \
263  lapack_name(&jobvs, &sort, p, &n, &H(0,0), &n, &sdim, &eigvv[0], \
264  &Q(0,0), &n, &work1, &lwork, &rwork[0], &rwork[0], &info); \
265  lwork = BLAS_INT(gmm::real(work1)); \
266  std::vector<base_type> work(lwork); \
267  lapack_name(&jobvs, &sort, p, &n, &H(0,0), &n, &sdim, &eigvv[0], \
268  &Q(0,0), &n, &work[0], &lwork, &rwork[0], &rwork[0],&info);\
269  GMM_ASSERT1(!info, "QR algorithm failed"); \
270  extract_eig(H, eigval_, tol); \
271  }
273  gees_interface(sgees_, BLAS_S)
274  gees_interface(dgees_, BLAS_D)
275  gees_interface2(cgees_, BLAS_C)
276  gees_interface2(zgees_, BLAS_Z)
/* Job flags handed to the geev routines: `right' requests only the        */
/* vectors selected by jobvr, `left' only those selected by jobvl.         */
# define jobv_right char jobvl = 'N', jobvr = 'V';
# define jobv_left char jobvl = 'V', jobvr = 'N';

/* Generates geev_interface_<side>(A, eigval_, Q) for a real scalar type.  */
/* A is copied into a local matrix H, the LAPACK routine is called once    */
/* with a workspace size of -1 to obtain the size and once with the        */
/* allocated workspace. Q is passed for both vector outputs. The two       */
/* eigenvalue arrays returned by LAPACK are copied into the real and       */
/* imaginary parts of eigval_.                                             */
/* Returns immediately when A has no rows; a nonzero info code triggers    */
/* GMM_ASSERT1.                                                            */
# define geev_interface(lapack_name, base_type, side)                       \
  template <typename VECT>                                                  \
  inline void geev_interface_ ## side(                                      \
      const dense_matrix<base_type> &A, VECT &eigval_,                      \
      dense_matrix<base_type> &Q) {                                         \
    GMMLAPACK_TRACE("geev_interface");                                      \
    BLAS_INT n = BLAS_INT(mat_nrows(A));                                    \
    BLAS_INT info = 0, wsize = -1;                                          \
    base_type wquery;                                                       \
    if (n == 0) return;                                                     \
    dense_matrix<base_type> H(n,n);                                         \
    gmm::copy(A, H);                                                        \
    jobv_ ## side                                                           \
    std::vector<base_type> eig_re(n), eig_im(n);                            \
    lapack_name(&jobvl, &jobvr, &n, &H(0,0), &n, &eig_re[0], &eig_im[0],    \
                &Q(0,0), &n, &Q(0,0), &n, &wquery, &wsize, &info);          \
    wsize = BLAS_INT(gmm::real(wquery));                                    \
    std::vector<base_type> workspace(wsize);                                \
    lapack_name(&jobvl, &jobvr, &n, &H(0,0), &n, &eig_re[0], &eig_im[0],    \
                &Q(0,0), &n, &Q(0,0), &n, &workspace[0], &wsize, &info);    \
    GMM_ASSERT1(!info, "QR algorithm failed");                              \
    gmm::copy(eig_re, gmm::real_part(eigval_));                             \
    gmm::copy(eig_im, gmm::imag_part(eigval_));                             \
  }
304 # define geev_interface2(lapack_name, base_type, side) \
305  template <typename VECT> inline void geev_interface_ ## side( \
306  const dense_matrix<base_type> &A, VECT &eigval_, \
307  dense_matrix<base_type> &Q) { \
308  GMMLAPACK_TRACE("geev_interface"); \
309  BLAS_INT n = BLAS_INT(mat_nrows(A)), info(0), lwork(-1); \
310  base_type work1; \
311  if (!n) return; \
312  dense_matrix<base_type> H(n,n); gmm::copy(A, H); \
313  jobv_ ## side \
314  std::vector<base_type::value_type> rwork(2*n); \
315  std::vector<base_type> eigv(n); \
316  lapack_name(&jobvl, &jobvr, &n, &H(0,0), &n, &eigv[0], &Q(0,0), &n, \
317  &Q(0,0), &n, &work1, &lwork, &rwork[0], &info); \
318  lwork = BLAS_INT(gmm::real(work1)); \
319  std::vector<base_type> work(lwork); \
320  lapack_name(&jobvl, &jobvr, &n, &H(0,0), &n, &eigv[0], &Q(0,0), &n, \
321  &Q(0,0), &n, &work[0], &lwork, &rwork[0], &info); \
322  GMM_ASSERT1(!info, "QR algorithm failed"); \
323  gmm::copy(eigv, eigval_); \
324  }
326  geev_interface(sgeev_, BLAS_S, right)
327  geev_interface(dgeev_, BLAS_D, right)
328  geev_interface2(cgeev_, BLAS_C, right)
329  geev_interface2(zgeev_, BLAS_Z, right)
331  geev_interface(sgeev_, BLAS_S, left)
332  geev_interface(dgeev_, BLAS_D, left)
333  geev_interface2(cgeev_, BLAS_C, left)
334  geev_interface2(zgeev_, BLAS_Z, left)
337  /* ********************************************************************** */
338  /* SCHUR algorithm: */
  /* A = Q*S*(Q^T), with Q orthogonal and S upper quasi-triangular */
340  /* ********************************************************************** */
/* Generates schur(A, S, Q) for a real scalar type.                        */
/* A must be square (checked with GMM_ASSERT1). S and Q are resized to     */
/* n x n, A is copied into S and the LAPACK routine works in place on S,   */
/* with Q passed as the vector output (jobvs = 'V'). No sorting and no     */
/* condition estimate are requested (sort = 'N', sense = 'N').             */
/* For an empty matrix (n == 0) the function returns right after the       */
/* resize: there is nothing to decompose, and calling LAPACK would take    */
/* the address of the first element of empty arrays and pass a zero        */
/* workspace size.                                                         */
/* A nonzero info code triggers GMM_ASSERT1.                               */
# define geesx_interface(lapack_name, base_type)                            \
  inline void schur(dense_matrix<base_type> &A,                             \
                    dense_matrix<base_type> &S,                             \
                    dense_matrix<base_type> &Q) {                           \
    GMMLAPACK_TRACE("geesx_interface");                                     \
    const BLAS_INT m=BLAS_INT(mat_nrows(A)), n=BLAS_INT(mat_ncols(A));      \
    GMM_ASSERT1(m == n, "Schur decomposition requires square matrix");      \
    resize(S, n, n); copy(A, S);                                            \
    resize(Q, n, n);                                                        \
    if (!n) return;                                                         \
    char jobvs = 'V', sort = 'N', sense = 'N';                              \
    bool select = false;                                                    \
    BLAS_INT lwork = 8*n, sdim = 0, liwork = 1;                             \
    std::vector<base_type> work(lwork), wr(n), wi(n);                       \
    std::vector<BLAS_INT> iwork(liwork);                                    \
    std::vector<BLAS_INT> bwork(1);                                         \
    base_type rconde(0), rcondv(0);                                         \
    BLAS_INT info(0);                                                       \
    lapack_name(&jobvs, &sort, &select, &sense, &n, &S(0,0), &n,            \
                &sdim, &wr[0], &wi[0], &Q(0,0), &n, &rconde, &rcondv,       \
                &work[0], &lwork, &iwork[0], &liwork, &bwork[0], &info);    \
    GMM_ASSERT1(!info, "SCHUR algorithm failed");                           \
  }
365 # define geesx_interface2(lapack_name, base_type) \
366  inline void schur(dense_matrix<base_type> &A, \
367  dense_matrix<base_type> &S, \
368  dense_matrix<base_type> &Q) { \
369  GMMLAPACK_TRACE("geesx_interface"); \
370  const BLAS_INT m=BLAS_INT(mat_nrows(A)), n=BLAS_INT(mat_ncols(A)); \
371  GMM_ASSERT1(m == n, "Schur decomposition requires square matrix"); \
372  char jobvs = 'V', sort = 'N', sense = 'N'; \
373  bool select = false; \
374  BLAS_INT lwork = 8*n, sdim = 0; \
375  std::vector<base_type::value_type> rwork(lwork); \
376  std::vector<base_type> work(lwork), w(n); \
377  std::vector<BLAS_INT> bwork(1); \
378  resize(S, n, n); copy(A, S); \
379  resize(Q, n, n); \
380  base_type rconde(0), rcondv(0); \
381  BLAS_INT info(0); \
382  lapack_name(&jobvs, &sort, &select, &sense, &n, &S(0,0), &n, \
383  &sdim, &w[0], &Q(0,0), &n, &rconde, &rcondv, \
384  &work[0], &lwork, &rwork[0], &bwork[0], &info); \
385  GMM_ASSERT1(!info, "SCHUR algorithm failed"); \
386  }
388  geesx_interface(sgeesx_, BLAS_S)
389  geesx_interface(dgeesx_, BLAS_D)
390  geesx_interface2(cgeesx_, BLAS_C)
391  geesx_interface2(zgeesx_, BLAS_Z)
393  template <typename MAT>
394  void schur(const MAT &A_, MAT &S, MAT &Q) {
395  MAT A(A_);
396  schur(A, S, Q);
397  }
400  /* ********************************************************************** */
  /* Interface to SVD. Does not correspond to a Gmm++ functionality. */
402  /* Author : Sebastian Nowozin <[email protected]> */
403  /* ********************************************************************** */
/* Generates svd(X, U, Vtransposed, sigma) for a real scalar type.         */
/* sigma is resized to min(m,n), U to m x m and Vtransposed to n x n, then */
/* the LAPACK routine is called with job 'A' for both sides. X is passed   */
/* as the in/out matrix argument of the routine.                           */
/* The workspace size is obtained from the routine itself (first call with */
/* lwork = -1) instead of a fixed multiple of min(m,n): LAPACK requires at */
/* least max(3*min(m,n)+max(m,n), 5*min(m,n)) elements, which a fixed      */
/* 15*min(m,n) does not provide for strongly rectangular matrices.         */
/* When X has no rows or no columns the function returns after the         */
/* resizes, without calling LAPACK.                                        */
/* A nonzero info code triggers GMM_ASSERT1.                               */
# define gesvd_interface(lapack_name, base_type)                            \
  inline void svd(dense_matrix<base_type> &X,                               \
                  dense_matrix<base_type> &U,                               \
                  dense_matrix<base_type> &Vtransposed,                     \
                  std::vector<base_type> &sigma) {                          \
    GMMLAPACK_TRACE("gesvd_interface");                                     \
    BLAS_INT m = BLAS_INT(mat_nrows(X)), n = BLAS_INT(mat_ncols(X));        \
    BLAS_INT mn_min = m < n ? m : n;                                        \
    sigma.resize(mn_min);                                                   \
    resize(U, m, m);                                                        \
    resize(Vtransposed, n, n);                                              \
    if (!mn_min) return;                                                    \
    char job = 'A';                                                         \
    BLAS_INT info(0), lwork(-1); base_type work1(0);                        \
    lapack_name(&job, &job, &m, &n, &X(0,0), &m, &sigma[0], &U(0,0),        \
                &m, &Vtransposed(0,0), &n, &work1, &lwork, &info);          \
    GMM_ASSERT1(!info, "SVD workspace query failed");                       \
    lwork = BLAS_INT(gmm::real(work1));                                     \
    if (lwork < 1) lwork = 1;                                               \
    std::vector<base_type> work(lwork);                                     \
    lapack_name(&job, &job, &m, &n, &X(0,0), &m, &sigma[0], &U(0,0),        \
                &m, &Vtransposed(0,0), &n, &work[0], &lwork, &info);        \
    GMM_ASSERT1(!info, "SVD failed");                                       \
  }
424 # define cgesvd_interface(lapack_name, base_type, base_type2) \
425  inline void svd(dense_matrix<base_type> &X, \
426  dense_matrix<base_type> &U, \
427  dense_matrix<base_type> &Vtransposed, \
428  std::vector<base_type2> &sigma) { \
429  GMMLAPACK_TRACE("gesvd_interface"); \
430  BLAS_INT m = BLAS_INT(mat_nrows(X)), n = BLAS_INT(mat_ncols(X)); \
431  BLAS_INT mn_min = m < n ? m : n; \
432  sigma.resize(mn_min); \
433  std::vector<base_type> work(15 * mn_min); \
434  std::vector<base_type2> rwork(5 * mn_min); \
435  BLAS_INT lwork = BLAS_INT(work.size()); \
436  resize(U, m, m); \
437  resize(Vtransposed, n, n); \
438  char job = 'A'; \
439  BLAS_INT info(0); \
440  lapack_name(&job, &job, &m, &n, &X(0,0), &m, &sigma[0], &U(0,0), \
441  &m, &Vtransposed(0,0), &n, &work[0], &lwork, \
442  &rwork[0], &info); \
443  }
445  gesvd_interface(sgesvd_, BLAS_S)
446  gesvd_interface(dgesvd_, BLAS_D)
447  cgesvd_interface(cgesvd_, BLAS_C, BLAS_S)
448  cgesvd_interface(zgesvd_, BLAS_Z, BLAS_D)
450  template <typename MAT, typename VEC>
451  void svd(const MAT &X_, MAT &U, MAT &Vtransposed, VEC &sigma) {
452  MAT X(X_);
453  svd(X, U, Vtransposed, sigma);
454  }
456 }
458 #else
460 namespace gmm
461 {
462 template <typename MAT>
463 void schur(const MAT &, MAT &, MAT &)
464 {
465  GMM_ASSERT1(false, "Use of function schur(A,S,Q) requires GetFEM "
466  "to be built with Lapack");
467 }
469 template <typename BLAS_TYPE>
470 inline void svd(dense_matrix<BLAS_TYPE> &, dense_matrix<BLAS_TYPE> &,
471  dense_matrix<BLAS_TYPE> &, std::vector<BLAS_TYPE> &)
472 {
473  GMM_ASSERT1(false, "Use of function svd(X,U,Vtransposed,sigma) requires GetFEM "
474  "to be built with Lapack");
475 }
477 }// namespace gmm
479 #endif // GMM_USES_LAPACK
