Collective Variables Module - Developer Documentation
Loading...
Searching...
No Matches
colvar_arithmeticpath.h
1#ifndef ARITHMETICPATHCV_H
2#define ARITHMETICPATHCV_H
3
4#include "colvarmodule.h"
5
6#include <vector>
7#include <cmath>
8#include <limits>
9#include <string>
10#include <algorithm>
11
12namespace ArithmeticPathCV {
13
14using std::vector;
15
/// Shared state and interface for the arithmetic path collective variables.
///
/// The object is configured with initialize(); computeValue() then caches the
/// per-frame weights that computeDerivatives() consumes. All scalar members are
/// value-initialized so that a freshly constructed object holds a defined state.
template <typename scalar_type>
class ArithmeticPathBase {
public:
    /// Store the sizes, lambda and the squares of @p p_weights, and reset the cached results.
    void initialize(size_t p_num_elements, size_t p_total_frames, scalar_type p_lambda, const std::vector<scalar_type>& p_weights);
    /// Replace lambda by the inverse of the mean of the squared entries of @p rmsd_between_refs.
    void reComputeLambda(const std::vector<scalar_type>& rmsd_between_refs);
    /// Evaluate the path variables from frame_element_distances[frame][element];
    /// results are written through @p s and @p z when these are not null.
    template <typename element_type>
    void computeValue(const std::vector<std::vector<element_type>>& frame_element_distances, scalar_type *s = nullptr, scalar_type *z = nullptr);
    /// Element-wise derivatives; can only be called after computeValue().
    /// The derivatives with respect to the elements of the i-th frame are
    /// stored in (*dsdx)[i] and (*dzdx)[i] when the pointers are not null.
    template <typename element_type>
    void computeDerivatives(const std::vector<std::vector<element_type>>& frame_element_distances, std::vector<std::vector<element_type>> *dsdx = nullptr, std::vector<std::vector<element_type>> *dzdx = nullptr);
protected:
    scalar_type lambda{};                      ///< Prefactor applied to the weighted squared distances
    std::vector<scalar_type> squared_weights;  ///< Squares of the weights passed to initialize()
    size_t num_elements{0};                    ///< Number of elements read per frame
    size_t total_frames{0};                    ///< Number of frames along the path
    std::vector<scalar_type> exponents;        ///< Per-frame exponents; hold exp(exponent - max_exponent) after computeValue()
    scalar_type max_exponent{};                ///< Largest exponent found by the last computeValue()
    scalar_type saved_exponent_sum{};          ///< Sum of the shifted exponentials from the last computeValue()
    scalar_type normalization_factor{};        ///< 1 / (total_frames - 1), set by initialize()
    scalar_type saved_s{};                     ///< Value of s from the last computeValue()
};
39
40template <typename scalar_type>
41void ArithmeticPathBase<scalar_type>::initialize(size_t p_num_elements, size_t p_total_frames, scalar_type p_lambda, const vector<scalar_type>& p_weights) {
42 lambda = p_lambda;
43 for (size_t i = 0; i < p_weights.size(); ++i) squared_weights.push_back(p_weights[i] * p_weights[i]);
44 num_elements = p_num_elements;
45 total_frames = p_total_frames;
46 exponents.resize(total_frames);
47 normalization_factor = 1.0 / static_cast<scalar_type>(total_frames - 1);
48 saved_s = scalar_type();
49 saved_exponent_sum = scalar_type();
50 max_exponent = scalar_type();
51}
52
53template <typename scalar_type>
54template <typename element_type>
55void ArithmeticPathBase<scalar_type>::computeValue(
56 const vector<vector<element_type>>& frame_element_distances,
57 scalar_type *s, scalar_type *z)
58{
59 for (size_t i_frame = 0; i_frame < total_frames; ++i_frame) {
60 scalar_type exponent_tmp = scalar_type();
61 for (size_t j_elem = 0; j_elem < num_elements; ++j_elem) {
62 exponent_tmp += squared_weights[j_elem] * frame_element_distances[i_frame][j_elem] * frame_element_distances[i_frame][j_elem];
63 }
64 exponents[i_frame] = exponent_tmp * -1.0 * lambda;
65 if (i_frame == 0 || exponents[i_frame] > max_exponent) max_exponent = exponents[i_frame];
66 }
67 scalar_type log_sum_exp_0 = scalar_type();
68 scalar_type log_sum_exp_1 = scalar_type();
69 for (size_t i_frame = 0; i_frame < total_frames; ++i_frame) {
70 exponents[i_frame] = cvm::exp(exponents[i_frame] - max_exponent);
71 log_sum_exp_0 += exponents[i_frame];
72 log_sum_exp_1 += i_frame * exponents[i_frame];
73 }
74 saved_exponent_sum = log_sum_exp_0;
75 log_sum_exp_0 = max_exponent + cvm::logn(log_sum_exp_0);
76 log_sum_exp_1 = max_exponent + cvm::logn(log_sum_exp_1);
77 saved_s = normalization_factor * cvm::exp(log_sum_exp_1 - log_sum_exp_0);
78 if (s != nullptr) {
79 *s = saved_s;
80 }
81 if (z != nullptr) {
82 *z = -1.0 / lambda * log_sum_exp_0;
83 }
84}
85
86template <typename scalar_type>
87void ArithmeticPathBase<scalar_type>::reComputeLambda(const vector<scalar_type>& rmsd_between_refs) {
88 scalar_type mean_square_displacements = 0.0;
89 for (size_t i_frame = 1; i_frame < total_frames; ++i_frame) {
90 cvm::log(std::string("Distance between frame ") + cvm::to_str(i_frame) + " and " + cvm::to_str(i_frame + 1) + " is " + cvm::to_str(rmsd_between_refs[i_frame - 1]) + std::string("\n"));
91 mean_square_displacements += rmsd_between_refs[i_frame - 1] * rmsd_between_refs[i_frame - 1];
92 }
93 mean_square_displacements /= scalar_type(total_frames - 1);
94 lambda = 1.0 / mean_square_displacements;
95}
96
97// frame-wise derivatives for frames using optimal rotation
98template <typename scalar_type>
99template <typename element_type>
100void ArithmeticPathBase<scalar_type>::computeDerivatives(
101 const vector<vector<element_type>>& frame_element_distances,
102 vector<vector<element_type>> *dsdx,
103 vector<vector<element_type>> *dzdx)
104{
105 vector<scalar_type> softmax_out, tmps;
106 softmax_out.reserve(total_frames);
107 tmps.reserve(total_frames);
108 for (size_t i_frame = 0; i_frame < total_frames; ++i_frame) {
109 softmax_out.push_back(exponents[i_frame] / saved_exponent_sum);
110 tmps.push_back(
111 (static_cast<scalar_type>(i_frame) -
112 static_cast<scalar_type>(total_frames - 1) * saved_s) *
113 normalization_factor);
114 }
115 if (dsdx != nullptr) {
116 for (size_t i_frame = 0; i_frame < total_frames; ++i_frame) {
117 for (size_t j_elem = 0; j_elem < num_elements; ++j_elem) {
118 (*dsdx)[i_frame][j_elem] =
119 -2.0 * squared_weights[j_elem] * lambda *
120 frame_element_distances[i_frame][j_elem] *
121 softmax_out[i_frame] * tmps[i_frame];
122 }
123 }
124 }
125 if (dzdx != nullptr) {
126 for (size_t i_frame = 0; i_frame < total_frames; ++i_frame) {
127 for (size_t j_elem = 0; j_elem < num_elements; ++j_elem) {
128 (*dzdx)[i_frame][j_elem] =
129 2.0 * squared_weights[j_elem] * softmax_out[i_frame] *
130 frame_element_distances[i_frame][j_elem];
131 }
132 }
133 }
134}
135}
136
137#endif // ARITHMETICPATHCV_H
Definition: colvar_arithmeticpath.h:17
static real logn(real const &x)
Definition: colvarmodule.h:177
static real exp(real const &x)
Reimplemented to work around MS compiler issues.
Definition: colvarmodule.h:169
static void log(std::string const &message, int min_log_level=10)
Definition: colvarmodule.cpp:1955
static std::string to_str(char const *s)
Convert to string for output purposes.
Definition: colvarmodule.cpp:2378
Collective variables main module.