// gmm <-> BLAS bridge: everything below is only compiled when the
// GMM_USES_BLAS or GMM_USES_LAPACK macro is defined.
37 #if defined(GMM_USES_BLAS) || defined(GMM_USES_LAPACK)
// Include guard.
39 #ifndef GMM_BLAS_INTERFACE_H
40 #define GMM_BLAS_INTERFACE_H
// Tracing hook called at the start of every BLAS wrapper below with a
// descriptive string. As defined here it expands to nothing (no tracing).
50 #define GMMLAPACK_TRACE(f)
// Configuration switch for WeirdNEC platforms or an explicit 64-bit BLAS
// interface request. NOTE(review): the branch bodies (presumably where the
// BLAS_INT integer type used below is defined) are not visible here.
53 #if defined(WeirdNEC) || defined(GMM_USE_BLAS64_INTERFACE)
// Short names for the four BLAS precisions, matching the routine prefixes
// used below: S = float, D = double, C = complex<float>, Z = complex<double>.
149 # define BLAS_S float
150 # define BLAS_D double
151 # define BLAS_C std::complex<float>
152 # define BLAS_Z std::complex<double>
// Plain {r, i} structs used to receive complex values returned by value
// from Fortran BLAS functions (see BLAS_CPLX_FUNC_CALL).
153 typedef struct{
float r,i;} FORTRAN_BLAS_C;
154 typedef struct{
double r,i;} FORTRAN_BLAS_Z;
// BLAS_CPLX_FUNC_CALL(blasname, ftype, res, args...) stores the complex
// result of a BLAS function (used for the ?dotu/?dotc calls) into `res`.
// With GMM_BLAS_RETURN_COMPLEX_AS_ARGUMENT, the routine writes its result
// through a pointer passed as its first argument.
157 #if defined(GMM_BLAS_RETURN_COMPLEX_AS_ARGUMENT)
158 # define BLAS_CPLX_FUNC_CALL(blasname, ftype, res, ...) \
159 blasname(&res, __VA_ARGS__)
// Otherwise the routine returns an `ftype` struct by value, which is
// copied into `res`. NOTE(review): this variant expands to two statements
// (and declares a local `_res`), so it must not be used as the sole body
// of an unbraced if/else, nor twice in the same scope.
161 # define BLAS_CPLX_FUNC_CALL(blasname, ftype, res, ...) \
162 ftype _res=blasname(__VA_ARGS__); res=decltype(res){_res.r,_res.i};
// Prototypes of the Fortran BLAS routines wrapped below. Every argument is
// passed by pointer, and the trailing underscore follows the usual Fortran
// symbol-naming convention (this depends on the BLAS build).
// NOTE(review): these need C linkage to bind to the Fortran symbols; the
// enclosing extern "C" block is not visible here -- confirm.
// ?axpy: used by gmm::add (y += a*x).
169 void daxpy_(
const BLAS_INT *n,
const double *alpha,
const double *x,
170 const BLAS_INT *incx,
double *y,
const BLAS_INT *incy);
// The remaining routines are declared variadic "(...)", so the compiler does
// no argument type checking on calls to them; callers must pass exactly the
// argument list the Fortran routine expects.
171 void saxpy_(...);
void caxpy_(...);
void zaxpy_(...);
// ?gemm: used by the dense matrix-matrix mult_spec wrappers.
172 void dgemm_(
const char *tA,
const char *tB,
const BLAS_INT *m,
173 const BLAS_INT *n,
const BLAS_INT *k,
const BLAS_D *alpha,
174 const BLAS_D *A,
const BLAS_INT *ldA,
const BLAS_D *B,
175 const BLAS_INT *ldB,
const BLAS_D *beta, BLAS_D *C,
176 const BLAS_INT *ldC);
177 void sgemm_(...);
void cgemm_(...);
void zgemm_(...);
// ?gemv: used by the matrix-vector mult_spec / mult_add_spec wrappers.
178 void sgemv_(...);
void dgemv_(...);
void cgemv_(...);
void zgemv_(...);
// ?trsv: used by lower_tri_solve / upper_tri_solve.
179 void strsv_(...);
void dtrsv_(...);
void ctrsv_(...);
void ztrsv_(...);
// Real dot products return the scalar directly (used by vect_sp/vect_hp).
180 BLAS_S sdot_ (...); BLAS_D ddot_ (...);
// Complex dot products: with GMM_BLAS_RETURN_COMPLEX_AS_ARGUMENT they return
// void and write the result through a pointer (see BLAS_CPLX_FUNC_CALL)...
181 #if defined(GMM_BLAS_RETURN_COMPLEX_AS_ARGUMENT)
182 void cdotu_(...);
void zdotu_(...);
void cdotc_(...);
void zdotc_(...);
// ...otherwise they return a FORTRAN_BLAS_C / FORTRAN_BLAS_Z struct by value.
184 FORTRAN_BLAS_C cdotu_(...); FORTRAN_BLAS_Z zdotu_(...);
186 FORTRAN_BLAS_C cdotc_(...); FORTRAN_BLAS_Z zdotc_(...);
// Euclidean norms (used by vect_norm2); the complex-vector variants
// scnrm2_/dznrm2_ return a real float/double.
188 BLAS_S snrm2_(...); BLAS_D dnrm2_(...);
189 BLAS_S scnrm2_(...); BLAS_D dznrm2_(...);
// Rank-one updates (used by rank_one_update); the complex types use the
// conjugating ?gerc variants.
190 void sger_(...);
void dger_(...);
void cgerc_(...);
void zgerc_(...);
// nrm2_interface(blas_name, base_type) defines gmm::vect_norm2 for a
// std::vector<base_type> by calling the BLAS ?nrm2 routine `blas_name` over
// all vect_size(x) elements with unit stride. The result type is
// number_traits<base_type>::magnitude_type; for complex vectors the
// scnrm2_/dznrm2_ routines are declared returning float/double.
// NOTE(review): &x[0] is evaluated without an emptiness check, which is
// undefined behaviour for an empty std::vector; consider returning 0 when
// n == 0 (or passing x.data()). The size_t -> BLAS_INT conversion can also
// truncate very large sizes if BLAS_INT is narrower than size_t.
198 # define nrm2_interface(blas_name, base_type) \
199 inline number_traits<base_type>::magnitude_type \
200 vect_norm2(const std::vector<base_type> &x) { \
201 GMMLAPACK_TRACE("nrm2_interface"); \
202 const BLAS_INT n=BLAS_INT(vect_size(x)), inc(1); \
203 return blas_name(&n, &x[0], &inc); \
206 nrm2_interface(snrm2_, BLAS_S)
207 nrm2_interface(dnrm2_, BLAS_D)
208 nrm2_interface(scnrm2_, BLAS_C)
209 nrm2_interface(dznrm2_, BLAS_Z)
// dot_interface(funcname, msg, blas_name, base_type) defines four overloads
// of `funcname` for a real BLAS type, one for each combination of a plain
// std::vector and a scaled_vector_const_ref to one. Scaled operands are
// unwrapped with linalg_origin(), and the BLAS ?dot result is multiplied by
// the scale factor(s) (x_.r and/or y_.r, as in the doubly scaled overload).
// No scaled copy of the data is ever formed.
// The length is taken from y only; x is assumed to have the same size.
// NOTE(review): &x[0] / &y[0] are evaluated even for empty vectors, which
// is undefined behaviour for std::vector.
215 # define dot_interface(funcname, msg, blas_name, base_type) \
216 inline base_type funcname(const std::vector<base_type> &x, \
217 const std::vector<base_type> &y) { \
218 GMMLAPACK_TRACE(msg); \
219 const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); \
220 return blas_name(&n, &x[0], &inc, &y[0], &inc); \
222 inline base_type funcname \
223 (const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_, \
224 const std::vector<base_type> &y) { \
225 GMMLAPACK_TRACE(msg); \
226 const std::vector<base_type> &x = *(linalg_origin(x_)); \
228 const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); \
229 return a * blas_name(&n, &x[0], &inc, &y[0], &inc); \
231 inline base_type funcname \
232 (const std::vector<base_type> &x, \
233 const scaled_vector_const_ref<std::vector<base_type>,base_type> &y_) {\
234 GMMLAPACK_TRACE(msg); \
235 const std::vector<base_type> &y = *(linalg_origin(y_)); \
237 const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); \
238 return b * blas_name(&n, &x[0], &inc, &y[0], &inc); \
240 inline base_type funcname \
241 (const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_, \
242 const scaled_vector_const_ref<std::vector<base_type>,base_type> &y_) {\
243 GMMLAPACK_TRACE(msg); \
244 const std::vector<base_type> &x = *(linalg_origin(x_)); \
245 const std::vector<base_type> &y = *(linalg_origin(y_)); \
246 base_type a(x_.r), b(y_.r); \
247 const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); \
248 return a*b * blas_name(&n, &x[0], &inc, &y[0], &inc); \
251 dot_interface(vect_sp,
"dot_interface", sdot_, BLAS_S)
252 dot_interface(vect_sp,
"dot_interface", ddot_, BLAS_D)
// For real scalars the Hermitian product vect_hp equals vect_sp, so the
// same ?dot routines are reused.
253 dot_interface(vect_hp,
"dotc_interface", sdot_, BLAS_S)
254 dot_interface(vect_hp,
"dotc_interface", ddot_, BLAS_D)
// dot_interface_cplx(funcname, msg, blas_name, base_type, ftype, b) is the
// complex counterpart of dot_interface: four overloads of `funcname` for
// plain/scaled std::vector operands. The BLAS call goes through
// BLAS_CPLX_FUNC_CALL (result via pointer or returned as an `ftype` struct).
// Note the argument order: y is passed as the first vector, x as the second.
// `b` is the factor applied when y is scaled. It is y_.r for vect_sp (?dotu)
// and gmm::conj(y_.r) for vect_hp (?dotc), because vect_hp conjugates y and
// therefore also its scale factor.
// NOTE(review): vect_hp relies on ?dotc conjugating its first vector
// argument, yielding sum_i x[i]*conj(y[i]). Sizes are taken from y only, and
// &x[0] / &y[0] are evaluated even for empty vectors (undefined behaviour).
262 # define dot_interface_cplx(funcname, msg, blas_name, base_type, ftype, b) \
263 inline base_type funcname(const std::vector<base_type> &x, \
264 const std::vector<base_type> &y) { \
265 GMMLAPACK_TRACE(msg); \
266 const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); base_type res; \
267 BLAS_CPLX_FUNC_CALL(blas_name, ftype, res, \
268 &n, &y[0], &inc, &x[0], &inc) \
271 inline base_type funcname \
272 (const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_, \
273 const std::vector<base_type> &y) { \
274 GMMLAPACK_TRACE(msg); \
275 const std::vector<base_type> &x = *(linalg_origin(x_)); \
276 const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); base_type res; \
277 BLAS_CPLX_FUNC_CALL(blas_name, ftype, res, \
278 &n, &y[0], &inc, &x[0], &inc) \
281 inline base_type funcname \
282 (const std::vector<base_type> &x, \
283 const scaled_vector_const_ref<std::vector<base_type>,base_type> &y_) {\
284 GMMLAPACK_TRACE(msg); \
285 const std::vector<base_type> &y = *(linalg_origin(y_)); \
286 const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); base_type res; \
287 BLAS_CPLX_FUNC_CALL(blas_name, ftype, res, \
288 &n, &y[0], &inc, &x[0], &inc) \
291 inline base_type funcname \
292 (const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_, \
293 const scaled_vector_const_ref<std::vector<base_type>,base_type> &y_) {\
294 GMMLAPACK_TRACE(msg); \
295 const std::vector<base_type> &x = *(linalg_origin(x_)); \
296 const std::vector<base_type> &y = *(linalg_origin(y_)); \
297 const BLAS_INT n=BLAS_INT(vect_size(y)), inc(1); base_type res; \
298 BLAS_CPLX_FUNC_CALL(blas_name, ftype, res, \
299 &n, &y[0], &inc, &x[0], &inc) \
300 return (x_.r)*(b)*res; \
303 dot_interface_cplx(vect_sp,
"dot_interface", cdotu_,
304 BLAS_C, FORTRAN_BLAS_C, y_.r)
305 dot_interface_cplx(vect_sp, "dot_interface", zdotu_,
306 BLAS_Z, FORTRAN_BLAS_Z, y_.r)
// Hermitian products: the scale factor of y must be conjugated.
307 dot_interface_cplx(vect_hp, "dotc_interface", cdotc_,
308 BLAS_C, FORTRAN_BLAS_C, gmm::conj(y_.r))
309 dot_interface_cplx(vect_hp, "dotc_interface", zdotc_,
310 BLAS_Z, FORTRAN_BLAS_Z, gmm::conj(y_.r))
// add_fixed<N>(x, y): y[i] += x[i] for i in [0, N). N is a compile-time
// constant, presumably so the compiler can fully unroll the loop.
// No size checks: x and y must both have at least N elements.
316 template<
size_type N, class V1, class V2>
317 inline
void add_fixed(const V1 &x, V2 &y)
319 for(
size_type i = 0; i != N; ++i) y[i] += x[i];
// add_for_short_vectors(x, y, n): computes y += x for a runtime length n in
// 1..24 by dispatching to the matching add_fixed<n>. Any other n ends in the
// GMM_ASSERT2(false, ...) branch. axpy_interface uses it for vectors shorter
// than 25 elements instead of calling BLAS.
322 template<
class V1,
class V2>
323 inline void add_for_short_vectors(
const V1 &x, V2 &y,
size_type n)
327 case 1: add_fixed<1>(x, y);
break;
328 case 2: add_fixed<2>(x, y);
break;
329 case 3: add_fixed<3>(x, y);
break;
330 case 4: add_fixed<4>(x, y);
break;
331 case 5: add_fixed<5>(x, y);
break;
332 case 6: add_fixed<6>(x, y);
break;
333 case 7: add_fixed<7>(x, y);
break;
334 case 8: add_fixed<8>(x, y);
break;
335 case 9: add_fixed<9>(x, y);
break;
336 case 10: add_fixed<10>(x, y);
break;
337 case 11: add_fixed<11>(x, y);
break;
338 case 12: add_fixed<12>(x, y);
break;
339 case 13: add_fixed<13>(x, y);
break;
340 case 14: add_fixed<14>(x, y);
break;
341 case 15: add_fixed<15>(x, y);
break;
342 case 16: add_fixed<16>(x, y);
break;
343 case 17: add_fixed<17>(x, y);
break;
344 case 18: add_fixed<18>(x, y);
break;
345 case 19: add_fixed<19>(x, y);
break;
346 case 20: add_fixed<20>(x, y);
break;
347 case 21: add_fixed<21>(x, y);
break;
348 case 22: add_fixed<22>(x, y);
break;
349 case 23: add_fixed<23>(x, y);
break;
350 case 24: add_fixed<24>(x, y);
break;
352 GMM_ASSERT2(
false,
"add_for_short_vectors used with unsupported size");
// add_fixed<N>(x, y, a): scaled variant, y[i] += a*x[i] for i in [0, N).
// As with the unscaled version, x and y must both have at least N elements.
357 template<
size_type N,
class V1,
class V2,
class T>
358 inline void add_fixed(
const V1 &x, V2 &y,
const T &a)
360 for(
size_type i = 0; i != N; ++i) y[i] += a*x[i];
// add_for_short_vectors(x, y, a, n): scaled variant, y += a*x for a runtime
// length n in 1..24, dispatching to add_fixed<n>(x, y, a). Any other n ends
// in the GMM_ASSERT2(false, ...) branch. axpy_interface uses it for scaled
// vectors shorter than 25 elements.
363 template<
class V1,
class V2,
class T>
364 inline void add_for_short_vectors(
const V1 &x, V2 &y,
const T &a,
size_type n)
368 case 1: add_fixed<1>(x, y, a);
break;
369 case 2: add_fixed<2>(x, y, a);
break;
370 case 3: add_fixed<3>(x, y, a);
break;
371 case 4: add_fixed<4>(x, y, a);
break;
372 case 5: add_fixed<5>(x, y, a);
break;
373 case 6: add_fixed<6>(x, y, a);
break;
374 case 7: add_fixed<7>(x, y, a);
break;
375 case 8: add_fixed<8>(x, y, a);
break;
376 case 9: add_fixed<9>(x, y, a);
break;
377 case 10: add_fixed<10>(x, y, a);
break;
378 case 11: add_fixed<11>(x, y, a);
break;
379 case 12: add_fixed<12>(x, y, a);
break;
380 case 13: add_fixed<13>(x, y, a);
break;
381 case 14: add_fixed<14>(x, y, a);
break;
382 case 15: add_fixed<15>(x, y, a);
break;
383 case 16: add_fixed<16>(x, y, a);
break;
384 case 17: add_fixed<17>(x, y, a);
break;
385 case 18: add_fixed<18>(x, y, a);
break;
386 case 19: add_fixed<19>(x, y, a);
break;
387 case 20: add_fixed<20>(x, y, a);
break;
388 case 21: add_fixed<21>(x, y, a);
break;
389 case 22: add_fixed<22>(x, y, a);
break;
390 case 23: add_fixed<23>(x, y, a);
break;
391 case 24: add_fixed<24>(x, y, a);
break;
393 GMM_ASSERT2(
false,
"add_for_short_vectors used with unsupported size");
// axpy_interface(blas_name, base_type) defines gmm::add for std::vector:
//   add(x, y)   : y += x      (BLAS alpha = 1)
//   add(x_, y)  : y += a*x    (x_ is a scaled_vector_const_ref, a = x_.r)
// Behaviour depends on the length nn, which is taken from y:
//   - empty vectors are a no-op;
//   - lengths below 25 use the unrolled add_for_short_vectors helpers;
//   - longer vectors call BLAS ?axpy with unit stride.
// x is assumed to be at least as long as y; no size check is done here.
399 # define axpy_interface(blas_name, base_type) \
400 inline void add(const std::vector<base_type> &x, \
401 std::vector<base_type> &y) { \
402 GMMLAPACK_TRACE("axpy_interface"); \
403 const size_type nn=vect_size(y); \
404 if (nn == 0) return; \
405 else if (nn < 25) add_for_short_vectors(x, y, nn); \
406 else { const BLAS_INT n=BLAS_INT(nn), inc(1); const base_type a(1); \
407 blas_name(&n, &a, &x[0], &inc, &y[0], &inc); } \
409 inline void add(const scaled_vector_const_ref<std::vector<base_type>, \
411 std::vector<base_type> &y) { \
412 GMMLAPACK_TRACE("axpy_interface"); \
413 const size_type nn=vect_size(y); const base_type a(x_.r); \
414 const std::vector<base_type>& x = *(linalg_origin(x_)); \
415 if (nn == 0) return; \
416 else if (nn < 25) add_for_short_vectors(x, y, a, nn); \
417 else { const BLAS_INT n=BLAS_INT(nn), inc(1); \
418 blas_name(&n, &a, &x[0], &inc, &y[0], &inc); } \
421 axpy_interface(saxpy_, BLAS_S)
422 axpy_interface(daxpy_, BLAS_D)
423 axpy_interface(caxpy_, BLAS_C)
424 axpy_interface(zaxpy_, BLAS_Z)
// gemv_interface(param1, trans1, param2, trans2, blas_name, base_type, orien)
// defines two matrix-vector products on a dense matrix via BLAS ?gemv:
//   mult_add_spec : z += alpha * op(A) * x   (beta = 1)
//   mult_spec     : z  = alpha * op(A) * x   (beta = 0)
// The macro arguments are:
//   - param1 / param2 expand to the matrix / vector parameter declarations
//     (gem_p1_*, gemv_p2_*);
//   - trans1 binds the underlying dense matrix A (views are unwrapped with
//     linalg_origin) and the op flag t; gem_trans1_n sets t = 'N', and the
//     flags of the transposed / conjugated variants are presumably 'T' / 'C';
//   - trans2 binds the vector x and the scale alpha: 1 for a plain x,
//     x_.r for a scaled_vector_const_ref x_;
//   - orien is the orientation tag in the overload signature: col_major for
//     a plain dense_matrix, row_major for transposed / conjugated views.
// lda is mat_nrows of the underlying matrix. When that matrix has zero rows
// or columns, BLAS is skipped and z is cleared.
// NOTE(review): clearing z is correct for mult_spec but wrong for
// mult_add_spec. When the product's inner dimension is zero, op(A)*x is the
// zero vector and z += op(A)*x should leave z unchanged, yet z is zeroed.
// The size of z is not checked against op(A).
432 # define gemv_interface(param1, trans1, param2, trans2, blas_name, \
434 inline void mult_add_spec(param1(base_type), param2(base_type), \
435 std::vector<base_type> &z, orien) { \
436 GMMLAPACK_TRACE("gemv_interface"); \
437 trans1(base_type); trans2(base_type); const base_type beta(1); \
438 const BLAS_INT m=BLAS_INT(mat_nrows(A)), lda(m), \
439 n=BLAS_INT(mat_ncols(A)), inc(1); \
440 if (m && n) blas_name(&t, &m, &n, &alpha, &A(0,0), &lda, &x[0], &inc, \
441 &beta, &z[0], &inc); \
442 else gmm::clear(z); \
444 inline void mult_spec(param1(base_type), param2(base_type), \
445 std::vector<base_type> &z, orien) { \
446 GMMLAPACK_TRACE("gemv_interface2"); \
447 trans1(base_type); trans2(base_type); const base_type beta(0); \
448 const BLAS_INT m=BLAS_INT(mat_nrows(A)), lda(m), \
449 n=BLAS_INT(mat_ncols(A)), inc(1); \
450 if (m && n) blas_name(&t, &m, &n, &alpha, &A(0,0), &lda, \
451 &x[0], &inc, &beta, &z[0], &inc); \
452 else gmm::clear(z); \
456 # define gem_p1_n(base_type) const dense_matrix<base_type> &A
457 # define gem_trans1_n(base_type) const char t = 'N'
458 # define gem_p1_t(base_type) \
459 const transposed_col_ref<dense_matrix<base_type> *> &A_
460 # define gem_trans1_t(base_type) const dense_matrix<base_type> &A = \
461 *(linalg_origin(A_)); \
463 # define gem_p1_tc(base_type) \
464 const transposed_col_ref<const dense_matrix<base_type> *> &A_
465 # define gem_p1_c(base_type) \
466 const conjugated_col_matrix_const_ref<dense_matrix<base_type> > &A_
467 # define gem_trans1_c(base_type) const dense_matrix<base_type> &A = \
468 *(linalg_origin(A_)); \
472 # define gemv_p2_n(base_type) const std::vector<base_type> &x
473 # define gemv_trans2_n(base_type) base_type alpha(1)
474 # define gemv_p2_s(base_type) \
475 const scaled_vector_const_ref<std::vector<base_type>,base_type> &x_
476 # define gemv_trans2_s(base_type) const std::vector<base_type> &x = \
477 (*(linalg_origin(x_))); \
478 base_type alpha(x_.r)
// Instantiations with a plain vector x (alpha = 1): plain matrix A.
482 gemv_interface(gem_p1_n, gem_trans1_n,
483 gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, col_major)
484 gemv_interface(gem_p1_n, gem_trans1_n,
485 gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, col_major)
486 gemv_interface(gem_p1_n, gem_trans1_n,
487 gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, col_major)
488 gemv_interface(gem_p1_n, gem_trans1_n,
489 gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, col_major)
// Plain x, transposed view of a mutable dense matrix.
493 gemv_interface(gem_p1_t, gem_trans1_t,
494 gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, row_major)
495 gemv_interface(gem_p1_t, gem_trans1_t,
496 gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, row_major)
497 gemv_interface(gem_p1_t, gem_trans1_t,
498 gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, row_major)
499 gemv_interface(gem_p1_t, gem_trans1_t,
500 gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, row_major)
// Plain x, transposed view of a const dense matrix.
504 gemv_interface(gem_p1_tc, gem_trans1_t,
505 gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, row_major)
506 gemv_interface(gem_p1_tc, gem_trans1_t,
507 gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, row_major)
508 gemv_interface(gem_p1_tc, gem_trans1_t,
509 gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, row_major)
510 gemv_interface(gem_p1_tc, gem_trans1_t,
511 gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, row_major)
// Plain x, conjugated view of a dense matrix.
515 gemv_interface(gem_p1_c, gem_trans1_c,
516 gemv_p2_n, gemv_trans2_n, sgemv_, BLAS_S, row_major)
517 gemv_interface(gem_p1_c, gem_trans1_c,
518 gemv_p2_n, gemv_trans2_n, dgemv_, BLAS_D, row_major)
519 gemv_interface(gem_p1_c, gem_trans1_c,
520 gemv_p2_n, gemv_trans2_n, cgemv_, BLAS_C, row_major)
521 gemv_interface(gem_p1_c, gem_trans1_c,
522 gemv_p2_n, gemv_trans2_n, zgemv_, BLAS_Z, row_major)
// Scaled vector x_ (alpha = x_.r): plain matrix A.
526 gemv_interface(gem_p1_n, gem_trans1_n,
527 gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, col_major)
528 gemv_interface(gem_p1_n, gem_trans1_n,
529 gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, col_major)
530 gemv_interface(gem_p1_n, gem_trans1_n,
531 gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, col_major)
532 gemv_interface(gem_p1_n, gem_trans1_n,
533 gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, col_major)
// Scaled x_, transposed view of a mutable dense matrix.
537 gemv_interface(gem_p1_t, gem_trans1_t,
538 gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, row_major)
539 gemv_interface(gem_p1_t, gem_trans1_t,
540 gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, row_major)
541 gemv_interface(gem_p1_t, gem_trans1_t,
542 gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, row_major)
543 gemv_interface(gem_p1_t, gem_trans1_t,
544 gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, row_major)
// Scaled x_, transposed view of a const dense matrix.
548 gemv_interface(gem_p1_tc, gem_trans1_t,
549 gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, row_major)
550 gemv_interface(gem_p1_tc, gem_trans1_t,
551 gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, row_major)
552 gemv_interface(gem_p1_tc, gem_trans1_t,
553 gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, row_major)
554 gemv_interface(gem_p1_tc, gem_trans1_t,
555 gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, row_major)
// Scaled x_, conjugated view of a dense matrix.
559 gemv_interface(gem_p1_c, gem_trans1_c,
560 gemv_p2_s, gemv_trans2_s, sgemv_, BLAS_S, row_major)
561 gemv_interface(gem_p1_c, gem_trans1_c,
562 gemv_p2_s, gemv_trans2_s, dgemv_, BLAS_D, row_major)
563 gemv_interface(gem_p1_c, gem_trans1_c,
564 gemv_p2_s, gemv_trans2_s, cgemv_, BLAS_C, row_major)
565 gemv_interface(gem_p1_c, gem_trans1_c,
566 gemv_p2_s, gemv_trans2_s, zgemv_, BLAS_Z, row_major)
// ger_interface(blas_name, base_type) defines gmm::rank_one_update for a
// dense_matrix A and std::vector operands via BLAS ?ger / ?gerc:
//   rank_one_update(A, V, W)   : A += V * W^H           (alpha = 1)
//   rank_one_update(A, x_, W)  : alpha = x_.r           (first vector scaled)
//   rank_one_update(A, V, x_)  : alpha = conj(x_.r)     (second vector scaled;
//                                conjugated because it enters as W^H)
// For the real types conj is a no-op and ?ger gives A += V * W^T.
// m and n come from A, and lda = mat_nrows(A). The vector lengths are not
// checked against m and n here; any zero-size guard before the BLAS calls
// is not visible in this excerpt.
573 # define ger_interface(blas_name, base_type) \
574 inline void rank_one_update(dense_matrix<base_type> &A, \
575 const std::vector<base_type> &V, \
576 const std::vector<base_type> &W) { \
577 GMMLAPACK_TRACE("ger_interface"); \
578 const BLAS_INT m=BLAS_INT(mat_nrows(A)), lda(m), \
579 n=BLAS_INT(mat_ncols(A)), inc(1); \
580 const base_type alpha(1); \
582 blas_name(&m, &n, &alpha, &V[0], &inc, &W[0], &inc, &A(0,0), &lda); \
585 rank_one_update(dense_matrix<base_type> &A, \
586 const scaled_vector_const_ref<std::vector<base_type>, \
588 const std::vector<base_type> &W) { \
589 GMMLAPACK_TRACE("ger_interface"); \
590 const std::vector<base_type> &x = (*(linalg_origin(x_))); \
591 const base_type alpha(x_.r); \
592 const BLAS_INT m=BLAS_INT(mat_nrows(A)), lda(m), \
593 n=BLAS_INT(mat_ncols(A)), inc(1); \
595 blas_name(&m, &n, &alpha, &x[0], &inc, &W[0], &inc, &A(0,0), &lda); \
598 rank_one_update(dense_matrix<base_type> &A, \
599 const std::vector<base_type> &V, \
600 const scaled_vector_const_ref<std::vector<base_type>, \
602 GMMLAPACK_TRACE("ger_interface"); \
603 const std::vector<base_type> &x = (*(linalg_origin(x_))); \
604 const BLAS_INT m=BLAS_INT(mat_nrows(A)), lda(m), \
605 n=BLAS_INT(mat_ncols(A)), inc(1); \
606 const base_type alpha0(x_.r), alpha=gmm::conj(alpha0); \
608 blas_name(&m, &n, &alpha, &V[0], &inc, &x[0], &inc, &A(0,0), &lda); \
611 ger_interface(sger_, BLAS_S)
612 ger_interface(dger_, BLAS_D)
613 ger_interface(cgerc_, BLAS_C)
614 ger_interface(zgerc_, BLAS_Z)
// gemm_interface_nn(blas_name, base_type) defines mult_spec(A, B, C, c_mult)
// for plain dense matrices: C = A * B via BLAS ?gemm ('N', 'N'), with
// alpha = 1 and beta = 0. Dimensions: A is m x k, B is k x n, and the
// leading dimensions are m (A), k (B) and m (C). C is not resized here and
// is presumably already m x n. When BLAS is skipped (the guard is on a line
// not visible here, presumably m && k && n as in the other gemm wrappers),
// C is cleared.
621 # define gemm_interface_nn(blas_name, base_type) \
622 inline void mult_spec(const dense_matrix<base_type> &A, \
623 const dense_matrix<base_type> &B, \
624 dense_matrix<base_type> &C, c_mult) { \
625 GMMLAPACK_TRACE("gemm_interface_nn"); \
626 const char t='N'; const BLAS_INT m=BLAS_INT(mat_nrows(A)), lda(m), \
627 k=BLAS_INT(mat_ncols(A)), ldb(k), \
628 n=BLAS_INT(mat_ncols(B)), ldc(m); \
629 const base_type alpha(1), beta(0); \
631 blas_name(&t, &t, &m, &n, &k, &alpha, \
632 &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
633 else gmm::clear(C); \
636 gemm_interface_nn(sgemm_, BLAS_S)
637 gemm_interface_nn(dgemm_, BLAS_D)
638 gemm_interface_nn(cgemm_, BLAS_C)
639 gemm_interface_nn(zgemm_, BLAS_Z)
// gemm_interface_tn_nt(blas_name, base_type, mat_type) defines two
// mult_spec overloads via BLAS ?gemm with alpha = 1 and beta = 0:
//   C = A^T * B  (tag rcmult; A_ is a transposed view, flags 'T','N')
//   C = A * B^T  (tag r_mult; B_ is a transposed view, flags 'N','T')
// The transposed operand is a transposed_col_ref to a mat_type<base_type>.
// mat_type is instantiated with both dense_matrix and const dense_matrix, so
// views of mutable and const matrices are covered. Views are unwrapped with
// linalg_origin(), so leading dimensions refer to the stored matrix. C is not
// resized here (presumably already m x n). If any of m, k, n is zero, C is
// cleared instead of calling BLAS.
646 # define gemm_interface_tn_nt(blas_name, base_type, mat_type) \
647 inline void mult_spec( \
648 const transposed_col_ref<mat_type<base_type> *> &A_, \
649 const dense_matrix<base_type> &B, \
650 dense_matrix<base_type> &C, rcmult) { \
651 GMMLAPACK_TRACE("gemm_interface_tn"); \
652 const dense_matrix<base_type> &A = *(linalg_origin(A_)); \
653 const char t = 'T', u = 'N'; \
654 const BLAS_INT m=BLAS_INT(mat_ncols(A)), k=BLAS_INT(mat_nrows(A)), \
655 n=BLAS_INT(mat_ncols(B)), lda(k), ldb(k), ldc(m); \
656 const base_type alpha(1), beta(0); \
657 if (m && k && n) blas_name(&t, &u, &m, &n, &k, &alpha, &A(0,0), &lda, \
658 &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
659 else gmm::clear(C); \
662 mult_spec(const dense_matrix<base_type> &A, \
663 const transposed_col_ref<mat_type<base_type> *> &B_, \
664 dense_matrix<base_type> &C, r_mult) { \
665 GMMLAPACK_TRACE("gemm_interface_nt"); \
666 const dense_matrix<base_type> &B = *(linalg_origin(B_)); \
667 const char t = 'N', u = 'T'; \
668 const BLAS_INT m=BLAS_INT(mat_nrows(A)), k=BLAS_INT(mat_ncols(A)), \
669 n=BLAS_INT(mat_nrows(B)), lda(m), ldb(n), ldc(m); \
670 const base_type alpha(1), beta(0); \
671 if (m && k && n) blas_name(&t, &u, &m, &n, &k, &alpha, &A(0,0), &lda, \
672 &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
673 else gmm::clear(C); \
676 gemm_interface_tn_nt(sgemm_, BLAS_S, dense_matrix)
677 gemm_interface_tn_nt(dgemm_, BLAS_D, dense_matrix)
678 gemm_interface_tn_nt(cgemm_, BLAS_C, dense_matrix)
679 gemm_interface_tn_nt(zgemm_, BLAS_Z, dense_matrix)
// Same wrappers for transposed views of const matrices.
680 gemm_interface_tn_nt(sgemm_, BLAS_S,
const dense_matrix)
681 gemm_interface_tn_nt(dgemm_, BLAS_D,
const dense_matrix)
682 gemm_interface_tn_nt(cgemm_, BLAS_C,
const dense_matrix)
683 gemm_interface_tn_nt(zgemm_, BLAS_Z,
const dense_matrix)
// gemm_interface_tt(blas_name, base_type, matA_type, matB_type) defines
// mult_spec(A_, B_, C, r_mult) for C = A^T * B^T via BLAS ?gemm ('T', 'T'),
// with alpha = 1 and beta = 0. Both operands are transposed_col_ref views,
// unwrapped with linalg_origin(). The stored A is k x m (lda = k), the stored
// B is n x k (ldb = n), and ldc = m. The instantiations cover all four
// const / non-const combinations of the two viewed matrix types. C is not
// resized here. When BLAS is skipped (the guard is on a line not visible
// here), C is cleared.
690 # define gemm_interface_tt(blas_name, base_type, matA_type, matB_type) \
692 mult_spec(const transposed_col_ref<matA_type<base_type> *> &A_, \
693 const transposed_col_ref<matB_type<base_type> *> &B_, \
694 dense_matrix<base_type> &C, r_mult) { \
695 GMMLAPACK_TRACE("gemm_interface_tt"); \
696 const dense_matrix<base_type> &A = *(linalg_origin(A_)); \
697 const dense_matrix<base_type> &B = *(linalg_origin(B_)); \
698 const char t = 'T', u = 'T'; \
699 const BLAS_INT m=BLAS_INT(mat_ncols(A)), k=BLAS_INT(mat_nrows(A)), \
700 n=BLAS_INT(mat_nrows(B)), lda(k), ldb(n), ldc(m); \
701 base_type alpha(1), beta(0); \
703 blas_name(&t, &u, &m, &n, &k, &alpha, \
704 &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
705 else gmm::clear(C); \
708 gemm_interface_tt(sgemm_, BLAS_S, dense_matrix, dense_matrix)
709 gemm_interface_tt(dgemm_, BLAS_D, dense_matrix, dense_matrix)
710 gemm_interface_tt(cgemm_, BLAS_C, dense_matrix, dense_matrix)
711 gemm_interface_tt(zgemm_, BLAS_Z, dense_matrix, dense_matrix)
// A is a view of a const matrix.
712 gemm_interface_tt(sgemm_, BLAS_S,
const dense_matrix, dense_matrix)
713 gemm_interface_tt(dgemm_, BLAS_D,
const dense_matrix, dense_matrix)
714 gemm_interface_tt(cgemm_, BLAS_C,
const dense_matrix, dense_matrix)
715 gemm_interface_tt(zgemm_, BLAS_Z,
const dense_matrix, dense_matrix)
// B is a view of a const matrix.
716 gemm_interface_tt(sgemm_, BLAS_S, dense_matrix,
const dense_matrix)
717 gemm_interface_tt(dgemm_, BLAS_D, dense_matrix,
const dense_matrix)
718 gemm_interface_tt(cgemm_, BLAS_C, dense_matrix,
const dense_matrix)
719 gemm_interface_tt(zgemm_, BLAS_Z, dense_matrix,
const dense_matrix)
// Both are views of const matrices.
720 gemm_interface_tt(sgemm_, BLAS_S,
const dense_matrix,
const dense_matrix)
721 gemm_interface_tt(dgemm_, BLAS_D,
const dense_matrix,
const dense_matrix)
722 gemm_interface_tt(cgemm_, BLAS_C,
const dense_matrix,
const dense_matrix)
723 gemm_interface_tt(zgemm_, BLAS_Z,
const dense_matrix,
const dense_matrix)
// gemm_interface_cn_nc_cc(blas_name, base_type) defines three mult_spec
// overloads for conjugated_col_matrix_const_ref views, via BLAS ?gemm with
// the 'C' (conjugate-transpose) flag, alpha = 1 and beta = 0:
//   cn : C = conj(A)^T * B        (tag rcmult, flags 'C','N')
//   nc : C = A * conj(B)^T        (tags c_mult, row_major, flags 'N','C')
//   cc : C = conj(A)^T * conj(B)^T (tag r_mult, flags 'C','C')
// Views are unwrapped with linalg_origin(), so leading dimensions refer to
// the stored matrices. For the real instantiations (sgemm_/dgemm_), 'C'
// behaves like 'T'. C is not resized here. When BLAS is skipped (the guards
// are on lines not visible here), C is cleared.
732 # define gemm_interface_cn_nc_cc(blas_name, base_type) \
733 inline void mult_spec( \
734 const conjugated_col_matrix_const_ref<dense_matrix<base_type> > &A_, \
735 const dense_matrix<base_type> &B, \
736 dense_matrix<base_type> &C, rcmult) { \
737 GMMLAPACK_TRACE("gemm_interface_cn"); \
738 const dense_matrix<base_type> &A = *(linalg_origin(A_)); \
739 const char t = 'C', u = 'N'; \
740 const BLAS_INT m=BLAS_INT(mat_ncols(A)), k=BLAS_INT(mat_nrows(A)), \
741 n=BLAS_INT(mat_ncols(B)), lda(k), ldb(k), ldc(m); \
742 const base_type alpha(1), beta(0); \
744 blas_name(&t, &u, &m, &n, &k, &alpha, \
745 &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
746 else gmm::clear(C); \
748 inline void mult_spec( \
749 const dense_matrix<base_type> &A, \
750 const conjugated_col_matrix_const_ref<dense_matrix<base_type> > &B_, \
751 dense_matrix<base_type> &C, c_mult, row_major) { \
752 GMMLAPACK_TRACE("gemm_interface_nc"); \
753 const dense_matrix<base_type> &B = *(linalg_origin(B_)); \
754 const char t = 'N', u = 'C'; \
755 const BLAS_INT m=BLAS_INT(mat_nrows(A)), k=BLAS_INT(mat_ncols(A)), \
756 n=BLAS_INT(mat_nrows(B)), lda(m), ldb(n), ldc(m); \
757 const base_type alpha(1), beta(0); \
759 blas_name(&t, &u, &m, &n, &k, &alpha, \
760 &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
761 else gmm::clear(C); \
763 inline void mult_spec( \
764 const conjugated_col_matrix_const_ref<dense_matrix<base_type> > &A_, \
765 const conjugated_col_matrix_const_ref<dense_matrix<base_type> > &B_, \
766 dense_matrix<base_type> &C, r_mult) { \
767 GMMLAPACK_TRACE("gemm_interface_cc"); \
768 const dense_matrix<base_type> &A = *(linalg_origin(A_)); \
769 const dense_matrix<base_type> &B = *(linalg_origin(B_)); \
770 const char t = 'C', u = 'C'; \
771 const BLAS_INT m=BLAS_INT(mat_ncols(A)), k=BLAS_INT(mat_nrows(A)), \
772 n=BLAS_INT(mat_nrows(B)), lda(k), ldb(n), ldc(m); \
773 const base_type alpha(1), beta(0); \
775 blas_name(&t, &u, &m, &n, &k, &alpha, \
776 &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
777 else gmm::clear(C); \
780 gemm_interface_cn_nc_cc(sgemm_, BLAS_S)
781 gemm_interface_cn_nc_cc(dgemm_, BLAS_D)
782 gemm_interface_cn_nc_cc(cgemm_, BLAS_C)
783 gemm_interface_cn_nc_cc(zgemm_, BLAS_Z)
// trsv_interface(LorU1, LorU2, param1, trans1, blas_name, base_type) defines
// lower_tri_solve and upper_tri_solve for a dense matrix argument and a
// std::vector x. Each solves in place on the leading k x k block with BLAS
// ?trsv (n = k, lda = mat_nrows of the stored matrix, unit stride).
// The matrix argument can be a plain matrix or a transposed, const
// transposed or conjugated view. param1 / trans1 are the gem_p1_* /
// gem_trans1_* helpers also used by gemv_interface.
// LorU1 / LorU2 give the triangle flag ('L' / 'U') passed to BLAS for the
// lower / upper solve. For views the triangles are swapped ('U', 'L'),
// because the lower triangle of a transposed view is the upper triangle of
// the stored matrix.
// is_unit selects a unit diagonal ('U') or a general one ('N'). The BLAS
// call is skipped only when the stored matrix has zero rows.
// NOTE(review): &x[0] is evaluated even when x is empty (undefined behaviour
// for std::vector), and k is not checked against the sizes of A or x.
790 # define trsv_interface(LorU1, LorU2, param1, trans1, blas_name, base_type)\
792 lower_tri_solve(param1(base_type), std::vector<base_type> &x, \
793 size_type k, bool is_unit) { \
794 GMMLAPACK_TRACE("trsv_interface"); \
795 const char l = LorU1; trans1(base_type); char d = is_unit ? 'U' : 'N'; \
796 const BLAS_INT lda=BLAS_INT(mat_nrows(A)), inc(1), n=BLAS_INT(k); \
797 if (lda) blas_name(&l, &t, &d, &n, &A(0,0), &lda, &x[0], &inc); \
800 upper_tri_solve(param1(base_type), std::vector<base_type> &x, \
801 size_type k, bool is_unit) { \
802 GMMLAPACK_TRACE("trsv_interface"); \
803 const char l = LorU2; trans1(base_type); char d = is_unit ? 'U' : 'N'; \
804 const BLAS_INT lda=BLAS_INT(mat_nrows(A)), inc(1), n=BLAS_INT(k); \
805 if (lda) blas_name(&l, &t, &d, &n, &A(0,0), &lda, &x[0], &inc); \
// Plain dense matrix: triangles used as-is.
810 trsv_interface(
'L',
'U', gem_p1_n, gem_trans1_n, strsv_, BLAS_S)
811 trsv_interface(
'L',
'U', gem_p1_n, gem_trans1_n, dtrsv_, BLAS_D)
812 trsv_interface(
'L',
'U', gem_p1_n, gem_trans1_n, ctrsv_, BLAS_C)
813 trsv_interface(
'L',
'U', gem_p1_n, gem_trans1_n, ztrsv_, BLAS_Z)
// Transposed view of a mutable matrix: triangles swapped.
817 trsv_interface(
'U',
'L', gem_p1_t, gem_trans1_t, strsv_, BLAS_S)
818 trsv_interface(
'U',
'L', gem_p1_t, gem_trans1_t, dtrsv_, BLAS_D)
819 trsv_interface(
'U',
'L', gem_p1_t, gem_trans1_t, ctrsv_, BLAS_C)
820 trsv_interface(
'U',
'L', gem_p1_t, gem_trans1_t, ztrsv_, BLAS_Z)
// Transposed view of a const matrix: triangles swapped.
824 trsv_interface(
'U',
'L', gem_p1_tc, gem_trans1_t, strsv_, BLAS_S)
825 trsv_interface(
'U',
'L', gem_p1_tc, gem_trans1_t, dtrsv_, BLAS_D)
826 trsv_interface(
'U',
'L', gem_p1_tc, gem_trans1_t, ctrsv_, BLAS_C)
827 trsv_interface(
'U',
'L', gem_p1_tc, gem_trans1_t, ztrsv_, BLAS_Z)
// Conjugated view: triangles swapped.
831 trsv_interface(
'U',
'L', gem_p1_c, gem_trans1_c, strsv_, BLAS_S)
832 trsv_interface(
'U',
'L', gem_p1_c, gem_trans1_c, dtrsv_, BLAS_D)
833 trsv_interface(
'U',
'L', gem_p1_c, gem_trans1_c, ctrsv_, BLAS_C)
834 trsv_interface(
'U',
'L', gem_p1_c, gem_trans1_c, ztrsv_, BLAS_Z)
Basic linear algebra functions.
gmm interface for STL vectors.
Declaration of some matrix types (gmm::dense_matrix, gmm::row_matrix, gmm::col_matrix, ...).
size_type (typedef for size_t): used as the common size type in the library.