1#ifndef ARITHMETICPATHCV_H
2#define ARITHMETICPATHCV_H
12namespace ArithmeticPathCV {
16template <
typename scalar_type>
21 void initialize(
size_t p_num_elements,
size_t p_total_frames, scalar_type p_lambda,
const vector<scalar_type>& p_weights,
colvarmodule* cvmodule_in);
22 void reComputeLambda(
const vector<scalar_type>& rmsd_between_refs);
23 template <
typename element_type>
24 void computeValue(
const vector<vector<element_type>>& frame_element_distances, scalar_type *s =
nullptr, scalar_type *z =
nullptr);
26 template <
typename element_type>
27 void computeDerivatives(
const vector<vector<element_type>>& frame_element_distances, vector<vector<element_type>> *dsdx =
nullptr, vector<vector<element_type>> *dzdx =
nullptr);
30 vector<scalar_type> squared_weights;
33 vector<scalar_type> exponents;
34 scalar_type max_exponent;
35 scalar_type saved_exponent_sum;
36 scalar_type normalization_factor;
41template <
typename scalar_type>
43 const vector<scalar_type>& p_weights,
colvarmodule* cvmodule_in) {
45 for (
size_t i = 0; i < p_weights.size(); ++i) squared_weights.push_back(p_weights[i] * p_weights[i]);
46 num_elements = p_num_elements;
47 total_frames = p_total_frames;
48 exponents.resize(total_frames);
49 normalization_factor = 1.0 /
static_cast<scalar_type
>(total_frames - 1);
50 saved_s = scalar_type();
51 saved_exponent_sum = scalar_type();
52 max_exponent = scalar_type();
53 cvmodule = cvmodule_in;
56template <
typename scalar_type>
57template <
typename element_type>
58void ArithmeticPathBase<scalar_type>::computeValue(
59 const vector<vector<element_type>>& frame_element_distances,
60 scalar_type *s, scalar_type *z)
62 for (
size_t i_frame = 0; i_frame < total_frames; ++i_frame) {
63 scalar_type exponent_tmp = scalar_type();
64 for (
size_t j_elem = 0; j_elem < num_elements; ++j_elem) {
65 exponent_tmp += squared_weights[j_elem] * frame_element_distances[i_frame][j_elem] * frame_element_distances[i_frame][j_elem];
67 exponents[i_frame] = exponent_tmp * -1.0 * lambda;
68 if (i_frame == 0 || exponents[i_frame] > max_exponent) max_exponent = exponents[i_frame];
70 scalar_type log_sum_exp_0 = scalar_type();
71 scalar_type log_sum_exp_1 = scalar_type();
72 for (
size_t i_frame = 0; i_frame < total_frames; ++i_frame) {
73 exponents[i_frame] =
cvm::exp(exponents[i_frame] - max_exponent);
74 log_sum_exp_0 += exponents[i_frame];
75 log_sum_exp_1 += i_frame * exponents[i_frame];
77 saved_exponent_sum = log_sum_exp_0;
78 log_sum_exp_0 = max_exponent +
cvm::logn(log_sum_exp_0);
79 log_sum_exp_1 = max_exponent +
cvm::logn(log_sum_exp_1);
80 saved_s = normalization_factor *
cvm::exp(log_sum_exp_1 - log_sum_exp_0);
85 *z = -1.0 / lambda * log_sum_exp_0;
89template <
typename scalar_type>
90void ArithmeticPathBase<scalar_type>::reComputeLambda(
const vector<scalar_type>& rmsd_between_refs) {
91 scalar_type mean_square_displacements = 0.0;
92 for (
size_t i_frame = 1; i_frame < total_frames; ++i_frame) {
93 cvmodule->log(std::string(
"Distance between frame ") +
cvm::to_str(i_frame) +
" and " +
cvm::to_str(i_frame + 1)
94 +
" is " +
cvm::to_str(rmsd_between_refs[i_frame - 1]) + std::string(
"\n"));
95 mean_square_displacements += rmsd_between_refs[i_frame - 1] * rmsd_between_refs[i_frame - 1];
97 mean_square_displacements /= scalar_type(total_frames - 1);
98 lambda = 1.0 / mean_square_displacements;
102template <
typename scalar_type>
103template <
typename element_type>
104void ArithmeticPathBase<scalar_type>::computeDerivatives(
105 const vector<vector<element_type>>& frame_element_distances,
106 vector<vector<element_type>> *dsdx,
107 vector<vector<element_type>> *dzdx)
109 vector<scalar_type> softmax_out, tmps;
110 softmax_out.reserve(total_frames);
111 tmps.reserve(total_frames);
112 for (
size_t i_frame = 0; i_frame < total_frames; ++i_frame) {
113 softmax_out.push_back(exponents[i_frame] / saved_exponent_sum);
115 (
static_cast<scalar_type
>(i_frame) -
116 static_cast<scalar_type
>(total_frames - 1) * saved_s) *
117 normalization_factor);
119 if (dsdx !=
nullptr) {
120 for (
size_t i_frame = 0; i_frame < total_frames; ++i_frame) {
121 for (
size_t j_elem = 0; j_elem < num_elements; ++j_elem) {
122 (*dsdx)[i_frame][j_elem] =
123 -2.0 * squared_weights[j_elem] * lambda *
124 frame_element_distances[i_frame][j_elem] *
125 softmax_out[i_frame] * tmps[i_frame];
129 if (dzdx !=
nullptr) {
130 for (
size_t i_frame = 0; i_frame < total_frames; ++i_frame) {
131 for (
size_t j_elem = 0; j_elem < num_elements; ++j_elem) {
132 (*dzdx)[i_frame][j_elem] =
133 2.0 * squared_weights[j_elem] * softmax_out[i_frame] *
134 frame_element_distances[i_frame][j_elem];
Definition: colvar_arithmeticpath.h:17
Collective variables module (main class)
Definition: colvarmodule.h:72
static real logn(real const &x)
Definition: colvarmodule.h:220
static real exp(real const &x)
Reimplemented to work around MS compiler issues.
Definition: colvarmodule.h:212
static std::string to_str(char const *s)
Convert to string for output purposes.
Definition: colvarmodule.cpp:2541
Collective variables main module.