LCOV - code coverage report
Current view: top level - metatomic - vesin.cpp (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 268 377 71.1 %
Date: 2025-12-04 11:19:34 Functions: 27 31 87.1 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             : Copyright (c) 2024 The METATOMIC-PLUMED team
       3             : (see the PEOPLE-METATOMIC file at the root of this folder for a list of names)
       4             : 
       5             : See https://docs.metatensor.org/metatomic/ for more information about the
       6             : metatomic package that this module allows you to call from PLUMED.
       7             : 
       8             : This file is part of METATOMIC-PLUMED module.
       9             : 
      10             : The METATOMIC-PLUMED module is free software: you can redistribute it and/or modify
      11             : it under the terms of the GNU Lesser General Public License as published by
      12             : the Free Software Foundation, either version 3 of the License, or
      13             : (at your option) any later version.
      14             : 
      15             : The METATOMIC-PLUMED module is distributed in the hope that it will be useful,
      16             : but WITHOUT ANY WARRANTY; without even the implied warranty of
      17             : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      18             : GNU Lesser General Public License for more details.
      19             : 
      20             : You should have received a copy of the GNU Lesser General Public License
      21             : along with the METATOMIC-PLUMED module. If not, see <http://www.gnu.org/licenses/>.
      22             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      23             : /*INDENT-OFF*/
#include "vesin.h"

#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <cstring>

#include <algorithm>
#include <limits>
#include <new>
#include <numeric>
#include <tuple>
      33             : 
      34             : #ifndef VESIN_CPU_CELL_LIST_HPP
      35             : #define VESIN_CPU_CELL_LIST_HPP
      36             : 
      37             : #include <vector>
      38             : 
      39             : #include "vesin.h"
      40             : 
      41             : #ifndef VESIN_TYPES_HPP
      42             : #define VESIN_TYPES_HPP
      43             : 
      44             : #ifndef VESIN_MATH_HPP
      45             : #define VESIN_MATH_HPP
      46             : 
      47             : #include <array>
      48             : #include <cmath>
      49             : #include <stdexcept>
      50             : 
      51             : namespace PLMD {
      52             : namespace metatomic {
      53             : namespace vesin {
      54             : struct Vector;
      55             : 
      56             : Vector operator*(Vector vector, double scalar);
      57             : 
      58             : struct Vector: public std::array<double, 3> {
      59             :     double dot(Vector other) const {
      60      109229 :         return (*this)[0] * other[0] + (*this)[1] * other[1] + (*this)[2] * other[2];
      61             :     }
      62             : 
      63          45 :     double norm() const {
      64          45 :         return std::sqrt(this->dot(*this));
      65             :     }
      66             : 
      67          45 :     Vector normalize() const {
      68          45 :         return *this * (1.0 / this->norm());
      69             :     }
      70             : 
      71             :     Vector cross(Vector other) const {
      72             :         return Vector{
      73          45 :             (*this)[1] * other[2] - (*this)[2] * other[1],
      74          45 :             (*this)[2] * other[0] - (*this)[0] * other[2],
      75          45 :             (*this)[0] * other[1] - (*this)[1] * other[0],
      76          45 :         };
      77             :     }
      78             : };
      79             : 
      80             : inline Vector operator+(Vector u, Vector v) {
      81             :     return Vector{
      82      109169 :         u[0] + v[0],
      83      109169 :         u[1] + v[1],
      84      109169 :         u[2] + v[2],
      85             :     };
      86             : }
      87             : 
      88             : inline Vector operator-(Vector u, Vector v) {
      89             :     return Vector{
      90      109169 :         u[0] - v[0],
      91      109169 :         u[1] - v[1],
      92      109169 :         u[2] - v[2],
      93             :     };
      94             : }
      95             : 
      96             : inline Vector operator*(double scalar, Vector vector) {
      97             :     return Vector{
      98             :         scalar * vector[0],
      99             :         scalar * vector[1],
     100             :         scalar * vector[2],
     101             :     };
     102             : }
     103             : 
     104             : inline Vector operator*(Vector vector, double scalar) {
     105             :     return Vector{
     106          45 :         scalar * vector[0],
     107          45 :         scalar * vector[1],
     108          45 :         scalar * vector[2],
     109          45 :     };
     110             : }
     111             : 
     112             : 
     113             : struct Matrix: public std::array<std::array<double, 3>, 3> {
     114          16 :     double determinant() const {
     115          16 :         return (*this)[0][0] * ((*this)[1][1] * (*this)[2][2] - (*this)[2][1] * (*this)[1][2])
     116          16 :              - (*this)[0][1] * ((*this)[1][0] * (*this)[2][2] - (*this)[1][2] * (*this)[2][0])
     117          16 :              + (*this)[0][2] * ((*this)[1][0] * (*this)[2][1] - (*this)[1][1] * (*this)[2][0]);
     118             :     }
     119             : 
     120           8 :     Matrix inverse() const {
     121           8 :         auto det = this->determinant();
     122             : 
     123           8 :         if (std::abs(det) < 1e-30) {
     124           0 :             throw std::runtime_error("this matrix is not invertible");
     125             :         }
     126             : 
     127             :         auto inverse = Matrix();
     128           8 :         inverse[0][0] = ((*this)[1][1] * (*this)[2][2] - (*this)[2][1] * (*this)[1][2]) / det;
     129           8 :         inverse[0][1] = ((*this)[0][2] * (*this)[2][1] - (*this)[0][1] * (*this)[2][2]) / det;
     130           8 :         inverse[0][2] = ((*this)[0][1] * (*this)[1][2] - (*this)[0][2] * (*this)[1][1]) / det;
     131           8 :         inverse[1][0] = ((*this)[1][2] * (*this)[2][0] - (*this)[1][0] * (*this)[2][2]) / det;
     132           8 :         inverse[1][1] = ((*this)[0][0] * (*this)[2][2] - (*this)[0][2] * (*this)[2][0]) / det;
     133           8 :         inverse[1][2] = ((*this)[1][0] * (*this)[0][2] - (*this)[0][0] * (*this)[1][2]) / det;
     134           8 :         inverse[2][0] = ((*this)[1][0] * (*this)[2][1] - (*this)[2][0] * (*this)[1][1]) / det;
     135           8 :         inverse[2][1] = ((*this)[2][0] * (*this)[0][1] - (*this)[0][0] * (*this)[2][1]) / det;
     136           8 :         inverse[2][2] = ((*this)[0][0] * (*this)[1][1] - (*this)[1][0] * (*this)[0][1]) / det;
     137           8 :         return inverse;
     138             :     }
     139             : };
     140             : 
     141             : 
     142             : inline Vector operator*(Matrix matrix, Vector vector) {
     143             :     return Vector{
     144             :         matrix[0][0] * vector[0] + matrix[0][1] * vector[1] + matrix[0][2] * vector[2],
     145             :         matrix[1][0] * vector[0] + matrix[1][1] * vector[1] + matrix[1][2] * vector[2],
     146             :         matrix[2][0] * vector[0] + matrix[2][1] * vector[1] + matrix[2][2] * vector[2],
     147             :     };
     148             : }
     149             : 
     150      111524 : inline Vector operator*(Vector vector, Matrix matrix) {
     151             :     return Vector{
     152      111524 :         vector[0] * matrix[0][0] + vector[1] * matrix[1][0] + vector[2] * matrix[2][0],
     153      111524 :         vector[0] * matrix[0][1] + vector[1] * matrix[1][1] + vector[2] * matrix[2][1],
     154      111524 :         vector[0] * matrix[0][2] + vector[1] * matrix[1][2] + vector[2] * matrix[2][2],
     155      111524 :     };
     156             : }
     157             : 
     158             : } // namespace vesin
     159             : } // namespace metatomic
     160             : } // namespace PLMD
     161             : 
     162             : #endif
     163             : 
     164             : namespace PLMD {
     165             : namespace metatomic {
     166             : namespace vesin {
     167             : 
     168             : class BoundingBox {
     169             : public:
     170          15 :     BoundingBox(Matrix matrix, bool periodic): matrix_(matrix), periodic_(periodic) {
     171          15 :         if (periodic) {
     172           8 :             auto det = matrix_.determinant();
     173           8 :             if (std::abs(det) < 1e-30) {
     174           0 :                 throw std::runtime_error("the box matrix is not invertible");
     175             :             }
     176             : 
     177           8 :             this->inverse_ = matrix_.inverse();
     178             :         } else {
     179           7 :             this->matrix_ = Matrix{{{
     180             :                 {{1, 0, 0}},
     181             :                 {{0, 1, 0}},
     182             :                 {{0, 0, 1}}
     183             :             }}};
     184           7 :             this->inverse_ = matrix_;
     185             :         }
     186          15 :     }
     187             : 
     188             :     const Matrix& matrix() const {
     189             :         return this->matrix_;
     190             :     }
     191             : 
     192             :     bool periodic() const {
     193      214354 :         return this->periodic_;
     194             :     }
     195             : 
     196             :     /// Convert a vector from cartesian coordinates to fractional coordinates
     197             :     Vector cartesian_to_fractional(Vector cartesian) const {
     198        2355 :         return cartesian * inverse_;
     199             :     }
     200             : 
     201             :     /// Convert a vector from fractional coordinates to cartesian coordinates
     202             :     Vector fractional_to_cartesian(Vector fractional) const {
     203             :         return fractional * matrix_;
     204             :     }
     205             : 
     206             :     /// Get the three distances between faces of the bounding box
     207          15 :     Vector distances_between_faces() const {
     208          15 :         auto a = Vector{matrix_[0]};
     209          15 :         auto b = Vector{matrix_[1]};
     210          15 :         auto c = Vector{matrix_[2]};
     211             : 
     212             :         // Plans normal vectors
     213          15 :         auto na = b.cross(c).normalize();
     214          15 :         auto nb = c.cross(a).normalize();
     215          15 :         auto nc = a.cross(b).normalize();
     216             : 
     217             :         return Vector{
     218             :             std::abs(na.dot(a)),
     219             :             std::abs(nb.dot(b)),
     220             :             std::abs(nc.dot(c)),
     221          15 :         };
     222             :     }
     223             : 
     224             : private:
     225             :     Matrix matrix_;
     226             :     Matrix inverse_;
     227             :     bool periodic_;
     228             : };
     229             : 
     230             : 
     231             : /// A cell shift represents the displacement along cell axis between the actual
     232             : /// position of an atom and a periodic image of this atom.
     233             : ///
     234             : /// The cell shift can be used to reconstruct the vector between two points,
     235             : /// wrapped inside the unit cell.
     236             : struct CellShift: public std::array<int32_t, 3> {
     237             :     /// Compute the shift vector in cartesian coordinates, using the given cell
     238             :     /// matrix (stored in row major order).
     239      109169 :     Vector cartesian(Matrix cell) const {
     240             :         auto vector = Vector{
     241      109169 :             static_cast<double>((*this)[0]),
     242      109169 :             static_cast<double>((*this)[1]),
     243      109169 :             static_cast<double>((*this)[2]),
     244      109169 :         };
     245      109169 :         return vector * cell;
     246             :     }
     247             : };
     248             : 
     249             : inline CellShift operator+(CellShift a, CellShift b) {
     250             :     return CellShift{
     251      211981 :         a[0] + b[0],
     252      211981 :         a[1] + b[1],
     253      211981 :         a[2] + b[2],
     254             :     };
     255             : }
     256             : 
     257             : inline CellShift operator-(CellShift a, CellShift b) {
     258             :     return CellShift{
     259      211981 :         a[0] - b[0],
     260      211981 :         a[1] - b[1],
     261      211981 :         a[2] - b[2],
     262             :     };
     263             : }
     264             : 
     265             : 
     266             : } // namespace vesin
     267             : } // namespace metatomic
     268             : } // namespace PLMD
     269             : 
     270             : #endif
     271             : 
     272             : namespace PLMD {
     273             : namespace metatomic {
     274             : namespace vesin { namespace cpu {
     275             : 
     276             : void free_neighbors(VesinNeighborList& neighbors);
     277             : 
     278             : void neighbors(
     279             :     const Vector* points,
     280             :     size_t n_points,
     281             :     BoundingBox cell,
     282             :     VesinOptions options,
     283             :     VesinNeighborList& neighbors
     284             : );
     285             : 
     286             : 
     287             : /// The cell list is used to sort atoms inside bins/cells.
     288             : ///
     289             : /// The list of potential pairs is then constructed by looking through all
     290             : /// neighboring cells (the number of cells to search depends on the cutoff and
     291             : /// the size of the cells) for each atom to create pair candidates.
     292          15 : class CellList {
     293             : public:
     294             :     /// Create a new `CellList` for the given bounding box and cutoff,
     295             :     /// determining all required parameters.
     296             :     CellList(BoundingBox box, double cutoff);
     297             : 
     298             :     /// Add a single point to the cell list at the given `position`. The point
     299             :     /// is uniquely identified by its `index`.
     300             :     void add_point(size_t index, Vector position);
     301             : 
     302             :     /// Iterate over all possible pairs, calling the given callback every time
     303             :     template <typename Function>
     304             :     void foreach_pair(Function callback);
     305             : 
     306             : private:
     307             :     /// How many cells do we need to look at when searching neighbors to include
     308             :     /// all neighbors below cutoff
     309             :     std::array<int32_t, 3> n_search_;
     310             : 
     311             :     /// the cells themselves are a list of points & corresponding
     312             :     /// shift to place the point inside the cell
     313             :     struct Point {
     314             :         size_t index;
     315             :         CellShift shift;
     316             :     };
     317         872 :     struct Cell: public std::vector<Point> {};
     318             : 
     319             :     // raw data for the cells
     320             :     std::vector<Cell> cells_;
     321             :     // shape of the cell array
     322             :     std::array<size_t, 3> cells_shape_;
     323             : 
     324             :     BoundingBox box_;
     325             : 
     326             :     Cell& get_cell(std::array<int32_t, 3> index);
     327             : };
     328             : 
     329             : /// Wrapper around `VesinNeighborList` that behaves like a std::vector,
     330             : /// automatically growing memory allocations.
     331             : class GrowableNeighborList {
     332             : public:
     333             :     VesinNeighborList& neighbors;
     334             :     size_t capacity;
     335             :     VesinOptions options;
     336             : 
     337             :     size_t length() const {
     338       15174 :         return neighbors.length;
     339             :     }
     340             : 
     341             :     void increment_length() {
     342       15174 :         neighbors.length += 1;
     343       15174 :     }
     344             : 
     345             :     void set_pair(size_t index, size_t first, size_t second);
     346             :     void set_shift(size_t index, PLMD::metatomic::vesin::CellShift shift);
     347             :     void set_distance(size_t index, double distance);
     348             :     void set_vector(size_t index, PLMD::metatomic::vesin::Vector vector);
     349             : 
     350             :     // reset length to 0, and allocate/deallocate members of
     351             :     // `neighbors` according to `options`
     352             :     void reset();
     353             : 
     354             :     // allocate more memory & update capacity
     355             :     void grow();
     356             : 
     357             :     // sort the pairs currently in the neighbor list
     358             :     void sort();
     359             : };
     360             : 
     361             : } // namespace vesin
     362             : } // namespace metatomic
     363             : } // namespace PLMD
     364             : } // namespace cpu
     365             : 
     366             : #endif
     367             : 
     368             : using namespace PLMD::metatomic::vesin;
     369             : using namespace PLMD::metatomic::vesin::cpu;
     370             : 
     371          15 : void PLMD::metatomic::vesin::cpu::neighbors(
     372             :     const Vector* points,
     373             :     size_t n_points,
     374             :     BoundingBox cell,
     375             :     VesinOptions options,
     376             :     VesinNeighborList& raw_neighbors
     377             : ) {
     378          15 :     auto cell_list = CellList(cell, options.cutoff);
     379             : 
     380        2370 :     for (size_t i=0; i<n_points; i++) {
     381        2355 :         cell_list.add_point(i, points[i]);
     382             :     }
     383             : 
     384          15 :     auto cell_matrix = cell.matrix();
     385          15 :     auto cutoff2 = options.cutoff * options.cutoff;
     386             : 
     387             :     // the cell list creates too many pairs, we only need to keep the
     388             :     // one where the distance is actually below the cutoff
     389          15 :     auto neighbors = GrowableNeighborList{raw_neighbors, raw_neighbors.length, options};
     390          15 :     neighbors.reset();
     391             : 
     392          15 :     cell_list.foreach_pair([&](size_t first, size_t second, CellShift shift) {
     393      209626 :         if (!options.full) {
     394             :             // filter out some pairs for half neighbor lists
     395      200914 :             if (first > second) {
     396             :                 return;
     397             :             }
     398             : 
     399      100457 :             if (first == second) {
     400             :                 // When creating pairs between a point and one of its periodic
     401             :                 // images, the code generate multiple redundant pairs (e.g. with
     402             :                 // shifts 0 1 1 and 0 -1 -1); and we want to only keep one of
     403             :                 // these.
     404           0 :                 if (shift[0] + shift[1] + shift[2] < 0) {
     405             :                     // drop shifts on the negative half-space
     406             :                     return;
     407             :                 }
     408             : 
     409             :                 if ((shift[0] + shift[1] + shift[2] == 0)
     410           0 :                     && (shift[2] < 0 || (shift[2] == 0 && shift[1] < 0))) {
     411             :                     // drop shifts in the negative half plane or the negative
     412             :                     // shift[1] axis. See below for a graphical representation:
     413             :                     // we are keeping the shifts indicated with `O` and dropping
     414             :                     // the ones indicated with `X`
     415             :                     //
     416             :                     //  O O O │ O O O
     417             :                     //  O O O │ O O O
     418             :                     //  O O O │ O O O
     419             :                     // ─X─X─X─┼─O─O─O─
     420             :                     //  X X X │ X X X
     421             :                     //  X X X │ X X X
     422             :                     //  X X X │ X X X
     423             :                     return;
     424             :                 }
     425             :             }
     426             :         }
     427             : 
     428      109169 :         auto vector = points[second] - points[first] + shift.cartesian(cell_matrix);
     429             :         auto distance2 = vector.dot(vector);
     430             : 
     431      109169 :         if (distance2 < cutoff2) {
     432       15174 :             auto index = neighbors.length();
     433       15174 :             neighbors.set_pair(index, first, second);
     434             : 
     435       15174 :             if (options.return_shifts) {
     436       15174 :                 neighbors.set_shift(index, shift);
     437             :             }
     438             : 
     439       15174 :             if (options.return_distances) {
     440           0 :                 neighbors.set_distance(index, std::sqrt(distance2));
     441             :             }
     442             : 
     443       15174 :             if (options.return_vectors) {
     444       15174 :                 neighbors.set_vector(index, vector);
     445             :             }
     446             : 
     447             :             neighbors.increment_length();
     448             :         }
     449             :     });
     450             : 
     451          15 :     if (options.sorted) {
     452           0 :         neighbors.sort();
     453             :     }
     454          15 : }
     455             : 
     456             : /* ========================================================================== */
     457             : 
     458             : /// Maximal number of cells, we need to use this to prevent having too many
     459             : /// cells with a small bounding box and a large cutoff
     460             : #define MAX_NUMBER_OF_CELLS 1e5
     461             : 
     462             : 
     463             : /// Function to compute both quotient and remainder of the division of a by b.
     464             : /// This function follows Python convention, making sure the remainder have the
     465             : /// same sign as `b`.
     466       77520 : static std::tuple<int32_t, int32_t> divmod(int32_t a, size_t b) {
     467             :     assert(b < (std::numeric_limits<int32_t>::max()));
     468       77520 :     auto b_32 = static_cast<int32_t>(b);
     469       77520 :     auto quotient = a / b_32;
     470       77520 :     auto remainder = a % b_32;
     471       77520 :     if (remainder < 0) {
     472        3975 :         remainder += b_32;
     473        3975 :         quotient -= 1;
     474             :     }
     475       77520 :     return std::make_tuple(quotient, remainder);
     476             : }
     477             : 
     478             : /// Apply the `divmod` function to three components at the time
     479             : static std::tuple<std::array<int32_t, 3>, std::array<int32_t, 3>>
     480       25840 : divmod(std::array<int32_t, 3> a, std::array<size_t, 3> b) {
     481       25840 :     auto [qx, rx] = divmod(a[0], b[0]);
     482       25840 :     auto [qy, ry] = divmod(a[1], b[1]);
     483       25840 :     auto [qz, rz] = divmod(a[2], b[2]);
     484             :     return std::make_tuple(
     485       25840 :         std::array<int32_t, 3>{qx, qy, qz},
     486       25840 :         std::array<int32_t, 3>{rx, ry, rz}
     487       25840 :     );
     488             : }
     489             : 
     490          15 : CellList::CellList(BoundingBox box, double cutoff):
     491          15 :     n_search_({0, 0, 0}),
     492          15 :     cells_shape_({0, 0, 0}),
     493          15 :     box_(box)
     494             : {
     495          15 :     auto distances_between_faces = box_.distances_between_faces();
     496             : 
     497             :     auto n_cells = Vector{
     498          15 :         std::clamp(std::trunc(distances_between_faces[0] / cutoff), 1.0, HUGE_VAL),
     499          30 :         std::clamp(std::trunc(distances_between_faces[1] / cutoff), 1.0, HUGE_VAL),
     500          30 :         std::clamp(std::trunc(distances_between_faces[2] / cutoff), 1.0, HUGE_VAL),
     501          28 :     };
     502             : 
     503             :     assert(std::isfinite(n_cells[0]) && std::isfinite(n_cells[1]) && std::isfinite(n_cells[2]));
     504             : 
     505             :     // limit memory consumption by ensuring we have less than `MAX_N_CELLS`
     506             :     // cells to look though
     507          15 :     auto n_cells_total = n_cells[0] * n_cells[1] * n_cells[2];
     508          15 :     if (n_cells_total > MAX_NUMBER_OF_CELLS) {
     509             :         // set the total number of cells close to MAX_N_CELLS, while keeping
     510             :         // roughly the ratio of cells in each direction
     511           0 :         auto ratio_x_y = n_cells[0] / n_cells[1];
     512           0 :         auto ratio_y_z = n_cells[1] / n_cells[2];
     513             : 
     514           0 :         n_cells[2] = std::trunc(std::cbrt(MAX_NUMBER_OF_CELLS / (ratio_x_y * ratio_y_z * ratio_y_z)));
     515           0 :         n_cells[1] = std::trunc(ratio_y_z * n_cells[2]);
     516           0 :         n_cells[0] = std::trunc(ratio_x_y * n_cells[1]);
     517             :     }
     518             : 
     519             :     // number of cells to search in each direction to make sure all possible
     520             :     // pairs below the cutoff are accounted for.
     521          15 :     this->n_search_ = std::array<int32_t, 3>{
     522          15 :         static_cast<int32_t>(std::ceil(cutoff * n_cells[0] / distances_between_faces[0])),
     523          15 :         static_cast<int32_t>(std::ceil(cutoff * n_cells[1] / distances_between_faces[1])),
     524          15 :         static_cast<int32_t>(std::ceil(cutoff * n_cells[2] / distances_between_faces[2])),
     525             :     };
     526             : 
     527          15 :     this->cells_shape_ = std::array<size_t, 3>{
     528          15 :         static_cast<size_t>(n_cells[0]),
     529          15 :         static_cast<size_t>(n_cells[1]),
     530          15 :         static_cast<size_t>(n_cells[2]),
     531             :     };
     532             : 
     533          60 :     for (size_t spatial=0; spatial<3; spatial++) {
     534          45 :         if (n_search_[spatial] < 1) {
     535           0 :             n_search_[spatial] = 1;
     536             :         }
     537             : 
     538             :         // don't look for neighboring cells if we have only one cell and no
     539             :         // periodic boundary condition
     540          45 :         if (n_cells[spatial] == 1 && !box.periodic()) {
     541           6 :             n_search_[spatial] = 0;
     542             :         }
     543             :     }
     544             : 
     545          15 :     this->cells_.resize(cells_shape_[0] * cells_shape_[1] * cells_shape_[2]);
     546          15 : }
     547             : 
     548        2355 : void CellList::add_point(size_t index, Vector position) {
     549        2355 :     auto fractional = box_.cartesian_to_fractional(position);
     550             : 
     551             :     // find the cell in which this atom should go
     552             :     auto cell_index = std::array<int32_t, 3>{
     553        2355 :         static_cast<int32_t>(std::floor(fractional[0] * static_cast<double>(cells_shape_[0]))),
     554        2355 :         static_cast<int32_t>(std::floor(fractional[1] * static_cast<double>(cells_shape_[1]))),
     555        2355 :         static_cast<int32_t>(std::floor(fractional[2] * static_cast<double>(cells_shape_[2]))),
     556        2355 :     };
     557             : 
     558             :     // deal with pbc by wrapping the atom inside if it was outside of the
     559             :     // cell
     560             :     CellShift shift;
     561             :     // auto (shift, cell_index) =
     562        2355 :     if (box_.periodic()) {
     563        2348 :         auto result = divmod(cell_index, cells_shape_);
     564        2348 :         shift = CellShift{std::get<0>(result)};
     565        2348 :         cell_index = std::get<1>(result);
     566             :     } else {
     567             :         shift = CellShift({0, 0, 0});
     568           7 :         cell_index = std::array<int32_t, 3>{
     569           7 :             std::clamp(cell_index[0], 0, static_cast<int32_t>(cells_shape_[0] - 1)),
     570           7 :             std::clamp(cell_index[1], 0, static_cast<int32_t>(cells_shape_[1] - 1)),
     571          14 :             std::clamp(cell_index[2], 0, static_cast<int32_t>(cells_shape_[2] - 1)),
     572             :         };
     573             :     }
     574             : 
     575        2355 :     this->get_cell(cell_index).emplace_back(Point{index, shift});
     576        2355 : }
     577             : 
     578             : 
     579             : template <typename Function>
     580          15 : void CellList::foreach_pair(Function callback) {
     581          51 :     for (int32_t cell_i_x=0; cell_i_x<static_cast<int32_t>(cells_shape_[0]); cell_i_x++) {
     582         180 :     for (int32_t cell_i_y=0; cell_i_y<static_cast<int32_t>(cells_shape_[1]); cell_i_y++) {
     583        1016 :     for (int32_t cell_i_z=0; cell_i_z<static_cast<int32_t>(cells_shape_[2]); cell_i_z++) {
     584         872 :         const auto& current_cell = this->get_cell({cell_i_x, cell_i_y, cell_i_z});
     585             :         // look through each neighboring cell
     586        3484 :         for (int32_t delta_x=-n_search_[0]; delta_x<=n_search_[0]; delta_x++) {
     587       10444 :         for (int32_t delta_y=-n_search_[1]; delta_y<=n_search_[1]; delta_y++) {
     588       31324 :         for (int32_t delta_z=-n_search_[2]; delta_z<=n_search_[2]; delta_z++) {
     589       23492 :             auto cell_i = std::array<int32_t, 3>{
     590       23492 :                 cell_i_x + delta_x,
     591       23492 :                 cell_i_y + delta_y,
     592       23492 :                 cell_i_z + delta_z,
     593             :             };
     594             : 
     595             :             // shift vector from one cell to the other and index of
     596             :             // the neighboring cell
     597       23492 :             auto [cell_shift, neighbor_cell_i] = divmod(cell_i, cells_shape_);
     598             : 
     599       87025 :             for (const auto& atom_i: current_cell) {
     600      275514 :                 for (const auto& atom_j: this->get_cell(neighbor_cell_i)) {
     601      211981 :                     auto shift = CellShift{cell_shift} + atom_i.shift - atom_j.shift;
     602      211981 :                     auto shift_is_zero = shift[0] == 0 && shift[1] == 0 && shift[2] == 0;
     603             : 
     604      211981 :                     if (!box_.periodic() && !shift_is_zero) {
     605             :                         // do not create pairs crossing the periodic
     606             :                         // boundaries in a non-periodic box
     607           0 :                         continue;
     608             :                     }
     609             : 
     610      211981 :                     if (atom_i.index == atom_j.index && shift_is_zero) {
     611             :                         // only create pairs with the same atom twice if the
     612             :                         // pair spans more than one bounding box
     613        2355 :                         continue;
     614             :                     }
     615             : 
     616      209626 :                     callback(atom_i.index, atom_j.index, shift);
     617             :                 }
     618             :             } // loop over atoms in current neighbor cells
     619             :         }}}
     620             :     }}} // loop over neighboring cells
     621          15 : }
     622             : 
     623       66760 : CellList::Cell& CellList::get_cell(std::array<int32_t, 3> index) {
     624       66760 :     size_t linear_index = (cells_shape_[0] * cells_shape_[1] * index[2])
     625       66760 :                         + (cells_shape_[0] * index[1])
     626       66760 :                         + index[0];
     627       66760 :     return cells_[linear_index];
     628             : }
     629             : 
     630             : /* ========================================================================== */
     631             : 
     632             : 
     633       15174 : void GrowableNeighborList::set_pair(size_t index, size_t first, size_t second) {
     634       15174 :     if (index >= this->capacity) {
     635          88 :         this->grow();
     636             :     }
     637             : 
     638       15174 :     this->neighbors.pairs[index][0] = first;
     639       15174 :     this->neighbors.pairs[index][1] = second;
     640       15174 : }
     641             : 
     642       15174 : void GrowableNeighborList::set_shift(size_t index, PLMD::metatomic::vesin::CellShift shift) {
     643       15174 :     if (index >= this->capacity) {
     644           0 :         this->grow();
     645             :     }
     646             : 
     647       15174 :     this->neighbors.shifts[index][0] = shift[0];
     648       15174 :     this->neighbors.shifts[index][1] = shift[1];
     649       15174 :     this->neighbors.shifts[index][2] = shift[2];
     650       15174 : }
     651             : 
     652           0 : void GrowableNeighborList::set_distance(size_t index, double distance) {
     653           0 :     if (index >= this->capacity) {
     654           0 :         this->grow();
     655             :     }
     656             : 
     657           0 :     this->neighbors.distances[index] = distance;
     658           0 : }
     659             : 
     660       15174 : void GrowableNeighborList::set_vector(size_t index, PLMD::metatomic::vesin::Vector vector) {
     661       15174 :     if (index >= this->capacity) {
     662           0 :         this->grow();
     663             :     }
     664             : 
     665       15174 :     this->neighbors.vectors[index][0] = vector[0];
     666       15174 :     this->neighbors.vectors[index][1] = vector[1];
     667       15174 :     this->neighbors.vectors[index][2] = vector[2];
     668       15174 : }
     669             : 
     // Resize the row-major array `ptr` (currently `size` rows of N scalars) to
// `new_size` rows with std::realloc. Rows [size, new_size) are filled with
// all-one bytes, which reads as NaN for double (and -1 for signed integers).
// Existing rows are preserved and the old pointer must not be used afterwards.
// Throws std::bad_alloc when the allocation fails or its byte count overflows.
template <typename scalar_t, size_t N>
static scalar_t (*alloc(scalar_t (*ptr)[N], size_t size, size_t new_size))[N] {
    using row_t = scalar_t[N];

    // new_size * sizeof(row_t) must fit in size_t
    if (new_size > static_cast<size_t>(-1) / sizeof(row_t)) {
        throw std::bad_alloc();
    }

    // Never ask realloc for zero bytes: realloc(ptr, 0) may free `ptr` and/or
    // return NULL, which would be misreported as an allocation failure and
    // leave the caller holding a freed pointer.
    const size_t n_rows = new_size == 0 ? 1 : new_size;
    auto* new_ptr = reinterpret_cast<scalar_t (*)[N]>(std::realloc(ptr, n_rows * sizeof(row_t)));

    if (new_ptr == nullptr) {
        throw std::bad_alloc();
    }

    // initialize the new rows with a bit pattern that maps to NaN for double;
    // when shrinking there is nothing new to initialize
    if (new_size > size) {
        std::memset(new_ptr + size, 0b11111111, (new_size - size) * sizeof(row_t));
    }

    return new_ptr;
}
     683             : 
     // Resize the flat array `ptr` (currently `size` elements) to `new_size`
// elements with std::realloc. Elements [size, new_size) are filled with
// all-one bytes, which reads as NaN for double. Existing elements are preserved
// and the old pointer must not be used afterwards. Throws std::bad_alloc when
// the allocation fails or its byte count overflows.
template <typename scalar_t>
static scalar_t* alloc(scalar_t* ptr, size_t size, size_t new_size) {
    // new_size * sizeof(scalar_t) must fit in size_t
    if (new_size > static_cast<size_t>(-1) / sizeof(scalar_t)) {
        throw std::bad_alloc();
    }

    // Never ask realloc for zero bytes: realloc(ptr, 0) may free `ptr` and/or
    // return NULL, which would be misreported as an allocation failure and
    // leave the caller holding a freed pointer.
    const size_t n_elements = new_size == 0 ? 1 : new_size;
    auto* new_ptr = reinterpret_cast<scalar_t*>(std::realloc(ptr, n_elements * sizeof(scalar_t)));

    if (new_ptr == nullptr) {
        throw std::bad_alloc();
    }

    // initialize the new elements with a bit pattern that maps to NaN for
    // double; when shrinking there is nothing new to initialize
    if (new_size > size) {
        std::memset(new_ptr + size, 0b11111111, (new_size - size) * sizeof(scalar_t));
    }

    return new_ptr;
}
     697             : 
     698          88 : void GrowableNeighborList::grow() {
     699          88 :     auto new_size = neighbors.length * 2;
     700             :     if (new_size == 0) {
     701             :         new_size = 1;
     702             :     }
     703             : 
     704          88 :     auto* new_pairs = alloc<size_t, 2>(neighbors.pairs, neighbors.length, new_size);
     705             : 
     706             :     int32_t (*new_shifts)[3] = nullptr;
     707          88 :     if (options.return_shifts) {
     708          88 :         new_shifts = alloc<int32_t, 3>(neighbors.shifts, neighbors.length, new_size);
     709             :     }
     710             : 
     711             :     double *new_distances = nullptr;
     712          88 :     if (options.return_distances) {
     713           0 :         new_distances = alloc<double>(neighbors.distances, neighbors.length, new_size);
     714             :     }
     715             : 
     716             :     double (*new_vectors)[3] = nullptr;
     717          88 :     if (options.return_vectors) {
     718          88 :         new_vectors = alloc<double, 3>(neighbors.vectors, neighbors.length, new_size);
     719             :     }
     720             : 
     721          88 :     this->neighbors.pairs = new_pairs;
     722          88 :     this->neighbors.shifts = new_shifts;
     723          88 :     this->neighbors.distances = new_distances;
     724          88 :     this->neighbors.vectors = new_vectors;
     725             : 
     726          88 :     this->capacity = new_size;
     727          88 : }
     728             : 
     729          15 : void GrowableNeighborList::reset() {
     730             :     // set all allocated data to zero
     731          15 :     auto size = this->neighbors.length;
     732          15 :     std::memset(this->neighbors.pairs, 0, size * sizeof(size_t[2]));
     733             : 
     734          15 :     if (this->neighbors.shifts != nullptr) {
     735           0 :         std::memset(this->neighbors.shifts, 0, size * sizeof(int32_t[3]));
     736             :     }
     737             : 
     738          15 :     if (this->neighbors.distances != nullptr) {
     739           0 :         std::memset(this->neighbors.distances, 0, size * sizeof(double));
     740             :     }
     741             : 
     742          15 :     if (this->neighbors.vectors != nullptr) {
     743           0 :         std::memset(this->neighbors.vectors, 0, size * sizeof(double[3]));
     744             :     }
     745             : 
     746             :     // reset length (but keep the capacity where it's at)
     747          15 :     this->neighbors.length = 0;
     748             : 
     749             :     // allocate/deallocate pointers as required
     750          15 :     auto* shifts = this->neighbors.shifts;
     751          15 :     if (this->options.return_shifts && shifts == nullptr) {
     752          15 :         shifts = alloc<int32_t, 3>(shifts, 0, capacity);
     753           0 :     } else if (!this->options.return_shifts && shifts != nullptr) {
     754           0 :         std::free(shifts);
     755             :         shifts = nullptr;
     756             :     }
     757             : 
     758          15 :     auto* distances = this->neighbors.distances;
     759          15 :     if (this->options.return_distances && distances == nullptr) {
     760           0 :         distances = alloc<double>(distances, 0, capacity);
     761          15 :     } else if (!this->options.return_distances && distances != nullptr) {
     762           0 :         std::free(distances);
     763             :         distances = nullptr;
     764             :     }
     765             : 
     766          15 :     auto* vectors = this->neighbors.vectors;
     767          15 :     if (this->options.return_vectors && vectors == nullptr) {
     768          15 :         vectors = alloc<double, 3>(vectors, 0, capacity);
     769           0 :     } else if (!this->options.return_vectors && vectors != nullptr) {
     770           0 :         std::free(vectors);
     771             :         vectors = nullptr;
     772             :     }
     773             : 
     774          15 :     this->neighbors.shifts = shifts;
     775          15 :     this->neighbors.distances = distances;
     776          15 :     this->neighbors.vectors = vectors;
     777          15 : }
     778             : 
     779           0 : void GrowableNeighborList::sort() {
     780           0 :     if (this->length() == 0) {
     781           0 :         return;
     782             :     }
     783             : 
     784             :     // step 1: sort an array of indices, comparing the pairs at the indices
     785           0 :     auto indices = std::vector<int64_t>(this->length(), 0);
     786           0 :     std::iota(std::begin(indices), std::end(indices), 0);
     787             : 
     788             :     struct compare_pairs {
     789             :         compare_pairs(size_t (*pairs_)[2]): pairs(pairs_) {}
     790             : 
     791           0 :         bool operator()(int64_t a, int64_t b) const {
     792           0 :             if (pairs[a][0] == pairs[b][0]) {
     793           0 :                 return pairs[a][1] < pairs[b][1];
     794             :             } else {
     795           0 :                 return pairs[a][0] < pairs[b][0];
     796             :             }
     797             :         }
     798             : 
     799             :         size_t (*pairs)[2];
     800             :     };
     801             : 
     802           0 :     std::sort(std::begin(indices), std::end(indices), compare_pairs(this->neighbors.pairs));
     803             : 
     804             :     // step 2: permute all data according to the sorted indices.
     805             :     int64_t cur = 0;
     806             :     int64_t is_sorted_up_to = 0;
     807             :     // data in `from` should go to `cur`
     808             :     auto from = indices[cur];
     809             : 
     810           0 :     size_t tmp_pair[2] = {0};
     811             :     double tmp_distance = 0;
     812           0 :     double tmp_vector[3] = {0};
     813           0 :     int32_t tmp_shift[3] = {0};
     814             : 
     815           0 :     while (cur < this->length()) {
     816             :         // move data from `cur` to temporary
     817           0 :         std::swap(tmp_pair, this->neighbors.pairs[cur]);
     818           0 :         if (options.return_distances) {
     819           0 :             std::swap(tmp_distance, this->neighbors.distances[cur]);
     820             :         }
     821           0 :         if (options.return_vectors) {
     822           0 :             std::swap(tmp_vector, this->neighbors.vectors[cur]);
     823             :         }
     824           0 :         if (options.return_shifts) {
     825           0 :             std::swap(tmp_shift, this->neighbors.shifts[cur]);
     826             :         }
     827             : 
     828           0 :         from = indices[cur];
     829             :         do {
     830           0 :             if (from == cur) {
     831             :                 // permutation loop of a single entry, i.e. this value stayed
     832             :                 // where is already was
     833             :                 break;
     834             :             }
     835             :             // move data from `from` to `cur`
     836           0 :             std::swap(this->neighbors.pairs[cur], this->neighbors.pairs[from]);
     837           0 :             if (options.return_distances) {
     838           0 :                 std::swap(this->neighbors.distances[cur], this->neighbors.distances[from]);
     839             :             }
     840           0 :             if (options.return_vectors) {
     841           0 :                 std::swap(this->neighbors.vectors[cur], this->neighbors.vectors[from]);
     842             :             }
     843           0 :             if (options.return_shifts) {
     844           0 :                 std::swap(this->neighbors.shifts[cur], this->neighbors.shifts[from]);
     845             :             }
     846             : 
     847             :             // mark this spot as already visited
     848           0 :             indices[cur] = -1;
     849             : 
     850             :             // update the indices
     851             :             cur = from;
     852           0 :             from = indices[cur];
     853           0 :         } while (indices[from] != -1);
     854             : 
     855             :         // we found a full loop of permutation, we can put tmp into `cur`
     856           0 :         std::swap(this->neighbors.pairs[cur], tmp_pair);
     857           0 :         if (options.return_distances) {
     858           0 :             std::swap(this->neighbors.distances[cur], tmp_distance);
     859             :         }
     860           0 :         if (options.return_vectors) {
     861           0 :             std::swap(this->neighbors.vectors[cur], tmp_vector);
     862             :         }
     863           0 :         if (options.return_shifts) {
     864           0 :             std::swap(this->neighbors.shifts[cur], tmp_shift);
     865             :         }
     866             : 
     867           0 :         indices[cur] = -1;
     868             : 
     869             :         // look for the next loop of permutation
     870             :         cur = is_sorted_up_to;
     871           0 :         while (indices[cur] == -1) {
     872           0 :             cur += 1;
     873           0 :             is_sorted_up_to += 1;
     874           0 :             if (cur == this->length()) {
     875             :                 break;
     876             :             }
     877             :         }
     878             :     }
     879             : }
     880             : 
     881             : 
     882          15 : void PLMD::metatomic::vesin::cpu::free_neighbors(VesinNeighborList& neighbors) {
     883             :     assert(neighbors.device == VesinCPU);
     884             : 
     885          15 :     std::free(neighbors.pairs);
     886          15 :     std::free(neighbors.shifts);
     887          15 :     std::free(neighbors.vectors);
     888          15 :     std::free(neighbors.distances);
     889          15 : }
     890             : #include <cstring>
     891             : #include <string>
     892             : 
     893             : 
     894             : 
     // Per-thread copy of the message of the last exception caught by
// vesin_neighbors(). The `*error_message` handed back to the caller points
// into this string, so it remains usable until this thread stores another
// message here.
thread_local std::string LAST_ERROR{};
     896             : 
     897          15 : extern "C" int vesin_neighbors(
     898             :     const double (*points)[3],
     899             :     size_t n_points,
     900             :     const double box[3][3],
     901             :     bool periodic,
     902             :     VesinDevice device,
     903             :     VesinOptions options,
     904             :     VesinNeighborList* neighbors,
     905             :     const char** error_message
     906             : ) {
     907          15 :     if (error_message == nullptr) {
     908             :         return EXIT_FAILURE;
     909             :     }
     910             : 
     911          15 :     if (points == nullptr) {
     912           0 :         *error_message = "`points` can not be a NULL pointer";
     913           0 :         return EXIT_FAILURE;
     914             :     }
     915             : 
     916          15 :     if (box == nullptr) {
     917           0 :         *error_message = "`cell` can not be a NULL pointer";
     918           0 :         return EXIT_FAILURE;
     919             :     }
     920             : 
     921          15 :     if (neighbors == nullptr) {
     922           0 :         *error_message = "`neighbors` can not be a NULL pointer";
     923           0 :         return EXIT_FAILURE;
     924             :     }
     925             : 
     926          15 :     if (!std::isfinite(options.cutoff) || options.cutoff <= 0) {
     927           0 :         *error_message = "cutoff must be a finite, positive number";
     928           0 :         return EXIT_FAILURE;
     929             :     }
     930             : 
     931          15 :     if (options.cutoff <= 1e-6) {
     932           0 :         *error_message = "cutoff is too small";
     933           0 :         return EXIT_FAILURE;
     934             :     }
     935             : 
     936          15 :     if (neighbors->device != VesinUnknownDevice && neighbors->device != device) {
     937           0 :         *error_message = "`neighbors` device and data `device` do not match, free the neighbors first";
     938           0 :         return EXIT_FAILURE;
     939             :     }
     940             : 
     941          15 :     if (device == VesinUnknownDevice) {
     942           0 :         *error_message = "got an unknown device to use when running simulation";
     943           0 :         return EXIT_FAILURE;
     944             :     }
     945             : 
     946          15 :     if (neighbors->device == VesinUnknownDevice) {
     947             :         // initialize the device
     948          15 :         neighbors->device = device;
     949           0 :     } else if (neighbors->device != device) {
     950           0 :         *error_message = "`neighbors.device` and `device` do not match, free the neighbors first";
     951           0 :         return EXIT_FAILURE;
     952             :     }
     953             : 
     954             :     try {
     955          15 :         if (device == VesinCPU) {
     956             :             auto matrix = PLMD::metatomic::vesin::Matrix{{{
     957          15 :                 {{box[0][0], box[0][1], box[0][2]}},
     958          15 :                 {{box[1][0], box[1][1], box[1][2]}},
     959          15 :                 {{box[2][0], box[2][1], box[2][2]}},
     960          15 :             }}};
     961             : 
     962          15 :             PLMD::metatomic::vesin::cpu::neighbors(
     963             :                 reinterpret_cast<const PLMD::metatomic::vesin::Vector*>(points),
     964             :                 n_points,
     965             :                 PLMD::metatomic::vesin::BoundingBox(matrix, periodic),
     966             :                 options,
     967             :                 *neighbors
     968             :             );
     969             :         } else {
     970           0 :             throw std::runtime_error("unknown device " + std::to_string(device));
     971             :         }
     972           0 :     } catch (const std::bad_alloc&) {
     973           0 :         LAST_ERROR = "failed to allocate memory";
     974           0 :         *error_message = LAST_ERROR.c_str();
     975             :         return EXIT_FAILURE;
     976           0 :     } catch (const std::exception& e) {
     977           0 :         LAST_ERROR = e.what();
     978           0 :         *error_message = LAST_ERROR.c_str();
     979             :         return EXIT_FAILURE;
     980           0 :     } catch (...) {
     981           0 :         *error_message = "fatal error: unknown type thrown as exception";
     982             :         return EXIT_FAILURE;
     983           0 :     }
     984             : 
     985          15 :     return EXIT_SUCCESS;
     986             : }
     987             : 
     988             : 
     989          15 : extern "C" void vesin_free(VesinNeighborList* neighbors) {
     990          15 :     if (neighbors == nullptr) {
     991             :         return;
     992             :     }
     993             : 
     994          15 :     if (neighbors->device == VesinUnknownDevice) {
     995             :         // nothing to do
     996          15 :     } else if (neighbors->device == VesinCPU) {
     997          15 :         PLMD::metatomic::vesin::cpu::free_neighbors(*neighbors);
     998             :     }
     999             : 
    1000             :     std::memset(neighbors, 0, sizeof(VesinNeighborList));
    1001             : }

Generated by: LCOV version 1.16