sisi4s
Loading...
Searching...
No Matches
LapackMatrix.hpp
Go to the documentation of this file.
1#ifndef LAPACK_MATRIX_DEFINED
2#define LAPACK_MATRIX_DEFINED
3
4#include <math/Complex.hpp>
5#include <vector>
6#include <sstream>
7#include <util/Tensor.hpp>
8
9namespace sisi4s {
10template <typename F = real>
12public:
17 : rows(A.rows)
18 , columns(A.columns)
19 , values(A.values) {}
20
24 LapackMatrix(const int rows_, const int columns_)
25 : rows(rows_)
26 , columns(columns_)
27 , values(rows_ * columns_) {}
28
32 LapackMatrix(const int rows_,
33 const int columns_,
34 const std::vector<F> &values_)
35 : rows(rows_)
36 , columns(columns_)
37 , values(values_) {}
38
43 LapackMatrix(Tensor<F> &ctfA, const int rows_ = 0, const int columns_ = 0)
44 : rows(rows_ > 0 ? rows_ : ctfA.lens[0])
45 , columns(columns_ > 0 ? columns_ : ctfA.lens[1]) {
46 values.resize(rows * columns);
47 ctfA.read_all(values.data(), true);
48 }
49
51 rows = A.rows;
52 columns = A.columns;
53 values = A.values;
54 return *this;
55 }
56
57 const F &operator()(const int i, const int j) const {
58 return values[i + j * rows];
59 }
60
61 F &operator()(const int i, const int j) { return values[i + j * rows]; }
62
63 int getRows() const { return rows; }
64 int getColumns() const { return columns; }
65
69 const F *getValues() const { return values.data(); }
70
74 F *getValues() { return values.data(); }
75
79 void write(Tensor<F> &ctfA) const {
80 if (ctfA.lens[0] != rows || ctfA.lens[1] != columns) {
81 std::stringstream stream;
82 stream << "Tensor is not of correct shape to receive (" << rows << "x"
83 << columns << ") matrix: ";
84 std::string join("");
85 for (int d(0); d < ctfA.order; ++d) {
86 stream << join << ctfA.lens[d];
87 join = "x";
88 }
89 throw new EXCEPTION(stream.str());
90 }
91 int64_t size(values.size());
92 int64_t localSize(ctfA.wrld->rank == 0 ? size : 0);
93 std::vector<int64_t> indices(localSize);
94 for (int64_t i(0); i < localSize; ++i) { indices[i] = i; }
95 ctfA.write(localSize, indices.data(), values.data());
96 }
97
98protected:
103
107 std::vector<F> values;
108};
109
110// TODO: use blas (D|Z)GEMM for matrix multiplication
111// TODO: support move semantics
112template <typename F = real>
114 if (A.getColumns() != B.getRows()) {
115 std::stringstream stream;
116 stream << "Matrix shapes not compatible for multiplication: ("
117 << A.getRows() << "x" << A.getColumns() << ") . (" << B.getRows()
118 << "x" << B.getColumns() << ")";
119 throw new EXCEPTION(stream.str());
120 }
122 for (int i(0); i < A.getRows(); ++i) {
123 for (int j(0); j < B.getColumns(); ++j) {
124 C(i, j) = 0;
125 for (int k(0); k < A.getColumns(); ++k) { C(i, j) += A(i, k) * B(k, j); }
126 }
127 }
128 return C;
129}
130} // namespace sisi4s
131
132#endif
#define EXCEPTION(message)
Definition Exception.hpp:8
Definition LapackMatrix.hpp:11
F & operator()(const int i, const int j)
Definition LapackMatrix.hpp:61
int columns
Definition LapackMatrix.hpp:102
void write(Tensor< F > &ctfA) const
Writes the data of this Lapack matrix to the CTF tensor.
Definition LapackMatrix.hpp:79
LapackMatrix(const LapackMatrix< F > &A)
Copies the content of the Lapack matrix.
Definition LapackMatrix.hpp:16
const F * getValues() const
Returns the pointer to the column major data.
Definition LapackMatrix.hpp:69
LapackMatrix(Tensor< F > &ctfA, const int rows_=0, const int columns_=0)
Constructs a LapackMatrix from a CTF tensor on all ranks if rows and columns are given,...
Definition LapackMatrix.hpp:43
int getColumns() const
Definition LapackMatrix.hpp:64
F * getValues()
Returns the pointer to the mutable column major data.
Definition LapackMatrix.hpp:74
LapackMatrix(const int rows_, const int columns_)
Construct a zero nxm Lapack matrix.
Definition LapackMatrix.hpp:24
LapackMatrix< F > & operator=(const LapackMatrix< F > &A)
Definition LapackMatrix.hpp:50
const F & operator()(const int i, const int j) const
Definition LapackMatrix.hpp:57
std::vector< F > values
The column major data.
Definition LapackMatrix.hpp:107
LapackMatrix(const int rows_, const int columns_, const std::vector< F > &values_)
Construct an nxm Lapack matrix from a given vector of values.
Definition LapackMatrix.hpp:32
int rows
Number of rows and columns.
Definition LapackMatrix.hpp:102
int getRows() const
Definition LapackMatrix.hpp:63
Definition Algorithm.hpp:10
CTF::Tensor< F > Tensor
Definition Tensor.hpp:9
std::string operator*(const std::string &s, const sisi4s::Permutation< N > &pi)
Definition CcsdPerturbativeTriples.cxx:24