LCOV - code coverage report
Current view: top level - metatomic - vesin.cpp (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 341 439 77.7 %
Date: 2026-06-05 17:04:24 Functions: 29 33 87.9 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             : Copyright (c) 2024 The METATOMIC-PLUMED team
       3             : (see the PEOPLE-METATOMIC file at the root of this folder for a list of names)
       4             : 
       5             : See https://docs.metatensor.org/metatomic/ for more information about the
       6             : metatomic package that this module allows you to call from PLUMED.
       7             : 
       8             : This file is part of METATOMIC-PLUMED module.
       9             : 
      10             : The METATOMIC-PLUMED module is free software: you can redistribute it and/or modify
      11             : it under the terms of the GNU Lesser General Public License as published by
      12             : the Free Software Foundation, either version 3 of the License, or
      13             : (at your option) any later version.
      14             : 
      15             : The METATOMIC-PLUMED module is distributed in the hope that it will be useful,
      16             : but WITHOUT ANY WARRANTY; without even the implied warranty of
      17             : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      18             : GNU Lesser General Public License for more details.
      19             : 
      20             : You should have received a copy of the GNU Lesser General Public License
      21             : along with the METATOMIC-PLUMED module. If not, see <http://www.gnu.org/licenses/>.
      22             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      23             : /*INDENT-OFF*/
      24             : #include "vesin.h"
      25             : // automatically generated 
      26             : // vesin version: 0.5.7
      27             : 
      28             : #include <cassert>
      29             : #include <cstdlib>
      30             : #include <cstring>
      31             : 
      32             : #include <algorithm>
      33             : #include <numeric>
      34             : #include <tuple>
      35             : 
      36             : #ifndef VESIN_CPU_CELL_LIST_HPP
      37             : #define VESIN_CPU_CELL_LIST_HPP
      38             : 
      39             : #include <vector>
      40             : 
      41             : #include "vesin.h"
      42             : 
      43             : #ifndef VESIN_TYPES_HPP
      44             : #define VESIN_TYPES_HPP
      45             : 
      46             : #include <array>
      47             : #include <cassert>
      48             : #include <string>
      49             : 
      50             : #ifndef VESIN_MATH_HPP
      51             : #define VESIN_MATH_HPP
      52             : 
      53             : #include <array>
      54             : #include <cmath>
      55             : #include <stdexcept>
      56             : 
      57             : namespace PLMD {
      58             : namespace metatomic {
      59             : namespace vesin {
      60             : struct Vector;
      61             : 
      62             : Vector operator*(Vector vector, double scalar);
      63             : 
      64             : struct Vector: public std::array<double, 3> {
      65             :     double dot(Vector other) const {
      66          27 :         return (*this)[0] * other[0] + (*this)[1] * other[1] + (*this)[2] * other[2];
      67             :     }
      68             : 
      69          51 :     double norm() const {
      70          51 :         return std::sqrt(this->dot(*this));
      71             :     }
      72             : 
      73          51 :     Vector normalize() const {
      74          51 :         return *this * (1.0 / this->norm());
      75             :     }
      76             : 
      77             :     Vector cross(Vector other) const {
      78             :         return Vector{
      79          51 :             (*this)[1] * other[2] - (*this)[2] * other[1],
      80          51 :             (*this)[2] * other[0] - (*this)[0] * other[2],
      81          51 :             (*this)[0] * other[1] - (*this)[1] * other[0],
      82          51 :         };
      83             :     }
      84             : };
      85             : 
      86             : inline Vector operator+(Vector u, Vector v) {
      87             :     return Vector{
      88      111347 :         u[0] + v[0],
      89      111347 :         u[1] + v[1],
      90      111347 :         u[2] + v[2],
      91             :     };
      92             : }
      93             : 
      94             : inline Vector operator-(Vector u, Vector v) {
      95             :     return Vector{
      96      111347 :         u[0] - v[0],
      97      111347 :         u[1] - v[1],
      98      111347 :         u[2] - v[2],
      99             :     };
     100             : }
     101             : 
     102             : inline Vector operator*(double scalar, Vector vector) {
     103             :     return Vector{
     104             :         scalar * vector[0],
     105             :         scalar * vector[1],
     106             :         scalar * vector[2],
     107             :     };
     108             : }
     109             : 
     110             : inline Vector operator*(Vector vector, double scalar) {
     111             :     return Vector{
     112          51 :         scalar * vector[0],
     113          51 :         scalar * vector[1],
     114          51 :         scalar * vector[2],
     115          51 :     };
     116             : }
     117             : 
     118             : struct Matrix: public std::array<std::array<double, 3>, 3> {
     119          34 :     double determinant() const {
     120             :         // clang-format off
     121          34 :         return (*this)[0][0] * ((*this)[1][1] * (*this)[2][2] - (*this)[2][1] * (*this)[1][2])
     122          34 :              - (*this)[0][1] * ((*this)[1][0] * (*this)[2][2] - (*this)[1][2] * (*this)[2][0])
     123          34 :              + (*this)[0][2] * ((*this)[1][0] * (*this)[2][1] - (*this)[1][1] * (*this)[2][0]);
     124             :         // clang-format on
     125             :     }
     126             : 
     127          17 :     Matrix inverse() const {
     128          17 :         auto det = this->determinant();
     129             : 
     130          17 :         if (std::abs(det) < 1e-30) {
     131           0 :             throw std::runtime_error("this matrix is not invertible");
     132             :         }
     133             : 
     134             :         auto inverse = Matrix();
     135          17 :         inverse[0][0] = ((*this)[1][1] * (*this)[2][2] - (*this)[2][1] * (*this)[1][2]) / det;
     136          17 :         inverse[0][1] = ((*this)[0][2] * (*this)[2][1] - (*this)[0][1] * (*this)[2][2]) / det;
     137          17 :         inverse[0][2] = ((*this)[0][1] * (*this)[1][2] - (*this)[0][2] * (*this)[1][1]) / det;
     138          17 :         inverse[1][0] = ((*this)[1][2] * (*this)[2][0] - (*this)[1][0] * (*this)[2][2]) / det;
     139          17 :         inverse[1][1] = ((*this)[0][0] * (*this)[2][2] - (*this)[0][2] * (*this)[2][0]) / det;
     140          17 :         inverse[1][2] = ((*this)[1][0] * (*this)[0][2] - (*this)[0][0] * (*this)[1][2]) / det;
     141          17 :         inverse[2][0] = ((*this)[1][0] * (*this)[2][1] - (*this)[2][0] * (*this)[1][1]) / det;
     142          17 :         inverse[2][1] = ((*this)[2][0] * (*this)[0][1] - (*this)[0][0] * (*this)[2][1]) / det;
     143          17 :         inverse[2][2] = ((*this)[0][0] * (*this)[1][1] - (*this)[1][0] * (*this)[0][1]) / det;
     144          17 :         return inverse;
     145             :     }
     146             : };
     147             : 
     148             : inline Vector operator*(Matrix matrix, Vector vector) {
     149             :     return Vector{
     150             :         matrix[0][0] * vector[0] + matrix[0][1] * vector[1] + matrix[0][2] * vector[2],
     151             :         matrix[1][0] * vector[0] + matrix[1][1] * vector[1] + matrix[1][2] * vector[2],
     152             :         matrix[2][0] * vector[0] + matrix[2][1] * vector[1] + matrix[2][2] * vector[2],
     153             :     };
     154             : }
     155             : 
     156      113712 : inline Vector operator*(Vector vector, Matrix matrix) {
     157             :     return Vector{
     158      113712 :         vector[0] * matrix[0][0] + vector[1] * matrix[1][0] + vector[2] * matrix[2][0],
     159      113712 :         vector[0] * matrix[0][1] + vector[1] * matrix[1][1] + vector[2] * matrix[2][1],
     160      113712 :         vector[0] * matrix[0][2] + vector[1] * matrix[1][2] + vector[2] * matrix[2][2],
     161      113712 :     };
     162             : }
     163             : 
     164             : } // namespace vesin
     165             : } // namespace metatomic
     166             : } // namespace PLMD
     167             : 
     168             : #endif
     169             : 
     170             : namespace PLMD {
     171             : namespace metatomic {
     172             : namespace vesin {
     173             : 
     174             : class BoundingBox {
     175             : public:
     176             :     BoundingBox(const BoundingBox&) = delete;
     177             :     BoundingBox& operator=(const BoundingBox&) = delete;
     178             : 
     179             :     BoundingBox(BoundingBox&&) = default;
     180             :     BoundingBox& operator=(BoundingBox&&) = default;
     181             : 
     182          17 :     BoundingBox(Matrix matrix, const bool periodic[3]):
     183          17 :         matrix_(matrix),
     184          17 :         periodic_({periodic[0], periodic[1], periodic[2]}),
     185          17 :         max_positions_({-1e300, -1e300, -1e300}),
     186          17 :         min_positions_({1e300, 1e300, 1e300}) {
     187             : 
     188             :         // find number of periodic directions and their indices
     189             :         int n_periodic = 0;
     190             :         int periodic_idx_1 = -1;
     191             :         int periodic_idx_2 = -1;
     192          68 :         for (int i = 0; i < 3; ++i) {
     193          51 :             if (periodic_[i]) {
     194          27 :                 n_periodic += 1;
     195          27 :                 if (periodic_idx_1 == -1) {
     196             :                     periodic_idx_1 = i;
     197          18 :                 } else if (periodic_idx_2 == -1) {
     198             :                     periodic_idx_2 = i;
     199             :                 }
     200             :             }
     201             :         }
     202             : 
     203             :         // adjust the box matrix to have a simple orthogonal dimension along
     204             :         // non-periodic directions
     205          17 :         if (n_periodic == 0) {
     206           8 :             matrix_ = Matrix{
     207             :                 std::array<double, 3>{1, 0, 0},
     208             :                 std::array<double, 3>{0, 1, 0},
     209             :                 std::array<double, 3>{0, 0, 1},
     210             :             };
     211           9 :         } else if (n_periodic == 1) {
     212             :             assert(periodic_idx_1 != -1);
     213             :             // Make the two non-periodic directions orthogonal to the periodic one
     214           0 :             auto a = Vector{matrix_[periodic_idx_1]};
     215             :             auto b = Vector{0, 1, 0};
     216           0 :             if (std::abs(a.normalize().dot(b)) > 0.9) {
     217             :                 b = Vector{0, 0, 1};
     218             :             }
     219           0 :             auto c = a.cross(b).normalize();
     220           0 :             b = c.cross(a).normalize();
     221             : 
     222             :             // Assign back to the matrix picking the "non-periodic" indices without ifs
     223           0 :             matrix_[(periodic_idx_1 + 1) % 3] = b;
     224           0 :             matrix_[(periodic_idx_1 + 2) % 3] = c;
     225           9 :         } else if (n_periodic == 2) {
     226             :             assert(periodic_idx_1 != -1 && periodic_idx_2 != -1);
     227             :             // Make the one non-periodic direction orthogonal to the two periodic ones
     228           0 :             auto a = Vector{matrix_[periodic_idx_1]};
     229           0 :             auto b = Vector{matrix_[periodic_idx_2]};
     230           0 :             auto c = a.cross(b).normalize();
     231             : 
     232             :             // Assign back to the matrix picking the "non-periodic" index without ifs
     233           0 :             matrix_[(3 - periodic_idx_1 - periodic_idx_2)] = c;
     234             :         }
     235             : 
     236             :         // precompute the inverse matrix
     237          17 :         auto det = matrix_.determinant();
     238          17 :         if (std::abs(det) < 1e-30) {
     239           0 :             throw std::runtime_error("the box matrix is not invertible");
     240             :         }
     241             : 
     242          17 :         this->inverse_ = matrix_.inverse();
     243             : 
     244             :         // precompute distances between faces of the bounding box
     245          17 :         auto a = Vector{matrix_[0]};
     246          17 :         auto b = Vector{matrix_[1]};
     247          17 :         auto c = Vector{matrix_[2]};
     248             : 
     249             :         // Plans normal vectors
     250          17 :         auto na = b.cross(c).normalize();
     251          17 :         auto nb = c.cross(a).normalize();
     252          17 :         auto nc = a.cross(b).normalize();
     253             : 
     254          17 :         distances_between_faces_ = Vector{
     255          17 :             periodic_[0] ? std::abs(na.dot(a)) : max_positions_[0] - min_positions_[0],
     256          17 :             periodic_[1] ? std::abs(nb.dot(b)) : max_positions_[1] - min_positions_[1],
     257          17 :             periodic_[2] ? std::abs(nc.dot(c)) : max_positions_[2] - min_positions_[2],
     258             :         };
     259          17 :     }
     260             : 
     261             :     const Matrix& matrix() const {
     262      111347 :         return this->matrix_;
     263             :     }
     264             : 
     265             :     bool periodic(size_t spatial) const {
     266      112464 :         return this->periodic_[spatial];
     267             :     }
     268             : 
     269             :     /// Convert a vector from cartesian coordinates to fractional coordinates
     270             :     ///
     271             :     /// For non-periodic dimensions, the fractional coordinates are not wrapped
     272             :     /// inside [0, 1], but are normalized by the corresponding box length.
     273        2365 :     Vector cartesian_to_fractional(Vector cartesian) const {
     274        2365 :         auto fractional = cartesian * inverse_;
     275        2365 :         if (!periodic_[0]) {
     276           8 :             fractional[0] = (cartesian[0] - min_positions_[0]) / distances_between_faces_[0];
     277             :         }
     278             : 
     279        2365 :         if (!periodic_[1]) {
     280           8 :             fractional[1] = (cartesian[1] - min_positions_[1]) / distances_between_faces_[1];
     281             :         }
     282             : 
     283        2365 :         if (!periodic_[2]) {
     284           8 :             fractional[2] = (cartesian[2] - min_positions_[2]) / distances_between_faces_[2];
     285             :         }
     286             : 
     287        2365 :         return fractional;
     288             :     }
     289             : 
     290             :     /// Convert a vector from fractional coordinates to cartesian coordinates
     291             :     Vector fractional_to_cartesian(Vector fractional) const {
     292             :         auto cartesian = fractional * matrix_;
     293             : 
     294             :         if (!periodic_[0]) {
     295             :             cartesian[0] *= distances_between_faces_[0];
     296             :             cartesian[0] += min_positions_[0];
     297             :         }
     298             : 
     299             :         if (!periodic_[1]) {
     300             :             cartesian[1] *= distances_between_faces_[1];
     301             :             cartesian[1] += min_positions_[1];
     302             :         }
     303             : 
     304             :         if (!periodic_[2]) {
     305             :             cartesian[2] *= distances_between_faces_[2];
     306             :             cartesian[2] += min_positions_[2];
     307             :         }
     308             : 
     309             :         return cartesian;
     310             :     }
     311             : 
     312             :     /// Get the three distances between faces of the bounding box
     313             :     Vector distances_between_faces() const {
     314          17 :         return distances_between_faces_;
     315             :     }
     316             : 
     317          17 :     void make_bounding_for(const double (*points)[3], size_t n_points) {
     318             :         // find the min and max coordinates along each axis
     319        2382 :         for (size_t i = 0; i < n_points; i++) {
     320        9460 :             for (size_t spatial = 0; spatial < 3; spatial++) {
     321        7095 :                 if (!std::isfinite(points[i][spatial])) {
     322           0 :                     throw std::runtime_error(
     323           0 :                         "point " + std::to_string(i) + " has non-finite coordinate " +
     324           0 :                         "along axis " + std::to_string(spatial) + ": " +
     325           0 :                         std::to_string(points[i][spatial])
     326           0 :                     );
     327             :                 }
     328             : 
     329        7095 :                 if (points[i][spatial] < min_positions_[spatial]) {
     330         148 :                     min_positions_[spatial] = points[i][spatial];
     331             :                 }
     332        7095 :                 if (points[i][spatial] > max_positions_[spatial]) {
     333         232 :                     max_positions_[spatial] = points[i][spatial];
     334             :                 }
     335             :             }
     336             :         }
     337             : 
     338          68 :         for (int dim = 0; dim < 3; dim++) {
     339             :             // if all atoms have the same coordinate in this dimension, pretend
     340             :             // that the bounding box is at least 1 unit wide to avoid numerical issues
     341          51 :             if (max_positions_[dim] - min_positions_[dim] < 1e-6) {
     342          24 :                 max_positions_[dim] = min_positions_[dim] + 1;
     343             :             }
     344             : 
     345          51 :             if (!periodic_[dim]) {
     346             :                 // add a 1% margin to make sure all points are strictly inside the
     347             :                 // bounding box
     348          24 :                 distances_between_faces_[dim] = max_positions_[dim] * 1.01 - min_positions_[dim];
     349             :             }
     350             :         }
     351          17 :     }
     352             : 
     353             : private:
     354             :     Matrix matrix_;
     355             :     std::array<bool, 3> periodic_;
     356             : 
     357             :     Matrix inverse_;
     358             :     Vector min_positions_;
     359             :     Vector max_positions_;
     360             :     Vector distances_between_faces_;
     361             : };
     362             : 
     363             : /// A cell shift represents the displacement along cell axis between the actual
     364             : /// position of an atom and a periodic image of this atom.
     365             : ///
     366             : /// The cell shift can be used to reconstruct the vector between two points,
     367             : /// wrapped inside the unit cell.
     368             : struct CellShift: public std::array<int32_t, 3> {
     369             :     /// Compute the shift vector in cartesian coordinates, using the given cell
     370             :     /// matrix (stored in row major order).
     371      111347 :     Vector cartesian(const BoundingBox& box) const {
     372             :         assert(box.periodic(0) || (*this)[0] == 0);
     373             :         assert(box.periodic(1) || (*this)[1] == 0);
     374             :         assert(box.periodic(2) || (*this)[2] == 0);
     375             : 
     376             :         auto vector = Vector{
     377      111347 :             static_cast<double>((*this)[0]),
     378      111347 :             static_cast<double>((*this)[1]),
     379      111347 :             static_cast<double>((*this)[2]),
     380      111347 :         };
     381      111347 :         return vector * box.matrix();
     382             :     }
     383             : };
     384             : 
     385             : inline CellShift operator+(CellShift a, CellShift b) {
     386             :     return CellShift{
     387      215625 :         a[0] + b[0],
     388      215625 :         a[1] + b[1],
     389      215625 :         a[2] + b[2],
     390             :     };
     391             : }
     392             : 
     393             : inline CellShift operator-(CellShift a, CellShift b) {
     394             :     return CellShift{
     395      215625 :         a[0] - b[0],
     396      215625 :         a[1] - b[1],
     397      215625 :         a[2] - b[2],
     398             :     };
     399             : }
     400             : 
     401             : } // namespace vesin
     402             : } // namespace metatomic
     403             : } // namespace PLMD
     404             : 
     405             : #endif
     406             : 
     407             : namespace PLMD {
     408             : namespace metatomic {
     409             : namespace vesin {
     410             : namespace cpu {
     411             : 
     412             : void free_neighbors(VesinNeighborList& neighbors);
     413             : 
     414             : void neighbors(
     415             :     const Vector* points,
     416             :     size_t n_points,
     417             :     BoundingBox box,
     418             :     VesinOptions options,
     419             :     VesinNeighborList& neighbors
     420             : );
     421             : 
     422             : /// The cell list is used to sort atoms inside bins/cells.
     423             : ///
     424             : /// The list of potential pairs is then constructed by looking through all
     425             : /// neighboring cells (the number of cells to search depends on the cutoff and
     426             : /// the size of the cells) for each atom to create pair candidates.
     427          17 : class CellList {
     428             : public:
     429             :     /// Create a new `CellList` for the given bounding box and cutoff,
     430             :     /// determining all required parameters.
     431             :     CellList(BoundingBox box, double cutoff);
     432             : 
     433             :     /// Add a single point to the cell list at the given `position`. The point
     434             :     /// is uniquely identified by its `index`.
     435             :     void add_point(size_t index, Vector position);
     436             : 
     437             :     /// Iterate over all possible pairs, calling the given callback every time
     438             :     template <typename Function>
     439             :     void foreach_pair(Function callback);
     440             : 
     441             : private:
     442             :     /// How many cells do we need to look at when searching neighbors to include
     443             :     /// all neighbors below cutoff
     444             :     std::array<int32_t, 3> n_search_;
     445             : 
     446             :     /// the cells themselves are a list of points & corresponding
     447             :     /// shift to place the point inside the cell
     448             :     struct Point {
     449             :         size_t index;
     450             :         CellShift shift;
     451             :     };
     452         881 :     struct Cell: public std::vector<Point> {};
     453             : 
     454             :     // raw data for the cells
     455             :     std::vector<Cell> cells_;
     456             :     // shape of the cell array
     457             :     std::array<size_t, 3> cells_shape_;
     458             : 
     459             :     BoundingBox box_;
     460             : 
     461             :     Cell& get_cell(std::array<int32_t, 3> index);
     462             : };
     463             : 
     464             : /// Wrapper around `VesinNeighborList` that behaves like a std::vector,
     465             : /// automatically growing memory allocations.
     466             : class GrowableNeighborList {
     467             : public:
     468             :     VesinNeighborList& neighbors;
     469             :     size_t capacity;
     470             :     VesinOptions options;
     471             : 
     472             :     size_t length() const {
     473       15407 :         return neighbors.length;
     474             :     }
     475             : 
     476             :     void increment_length() {
     477       15390 :         neighbors.length += 1;
     478       15390 :     }
     479             : 
     480             :     void set_pair(size_t index, size_t first, size_t second);
     481             :     void set_shift(size_t index, PLMD::metatomic::vesin::CellShift shift);
     482             :     void set_distance(size_t index, double distance);
     483             :     void set_vector(size_t index, PLMD::metatomic::vesin::Vector vector);
     484             : 
     485             :     // reset length to 0, and allocate/deallocate members of
     486             :     // `neighbors` according to `options`
     487             :     void reset();
     488             : 
     489             :     // allocate more memory & update capacity
     490             :     void grow();
     491             : 
     492             :     // sort the pairs currently in the neighbor list
     493             :     void sort();
     494             : };
     495             : 
     496             : } // namespace cpu
     497             : } // namespace vesin
     498             : } // namespace metatomic
     499             : } // namespace PLMD
     500             : 
     501             : #endif
     502             : 
     503             : using namespace PLMD::metatomic::vesin;
     504             : using namespace PLMD::metatomic::vesin::cpu;
     505             : 
     506          17 : void PLMD::metatomic::vesin::cpu::neighbors(
     507             :     const Vector* points,
     508             :     size_t n_points,
     509             :     BoundingBox box,
     510             :     VesinOptions options,
     511             :     VesinNeighborList& raw_neighbors
     512             : ) {
     513          17 :     if (options.algorithm == VesinAutoAlgorithm || options.algorithm == VesinCellList) {
     514             :         // all good, this is the only thing we implement
     515             :     } else {
     516           0 :         throw std::runtime_error("only VesinAutoAlgorithm and VesinCellList are supported on CPU");
     517             :     }
     518             : 
     519          17 :     auto cell_list = CellList(std::move(box), options.cutoff);
     520             : 
     521        2382 :     for (size_t i = 0; i < n_points; i++) {
     522        2365 :         cell_list.add_point(i, points[i]);
     523             :     }
     524             : 
     525          17 :     auto cutoff2 = options.cutoff * options.cutoff;
     526             : 
     527             :     // the cell list creates too many pairs, we only need to keep the
     528             :     // one where the distance is actually below the cutoff
     529          17 :     auto neighbors = GrowableNeighborList{raw_neighbors, raw_neighbors.length, options};
     530          17 :     neighbors.reset();
     531             : 
     532          17 :     cell_list.foreach_pair([&](size_t first, size_t second, CellShift shift) {
     533      211804 :         if (!options.full) {
     534             :             // filter out some pairs for half neighbor lists
     535      200914 :             if (first > second) {
     536             :                 return;
     537             :             }
     538             : 
     539      100457 :             if (first == second) {
     540             :                 // When creating pairs between a point and one of its periodic
     541             :                 // images, the code generate multiple redundant pairs (e.g. with
     542             :                 // shifts 0 1 1 and 0 -1 -1); and we want to only keep one of
     543             :                 // these.
     544           0 :                 if (shift[0] + shift[1] + shift[2] < 0) {
     545             :                     // drop shifts on the negative half-space
     546             :                     return;
     547             :                 }
     548             : 
     549           0 :                 if ((shift[0] + shift[1] + shift[2] == 0) && (shift[2] < 0 || (shift[2] == 0 && shift[1] < 0))) {
     550             :                     // drop shifts in the negative half plane or the negative
     551             :                     // shift[1] axis. See below for a graphical representation:
     552             :                     // we are keeping the shifts indicated with `O` and dropping
     553             :                     // the ones indicated with `X`
     554             :                     //
     555             :                     //  O O O │ O O O
     556             :                     //  O O O │ O O O
     557             :                     //  O O O │ O O O
     558             :                     // ─X─X─X─┼─O─O─O─
     559             :                     //  X X X │ X X X
     560             :                     //  X X X │ X X X
     561             :                     //  X X X │ X X X
     562             :                     return;
     563             :                 }
     564             :             }
     565             :         }
     566             : 
     567      111347 :         auto vector = points[second] - points[first] + shift.cartesian(box);
     568             :         auto distance2 = vector.dot(vector);
     569             : 
     570      111347 :         if (distance2 < cutoff2) {
     571       15390 :             auto index = neighbors.length();
     572       15390 :             neighbors.set_pair(index, first, second);
     573             : 
     574       15390 :             if (options.return_shifts) {
     575       15390 :                 neighbors.set_shift(index, shift);
     576             :             }
     577             : 
     578       15390 :             if (options.return_distances) {
     579           0 :                 neighbors.set_distance(index, std::sqrt(distance2));
     580             :             }
     581             : 
     582       15390 :             if (options.return_vectors) {
     583       15390 :                 neighbors.set_vector(index, vector);
     584             :             }
     585             : 
     586             :             neighbors.increment_length();
     587             :         }
     588             :     });
     589             : 
     590          17 :     if (options.sorted) {
     591          17 :         neighbors.sort();
     592             :     }
     593          17 : }
     594             : 
     595             : /* ========================================================================== */
     596             : 
     597             : /// Maximal number of cells, we need to use this to prevent having too many
     598             : /// cells with a small bounding box and a large cutoff
     599             : #define MAX_NUMBER_OF_CELLS 1e5
     600             : 
     601             : /// Function to compute both quotient and remainder of the division of a by b.
     602             : /// This function follows Python convention, making sure the remainder have the
     603             : /// same sign as `b`.
     604       82668 : static std::tuple<int32_t, int32_t> divmod(int32_t a, size_t b) {
     605             :     assert(b < (std::numeric_limits<int32_t>::max()));
     606       82668 :     auto b_32 = static_cast<int32_t>(b);
     607       82668 :     auto quotient = a / b_32;
     608       82668 :     auto remainder = a % b_32;
     609       82668 :     if (remainder < 0) {
     610        4083 :         remainder += b_32;
     611        4083 :         quotient -= 1;
     612             :     }
     613       82668 :     return std::make_tuple(quotient, remainder);
     614             : }
     615             : 
     616             : /// Apply the `divmod` function to three components at the time
     617             : static std::tuple<std::array<int32_t, 3>, std::array<int32_t, 3>>
     618       25191 : divmod(std::array<int32_t, 3> a, std::array<size_t, 3> b) {
     619       25191 :     auto [qx, rx] = divmod(a[0], b[0]);
     620       25191 :     auto [qy, ry] = divmod(a[1], b[1]);
     621       25191 :     auto [qz, rz] = divmod(a[2], b[2]);
     622             :     return std::make_tuple(
     623       25191 :         std::array<int32_t, 3>{qx, qy, qz},
     624       25191 :         std::array<int32_t, 3>{rx, ry, rz}
     625       25191 :     );
     626             : }
     627             : 
     628          17 : CellList::CellList(BoundingBox box, double cutoff):
     629             :     n_search_({0, 0, 0}),
     630             :     cells_shape_({0, 0, 0}),
     631          17 :     box_(std::move(box)) {
     632             :     auto distances_between_faces = box_.distances_between_faces();
     633             : 
     634             :     auto n_cells = Vector{
     635          17 :         std::clamp(std::trunc(distances_between_faces[0] / cutoff), 1.0, HUGE_VAL),
     636          17 :         std::clamp(std::trunc(distances_between_faces[1] / cutoff), 1.0, HUGE_VAL),
     637          17 :         std::clamp(std::trunc(distances_between_faces[2] / cutoff), 1.0, HUGE_VAL),
     638             :     };
     639             : 
     640             :     assert(std::isfinite(n_cells[0]) && std::isfinite(n_cells[1]) && std::isfinite(n_cells[2]));
     641             : 
     642             :     // limit memory consumption by ensuring we have less than `MAX_N_CELLS`
     643             :     // cells to look though
     644          17 :     auto n_cells_total = n_cells[0] * n_cells[1] * n_cells[2];
     645          17 :     if (n_cells_total > MAX_NUMBER_OF_CELLS) {
     646             :         // set the total number of cells close to MAX_N_CELLS, while keeping
     647             :         // roughly the ratio of cells in each direction
     648           0 :         auto ratio_x_y = n_cells[0] / n_cells[1];
     649           0 :         auto ratio_y_z = n_cells[1] / n_cells[2];
     650             : 
     651           0 :         n_cells[2] = std::trunc(std::cbrt(MAX_NUMBER_OF_CELLS / (ratio_x_y * ratio_y_z * ratio_y_z)));
     652           0 :         n_cells[1] = std::trunc(ratio_y_z * n_cells[2]);
     653           0 :         n_cells[0] = std::trunc(ratio_x_y * n_cells[1]);
     654             :     }
     655             : 
     656             :     // number of cells to search in each direction to make sure all possible
     657             :     // pairs below the cutoff are accounted for.
     658          17 :     this->n_search_ = std::array<int32_t, 3>{
     659          17 :         static_cast<int32_t>(std::ceil(cutoff * n_cells[0] / distances_between_faces[0])),
     660          17 :         static_cast<int32_t>(std::ceil(cutoff * n_cells[1] / distances_between_faces[1])),
     661          17 :         static_cast<int32_t>(std::ceil(cutoff * n_cells[2] / distances_between_faces[2])),
     662             :     };
     663             : 
     664          17 :     this->cells_shape_ = std::array<size_t, 3>{
     665          17 :         static_cast<size_t>(n_cells[0]),
     666          17 :         static_cast<size_t>(n_cells[1]),
     667          17 :         static_cast<size_t>(n_cells[2]),
     668             :     };
     669             : 
     670          68 :     for (size_t spatial = 0; spatial < 3; spatial++) {
     671          51 :         if (n_search_[spatial] < 1) {
     672           0 :             n_search_[spatial] = 1;
     673             :         }
     674             :     }
     675             : 
     676          17 :     this->cells_.resize(cells_shape_[0] * cells_shape_[1] * cells_shape_[2]);
     677          17 : }
     678             : 
     679        2365 : void CellList::add_point(size_t index, Vector position) {
     680        2365 :     auto fractional = box_.cartesian_to_fractional(position);
     681             : 
     682             :     // find the cell in which this atom should go
     683             :     auto cell_index = std::array<int32_t, 3>{
     684        2365 :         static_cast<int32_t>(std::floor(fractional[0] * static_cast<double>(cells_shape_[0]))),
     685        2365 :         static_cast<int32_t>(std::floor(fractional[1] * static_cast<double>(cells_shape_[1]))),
     686        2365 :         static_cast<int32_t>(std::floor(fractional[2] * static_cast<double>(cells_shape_[2]))),
     687        2365 :     };
     688             : 
     689             :     // deal with pbc by wrapping the atom inside if it was outside of the cell
     690             :     CellShift shift;
     691        9460 :     for (size_t spatial = 0; spatial < 3; spatial++) {
     692        7095 :         auto result = divmod(cell_index[spatial], cells_shape_[spatial]);
     693        7095 :         shift[spatial] = std::get<0>(result);
     694        7095 :         cell_index[spatial] = std::get<1>(result);
     695             : 
     696             :         assert(box_.periodic(spatial) || shift[spatial] == 0);
     697             :     }
     698             : 
     699        2365 :     this->get_cell(cell_index).emplace_back(Point{index, shift});
     700        2365 : }
     701             : 
     702             : // clang-format off
     703             : template <typename Function>
     704          17 : void CellList::foreach_pair(Function callback) {
     705          56 :     for (int32_t cell_i_x=0; cell_i_x<static_cast<int32_t>(cells_shape_[0]); cell_i_x++) {
     706         188 :     for (int32_t cell_i_y=0; cell_i_y<static_cast<int32_t>(cells_shape_[1]); cell_i_y++) {
     707        1030 :     for (int32_t cell_i_z=0; cell_i_z<static_cast<int32_t>(cells_shape_[2]); cell_i_z++) {
     708         881 :         const auto& current_cell = this->get_cell({cell_i_x, cell_i_y, cell_i_z});
     709             :         // look through each neighboring cell
     710        3536 :         for (int32_t delta_x=-n_search_[0]; delta_x<=n_search_[0]; delta_x++) {
     711       10728 :         for (int32_t delta_y=-n_search_[1]; delta_y<=n_search_[1]; delta_y++) {
     712       33264 :         for (int32_t delta_z=-n_search_[2]; delta_z<=n_search_[2]; delta_z++) {
     713       25191 :             auto cell_i = std::array<int32_t, 3>{
     714       25191 :                 cell_i_x + delta_x,
     715       25191 :                 cell_i_y + delta_y,
     716       25191 :                 cell_i_z + delta_z,
     717             :             };
     718             : 
     719             :             // shift vector from one cell to the other and index of
     720             :             // the neighboring cell
     721       25191 :             auto [cell_shift, neighbor_cell_i] = divmod(cell_i, cells_shape_);
     722             : 
     723       90450 :             for (const auto& atom_i: current_cell) {
     724      280884 :                 for (const auto& atom_j: this->get_cell(neighbor_cell_i)) {
     725      215625 :                     auto shift = CellShift{cell_shift} + atom_i.shift - atom_j.shift;
     726      215625 :                     auto shift_is_zero = shift[0] == 0 && shift[1] == 0 && shift[2] == 0;
     727             : 
     728      215625 :                     if ((shift[0] != 0 && !box_.periodic(0)) ||
     729      429954 :                         (shift[1] != 0 && !box_.periodic(1)) ||
     730       35744 :                         (shift[2] != 0 && !box_.periodic(2)))
     731             :                     {
     732             :                         // do not create pairs crossing the periodic
     733             :                         // boundaries in a non-periodic box
     734        1456 :                         continue;
     735             :                     }
     736             : 
     737      214169 :                     if (atom_i.index == atom_j.index && shift_is_zero) {
     738             :                         // only create pairs with the same atom twice if the
     739             :                         // pair spans more than one bounding box
     740        2365 :                         continue;
     741             :                     }
     742             : 
     743      211804 :                     callback(atom_i.index, atom_j.index, shift);
     744             :                 }
     745             :             } // loop over atoms in current neighbor cells
     746             :         }}}
     747             :     }}} // loop over neighboring cells
     748          17 : }
     749             : 
     750       68505 : CellList::Cell& CellList::get_cell(std::array<int32_t, 3> index) {
     751       68505 :     size_t linear_index = (cells_shape_[0] * cells_shape_[1] * index[2])
     752       68505 :                         + (cells_shape_[0] * index[1])
     753       68505 :                         + index[0];
     754       68505 :     return cells_[linear_index];
     755             : }
     756             : // clang-format on
     757             : 
     758             : /* ========================================================================== */
     759             : 
     760       15390 : void GrowableNeighborList::set_pair(size_t index, size_t first, size_t second) {
     761       15390 :     if (index >= this->capacity) {
     762          83 :         this->grow();
     763             :     }
     764             : 
     765       15390 :     this->neighbors.pairs[index][0] = first;
     766       15390 :     this->neighbors.pairs[index][1] = second;
     767       15390 : }
     768             : 
     769       15390 : void GrowableNeighborList::set_shift(size_t index, PLMD::metatomic::vesin::CellShift shift) {
     770       15390 :     if (index >= this->capacity) {
     771           0 :         this->grow();
     772             :     }
     773             : 
     774       15390 :     this->neighbors.shifts[index][0] = shift[0];
     775       15390 :     this->neighbors.shifts[index][1] = shift[1];
     776       15390 :     this->neighbors.shifts[index][2] = shift[2];
     777       15390 : }
     778             : 
     779           0 : void GrowableNeighborList::set_distance(size_t index, double distance) {
     780           0 :     if (index >= this->capacity) {
     781           0 :         this->grow();
     782             :     }
     783             : 
     784           0 :     this->neighbors.distances[index] = distance;
     785           0 : }
     786             : 
     787       15390 : void GrowableNeighborList::set_vector(size_t index, PLMD::metatomic::vesin::Vector vector) {
     788       15390 :     if (index >= this->capacity) {
     789           0 :         this->grow();
     790             :     }
     791             : 
     792       15390 :     this->neighbors.vectors[index][0] = vector[0];
     793       15390 :     this->neighbors.vectors[index][1] = vector[1];
     794       15390 :     this->neighbors.vectors[index][2] = vector[2];
     795       15390 : }
     796             : 
     797             : template <typename scalar_t, size_t N>
     798             : using array_ptr = scalar_t (*)[N];
     799             : 
     800             : template <typename scalar_t, size_t N>
     801         292 : static array_ptr<scalar_t, N> alloc(array_ptr<scalar_t, N> ptr, size_t size, size_t new_size) {
     802         292 :     auto* new_ptr = reinterpret_cast<scalar_t(*)[N]>(std::realloc(ptr, new_size * sizeof(scalar_t[N])));
     803             : 
     804         292 :     if (new_ptr == nullptr) {
     805           0 :         return nullptr;
     806             :     }
     807             : 
     808             : #ifndef NDEBUG
     809             :     // initialize with a bit pattern that maps to NaN for double
     810             :     std::memset(new_ptr + size, 0b11111111, (new_size - size) * sizeof(scalar_t[N]));
     811             : #endif
     812             : 
     813             :     return new_ptr;
     814             : }
     815             : 
     816             : template <typename scalar_t>
     817           0 : static scalar_t* alloc(scalar_t* ptr, size_t size, size_t new_size) {
     818           0 :     auto* new_ptr = reinterpret_cast<scalar_t*>(std::realloc(ptr, new_size * sizeof(scalar_t)));
     819             : 
     820           0 :     if (new_ptr == nullptr) {
     821           0 :         return nullptr;
     822             :     }
     823             : 
     824             : #ifndef NDEBUG
     825             :     // initialize with a bit pattern that maps to NaN for double
     826             :     std::memset(new_ptr + size, 0b11111111, (new_size - size) * sizeof(scalar_t));
     827             : #endif
     828             : 
     829             :     return new_ptr;
     830             : }
     831             : 
     832          83 : void GrowableNeighborList::grow() {
     833          83 :     auto new_size = neighbors.length * 2;
     834             :     if (new_size == 0) {
     835             :         new_size = 1;
     836             :     }
     837             : 
     838          83 :     auto* new_pairs = alloc<size_t, 2>(neighbors.pairs, neighbors.length, new_size);
     839             : 
     840             :     int32_t (*new_shifts)[3] = nullptr;
     841          83 :     if (options.return_shifts) {
     842          83 :         new_shifts = alloc<int32_t, 3>(neighbors.shifts, neighbors.length, new_size);
     843             :     }
     844             : 
     845             :     double* new_distances = nullptr;
     846          83 :     if (options.return_distances) {
     847           0 :         new_distances = alloc<double>(neighbors.distances, neighbors.length, new_size);
     848             :     }
     849             : 
     850             :     double (*new_vectors)[3] = nullptr;
     851          83 :     if (options.return_vectors) {
     852          83 :         new_vectors = alloc<double, 3>(neighbors.vectors, neighbors.length, new_size);
     853             :     }
     854             : 
     855          83 :     if (
     856          83 :         (new_pairs == nullptr) ||
     857          83 :         (options.return_shifts && new_shifts == nullptr) ||
     858          83 :         (options.return_distances && new_distances == nullptr) ||
     859          83 :         (options.return_vectors && new_vectors == nullptr)
     860             :     ) {
     861           0 :         std::free(new_pairs);
     862           0 :         std::free(new_shifts);
     863           0 :         std::free(new_distances);
     864           0 :         std::free(new_vectors);
     865           0 :         throw std::runtime_error("could not allocate memory for growing neighbor list");
     866             :     }
     867             : 
     868          83 :     this->neighbors.pairs = new_pairs;
     869          83 :     this->neighbors.shifts = new_shifts;
     870          83 :     this->neighbors.distances = new_distances;
     871          83 :     this->neighbors.vectors = new_vectors;
     872             : 
     873          83 :     this->capacity = new_size;
     874          83 : }
     875             : 
     876          17 : void GrowableNeighborList::reset() {
     877             : #ifndef NDEBUG
     878             :     auto size = this->neighbors.length;
     879             :     // set all allocated data to a bit pattern that maps to NaN for double
     880             :     std::memset(this->neighbors.pairs, 0b11111111, size * sizeof(size_t[2]));
     881             : 
     882             :     if (this->neighbors.shifts != nullptr) {
     883             :         std::memset(this->neighbors.shifts, 0b11111111, size * sizeof(int32_t[3]));
     884             :     }
     885             : 
     886             :     if (this->neighbors.distances != nullptr) {
     887             :         std::memset(this->neighbors.distances, 0b11111111, size * sizeof(double));
     888             :     }
     889             : 
     890             :     if (this->neighbors.vectors != nullptr) {
     891             :         std::memset(this->neighbors.vectors, 0b11111111, size * sizeof(double[3]));
     892             :     }
     893             : #endif
     894             : 
     895             :     // reset length (but keep the capacity where it's at)
     896          17 :     this->neighbors.length = 0;
     897             : 
     898             :     // allocate/deallocate pointers as required
     899          17 :     auto* shifts = this->neighbors.shifts;
     900          17 :     if (this->options.return_shifts && shifts == nullptr) {
     901           8 :         shifts = alloc<int32_t, 3>(shifts, 0, capacity);
     902           9 :     } else if (!this->options.return_shifts && shifts != nullptr) {
     903           0 :         std::free(shifts);
     904             :         shifts = nullptr;
     905             :     }
     906             : 
     907          17 :     auto* distances = this->neighbors.distances;
     908          17 :     if (this->options.return_distances && distances == nullptr) {
     909           0 :         distances = alloc<double>(distances, 0, capacity);
     910          17 :     } else if (!this->options.return_distances && distances != nullptr) {
     911           0 :         std::free(distances);
     912             :         distances = nullptr;
     913             :     }
     914             : 
     915          17 :     auto* vectors = this->neighbors.vectors;
     916          17 :     if (this->options.return_vectors && vectors == nullptr) {
     917           8 :         vectors = alloc<double, 3>(vectors, 0, capacity);
     918           9 :     } else if (!this->options.return_vectors && vectors != nullptr) {
     919           0 :         std::free(vectors);
     920             :         vectors = nullptr;
     921             :     }
     922             : 
     923          17 :     this->neighbors.shifts = shifts;
     924          17 :     this->neighbors.distances = distances;
     925          17 :     this->neighbors.vectors = vectors;
     926          17 : }
     927             : 
     928          17 : void GrowableNeighborList::sort() {
     929          17 :     if (this->length() == 0) {
     930           8 :         return;
     931             :     }
     932             : 
     933             :     // step 1: sort an array of indices, comparing the pairs at the indices
     934           9 :     auto indices = std::vector<int64_t>(this->length(), 0);
     935           9 :     std::iota(std::begin(indices), std::end(indices), 0);
     936             : 
     937             :     struct compare_pairs {
     938             :         compare_pairs(size_t (*pairs_)[2]):
     939             :             pairs(pairs_) {}
     940             : 
     941             :         bool operator()(int64_t a, int64_t b) const {
     942      186338 :             return pairs[a][0] < pairs[b][0];
     943             :         }
     944             : 
     945             :         size_t (*pairs)[2];
     946             :     };
     947             : 
     948          18 :     std::sort(
     949             :         std::begin(indices),
     950             :         std::end(indices),
     951           9 :         compare_pairs(this->neighbors.pairs)
     952             :     );
     953             : 
     954             :     // step 2: move all data according to the sorted indices.
     955           9 :     auto* sorted_pairs = alloc<size_t, 2>(nullptr, 0, this->capacity);
     956             : 
     957             :     int32_t (*sorted_shifts)[3] = nullptr;
     958           9 :     if (options.return_shifts) {
     959           9 :         sorted_shifts = alloc<int32_t, 3>(nullptr, 0, this->capacity);
     960             :     }
     961             : 
     962             :     double* sorted_distances = nullptr;
     963           9 :     if (options.return_distances) {
     964           0 :         sorted_distances = alloc<double>(nullptr, 0, this->capacity);
     965             :     }
     966             : 
     967             :     double (*sorted_vectors)[3] = nullptr;
     968           9 :     if (options.return_vectors) {
     969           9 :         sorted_vectors = alloc<double, 3>(nullptr, 0, this->capacity);
     970             :     }
     971             : 
     972           9 :     if (
     973           9 :         (sorted_pairs == nullptr) ||
     974           9 :         (options.return_shifts && sorted_shifts == nullptr) ||
     975           9 :         (options.return_distances && sorted_distances == nullptr) ||
     976           9 :         (options.return_vectors && sorted_vectors == nullptr)
     977             :     ) {
     978           0 :         std::free(sorted_pairs);
     979           0 :         std::free(sorted_shifts);
     980           0 :         std::free(sorted_distances);
     981           0 :         std::free(sorted_vectors);
     982           0 :         throw std::runtime_error("could not allocate memory for sorting neighbor list");
     983             :     }
     984             : 
     985       15399 :     for (size_t i = 0; i < this->neighbors.length; i++) {
     986       15390 :         auto from = static_cast<size_t>(indices[i]);
     987       15390 :         sorted_pairs[i][0] = this->neighbors.pairs[from][0];
     988       15390 :         sorted_pairs[i][1] = this->neighbors.pairs[from][1];
     989             : 
     990       15390 :         if (options.return_shifts) {
     991       15390 :             sorted_shifts[i][0] = this->neighbors.shifts[from][0];
     992       15390 :             sorted_shifts[i][1] = this->neighbors.shifts[from][1];
     993       15390 :             sorted_shifts[i][2] = this->neighbors.shifts[from][2];
     994             :         }
     995             : 
     996       15390 :         if (options.return_distances) {
     997           0 :             sorted_distances[i] = this->neighbors.distances[from];
     998             :         }
     999             : 
    1000       15390 :         if (options.return_vectors) {
    1001       15390 :             sorted_vectors[i][0] = this->neighbors.vectors[from][0];
    1002       15390 :             sorted_vectors[i][1] = this->neighbors.vectors[from][1];
    1003       15390 :             sorted_vectors[i][2] = this->neighbors.vectors[from][2];
    1004             :         }
    1005             :     }
    1006             : 
    1007           9 :     std::free(this->neighbors.pairs);
    1008           9 :     this->neighbors.pairs = sorted_pairs;
    1009             : 
    1010           9 :     if (options.return_shifts) {
    1011           9 :         std::free(this->neighbors.shifts);
    1012           9 :         this->neighbors.shifts = sorted_shifts;
    1013             :     }
    1014             : 
    1015           9 :     if (options.return_distances) {
    1016           0 :         std::free(this->neighbors.distances);
    1017           0 :         this->neighbors.distances = sorted_distances;
    1018             :     }
    1019             : 
    1020           9 :     if (options.return_vectors) {
    1021           9 :         std::free(this->neighbors.vectors);
    1022           9 :         this->neighbors.vectors = sorted_vectors;
    1023             :     }
    1024             : }
    1025             : 
    1026           8 : void PLMD::metatomic::vesin::cpu::free_neighbors(VesinNeighborList& neighbors) {
    1027             :     assert(neighbors.device.type == VesinCPU);
    1028             : 
    1029           8 :     std::free(neighbors.pairs);
    1030           8 :     std::free(neighbors.shifts);
    1031           8 :     std::free(neighbors.vectors);
    1032           8 :     std::free(neighbors.distances);
    1033           8 : }
    1034             : #include <stdexcept>
    1035             : 
    1036             : #ifndef VESIN_CUDA_HPP
    1037             : #define VESIN_CUDA_HPP
    1038             : 
    1039             : 
    1040             : namespace PLMD {
    1041             : namespace metatomic {
    1042             : namespace vesin {
    1043             : namespace cuda {
    1044             : 
    1045             : #ifndef VESIN_CUDA_AT_LEAST_PAIRS_PER_POINT
    1046             : /// Default value for the number of pairs per points in the CUDA implementation.
    1047             : /// Unless `VESIN_CUDA_MAX_PAIRS_PER_POINT` is set in the environement, the
    1048             : /// maximal number of pairs is `n_points *
    1049             : /// max(VESIN_CUDA_AT_LEAST_PAIRS_PER_POINT, cutoff^3)`. This can be overriden
    1050             : /// at compile time.
    1051             : #define VESIN_CUDA_AT_LEAST_PAIRS_PER_POINT 128
    1052             : #endif
    1053             : 
    1054             : /// @brief Buffers for cell list-based neighbor search
    1055             : struct CellListBuffers {
    1056             :     size_t max_points = 0; // Capacity for point-related arrays
    1057             :     size_t max_cells = 0;  // Capacity for cell-related arrays
    1058             : 
    1059             :     // Per-particle arrays
    1060             :     int32_t* cell_indices = nullptr;    // [max_points] linear cell index per particle
    1061             :     int32_t* particle_shifts = nullptr; // [max_points * 3] shift applied to wrap into cell
    1062             : 
    1063             :     // Per-cell arrays
    1064             :     int32_t* cell_counts = nullptr;  // [max_cells] number of particles in each cell
    1065             :     int32_t* cell_starts = nullptr;  // [max_cells] starting index in sorted arrays
    1066             :     int32_t* cell_offsets = nullptr; // [max_cells] working copy for scatter
    1067             : 
    1068             :     // Sorted particle data (for coalesced memory access)
    1069             :     double* sorted_positions = nullptr;     // [max_points * 3]
    1070             :     int32_t* sorted_indices = nullptr;      // [max_points] original particle indices
    1071             :     int32_t* sorted_shifts = nullptr;       // [max_points * 3] shifts for sorted particles
    1072             :     int32_t* sorted_cell_indices = nullptr; // [max_points] cell indices in sorted order
    1073             : 
    1074             :     // Cell grid parameters (computed on device)
    1075             :     double* inv_box = nullptr;        // [9] inverse box matrix
    1076             :     int32_t* n_cells = nullptr;       // [3] number of cells in each direction
    1077             :     int32_t* n_search = nullptr;      // [3] search range in each direction
    1078             :     int32_t* n_cells_total = nullptr; // [1] total number of cells
    1079             : 
    1080             :     double* bounding_min = nullptr; // [3] per-dimension min for non-periodic axes
    1081             :     double* bounding_max = nullptr; // [3] per-dimension max for non-periodic axes
    1082             : };
    1083             : 
    1084             : struct CudaNeighborListExtras {
    1085             :     size_t* length_ptr = nullptr;      // GPU-side counter
    1086             :     size_t capacity = 0;               // Current capacity per device
    1087             :     size_t max_pairs = 0;              // Maximum number of pairs that can be stored; depends on VESIN_CUDA_MAX_PAIRS_PER_POINT
    1088             :     int32_t* cell_check_ptr = nullptr; // GPU-side status code for checking cell
    1089             :     int32_t* overflow_flag = nullptr;  // GPU-side flag to detect overflow of pair buffers
    1090             :     int32_t allocated_device_id = -1;  // which device are we currently allocated on
    1091             : 
    1092             :     // Pinned host memory for async D2H copy (Approach 2)
    1093             :     size_t* pinned_length_ptr = nullptr;
    1094             : 
    1095             :     // Cell list buffers (allocated on demand for large systems)
    1096             :     CellListBuffers cell_list;
    1097             : 
    1098             :     // Buffers for optimized brute force kernels
    1099             :     double* box_diag = nullptr;      // [3] diagonal elements for orthogonal boxes
    1100             :     double* inv_box_brute = nullptr; // [9] inverse box matrix for general boxes
    1101             : 
    1102             :     // Temporary buffers for on-device sorting
    1103             :     size_t* sort_pairs_tmp = nullptr;     // [sort_capacity * 2]
    1104             :     int32_t* sort_shifts_tmp = nullptr;   // [sort_capacity * 3]
    1105             :     double* sort_distances_tmp = nullptr; // [sort_capacity]
    1106             :     double* sort_vectors_tmp = nullptr;   // [sort_capacity * 3]
    1107             :     size_t sort_capacity = 0;
    1108             : 
    1109             :     ~CudaNeighborListExtras();
    1110             : };
    1111             : 
    1112             : /// @brief Frees GPU memory associated with a VesinNeighborList.
    1113             : ///
    1114             : /// This function should be called to release all CUDA-allocated memory
    1115             : /// tied to the given neighbor list. It does not delete the structure itself,
    1116             : /// only the device-side memory buffers.
    1117             : ///
    1118             : /// @param neighbors Reference to the VesinNeighborList to clean up.
    1119             : void free_neighbors(VesinNeighborList& neighbors);
    1120             : 
    1121             : /// @brief Computes the neighbor list on the GPU.
    1122             : ///
    1123             : /// This function only works under Minimum Image Convention for now.
    1124             : ///
    1125             : /// This function generates a neighbor list for a set of points within a
    1126             : /// periodic simulation box using GPU acceleration. The output is stored in a
    1127             : /// `VesinNeighborList` structure, which must be initialized for GPU usage.
    1128             : ///
    1129             : /// @param points Pointer to an array of 3D points (shape: [n_points][3]).
    1130             : /// @param n_points Number of points (atoms, particles, etc.).
    1131             : /// @param box 3×3 matrix defining the bounding box of the system.
    1132             : /// @param periodic Array of three booleans indicating periodicity in each dimension.
    1133             : /// @param options Struct holding parameters such as cutoff, symmetry, etc.
    1134             : /// @param neighbors Output neighbor list (device memory will be allocated as
    1135             : /// needed).
    1136             : void neighbors(
    1137             :     const double (*points)[3],
    1138             :     size_t n_points,
    1139             :     const double box[3][3],
    1140             :     const bool periodic[3],
    1141             :     VesinOptions options,
    1142             :     VesinNeighborList& neighbors
    1143             : );
    1144             : 
    1145             : /// Get the `CudaNeighborListExtras` stored inside `VesinNeighborList`'s opaque pointer
    1146             : CudaNeighborListExtras* get_cuda_extras(VesinNeighborList* neighbors);
    1147             : 
    1148             : } // namespace cuda
    1149             : } // namespace vesin
    1150             : } // namespace metatomic
    1151             : } // namespace PLMD
    1152             : 
    1153             : #endif // VESIN_CUDA_HPP
    1154             : 
    1155           0 : void PLMD::metatomic::vesin::cuda::free_neighbors(VesinNeighborList& neighbors) {
    1156           0 :     throw std::runtime_error("CUDA neighbor list generation is not included in this build of vesin");
    1157             : }
    1158             : 
    1159           0 : void PLMD::metatomic::vesin::cuda::neighbors(
    1160             :     const double (*points)[3],
    1161             :     size_t n_points,
    1162             :     const double box[3][3],
    1163             :     const bool periodic[3],
    1164             :     VesinOptions options,
    1165             :     VesinNeighborList& neighbors
    1166             : ) {
    1167           0 :     throw std::runtime_error("CUDA neighbor list generation is not included in this build of vesin");
    1168             : }
    1169             : #include <cstring>
    1170             : #include <iostream>
    1171             : #include <string>
    1172             : 
    1173             : 
    1174             : // used to store dynamically allocated error messages before giving a pointer
    1175             : // to them back to the user
    1176             : thread_local std::string LAST_ERROR;
    1177             : 
    1178          17 : extern "C" int vesin_neighbors(
    1179             :     const double (*points)[3],
    1180             :     size_t n_points,
    1181             :     const double box[3][3],
    1182             :     const bool periodic[3],
    1183             :     VesinDevice device,
    1184             :     VesinOptions options,
    1185             :     VesinNeighborList* neighbors,
    1186             :     const char** error_message
    1187             : ) {
    1188          17 :     if (error_message == nullptr) {
    1189             :         return EXIT_FAILURE;
    1190             :     }
    1191             : 
    1192          17 :     if (points == nullptr) {
    1193           0 :         *error_message = "`points` can not be a NULL pointer";
    1194           0 :         return EXIT_FAILURE;
    1195             :     }
    1196             : 
    1197          17 :     if (box == nullptr) {
    1198           0 :         *error_message = "`cell` can not be a NULL pointer";
    1199           0 :         return EXIT_FAILURE;
    1200             :     }
    1201             : 
    1202          17 :     if (neighbors == nullptr) {
    1203           0 :         *error_message = "`neighbors` can not be a NULL pointer";
    1204           0 :         return EXIT_FAILURE;
    1205             :     }
    1206             : 
    1207          17 :     if (!std::isfinite(options.cutoff) || options.cutoff <= 0) {
    1208           0 :         *error_message = "cutoff must be a finite, positive number";
    1209           0 :         return EXIT_FAILURE;
    1210             :     }
    1211             : 
    1212          17 :     if (options.cutoff <= 1e-6) {
    1213           0 :         *error_message = "cutoff is too small";
    1214           0 :         return EXIT_FAILURE;
    1215             :     }
    1216             : 
    1217          17 :     if (neighbors->device.type != VesinUnknownDevice && neighbors->device.type != device.type) {
    1218           0 :         *error_message = "`neighbors` device and data `device` do not match, free the neighbors first";
    1219           0 :         return EXIT_FAILURE;
    1220             :     }
    1221             : 
    1222          17 :     if (device.type == VesinUnknownDevice) {
    1223           0 :         *error_message = "got an unknown device type";
    1224           0 :         return EXIT_FAILURE;
    1225             :     }
    1226             : 
    1227          17 :     if (neighbors->device.type == VesinUnknownDevice) {
    1228             :         // initialize the device
    1229           8 :         neighbors->device = device;
    1230           9 :     } else if (neighbors->device.type != device.type) {
    1231           0 :         *error_message = "`neighbors.device` and `device` do not match, free the neighbors first";
    1232           0 :         return EXIT_FAILURE;
    1233             :     }
    1234             : 
    1235             :     try {
    1236          17 :         if (device.type == VesinCPU) {
    1237             :             auto matrix = PLMD::metatomic::vesin::Matrix{{{
    1238          17 :                 {{box[0][0], box[0][1], box[0][2]}},
    1239          17 :                 {{box[1][0], box[1][1], box[1][2]}},
    1240          17 :                 {{box[2][0], box[2][1], box[2][2]}},
    1241          17 :             }}};
    1242             : 
    1243          17 :             auto box = PLMD::metatomic::vesin::BoundingBox(matrix, periodic);
    1244          17 :             box.make_bounding_for(points, n_points);
    1245             : 
    1246          17 :             PLMD::metatomic::vesin::cpu::neighbors(
    1247             :                 reinterpret_cast<const PLMD::metatomic::vesin::Vector*>(points),
    1248             :                 n_points,
    1249             :                 std::move(box),
    1250             :                 options,
    1251             :                 *neighbors
    1252             :             );
    1253           0 :         } else if (device.type == VesinCUDA) {
    1254           0 :             PLMD::metatomic::vesin::cuda::neighbors(
    1255             :                 points,
    1256             :                 n_points,
    1257             :                 box,
    1258             :                 periodic,
    1259             :                 options,
    1260             :                 *neighbors
    1261             :             );
    1262             :         } else {
    1263           0 :             throw std::runtime_error("unknown device " + std::to_string(device.type));
    1264             :         }
    1265           0 :     } catch (const std::bad_alloc&) {
    1266           0 :         LAST_ERROR = "failed to allocate memory";
    1267           0 :         *error_message = LAST_ERROR.c_str();
    1268             :         return EXIT_FAILURE;
    1269           0 :     } catch (const std::exception& e) {
    1270           0 :         LAST_ERROR = e.what();
    1271           0 :         *error_message = LAST_ERROR.c_str();
    1272             :         return EXIT_FAILURE;
    1273           0 :     } catch (...) {
    1274           0 :         *error_message = "fatal error: unknown type thrown as exception";
    1275             :         return EXIT_FAILURE;
    1276           0 :     }
    1277             : 
    1278             :     return EXIT_SUCCESS;
    1279             : }
    1280             : 
    1281           8 : extern "C" void vesin_free(VesinNeighborList* neighbors) {
    1282           8 :     if (neighbors == nullptr) {
    1283             :         return;
    1284             :     }
    1285             : 
    1286             :     try {
    1287           8 :         if (neighbors->device.type == VesinUnknownDevice) {
    1288             :             // nothing to do
    1289           8 :         } else if (neighbors->device.type == VesinCPU) {
    1290           8 :             PLMD::metatomic::vesin::cpu::free_neighbors(*neighbors);
    1291           0 :         } else if (neighbors->device.type == VesinCUDA) {
    1292           0 :             PLMD::metatomic::vesin::cuda::free_neighbors(*neighbors);
    1293             :         } else {
    1294           0 :             throw std::runtime_error("unknown device " + std::to_string(neighbors->device.type) + " when freeing memory");
    1295             :         }
    1296           0 :     } catch (const std::exception& e) {
    1297           0 :         std::cerr << "error in vesin_free: " << e.what() << std::endl;
    1298             :         return;
    1299           0 :     } catch (...) {
    1300           0 :         std::cerr << "fatal error in vesin_free, unknown type thrown as exception" << std::endl;
    1301             :         return;
    1302           0 :     }
    1303             : 
    1304             :     std::memset(neighbors, 0, sizeof(VesinNeighborList));
    1305             : }

Generated by: LCOV version 1.16