Collective Variables Module - Developer Documentation
Loading...
Searching...
No Matches
colvar_arithmeticpath.h
1#ifndef ARITHMETICPATHCV_H
2#define ARITHMETICPATHCV_H
3
4#include "colvarmodule.h"
5
6#include <vector>
7#include <cmath>
8#include <limits>
9#include <string>
10#include <algorithm>
11
12namespace ArithmeticPathCV {
13
14using std::vector;
15
// ArithmeticPathBase: evaluates arithmetic path variables (progress s and
// distance z) and their element-wise derivatives from a table of per-frame,
// per-element distances to a set of reference frames.
// NOTE(review): this rendering jumps from original line 16 to 18 and from 18
// to 21. The `class ArithmeticPathBase {` line (original line 17, cf. the
// "Definition: colvar_arithmeticpath.h:17" entry) and original lines 19-20 were
// dropped by the page extraction; restore them from the real header.
16template <typename scalar_type>
18public:
 // Stores lambda, the squared element weights and the frame/element counts,
 // and resets the cached state; cvmodule_in is kept for logging.
21 void initialize(size_t p_num_elements, size_t p_total_frames, scalar_type p_lambda, const vector<scalar_type>& p_weights, colvarmodule* cvmodule_in);
 // Replaces lambda with 1 / (mean of the squared RMSDs between consecutive
 // reference frames); rmsd_between_refs[k] is the RMSD between frames k and k+1.
22 void reComputeLambda(const vector<scalar_type>& rmsd_between_refs);
 // Computes s and z from frame_element_distances[frame][element]; either output
 // pointer may be null. Caches intermediate results for computeDerivatives().
23 template <typename element_type>
24 void computeValue(const vector<vector<element_type>>& frame_element_distances, scalar_type *s = nullptr, scalar_type *z = nullptr);
25 // can only be called after computeValue() for element-wise derivatives and store derivatives of i-th frame to dsdx and dzdx
 // dsdx / dzdx are written with operator[] and are not resized, so each must
 // already have total_frames rows of at least num_elements entries.
26 template <typename element_type>
27 void computeDerivatives(const vector<vector<element_type>>& frame_element_distances, vector<vector<element_type>> *dsdx = nullptr, vector<vector<element_type>> *dzdx = nullptr);
28protected:
29 scalar_type lambda; // scale of the exponents -lambda * sum_j w_j^2 * d_ij^2
30 vector<scalar_type> squared_weights; // per-element weights, squared in initialize()
31 size_t num_elements; // number of elements summed per frame
32 size_t total_frames; // number of reference frames
33 vector<scalar_type> exponents; // after computeValue(): exp(exponent_i - max_exponent) per frame
34 scalar_type max_exponent; // largest raw exponent of the last computeValue() (log-sum-exp shift)
35 scalar_type saved_exponent_sum; // sum of the shifted exponentials from the last computeValue()
36 scalar_type normalization_factor; // 1 / (total_frames - 1), maps the weighted frame index into [0, 1]
37 scalar_type saved_s; // s from the last computeValue(), reused by computeDerivatives()
38 colvarmodule* cvmodule; // used for log output in reComputeLambda()
39};
40
41template <typename scalar_type>
42void ArithmeticPathBase<scalar_type>::initialize(size_t p_num_elements, size_t p_total_frames, scalar_type p_lambda,
43 const vector<scalar_type>& p_weights, colvarmodule* cvmodule_in) {
44 lambda = p_lambda;
45 for (size_t i = 0; i < p_weights.size(); ++i) squared_weights.push_back(p_weights[i] * p_weights[i]);
46 num_elements = p_num_elements;
47 total_frames = p_total_frames;
48 exponents.resize(total_frames);
49 normalization_factor = 1.0 / static_cast<scalar_type>(total_frames - 1);
50 saved_s = scalar_type();
51 saved_exponent_sum = scalar_type();
52 max_exponent = scalar_type();
53 cvmodule = cvmodule_in;
54}
55
56template <typename scalar_type>
57template <typename element_type>
58void ArithmeticPathBase<scalar_type>::computeValue(
59 const vector<vector<element_type>>& frame_element_distances,
60 scalar_type *s, scalar_type *z)
61{
62 for (size_t i_frame = 0; i_frame < total_frames; ++i_frame) {
63 scalar_type exponent_tmp = scalar_type();
64 for (size_t j_elem = 0; j_elem < num_elements; ++j_elem) {
65 exponent_tmp += squared_weights[j_elem] * frame_element_distances[i_frame][j_elem] * frame_element_distances[i_frame][j_elem];
66 }
67 exponents[i_frame] = exponent_tmp * -1.0 * lambda;
68 if (i_frame == 0 || exponents[i_frame] > max_exponent) max_exponent = exponents[i_frame];
69 }
70 scalar_type log_sum_exp_0 = scalar_type();
71 scalar_type log_sum_exp_1 = scalar_type();
72 for (size_t i_frame = 0; i_frame < total_frames; ++i_frame) {
73 exponents[i_frame] = cvm::exp(exponents[i_frame] - max_exponent);
74 log_sum_exp_0 += exponents[i_frame];
75 log_sum_exp_1 += i_frame * exponents[i_frame];
76 }
77 saved_exponent_sum = log_sum_exp_0;
78 log_sum_exp_0 = max_exponent + cvm::logn(log_sum_exp_0);
79 log_sum_exp_1 = max_exponent + cvm::logn(log_sum_exp_1);
80 saved_s = normalization_factor * cvm::exp(log_sum_exp_1 - log_sum_exp_0);
81 if (s != nullptr) {
82 *s = saved_s;
83 }
84 if (z != nullptr) {
85 *z = -1.0 / lambda * log_sum_exp_0;
86 }
87}
88
89template <typename scalar_type>
90void ArithmeticPathBase<scalar_type>::reComputeLambda(const vector<scalar_type>& rmsd_between_refs) {
91 scalar_type mean_square_displacements = 0.0;
92 for (size_t i_frame = 1; i_frame < total_frames; ++i_frame) {
93 cvmodule->log(std::string("Distance between frame ") + cvm::to_str(i_frame) + " and " + cvm::to_str(i_frame + 1)
94 + " is " + cvm::to_str(rmsd_between_refs[i_frame - 1]) + std::string("\n"));
95 mean_square_displacements += rmsd_between_refs[i_frame - 1] * rmsd_between_refs[i_frame - 1];
96 }
97 mean_square_displacements /= scalar_type(total_frames - 1);
98 lambda = 1.0 / mean_square_displacements;
99}
100
101// frame-wise derivatives for frames using optimal rotation
102template <typename scalar_type>
103template <typename element_type>
104void ArithmeticPathBase<scalar_type>::computeDerivatives(
105 const vector<vector<element_type>>& frame_element_distances,
106 vector<vector<element_type>> *dsdx,
107 vector<vector<element_type>> *dzdx)
108{
109 vector<scalar_type> softmax_out, tmps;
110 softmax_out.reserve(total_frames);
111 tmps.reserve(total_frames);
112 for (size_t i_frame = 0; i_frame < total_frames; ++i_frame) {
113 softmax_out.push_back(exponents[i_frame] / saved_exponent_sum);
114 tmps.push_back(
115 (static_cast<scalar_type>(i_frame) -
116 static_cast<scalar_type>(total_frames - 1) * saved_s) *
117 normalization_factor);
118 }
119 if (dsdx != nullptr) {
120 for (size_t i_frame = 0; i_frame < total_frames; ++i_frame) {
121 for (size_t j_elem = 0; j_elem < num_elements; ++j_elem) {
122 (*dsdx)[i_frame][j_elem] =
123 -2.0 * squared_weights[j_elem] * lambda *
124 frame_element_distances[i_frame][j_elem] *
125 softmax_out[i_frame] * tmps[i_frame];
126 }
127 }
128 }
129 if (dzdx != nullptr) {
130 for (size_t i_frame = 0; i_frame < total_frames; ++i_frame) {
131 for (size_t j_elem = 0; j_elem < num_elements; ++j_elem) {
132 (*dzdx)[i_frame][j_elem] =
133 2.0 * squared_weights[j_elem] * softmax_out[i_frame] *
134 frame_element_distances[i_frame][j_elem];
135 }
136 }
137 }
138}
139}
140
141#endif // ARITHMETICPATHCV_H
Definition: colvar_arithmeticpath.h:17
Collective variables module (main class)
Definition: colvarmodule.h:72
static real logn(real const &x)
Definition: colvarmodule.h:220
static real exp(real const &x)
Reimplemented to work around MS compiler issues.
Definition: colvarmodule.h:212
static std::string to_str(char const *s)
Convert to string for output purposes.
Definition: colvarmodule.cpp:2541
Collective variables main module.