Collective Variables Module - Developer Documentation
colvarcomp_torchann.h
// -*- c++ -*-

// This file is part of the Collective Variables module (Colvars).
// The original version of Colvars and its updates are located at:
// https://github.com/Colvars/colvars
// Please update all Colvars source files before making any changes.
// If you wish to distribute your changes, please submit them to the
// Colvars repository at GitHub.
//
#ifndef COLVARCOMP_TORCH_H
#define COLVARCOMP_TORCH_H

// Declaration of torchANN

#include <memory>

#include "colvar.h"
#include "colvarcomp.h"
#include "colvarmodule.h"

#ifdef COLVARS_TORCH

#include <torch/torch.h>
#include <torch/script.h>

class colvar::torchANN
  : public colvar::linearCombination
{
protected:
  // TorchScript model evaluated on the sub-CVC values
  torch::jit::script::Module nn;
  // index of the network output component that is used
  size_t m_output_index;
  bool use_double_input;
  //bool use_gpu;
  // 1d tensor: concatenation of the values of the sub-CVCs
  torch::Tensor input_tensor;
  torch::Tensor nn_outputs;
  torch::Tensor input_grad;
  // starting index of each sub-CVC in input_tensor
  std::vector<int> cvc_indices;
public:
  torchANN();
  virtual ~torchANN();
  virtual int init(std::string const &conf);
  virtual void calc_value();
  virtual void calc_gradients();
  virtual void apply_force(colvarvalue const &force);
};

#else

class colvar::torchANN
  : public colvar::cvc
{
public:
  torchANN();
  virtual ~torchANN();
  virtual int init(std::string const &conf);
  virtual void calc_value();
};
#endif // COLVARS_TORCH checking

#endif
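The member comments above describe input_tensor as a 1d concatenation of the sub-CVC values and cvc_indices as the starting position of each sub-CVC within it. The following is a minimal sketch of that bookkeeping in plain C++, with std::vector<double> standing in for torch::Tensor so that it runs without LibTorch; the names ConcatInput and concat_inputs are hypothetical and not part of Colvars.

```cpp
#include <cassert>
#include <vector>

// Hypothetical result type: flattened input plus the start offset of each sub-CVC.
struct ConcatInput {
  std::vector<double> values;  // plays the role of the 1d input_tensor
  std::vector<int> offsets;    // plays the role of cvc_indices
};

// Concatenate the values of all sub-CVCs into one flat vector and record
// the index at which each sub-CVC's block begins.
ConcatInput concat_inputs(const std::vector<std::vector<double>> &sub_values) {
  ConcatInput out;
  for (const auto &v : sub_values) {
    assert(!v.empty() && "each sub-CVC contributes at least one value");
    out.offsets.push_back(static_cast<int>(out.values.size()));
    out.values.insert(out.values.end(), v.begin(), v.end());
  }
  return out;
}
```

For a scalar sub-CVC, a 3-component sub-CVC, and another scalar, the offsets come out as 0, 1 and 4, so component j of sub-CVC k sits at offsets[k] + j.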
colvar::cvc: Colvar component (base class for collective variables). Definition: colvarcomp.h:70
virtual void calc_gradients(): Calculate the atomic gradients, to be reused later in order to apply forces. Definition: colvarcomp.h:141
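For torchANN, the atomic gradients follow from the chain rule: if y is the selected network output and x_k the value of sub-CVC k, then dy/dr = (dy/dx_k)(dx_k/dr) for the atoms of that sub-CVC. A hypothetical sketch, assuming scalar sub-CVCs that do not share atoms and that dy/dx_k has already been obtained (for example by backpropagation into something like input_grad); the types and function names are illustrative only.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

using Vec3 = std::array<double, 3>;

// Hypothetical stand-in for a scalar sub-CVC: the gradient of its value
// with respect to each of its atoms (dx_k/dr_i).
struct SubCvcGradients {
  std::vector<Vec3> atom_grads;
};

// Chain rule: dy/dr_i = (dy/dx_k) * (dx_k/dr_i) for every atom i of sub-CVC k.
// dy_dx holds one derivative of the network output per scalar sub-CVC.
void scale_by_nn_gradient(const std::vector<double> &dy_dx,
                          std::vector<SubCvcGradients> &subs) {
  assert(dy_dx.size() == subs.size());
  for (std::size_t k = 0; k < subs.size(); ++k) {
    for (Vec3 &g : subs[k].atom_grads) {
      for (double &c : g) c *= dy_dx[k];
    }
  }
}
```

Scaling in place mirrors the idea of reusing each sub-CVC's own gradients; with shared atoms, the contributions would instead have to be summed per atom.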
virtual void apply_force(colvarvalue const &cvforce): Apply the collective variable force, by communicating the atomic forces to the simulation program (No... Definition: colvarcomp.cpp:543
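Applying a force F to the network output follows the same chain rule in the opposite direction: each input component x_j receives F * dy/dx_j, which can then be passed on to the corresponding sub-CVC. A sketch of that split, with offsets playing the role of cvc_indices; the function name and signature are hypothetical, and the actual torchANN::apply_force code is not shown on this page.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Split a force on the network output y into forces on each sub-CVC:
// component j of sub-CVC k receives force * dy/dx[offsets[k] + j].
std::vector<std::vector<double>> sub_cvc_forces(double force,
                                                const std::vector<double> &dy_dx,
                                                const std::vector<int> &offsets,
                                                const std::vector<std::size_t> &sizes) {
  assert(offsets.size() == sizes.size());
  std::vector<std::vector<double>> out;
  for (std::size_t k = 0; k < offsets.size(); ++k) {
    std::vector<double> f(sizes[k]);
    for (std::size_t j = 0; j < sizes[k]; ++j) {
      const std::size_t idx = static_cast<std::size_t>(offsets[k]) + j;
      assert(idx < dy_dx.size());
      f[j] = force * dy_dx[idx];
    }
    out.push_back(f);
  }
  return out;
}
```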
colvar::linearCombination: Currently, only a linear combination of sub-CVCs is available. Definition: colvarcomp.h:1313
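As a point of comparison, a plain linear combination y = sum over k of c_k x_k has the input gradient dy/dx_k = c_k, so it fits the same chain rule that torchANN applies with a network-derived dy/dx_k. A minimal sketch, assuming scalar sub-CVCs and hypothetical coefficients c_k; the configuration options of Colvars' own linearCombination are not shown here.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Linear combination of scalar sub-CVC values: y = sum_k c_k * x_k.
// The coefficients are hypothetical inputs for illustration; the
// derivative dy/dx_k is simply coeffs[k].
double linear_combination(const std::vector<double> &coeffs,
                          const std::vector<double> &x) {
  assert(coeffs.size() == x.size());
  double y = 0.0;
  for (std::size_t k = 0; k < x.size(); ++k) y += coeffs[k] * x[k];
  return y;
}
```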
colvar::torchANN (declaration used when COLVARS_TORCH is not defined). Definition: colvarcomp_torchann.h:54
virtual void calc_value(): Calculate the variable. Definition: colvarcomp_torchann.cpp:222
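In the LibTorch build, the members m_output_index and use_double_input suggest that the value is one selected component of the network output and that the input precision is configurable. A hedged sketch under those assumptions, with std::function standing in for the TorchScript module and single-precision input emulated by rounding each value through float; none of these names are Colvars APIs.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical stand-in for the TorchScript module: maps a flat input to outputs.
using Model = std::function<std::vector<double>(const std::vector<double> &)>;

// Evaluate the model and return the output component at output_index.
// When use_double_input is false, inputs are rounded to single precision
// first (an assumption about what the flag controls, not Colvars code).
double eval_output(const Model &model, std::vector<double> input,
                   std::size_t output_index, bool use_double_input) {
  if (!use_double_input) {
    for (double &x : input) x = static_cast<double>(static_cast<float>(x));
  }
  const std::vector<double> outputs = model(input);
  assert(output_index < outputs.size());
  return outputs[output_index];
}
```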
virtual int init(std::string const &conf). Definition: colvarcomp_torchann.cpp:212
colvarvalue: Value of a collective variable; this is a metatype which can be set at runtime. By default it is set ... Definition: colvarvalue.h:43
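colvarvalue is described as a metatype whose concrete type is chosen at runtime. A hypothetical illustration of that idea using C++17 std::variant with two alternatives (a scalar and a 3-vector); this is not necessarily how colvarvalue is implemented, only a sketch of the concept.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <variant>

// Hypothetical runtime-typed value: either a scalar or a 3-vector.
using Value = std::variant<double, std::array<double, 3>>;

// Number of real components, determined by the alternative held at runtime.
std::size_t num_components(const Value &v) {
  return std::holds_alternative<double>(v) ? 1 : 3;
}

// Read component i of either kind of value.
double component(const Value &v, std::size_t i) {
  assert(i < num_components(v));
  if (const double *x = std::get_if<double>(&v)) return *x;
  return std::get<std::array<double, 3>>(v)[i];
}

// Multiply any kind of value by a real factor, keeping its type.
Value scaled(const Value &v, double s) {
  if (const double *x = std::get_if<double>(&v)) return *x * s;
  auto a = std::get<std::array<double, 3>>(v);
  for (double &c : a) c *= s;
  return a;
}
```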
colvarmodule: Collective variables main module.