1#ifndef CTF_MACHINE_TENSOR_DEFINED
2#define CTF_MACHINE_TENSOR_DEFINED
4#include <tcc/MachineTensor.hpp>
// Forward declaration of the CTF machine-tensor factory.
// CtfMachineTensor names this factory type before the factory's full
// definition appears further down in this header.
// The element type F defaults to double.
template <class F = double>
class CtfMachineTensorFactory;
14template <
typename F =
double>
26 const std::string &name,
29 :
tensor(static_cast<int>(lens.size()),
31 std::vector<int>(0, lens.size()).data(),
39 static std::shared_ptr<CtfMachineTensor<F>>
create(
const Tensor &T) {
47 const std::shared_ptr<tcc::MachineTensor<F>> &A,
48 const std::string &aIndices,
50 const std::string &bIndices) {
51 std::shared_ptr<CtfMachineTensor<F>> ctfA(
54 throw new EXCEPTION(
"Passed machine tensor of wrong implementation.");
56 LOG(2,
"TCC") <<
"move " <<
getName() <<
"[" << bIndices
57 <<
"] <<= " << alpha <<
" * " << ctfA->getName() <<
"["
58 << aIndices <<
"] + " << beta <<
" * " <<
getName() <<
"["
59 << bIndices <<
"]" << std::endl;
60 tensor.sum(alpha, ctfA->tensor, aIndices.c_str(), beta, bIndices.c_str());
65 const std::shared_ptr<tcc::MachineTensor<F>> &A,
66 const std::string &aIndices,
68 const std::string &bIndices,
69 const std::function<F(
const F)> &f) {
70 std::shared_ptr<CtfMachineTensor<F>> ctfA(
73 throw new EXCEPTION(
"Passed machine tensor of wrong implementation.");
75 LOG(2,
"TCC") <<
"move " <<
getName() <<
"[" << bIndices
76 <<
"] <<= " << alpha <<
" * " << ctfA->getName() <<
"["
77 << aIndices <<
"] + " << beta <<
" * " <<
getName() <<
"["
78 << bIndices <<
"]" << std::endl;
84 CTF::Univar_Function<F>(f));
89 const std::shared_ptr<tcc::MachineTensor<F>> &A,
90 const std::string &aIndices,
91 const std::shared_ptr<tcc::MachineTensor<F>> &B,
92 const std::string &bIndices,
94 const std::string &cIndices) {
95 std::shared_ptr<CtfMachineTensor<F>> ctfA(
97 std::shared_ptr<CtfMachineTensor<F>> ctfB(
100 throw new EXCEPTION(
"Passed machine tensor of wrong implementation.");
102 LOG(2,
"TCC") <<
"contract " <<
getName() <<
"[" << cIndices <<
"] <<= g("
103 << alpha <<
" * " << ctfA->getName() <<
"[" << aIndices
104 <<
"], " << ctfB->getName() <<
"[" << bIndices <<
"]) + "
105 << beta <<
" * " <<
getName() <<
"[" << cIndices <<
"]"
118 const std::shared_ptr<tcc::MachineTensor<F>> &A,
119 const std::string &aIndices,
120 const std::shared_ptr<tcc::MachineTensor<F>> &B,
121 const std::string &bIndices,
123 const std::string &cIndices,
124 const std::function<F(
const F,
const F)> &g) {
125 std::shared_ptr<CtfMachineTensor<F>> ctfA(
127 std::shared_ptr<CtfMachineTensor<F>> ctfB(
129 if (!ctfA || !ctfB) {
130 throw new EXCEPTION(
"Passed machine tensor of wrong implementation.");
132 LOG(2,
"TCC") <<
"contract " <<
getName() <<
"[" << cIndices <<
"] <<= g("
133 << alpha <<
" * " << ctfA->getName() <<
"[" << aIndices
134 <<
"], " << ctfB->getName() <<
"[" << bIndices <<
"]) + "
135 << beta <<
" * " <<
getName() <<
"[" << cIndices <<
"]"
144 CTF::Bivar_Function<F>(g));
153 virtual std::string
getName()
const {
return std::string(
tensor.get_name()); }
174 virtual std::shared_ptr<tcc::MachineTensor<F>>
176 return std::shared_ptr<typename tcc::MachineTensor<F>>(
177 std::make_shared<CtfMachineTensor<F>>(
184 static std::shared_ptr<CtfMachineTensorFactory<F>>
186 return std::make_shared<CtfMachineTensorFactory<F>>(
world,
#define EXCEPTION(message)
Definition Exception.hpp:8
#define LOG(...)
Definition Log.hpp:119
Definition CtfMachineTensor.hpp:166
Definition CtfMachineTensor.hpp:164
CtfMachineTensorFactory(CTF::World *world_, const ProtectedToken &)
Definition CtfMachineTensor.hpp:169
virtual std::shared_ptr< tcc::MachineTensor< F > > createTensor(const std::vector< int > &lens, const std::string &name)
Definition CtfMachineTensor.hpp:175
static std::shared_ptr< CtfMachineTensorFactory< F > > create(CTF::World *world=Sisi4s::world)
Definition CtfMachineTensor.hpp:185
CTF::World * world
Definition CtfMachineTensor.hpp:191
virtual ~CtfMachineTensorFactory()
Definition CtfMachineTensor.hpp:172
Definition CtfMachineTensor.hpp:17
Definition CtfMachineTensor.hpp:15
CtfMachineTensor(const Tensor &T, const ProtectedToken &)
Definition CtfMachineTensor.hpp:36
void contract(F alpha, const std::shared_ptr< tcc::MachineTensor< F > > &A, const std::string &aIndices, const std::shared_ptr< tcc::MachineTensor< F > > &B, const std::string &bIndices, F beta, const std::string &cIndices)
Definition CtfMachineTensor.hpp:88
virtual ~CtfMachineTensor()
Definition CtfMachineTensor.hpp:43
void contract(F alpha, const std::shared_ptr< tcc::MachineTensor< F > > &A, const std::string &aIndices, const std::shared_ptr< tcc::MachineTensor< F > > &B, const std::string &bIndices, F beta, const std::string &cIndices, const std::function< F(const F, const F)> &g)
Definition CtfMachineTensor.hpp:117
CtfMachineTensorFactory< F > Factory
Definition CtfMachineTensor.hpp:21
CtfMachineTensor(const std::vector< int > &lens, const std::string &name, CTF::World *world, const ProtectedToken &)
Definition CtfMachineTensor.hpp:25
static std::shared_ptr< CtfMachineTensor< F > > create(const Tensor &T)
Definition CtfMachineTensor.hpp:39
void move(F alpha, const std::shared_ptr< tcc::MachineTensor< F > > &A, const std::string &aIndices, F beta, const std::string &bIndices, const std::function< F(const F)> &f)
Definition CtfMachineTensor.hpp:64
Tensor tensor
The adapted CTF tensor.
Definition CtfMachineTensor.hpp:158
virtual std::string getName() const
Definition CtfMachineTensor.hpp:153
virtual std::vector< int > getLens() const
Definition CtfMachineTensor.hpp:149
virtual void move(F alpha, const std::shared_ptr< tcc::MachineTensor< F > > &A, const std::string &aIndices, F beta, const std::string &bIndices)
Definition CtfMachineTensor.hpp:46
Tensor< F > Tensor
Definition CtfMachineTensor.hpp:22
static CTF::World * world
Definition Sisi4s.hpp:17
Definition Algorithm.hpp:10