LCOV - code coverage report
Current view: top level - sizeshape - mahadist.cpp (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 192 206 93.2 %
Date: 2025-12-04 11:19:34 Functions: 8 10 80.0 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             : Copyright (c) 2024 by Glen Hocky, New York University on behalf of authors
       3             : 
       4             : The sizeshape module is free software: you can redistribute it and/or modify
       5             : it under the terms of the GNU Lesser General Public License as published by
       6             : the Free Software Foundation, either version 3 of the License, or
       7             : (at your option) any later version.
       8             : 
       9             : The sizeshape module is distributed in the hope that it will be useful,
      10             : but WITHOUT ANY WARRANTY; without even the implied warranty of
      11             : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      12             : GNU Lesser General Public License for more details.
      13             : 
      14             : You should have received a copy of the GNU Lesser General Public License
      15             : along with plumed.  If not, see <http://www.gnu.org/licenses/>.
      16             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      17             : #include "colvar/Colvar.h"
      18             : #include "core/ActionRegister.h"
      19             : #include "tools/Pbc.h"
      20             : #include "tools/File.h"           // Input and output from files 
      21             : #include "tools/Matrix.h"         // Linear Algebra operations
      22             : #include <sstream>
      23             : #include <cmath>
      24             : 
      25             : namespace PLMD {
      26             : namespace sizeshape {
      27             : 
      28             : //+PLUMEDOC COLVAR SIZESHAPE_POSITION_MAHA_DIST
      29             : /*
      30             : Calculates the Mahalanobis distance of the current configuration from a given reference configurational distribution in size-and-shape space.
      31             : 
      32             : The Mahalanobis distance is given as:
      33             : 
      34             : $$
      35             : d(\mathbf{x}, \mathbf{\mu}, \mathbf{\Sigma}) = \sqrt{(\mathbf{x}-\mathbf{\mu})^T \mathbf{\Sigma}^{-1} (\mathbf{x}-\mathbf{\mu})}
      36             : $$
      37             : 
      38             : Here $\mathbf{x}$ is the configuration at time t, $\mathbf{\mu}$ is the reference and $\mathbf{\Sigma}^{-1}$ is the $N \times N$ precision matrix.
      39             : 
      40             : Size-and-shape Gaussian Mixture Model (shapeGMM) is a probabilistic clustering technique that is used to perform structural clustering on an ensemble of molecular configurations and to obtain the reference
      41             : $(\mathbf{\mu})$ and precision $(\mathbf{\Sigma}^{-1})$ corresponding to each of the cluster centers. Please check out <a href="https://github.com/mccullaghlab/shapeGMMTorch">shapeGMMTorch-GitHub</a> and <a href="https://pypi.org/project/shapeGMMTorch/">shapeGMMTorch-PyPI</a> for examples and information on performing shapeGMM clustering.
      42             : 
      43             : ## Examples
      44             : In the following example, a group is defined with the atom indices of the selected atoms and then the Mahalanobis distance is calculated with respect to the given reference and precision. Both files contain space-separated floating point numbers: the reference file has N lines of 3 values (one atom per line) and the precision file has N lines of N values.
      45             : 
      46             : ```plumed
      47             : #SETTINGS INPUTFILES=regtest/sizeshape/rt-mahadist/global_avg.txt
      48             : #SETTINGS INPUTFILES=regtest/sizeshape/rt-mahadist/global_precision.txt
      49             : 
      50             : UNITS LENGTH=A TIME=ps ENERGY=kcal/mol
      51             : GROUP ATOMS=18,20,22,31,33,35,44,46,48,57,59,61,70,72,74,83,85,87,96,98,100,109,111 LABEL=ga_list
      52             : d: SIZESHAPE_POSITION_MAHA_DIST ...
      53             :    REFERENCE=regtest/sizeshape/rt-mahadist/global_avg.txt
      54             :    PRECISION=regtest/sizeshape/rt-mahadist/global_precision.txt
      55             :    GROUP=ga_list
      56             : ...
      57             : PRINT ARG=d STRIDE=1 FILE=output FMT=%8.8f
      58             : ```
      59             : 
      60             : */
      61             : //+ENDPLUMEDOC
      62             : 
      63             : class position_maha_dist : public Colvar {
      64             :   // Colvar that aligns the selected atoms onto a reference structure and returns their Mahalanobis distance (or its square) from it.
      65             : private:
      66             :   bool pbc, squared;                            // pbc: calculate() calls makeWhole() first; squared: cal_maha_dist() skips the square root
      67             :   std::string prec_f_name;                      // precision file name (PRECISION keyword)
      68             :   std::string ref_f_name;                       // reference file name (REFERENCE keyword)
      69             :   IFile in_;                                    // file reader reused by readinputs() for both input files
      70             :   //Log out_;
      71             :   Matrix<double> ref_str;                 // N x 3 reference coordinates, filled by readinputs()
      72             :   Matrix<double> mobile_str;              // N x 3 current coordinates, centred in calculate()
      73             :   Matrix<double> prec;                            // N x N precision matrix, filled by readinputs()
      74             :   Matrix<double> rotation;                // 3 x 3 matrix produced by kabsch_rot_mat()
      75             :   Matrix<double> derv_;                   // N x 3 analytic derivatives produced by grad_maha()
      76             :   Matrix<double> derv_numeric;            // N x 3 finite-difference derivatives produced by numeric_maha()
      77             :   void readinputs();                            // reads the reference and precision files
      78             :   double dist;                                  // value stored by the most recent calculate()
      79             :   std::vector<AtomNumber> atom_list;            // atoms given through the GROUP keyword
      80             :   const double SMALL = 1.0E-30;                 // pivot magnitude below which determinant() returns 0
      81             :   const double delta = 0.00001;                 // coordinate step used by numeric_maha()
      82             : public:
      83             :   static void registerKeywords( Keywords& keys );
      84             :   explicit position_maha_dist(const ActionOptions&);
      85             :   double determinant(int n, const std::vector<std::vector<double>>* B);  // determinant of the n x n matrix *B
      86             :   void kabsch_rot_mat();                // fills `rotation` from mobile_str, prec and ref_str
      87             :   double cal_maha_dist();               // returns the Mahalanobis distance (squared if `squared` is set)
      88             :   void grad_maha(double);               // fills derv_; the argument is the current distance
      89             :   void numeric_maha();                  // fills derv_numeric by finite differences
      90             :   // active methods:
      91             :   void calculate() override;
      92             : };
      93             : 
      94             : PLUMED_REGISTER_ACTION(position_maha_dist,"SIZESHAPE_POSITION_MAHA_DIST")  // registers the class under the input keyword SIZESHAPE_POSITION_MAHA_DIST
      95             : 
      96           3 : void position_maha_dist::registerKeywords( Keywords& keys ) {  // declares the input keywords accepted by this action
      97           3 :   Colvar::registerKeywords( keys );  // start from the keywords registered by the Colvar base class
      98           3 :   keys.add("compulsory", "PRECISION", "Precision Matrix (inverse of covariance)" );  // file name, opened in readinputs()
      99           3 :   keys.add("compulsory", "REFERENCE", "Reference structure.");  // file name, opened in readinputs()
     100           3 :   keys.add("atoms","GROUP","The group of atoms being used");  // parsed into atom_list by the constructor
     101           3 :   keys.addFlag("SQUARED",false,"Returns the square of distance.");  // parsed into `squared` by the constructor
     102           6 :   keys.setValueDescription("scalar","the Mahalanobis distance between the instantaneous configuration and a given reference distribution in size-and-shape space");
     103           3 : }
     104             : 
     105             : // Constructor: parses GROUP/REFERENCE/PRECISION and the NOPBC/SQUARED flags, logs the setup, requests the atoms and loads both input files.
     106           1 : position_maha_dist::position_maha_dist(const ActionOptions&ao):
     107             :   PLUMED_COLVAR_INIT(ao),
     108           1 :   pbc(true),
     109           1 :   squared(false),
     110           1 :   dist(0),  // NOTE(review): dist is declared after the file-name members, so it is initialized after them despite this position
     111           2 :   prec_f_name(""),
     112           1 :   ref_f_name("") {  // Note! no comma here in the last line.
     113           1 :   parseAtomList("GROUP",atom_list);
     114           1 :   parse("REFERENCE", ref_f_name);
     115           1 :   parse("PRECISION", prec_f_name);
     116             :   // NOPBC is read as the negation of pbc: periodic handling stays on unless the flag is given
     117           1 :   bool nopbc=!pbc;
     118           1 :   parseFlag("NOPBC",nopbc);
     119           1 :   parseFlag("SQUARED",squared);
     120           1 :   pbc=!nopbc;
     121             : 
     122           1 :   checkRead();
     123             :   // report the number of atoms followed by their serial numbers
     124           1 :   log.printf("  of %u atoms\n",static_cast<unsigned>(atom_list.size()));
     125          24 :   for(unsigned int i=0; i<atom_list.size(); ++i) {
     126          23 :     log.printf("  %d", atom_list[i].serial());
     127             :   }
     128             : 
     129           1 :   if(squared) {
     130           0 :     log.printf("\n chosen to use SQUARED option for SIZESHAPE_POSITION_MAHA_DIST\n");
     131             :   }
     132             : 
     133           1 :   if(pbc) {
     134           1 :     log.printf("\n using periodic boundary conditions\n");
     135             :   } else {
     136           0 :     log.printf("\n without periodic boundary conditions\n");
     137             :   }
     138             :   // one non-periodic value that carries derivatives
     139           1 :   addValueWithDerivatives();
     140           1 :   setNotPeriodic();
     141             : 
     142           1 :   requestAtoms(atom_list);
     143             : 
     144             :   // load files only after requestAtoms(): readinputs() sizes its matrices with getNumberOfAtoms()
     145           1 :   readinputs();
     146             : 
     147           1 : }
     148             : 
     149             : // Reads the REFERENCE file (N lines, first 3 values of each used) into ref_str and the PRECISION file (N lines, first N values of each used) into prec.
     150           1 : void position_maha_dist::readinputs() {
     151             :   unsigned N=getNumberOfAtoms();
     152             :   // read ref coords
     153           1 :   in_.open(ref_f_name);
     154             :   // Tokens are split on any whitespace and fetched with at(): a short or empty line throws std::out_of_range instead of reading past the token list.
     155           1 :   ref_str.resize(N,3);
     156           1 :   prec.resize(N,N);
     157             : 
     158             :   std::string line_, val_;
     159             :   unsigned c_=0;
     160             : 
     161          24 :   while (c_ < N) {
     162          23 :     in_.getline(line_);
     163             :     std::vector<std::string> items_;
     164          23 :     std::stringstream check_(line_);
     165             :     // operator>> skips runs of spaces/tabs, so repeated or leading separators never produce empty tokens
     166          92 :     while(check_ >> val_) {
     167          69 :       items_.push_back(val_);
     168             :     }
     169          92 :     for(int i=0; i<3; ++i) {
     170          69 :       ref_str(c_,i) = std::stod(items_.at(i));  // std::stod matches the double storage; throws on non-numeric text
     171             :     }
     172          23 :     c_ += 1;
     173          23 :   }
     174           1 :   in_.close();
     175             : 
     176             :   //read precision
     177           1 :   in_.open(prec_f_name);
     178             : 
     179             :   std::string line, val;
     180             :   unsigned int c = 0;
     181             : 
     182          24 :   while(c < N) {
     183          23 :     in_.getline(line);
     184             : 
     185             :     // tokens of the current row
     186             :     std::vector<std::string> items;
     187             : 
     188             :     // stringstream lets the line be tokenised like an input stream
     189          23 :     std::stringstream check(line);
     190             : 
     191         552 :     while (check >> val) {
     192         529 :       items.push_back(val);
     193             :     }
     194             :     // row c of the precision matrix: the first N tokens of the line
     195         552 :     for(unsigned int i=0; i<N; ++i) {
     196         529 :       prec(c, i) = std::stod(items.at(i));
     197             :     }
     198             : 
     199          23 :     c += 1;
     200             : 
     201          23 :   }
     202           1 :   in_.close();
     203           1 : }
     204             : 
     205             : 
     206          10 : double position_maha_dist::determinant(int n, const std::vector<std::vector<double>>* B) {
     207             :   // Returns the determinant of the n x n matrix *B by Gaussian elimination with partial pivoting; returns 0.0 when a pivot is smaller than SMALL.
     208          10 :   std::vector<std::vector<double>> A(n, std::vector<double>(n, 0));
     209             :   // work on a private copy so *B is left untouched
     210          40 :   for(int i=0; i<n; ++i) {
     211         120 :     for(int j=0; j<n; ++j) {
     212          90 :       A[i][j] = (*B)[i][j];
     213             :     }
     214             :   }
     215             : 
     216             : 
     217             :   //  Reduce A to upper-triangular form; the determinant is the signed product of the pivots.
     218             : 
     219             :   double det = 1;
     220             : 
     221             :   // Row operations for i = 0, ..., n - 2 (the last row needs no elimination)
     222          30 :   for ( int i = 0; i < n - 1; i++ ) {
     223             :     // Partial pivot: find row r below with largest element in column i
     224             :     int r = i;
     225          20 :     double maxA = std::abs( A[i][i] );
     226          50 :     for ( int k = i + 1; k < n; k++ ) {
     227          30 :       double val = std::abs( A[k][i] );
     228          30 :       if ( val > maxA ) {
     229             :         r = k;
     230             :         maxA = val;
     231             :       }
     232             :     }
     233          20 :     if ( r != i ) {
     234          70 :       for ( int j = i; j < n; j++ ) {
     235          50 :         std::swap( A[i][j], A[r][j] );
     236             :       }
     237          20 :       det = -det;  // a row swap flips the sign of the determinant
     238             :     }
     239             : 
     240             :     // Row operations to make upper-triangular
     241          20 :     double pivot = A[i][i];
     242          20 :     if (std::abs( pivot ) < SMALL ) {
     243             :       return 0.0;  // Singular matrix
     244             :     }
     245             : 
     246          50 :     for ( int r = i + 1; r < n; r++ ) {                  // On lower rows (this r shadows the pivot-row index above)
     247          30 :       double multiple = A[r][i] / pivot;                // Multiple of row i to clear element in ith column
     248         110 :       for ( int j = i; j < n; j++ ) {
     249          80 :         A[r][j] -= multiple * A[i][j];
     250             :       }
     251             :     }
     252          20 :     det *= pivot;                                        // Determinant is product of diagonal
     253             :   }
     254             :   // last diagonal entry, not visited by the loop above; assumes n >= 1 since A[n-1][n-1] is read unconditionally
     255          10 :   det *= A[n-1][n-1];
     256             : 
     257          10 :   return det;
     258          10 : }
     259             : 
     260             : // Kabsch-style alignment: fills the 3x3 member `rotation` with U*VT from the SVD of mobile_str^T * prec * ref_str (mult(A,B,C) is taken to store A*B in C).
     261           5 : void position_maha_dist::kabsch_rot_mat() {
     262             : 
     263             :   unsigned N=getNumberOfAtoms();
     264             : 
     265           5 :   Matrix<double> mobile_str_T(3,N);
     266           5 :   Matrix<double> prec_dot_ref_str(N,3);
     267           5 :   Matrix<double> correlation(3,3);
     268             : 
     269             :   // correlation = mobile_str^T * (prec * ref_str)
     270           5 :   transpose(mobile_str, mobile_str_T);
     271           5 :   mult(prec, ref_str, prec_dot_ref_str);
     272           5 :   mult(mobile_str_T, prec_dot_ref_str, correlation);
     273             : 
     274             : 
     275           5 :   int rw = correlation.nrows();
     276           5 :   int cl = correlation.ncols();
     277           5 :   int sz = rw*cl;
     278             : 
     279             :   // SVD part (taking from plu2/src/tools/Matrix.h: pseudoInvert function)
     280             : 
     281           5 :   std::vector<double> da(sz);
     282             :   unsigned k=0;
     283             : 
     284             :   // Transfer the matrix to the local array, column by column (inner loop runs over rows)
     285          20 :   for (int i=0; i<cl; ++i)
     286          60 :     for (int j=0; j<rw; ++j) {
     287          45 :       da[k++]=static_cast<double>( correlation(j,i) );  // note! its [j][i] not [i][j]
     288             :     }
     289             : 
     290           5 :   int nsv, info, nrows=rw, ncols=cl;
     291             :   if(rw>cl) {
     292             :     nsv=cl;
     293             :   } else {
     294             :     nsv=rw;
     295             :   }
     296             : 
     297             :   // Containers for the singular values (S), the two factor matrices (U, VT) and the integer workspace
     298           5 :   std::vector<double> S(nsv);
     299           5 :   std::vector<double> U(nrows*nrows);
     300           5 :   std::vector<double> VT(ncols*ncols);
     301           5 :   std::vector<int> iwork(8*nsv);
     302             : 
     303             :   // Workspace query: with lwork=-1 the routine only reports the work-array size in work[0]
     304           5 :   int lwork=-1;
     305           5 :   std::vector<double> work(1);
     306           5 :   plumed_lapack_dgesdd( "A", &nrows, &ncols, da.data(), &nrows, S.data(), U.data(), &nrows, VT.data(), &ncols, work.data(), &lwork, iwork.data(), &info );
     307             :   // A non-zero info is only logged (with its value); execution continues as before
     308           5 :   if(info!=0) {
     309           0 :     log.printf("info: %d\n", info);
     310             :   }
     311             : 
     312             :   // Retrieve the reported size for work and reallocate
     313           5 :   lwork=static_cast<int>(work[0]);
     314           5 :   work.resize(lwork);
     315             : 
     316             :   // This does the singular value decomposition
     317           5 :   plumed_lapack_dgesdd( "A", &nrows, &ncols, da.data(), &nrows, S.data(), U.data(), &nrows, VT.data(), &ncols, work.data(), &lwork, iwork.data(), &info );
     318             :   // A non-zero info is only logged (with its value); execution continues as before
     319           5 :   if(info!=0) {
     320           0 :     log.printf("info: %d\n", info);
     321             :   }
     322             : 
     323             : 
     324             :   // get U and VT in form of 2D vector (U_, VT_)
     325           5 :   std::vector<std::vector<double>> U_(nrows, std::vector<double>(nrows,0));
     326           5 :   std::vector<std::vector<double>> VT_(ncols, std::vector<double>(ncols,0));
     327             : 
     328             :   int  c=0;
     329             :   // the flat arrays are read back in the same column-by-column order used to fill da
     330          20 :   for(int i=0; i<nrows; ++i) {
     331          60 :     for(int j=0; j<nrows; ++j) {
     332          45 :       U_[j][i] = U[c];
     333          45 :       c += 1;
     334             :     }
     335             :   }
     336             :   c = 0; // note! its [j][i] not [i][j]
     337          20 :   for(int i=0; i<ncols; ++i) {
     338          60 :     for(int j=0; j<ncols; ++j) {
     339          45 :       VT_[j][i] = VT[c];
     340          45 :       c += 1;
     341             :     }
     342             :   }
     343             :   c=0; // note! its [j][i] not [i][j]
     344             : 
     345             : 
     346             :   // calculate determinants
     347           5 :   double det_u = determinant(nrows, &U_);
     348           5 :   double det_vt = determinant(ncols, &VT_);
     349             : 
     350             :   // a negative product means U*VT would have determinant -1; negating the last column of U flips that sign
     351           5 :   if (det_u * det_vt < 0.0) {
     352           8 :     for(int i=0; i<nrows; ++i) {
     353           6 :       U_[i][nrows-1] *= -1;
     354             :     }
     355             :   }
     356             : 
     357             : 
     358             :   // copy the (possibly corrected) factors into Matrix objects so mult() can be used
     359           5 :   rotation.resize(3,3);
     360           5 :   Matrix<double> u(3,3), vt(3,3);
     361          20 :   for(int i=0; i<3; ++i) {
     362          60 :     for(int j=0; j<3; ++j) {
     363          45 :       u(i,j)=U_[i][j];
     364          45 :       vt(i,j)=VT_[i][j];
     365             :     }
     366             :   }
     367             : 
     368             :   // get rotation matrix
     369           5 :   mult(u, vt, rotation);
     370             : 
     371          10 : }
     372             : 
     373             : 
     374             : // Returns the Mahalanobis distance of the rotated mobile_str from ref_str: sqrt(trace(disp^T * prec * disp)), or the trace itself when `squared` is set.
     375           5 : double position_maha_dist::cal_maha_dist() {
     376             :   // relies on `rotation` and `mobile_str` having been set beforehand; mult(A,B,C) is taken to store A*B in C
     377             :   unsigned N=getNumberOfAtoms();
     378             : 
     379           5 :   Matrix<double> rotated_obj(N,3);
     380             :   // rotate the object
     381           5 :   mult(mobile_str, rotation, rotated_obj);
     382             : 
     383             :   // compute the displacement
     384           5 :   Matrix<double> disp(N,3);
     385         120 :   for(unsigned int i=0; i<N; ++i) {
     386         460 :     for(unsigned int j=0; j<3; ++j) {
     387         345 :       disp(i,j) = (rotated_obj(i,j)-ref_str(i,j));
     388             :     }
     389             :   }
     390             : 
     391           5 :   Matrix<double> prec_dot_disp(N,3);
     392           5 :   Matrix<double> disp_T(3,N);
     393           5 :   Matrix<double> out(3,3);
     394             :   // out = disp^T * (prec * disp), a 3 x 3 matrix
     395           5 :   mult(prec, disp, prec_dot_disp);
     396           5 :   transpose(disp, disp_T);
     397           5 :   mult(disp_T, prec_dot_disp, out);
     398             : 
     399             : 
     400             :   // the squared distance is the sum of the diagonal of out
     401             :   double maha_d=0.0;
     402          20 :   for(int i=0; i<3; ++i) {
     403          15 :     maha_d += out(i,i);
     404             :   }
     405             : 
     406           5 :   if (!squared) {
     407           5 :     maha_d = std::sqrt(maha_d);
     408             :   }
     409             : 
     410           5 :   return maha_d;
     411             : }
     412             : 
     413             : // Fills derv_ (N x 3) with prec * (mobile_str - ref_str * rotation^T), divided by d, or multiplied by 2 when `squared` is set.
     414           5 : void position_maha_dist::grad_maha(double d) {
     415             :   // d: the current (unsquared) distance used as divisor; mult(A,B,C) is taken to store A*B in C
     416             :   unsigned N=getNumberOfAtoms();
     417             : 
     418           5 :   derv_.resize(N,3);
     419             : 
     420           5 :   Matrix<double> ref_str_rot_T(N,3);
     421           5 :   Matrix<double> rot_T(3,3);
     422           5 :   Matrix<double> diff_(N,3);
     423             :   // bring the reference into the frame of the mobile structure: ref_str * rotation^T
     424           5 :   transpose(rotation, rot_T);
     425           5 :   mult(ref_str, rot_T, ref_str_rot_T);
     426             : 
     427         120 :   for(unsigned i=0; i<N; ++i) {
     428         460 :     for(unsigned j=0; j<3; ++j) {
     429         345 :       diff_(i,j) = mobile_str(i,j) - ref_str_rot_T(i,j);
     430             :     }
     431             :   }
     432             : 
     433           5 :   mult(prec, diff_, derv_);
     434             : 
     435             :   //for(unsigned i=0; i<N; ++i){ for(unsigned j=0; j<3; ++j) {derv_(i,j) /= (N*d);} }  // dividing by N here!
     436         120 :   for(unsigned i=0; i<N; ++i) {
     437         460 :     for(unsigned j=0; j<3; ++j) {
     438         345 :       if (!squared) {
     439         345 :         derv_(i,j) /= d;  // NOTE(review): d == 0 would divide by zero here; confirm whether that case can occur
     440             :       } else {
     441           0 :         derv_(i,j) *= 2.0;
     442             :       }
     443             :     }
     444             :   }
     445             : 
     446             : 
     447           5 : }
     448             : 
     449             : 
     450             : // Fills derv_numeric (N x 3) with forward finite differences of cal_maha_dist(), using step `delta` and the member `dist` as the unperturbed value.
     451           0 : void position_maha_dist::numeric_maha() {
     452             :   // Each coordinate is shifted by delta, the alignment is recomputed, the distance re-evaluated, and the shift undone.
     453             :   unsigned N=getNumberOfAtoms();
     454           0 :   derv_numeric.resize(N,3);
     455             :   // NOTE(review): `rotation` is left at the alignment of the last perturbed coordinate; it is not recomputed after the loops
     456           0 :   for(unsigned int atom=0; atom<N; ++atom) {
     457           0 :     for(unsigned int j=0; j<3; ++j) {
     458           0 :       mobile_str(atom,j) += delta;
     459           0 :       kabsch_rot_mat();
     460           0 :       derv_numeric(atom,j) = (cal_maha_dist() - dist)/delta;
     461           0 :       mobile_str(atom,j) -= delta;
     462             :     }
     463             :   }
     464             : 
     465           0 : }
     466             : 
     467             : 
     468             : // Per-step evaluation: loads and centres the atom positions, aligns them, then sets the value, the atomic derivatives and the box derivatives.
     469           5 : void position_maha_dist::calculate() {
     470             : 
     471           5 :   if(pbc) {
     472           5 :     makeWhole();
     473             :   }
     474             :   unsigned N=getNumberOfAtoms();
     475             : 
     476           5 :   mobile_str.resize(N,3);
     477             : 
     478             :   // copy the current positions into mobile_str, one atom per row
     479         120 :   for(unsigned int i=0; i<N; ++i) {
     480         115 :     Vector pos=getPosition(i);  // const PLMD::Vector
     481         460 :     for(unsigned j=0; j<3; ++j) {
     482         345 :       mobile_str(i,j) = pos[j];
     483             :     }
     484             :   }
     485             : 
     486             :   // translating the structure to center of geometry
     487           5 :   double center_of_geometry[3]= {0.0, 0.0, 0.0};
     488             :   // accumulate coordinate sums first; the division by N happens when subtracting below
     489         120 :   for(unsigned int i=0; i<N; ++i) {
     490         115 :     center_of_geometry[0] += mobile_str(i,0);
     491         115 :     center_of_geometry[1] += mobile_str(i,1);
     492         115 :     center_of_geometry[2] += mobile_str(i,2);
     493             :   }
     494             : 
     495         120 :   for(unsigned int i=0; i<N; ++i) {
     496         460 :     for(unsigned int j=0; j<3; ++j) {
     497         345 :       mobile_str(i,j) -= (center_of_geometry[j]/N);
     498             :     }
     499             :   }
     500             :   // alignment must come first: cal_maha_dist() and grad_maha() both read `rotation`
     501           5 :   kabsch_rot_mat();
     502           5 :   dist = cal_maha_dist();
     503             : 
     504           5 :   grad_maha(dist);
     505             :   // set derivatives
     506         120 :   for(unsigned i=0; i<N; ++i) {
     507         115 :     Vector vi(derv_(i,0), derv_(i,1), derv_(i,2) );
     508         115 :     setAtomsDerivatives(i, vi);
     509             :   }
     510           5 :   setBoxDerivativesNoPbc();
     511           5 :   setValue(dist);
     512             : 
     513           5 : }
     514             : 
     515             : }
     516             : }
     517             : 
     518             : 
     519             : 

Generated by: LCOV version 1.16