sisi4s
Loading...
Searching...
No Matches
CtfMachineTensor.hpp
Go to the documentation of this file.
1#ifndef CTF_MACHINE_TENSOR_DEFINED
2#define CTF_MACHINE_TENSOR_DEFINED
3
4#include <tcc/MachineTensor.hpp>
5#include <Sisi4s.hpp>
6#include <util/Tensor.hpp>
7#include <string>
8#include <memory>
9
10namespace sisi4s {
11template <typename F = double>
12class CtfMachineTensorFactory;
13
14template <typename F = double>
15class CtfMachineTensor : public tcc::MachineTensor<F> {
16protected:
18
19public:
20 // required by templates to infer corresponding Factory type
23
24 // constructors called by factory
25 CtfMachineTensor(const std::vector<int> &lens,
26 const std::string &name,
27 CTF::World *world,
28 const ProtectedToken &)
29 : tensor(static_cast<int>(lens.size()),
30 lens.data(),
31 std::vector<int>(0, lens.size()).data(),
32 *world,
33 name.c_str()) {}
34
35 // copy constructor from CTF tensor, for compatibility
37 : tensor(T) {}
38
39 static std::shared_ptr<CtfMachineTensor<F>> create(const Tensor &T) {
40 return std::make_shared<CtfMachineTensor<F>>(T, ProtectedToken());
41 }
42
43 virtual ~CtfMachineTensor() {}
44
45 // this[bIndices] = alpha * A[aIndices] + beta*this[bIndices]
46 virtual void move(F alpha,
47 const std::shared_ptr<tcc::MachineTensor<F>> &A,
48 const std::string &aIndices,
49 F beta,
50 const std::string &bIndices) {
51 std::shared_ptr<CtfMachineTensor<F>> ctfA(
52 std::dynamic_pointer_cast<CtfMachineTensor<F>>(A));
53 if (!ctfA) {
54 throw new EXCEPTION("Passed machine tensor of wrong implementation.");
55 }
56 LOG(2, "TCC") << "move " << getName() << "[" << bIndices
57 << "] <<= " << alpha << " * " << ctfA->getName() << "["
58 << aIndices << "] + " << beta << " * " << getName() << "["
59 << bIndices << "]" << std::endl;
60 tensor.sum(alpha, ctfA->tensor, aIndices.c_str(), beta, bIndices.c_str());
61 }
62
63 // this[bIndices] = alpha * f(A[aIndices]) + beta*this[bIndices]
64 void move(F alpha,
65 const std::shared_ptr<tcc::MachineTensor<F>> &A,
66 const std::string &aIndices,
67 F beta,
68 const std::string &bIndices,
69 const std::function<F(const F)> &f) {
70 std::shared_ptr<CtfMachineTensor<F>> ctfA(
71 std::dynamic_pointer_cast<CtfMachineTensor<F>>(A));
72 if (!ctfA) {
73 throw new EXCEPTION("Passed machine tensor of wrong implementation.");
74 }
75 LOG(2, "TCC") << "move " << getName() << "[" << bIndices
76 << "] <<= " << alpha << " * " << ctfA->getName() << "["
77 << aIndices << "] + " << beta << " * " << getName() << "["
78 << bIndices << "]" << std::endl;
79 tensor.sum(alpha,
80 ctfA->tensor,
81 aIndices.c_str(),
82 beta,
83 bIndices.c_str(),
84 CTF::Univar_Function<F>(f));
85 }
86
87 // this[cIndices] = alpha * A[aIndices] * B[bIndices] + beta*this[cIndices]
88 void contract(F alpha,
89 const std::shared_ptr<tcc::MachineTensor<F>> &A,
90 const std::string &aIndices,
91 const std::shared_ptr<tcc::MachineTensor<F>> &B,
92 const std::string &bIndices,
93 F beta,
94 const std::string &cIndices) {
95 std::shared_ptr<CtfMachineTensor<F>> ctfA(
96 std::dynamic_pointer_cast<CtfMachineTensor<F>>(A));
97 std::shared_ptr<CtfMachineTensor<F>> ctfB(
98 std::dynamic_pointer_cast<CtfMachineTensor<F>>(B));
99 if (!ctfA || !ctfB) {
100 throw new EXCEPTION("Passed machine tensor of wrong implementation.");
101 }
102 LOG(2, "TCC") << "contract " << getName() << "[" << cIndices << "] <<= g("
103 << alpha << " * " << ctfA->getName() << "[" << aIndices
104 << "], " << ctfB->getName() << "[" << bIndices << "]) + "
105 << beta << " * " << getName() << "[" << cIndices << "]"
106 << std::endl;
107 tensor.contract(alpha,
108 ctfA->tensor,
109 aIndices.c_str(),
110 ctfB->tensor,
111 bIndices.c_str(),
112 beta,
113 cIndices.c_str());
114 }
115
116 // this[cIndices] = alpha * g(A[aIndices],B[bIndices]) + beta*this[cIndices]
117 void contract(F alpha,
118 const std::shared_ptr<tcc::MachineTensor<F>> &A,
119 const std::string &aIndices,
120 const std::shared_ptr<tcc::MachineTensor<F>> &B,
121 const std::string &bIndices,
122 F beta,
123 const std::string &cIndices,
124 const std::function<F(const F, const F)> &g) {
125 std::shared_ptr<CtfMachineTensor<F>> ctfA(
126 std::dynamic_pointer_cast<CtfMachineTensor<F>>(A));
127 std::shared_ptr<CtfMachineTensor<F>> ctfB(
128 std::dynamic_pointer_cast<CtfMachineTensor<F>>(B));
129 if (!ctfA || !ctfB) {
130 throw new EXCEPTION("Passed machine tensor of wrong implementation.");
131 }
132 LOG(2, "TCC") << "contract " << getName() << "[" << cIndices << "] <<= g("
133 << alpha << " * " << ctfA->getName() << "[" << aIndices
134 << "], " << ctfB->getName() << "[" << bIndices << "]) + "
135 << beta << " * " << getName() << "[" << cIndices << "]"
136 << std::endl;
137 tensor.contract(alpha,
138 ctfA->tensor,
139 aIndices.c_str(),
140 ctfB->tensor,
141 bIndices.c_str(),
142 beta,
143 cIndices.c_str(),
144 CTF::Bivar_Function<F>(g));
145 }
146
147 // TODO: interfaces to be defined: slice, permute, transform
148
149 virtual std::vector<int> getLens() const {
150 return std::vector<int>(tensor.lens, tensor.lens + tensor.order);
151 }
152
153 virtual std::string getName() const { return std::string(tensor.get_name()); }
154
159
160 friend class CtfMachineTensorFactory<F>;
161};
162
163template <typename F>
164class CtfMachineTensorFactory : public tcc::MachineTensorFactory<F> {
165protected:
167
168public:
169 CtfMachineTensorFactory(CTF::World *world_, const ProtectedToken &)
170 : world(world_) {}
171
173
174 virtual std::shared_ptr<tcc::MachineTensor<F>>
175 createTensor(const std::vector<int> &lens, const std::string &name) {
176 return std::shared_ptr<typename tcc::MachineTensor<F>>(
177 std::make_shared<CtfMachineTensor<F>>(
178 lens,
179 name,
180 world,
182 }
183
184 static std::shared_ptr<CtfMachineTensorFactory<F>>
185 create(CTF::World *world = Sisi4s::world) {
186 return std::make_shared<CtfMachineTensorFactory<F>>(world,
188 }
189
190protected:
191 CTF::World *world;
192};
193} // namespace sisi4s
194
195#endif
#define EXCEPTION(message)
Definition Exception.hpp:8
#define LOG(...)
Definition Log.hpp:119
Definition CtfMachineTensor.hpp:166
Definition CtfMachineTensor.hpp:164
CtfMachineTensorFactory(CTF::World *world_, const ProtectedToken &)
Definition CtfMachineTensor.hpp:169
virtual std::shared_ptr< tcc::MachineTensor< F > > createTensor(const std::vector< int > &lens, const std::string &name)
Definition CtfMachineTensor.hpp:175
static std::shared_ptr< CtfMachineTensorFactory< F > > create(CTF::World *world=Sisi4s::world)
Definition CtfMachineTensor.hpp:185
CTF::World * world
Definition CtfMachineTensor.hpp:191
virtual ~CtfMachineTensorFactory()
Definition CtfMachineTensor.hpp:172
Definition CtfMachineTensor.hpp:17
Definition CtfMachineTensor.hpp:15
CtfMachineTensor(const Tensor &T, const ProtectedToken &)
Definition CtfMachineTensor.hpp:36
void contract(F alpha, const std::shared_ptr< tcc::MachineTensor< F > > &A, const std::string &aIndices, const std::shared_ptr< tcc::MachineTensor< F > > &B, const std::string &bIndices, F beta, const std::string &cIndices)
Definition CtfMachineTensor.hpp:88
virtual ~CtfMachineTensor()
Definition CtfMachineTensor.hpp:43
void contract(F alpha, const std::shared_ptr< tcc::MachineTensor< F > > &A, const std::string &aIndices, const std::shared_ptr< tcc::MachineTensor< F > > &B, const std::string &bIndices, F beta, const std::string &cIndices, const std::function< F(const F, const F)> &g)
Definition CtfMachineTensor.hpp:117
CtfMachineTensorFactory< F > Factory
Definition CtfMachineTensor.hpp:21
CtfMachineTensor(const std::vector< int > &lens, const std::string &name, CTF::World *world, const ProtectedToken &)
Definition CtfMachineTensor.hpp:25
static std::shared_ptr< CtfMachineTensor< F > > create(const Tensor &T)
Definition CtfMachineTensor.hpp:39
void move(F alpha, const std::shared_ptr< tcc::MachineTensor< F > > &A, const std::string &aIndices, F beta, const std::string &bIndices, const std::function< F(const F)> &f)
Definition CtfMachineTensor.hpp:64
Tensor tensor
The adapted CTF tensor.
Definition CtfMachineTensor.hpp:158
virtual std::string getName() const
Definition CtfMachineTensor.hpp:153
virtual std::vector< int > getLens() const
Definition CtfMachineTensor.hpp:149
virtual void move(F alpha, const std::shared_ptr< tcc::MachineTensor< F > > &A, const std::string &aIndices, F beta, const std::string &bIndices)
Definition CtfMachineTensor.hpp:46
Tensor< F > Tensor
Definition CtfMachineTensor.hpp:22
static CTF::World * world
Definition Sisi4s.hpp:17
Definition Algorithm.hpp:10