LCOV - code coverage report
Current view: top level - home/runner/.local/lib/python3.9/site-packages/metatensor/include - metatensor.hpp (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 13 21 61.9 %
Date: 2025-11-25 13:55:50 Functions: 3 4 75.0 %

          Line data    Source code
       1             : #ifndef METATENSOR_HPP
       2             : #define METATENSOR_HPP
       3             : 
       4             : #include <array>
       5             : #include <vector>
       6             : #include <string>
       7             : #include <memory>
       8             : #include <stdexcept>
       9             : #include <exception>
      10             : #include <functional>
      11             : #include <type_traits>
      12             : #include <initializer_list>
      13             : 
      14             : #include <cassert>
      15             : #include <cstdint>
      16             : #include <cstring>
      17             : #include <cstdlib>
      18             : 
      19             : #include "metatensor.h"
      20             : 
      21             : /// This file contains the C++ API to metatensor, manually built on top of the C
      22             : /// API defined in `metatensor.h`. This API uses the standard C++ library where
      23             : /// convenient, but also allow to drop back to the C API if required, by
      24             : /// providing functions to extract the C API handles (named `as_mts_XXX`).
      25             : 
      26             : static_assert(sizeof(char) == sizeof(uint8_t), "char must be 8-bits wide");
      27             : 
      28             : namespace metatensor_torch {
      29             :     class LabelsHolder;
      30             :     class TensorBlockHolder;
      31             :     class TensorMapHolder;
      32             : }
      33             : 
      34             : namespace metatensor {
      35             : class Labels;
      36             : class TensorMap;
      37             : class TensorBlock;
      38             : 
      39             : /// Tag for the creation of Labels without uniqueness checks
      40             : struct assume_unique {};
      41             : 
      42             : /// Exception class used for all errors in metatensor
      43             : class Error: public std::runtime_error {
      44             : public:
      45             :     /// Create a new Error with the given `message`
      46           0 :     Error(const std::string& message): std::runtime_error(message) {}
      47             : };
      48             : 
      49             : namespace details {
      50             :     /// Singleton class storing the last exception throw by a C++ callback.
      51             :     ///
      52             :     /// When passing callbacks from C++ to Rust, we need to convert exceptions
      53             :     /// into status code (see the `catch` blocks in this file). This class
      54             :     /// allows to save the message associated with an exception, and rethrow an
      55             :     /// exception with the same message later (the actual exception type is lost
      56             :     /// in the process).
      57             :     class LastCxxError {
      58             :     public:
      59             :         /// Set the last error message to `message`
      60             :         static void set_message(std::string message) {
      61             :             auto& stored_message = LastCxxError::get();
      62             :             stored_message = std::move(message);
      63             :         }
      64             : 
      65             :         /// Get the last error message
      66             :         static const std::string& message() {
      67           0 :             return LastCxxError::get();
      68             :         }
      69             : 
      70             :     private:
      71           0 :         static std::string& get() {
      72             :             #pragma clang diagnostic push
      73             :             #pragma clang diagnostic ignored "-Wexit-time-destructors"
      74             :             /// we are using a per-thread static value to store the last C++
      75             :             /// exception.
      76           0 :             static thread_local std::string STORED_MESSAGE;
      77             :             #pragma clang diagnostic pop
      78             : 
      79           0 :             return STORED_MESSAGE;
      80             :         }
      81             :     };
      82             : 
      83             :     /// Check if a return status from the C API indicates an error, and if it is
      84             :     /// the case, throw an exception of type `metatensor::Error` with the last
      85             :     /// error message from the library.
      86         124 :     inline void check_status(mts_status_t status) {
      87         124 :         if (status == MTS_SUCCESS) {
      88         124 :             return;
      89           0 :         } else if (status > 0) {
      90           0 :             throw Error(mts_last_error());
      91             :         } else { // status < 0
      92           0 :             throw Error("error in C++ callback: " + LastCxxError::message());
      93             :         }
      94             :     }
      95             : 
      96             :     /// Call the given `function` with the given `args` (the function should
      97             :     /// return an `mts_status_t`), catching any C++ exception, and translating
      98             :     /// them to negative metatensor error code.
      99             :     ///
     100             :     /// This is required to prevent callbacks unwinding through the C API.
     101             :     template<typename Function, typename ...Args>
     102             :     inline mts_status_t catch_exceptions(Function function, Args ...args) {
     103             :         try {
     104             :             return function(std::move(args)...);
     105             :         } catch (const std::exception& e) {
     106             :             details::LastCxxError::set_message(e.what());
     107             :             return -1;
     108             :         } catch (...) {
     109             :             details::LastCxxError::set_message("error was not an std::exception");
     110             :             return -128;
     111             :         }
     112             :     }
     113             : 
     114             :     /// Check if a pointer allocated by the C API is null, and if it is the
     115             :     /// case, throw an exception of type `metatensor::Error` with the last error
     116             :     /// message from the library.
     117             :     inline void check_pointer(const void* pointer) {
     118             :         if (pointer == nullptr) {
     119             :             throw Error(mts_last_error());
     120             :         }
     121             :     }
     122             : 
     123             :     /// Compute the product of all values in the `shape` vector
     124             :     inline size_t product(const std::vector<size_t>& shape) {
     125             :         size_t result = 1;
     126             :         for (auto size: shape) {
     127             :             result *= size;
     128             :         }
     129             :         return result;
     130             :     }
     131             : 
     132             :     /// Get the N-dimensional index corresponding to the given linear `index`
     133             :     /// and array `shape`
     134             :     inline std::vector<size_t> cartesian_index(const std::vector<size_t>& shape, size_t index) {
     135             :         auto result = std::vector<size_t>(shape.size(), 0);
     136             :         for (size_t i=0; i<shape.size(); i++) {
     137             :             result[i] = index % shape[i];
     138             :             index = index / shape[i];
     139             :         }
     140             :         assert(index == 0);
     141             :         return result;
     142             :     }
     143             : 
     144             :     /// Get the linear index corresponding to the N-dimensional
     145             :     /// `index[index_size]`, according to the given array `shape`
     146             :     inline size_t linear_index(const std::vector<size_t>& shape, const size_t* index, size_t index_size) {
     147             :         assert(index_size != 0);
     148             :         assert(index_size == shape.size());
     149             : 
     150             :         if (index_size == 1) {
     151             :             assert(index[0] < shape[0] && "out of bounds");
     152             :             return index[0];
     153             :         } else {
     154             :             assert(index[0] < shape[0]);
     155             :             auto linear_index = index[0];
     156             :             for (size_t i=1; i<index_size; i++) {
     157             :                 assert(index[i] < shape[i] && "out of bounds");
     158             :                 linear_index *= shape[i];
     159             :                 linear_index += index[i];
     160             :             }
     161             : 
     162             :             return linear_index;
     163             :         }
     164             :     }
     165             : 
     166             :     template<size_t N>
     167             :     size_t linear_index(const std::vector<size_t>& shape, const std::array<size_t, N>& index) {
     168             :         return linear_index(shape, index.data(), index.size());
     169             :     }
     170             : 
     171             :     inline size_t linear_index(const std::vector<size_t>& shape, const std::vector<size_t>& index) {
     172             :         return linear_index(shape, index.data(), index.size());
     173             :     }
     174             : 
     175             :     Labels labels_from_cxx(const std::vector<std::string>& names, const int32_t* values, size_t count, bool assume_unique);
     176             : }
     177             : 
     178             : /******************************************************************************/
     179             : /******************************************************************************/
     180             : /******************************************************************************/
     181             : 
     182             : /******************************************************************************/
     183             : /******************************************************************************/
     184             : /*                                                                            */
     185             : /*                 N-Dimensional arrays handling                              */
     186             : /*                                                                            */
     187             : /******************************************************************************/
     188             : /******************************************************************************/
     189             : 
     190             : 
     191             : /// Simple N-dimensional array interface
     192             : ///
     193             : /// This class can either be a non-owning view inside some existing memory (for
     194             : /// example memory allocated by Rust); or own its memory (in the form of an
     195             : /// `std::vector<double>`). If the array does not own its memory, accessing it
     196             : /// is only valid for as long as the memory is kept alive.
     197             : ///
     198             : /// The API of this class is very intentionally minimal to keep metatensor as
     199             : /// simple as possible. Feel free to wrap the corresponding data inside types
     200             : /// with richer API such as Eigen, Boost, etc.
     201             : template<typename T>
     202             : class NDArray {
     203             : public:
     204             :     /// Create a new empty `NDArray`, with shape `[0, 0]`.
     205             :     NDArray(): NDArray(nullptr, {0, 0}, true) {}
     206             : 
     207             :     /// Create a new `NDArray` using a non-owning view in `const` memory with
     208             :     /// the given `shape`.
     209             :     ///
     210             :     /// `data` must point to contiguous memory containing the right number of
     211             :     /// elements as described by the `shape`, which will be interpreted as an
     212             :     /// N-dimensional array in row-major order. The resulting `NDArray` is
     213             :     /// only valid for as long as `data` is.
     214             :     NDArray(const T* data, std::vector<size_t> shape):
     215             :         NDArray(data, std::move(shape), /*is_const*/ true) {}
     216             : 
     217             :     /// Create a new `NDArray` using a non-owning view in non-`const` memory
     218             :     /// with the given `shape`.
     219             :     ///
     220             :     /// `data` must point to contiguous memory containing the right number of
     221             :     /// elements as described by the `shape`, which will be interpreted as an
     222             :     /// N-dimensional array in row-major order. The resulting `NDArray` is
     223             :     /// only valid for as long as `data` is.
     224             :     NDArray(T* data, std::vector<size_t> shape):
     225             :         NDArray(data, std::move(shape), /*is_const*/ false) {}
     226             : 
     227             :     /// Create a new `NDArray` *owning* its `data` with the given `shape`.
     228             :     NDArray(std::vector<T> data, std::vector<size_t> shape):
     229             :         NDArray(data.data(), std::move(shape), /*is_const*/ false)
     230             :     {
     231             :         using vector_t = std::vector<T>;
     232             : 
     233             :         owned_data_ = reinterpret_cast<void*>(new vector_t(std::move(data)));
     234             :         deleter_ = [](void* data){
     235             :             auto data_vector = reinterpret_cast<vector_t*>(data);
     236             :             delete data_vector;
     237             :         };
     238             :     }
     239             : 
     240             :     ~NDArray() {
     241             :         deleter_(this->owned_data_);
     242             :     }
     243             : 
     244             :     /// NDArray is not copy-constructible
     245             :     NDArray(const NDArray&) = delete;
     246             :     /// NDArray can not be copy-assigned
     247             :     NDArray& operator=(const NDArray& other) = delete;
     248             : 
     249             :     /// NDArray is move-constructible
     250             :     NDArray(NDArray&& other) noexcept: NDArray() {
     251             :         *this = std::move(other);
     252             :     }
     253             : 
     254             :     /// NDArray can be move-assigned
     255             :     NDArray& operator=(NDArray&& other) noexcept {
     256             :         this->deleter_(this->owned_data_);
     257             : 
     258             :         this->data_ = std::move(other.data_);
     259             :         this->shape_ = std::move(other.shape_);
     260             :         this->is_const_ = other.is_const_;
     261             :         this->owned_data_ = other.owned_data_;
     262             :         this->deleter_ = std::move(other.deleter_);
     263             : 
     264             :         other.data_ = nullptr;
     265             :         other.owned_data_ = nullptr;
     266             :         other.deleter_ = [](void*){};
     267             : 
     268             :         return *this;
     269             :     }
     270             : 
     271             :     /// Is this NDArray a view into external data?
     272             :     bool is_view() const {
     273             :         return owned_data_ == nullptr;
     274             :     }
     275             : 
     276             :     /// Get the value inside this `NDArray` at the given index
     277             :     ///
     278             :     /// ```
     279             :     /// auto array = NDArray(...);
     280             :     ///
     281             :     /// double value = array(2, 3, 1);
     282             :     /// ```
     283             :     template<typename ...Args>
     284             :     T operator()(Args... args) const & {
     285             :         auto index = std::array<size_t, sizeof... (Args)>{static_cast<size_t>(args)...};
     286             :         if (index.size() != shape_.size()) {
     287             :             throw Error(
     288             :                 "expected " + std::to_string(shape_.size()) +
     289             :                 " indexes in NDArray::operator(), got " + std::to_string(index.size())
     290             :             );
     291             :         }
     292             :         return data_[details::linear_index(shape_, index)];
     293             :     }
     294             : 
     295             :     /// Get a reference to the value inside this `NDArray` at the given index
     296             :     ///
     297             :     /// ```
     298             :     /// auto array = NDArray(...);
     299             :     ///
     300             :     /// array(2, 3, 1) = 5.2;
     301             :     /// ```
     302             :     template<typename ...Args>
     303             :     T& operator()(Args... args) & {
     304             :         if (is_const_) {
     305             :             throw Error("This NDArray is const, can not get non const access to it");
     306             :         }
     307             : 
     308             :         auto index = std::array<size_t, sizeof... (Args)>{static_cast<size_t>(args)...};
     309             :         if (index.size() != shape_.size()) {
     310             :             throw Error(
     311             :                 "expected " + std::to_string(shape_.size()) +
     312             :                 " indexes in Labels::operator(), got " + std::to_string(index.size())
     313             :             );
     314             :         }
     315             :         return data_[details::linear_index(shape_, index)];
     316             :     }
     317             : 
     318             :     template<typename ...Args>
     319             :     T& operator()(Args... args) && = delete;
     320             : 
     321             :     /// Get the data pointer for this array, i.e. the pointer to the first
     322             :     /// element.
     323             :     const T* data() const & {
     324             :         return data_;
     325             :     }
     326             : 
     327             :     /// Get the data pointer for this array, i.e. the pointer to the first
     328             :     /// element.
     329             :     T* data() & {
     330             :         if (is_const_) {
     331             :             throw Error("This NDArray is const, can not get non const access to it");
     332             :         }
     333             :         return data_;
     334             :     }
     335             : 
     336             :     const T* data() && = delete;
     337             : 
     338             :     /// Get the shape of this array
     339             :     const std::vector<size_t>& shape() const & {
     340             :         return shape_;
     341             :     }
     342             : 
     343             :     const std::vector<size_t>& shape() && = delete;
     344             : 
     345             :     /// Check if this array is empty, i.e. if at least one of the shape element
     346             :     /// is 0.
     347             :     bool is_empty() const {
     348             :         for (auto s: shape_) {
     349             :             if (s == 0) {
     350             :                 return true;
     351             :             }
     352             :         }
     353             :         return false;
     354             :     }
     355             : 
     356             : private:
     357             :     /// Create an `NDArray` from a pointer to the (row-major) data & shape.
     358             :     ///
     359             :     /// The `is_const` parameter controls whether this class should allow
     360             :     /// non-const access to the data.
     361             :     NDArray(const T* data, std::vector<size_t> shape, bool is_const):
     362             :         data_(const_cast<T*>(data)),
     363             :         shape_(std::move(shape)),
     364             :         is_const_(is_const),
     365             :         deleter_([](void*){})
     366             :     {
     367             :         validate();
     368             :     }
     369             : 
     370             :     /// Create a 2D NDArray from a vector of initializer lists. All the inner
     371             :     /// lists must have the same `size`.
     372             :     ///
     373             :     /// This allows creating an array (and in particular a set of Labels) with
     374             :     /// `NDArray({{1, 2}, {3, 4}, {5, 6}}, 2)`
     375             :     NDArray(const std::vector<std::initializer_list<T>>& data, size_t size):
     376             :         data_(nullptr),
     377             :         shape_({data.size(), size}),
     378             :         is_const_(false),
     379             :         deleter_([](void*){})
     380             :     {
     381             :         using vector_t = std::vector<T>;
     382             :         auto vector = std::vector<T>();
     383             :         vector.reserve(data.size() * size);
     384             :         for (auto row: std::move(data)) {
     385             :             if (row.size() != size) {
     386             :                 throw Error(
     387             :                     "invalid size for row: expected " + std::to_string(size) +
     388             :                     " got " + std::to_string(row.size())
     389             :                 );
     390             :             }
     391             : 
     392             :             for (auto entry: row) {
     393             :                 vector.emplace_back(entry);
     394             :             }
     395             :         }
     396             : 
     397             :         data_ = vector.data();
     398             :         owned_data_ = reinterpret_cast<void*>(new vector_t(std::move(vector)));
     399             :         deleter_ = [](void* data){
     400             :             auto data_vector = reinterpret_cast<vector_t*>(data);
     401             :             delete data_vector;
     402             :         };
     403             :         validate();
     404             :     }
     405             : 
     406             :     friend class Labels;
     407             : 
     408             :     void validate() const {
     409             :         static_assert(
     410             :             std::is_arithmetic<T>::value,
     411             :             "NDArray only works with integers and floating points"
     412             :         );
     413             : 
     414             :         if (shape_.empty()) {
     415             :             throw Error("invalid parameters to NDArray, shape should contain at least one element");
     416             :         }
     417             : 
     418             :         size_t size = 1;
     419             :         for (auto s: shape_) {
     420             :             size *= s;
     421             :         }
     422             : 
     423             :         if (size != 0 && data_ == nullptr) {
     424             :             throw Error("invalid parameters to NDArray, got null data pointer and non zero size");
     425             :         }
     426             :     }
     427             : 
     428             :     /// Pointer to the data used by this array
     429             :     T* data_ = nullptr;
     430             :     /// Full shape of this array
     431             :     std::vector<size_t> shape_ = {0, 0};
     432             :     /// Is this array const? This will dynamically prevent calling non-const
     433             :     /// function on it.
     434             :     bool is_const_ = true;
     435             :     /// Type-erased owned data for this array. This is a `nullptr` if this array
     436             :     /// is a view.
     437             :     void* owned_data_ = nullptr;
     438             :     /// Custom delete function for the `owned_data_`.
     439             :     std::function<void(void*)> deleter_;
     440             : };
     441             : 
     442             : 
     443             : /// Compare this `NDArray` with another `NDarray`. The array are equal if
     444             : /// and only if both the shape and data are equal.
     445             : template<typename T>
     446             : bool operator==(const NDArray<T>& lhs, const NDArray<T>& rhs) {
     447             :     if (lhs.shape() != rhs.shape()) {
     448             :         return false;
     449             :     }
     450             :     return std::memcmp(lhs.data(), rhs.data(), sizeof(T) * details::product(lhs.shape())) == 0;
     451             : }
     452             : 
     453             : /// Compare this `NDArray` with another `NDarray`. The array are equal if
     454             : /// and only if both the shape and data are equal.
     455             : template<typename T>
     456             : bool operator!=(const NDArray<T>& lhs, const NDArray<T>& rhs) {
     457             :     return !(lhs == rhs);
     458             : }
     459             : 
     460             : 
     461             : /// `DataArrayBase` manages n-dimensional arrays used as data in a block or
     462             : /// tensor map. The array itself if opaque to this library and can come from
     463             : /// multiple sources: Rust program, a C/C++ program, a Fortran program, Python
     464             : /// with numpy or torch. The data does not have to live on CPU, or even on the
     465             : /// same machine where this code is executed.
     466             : ///
     467             : /// **WARNING**: all function implementations **MUST** be thread-safe, and can
     468             : /// be called from multiple threads at the same time. The `DataArrayBase` itself
     469             : /// might be moved from one thread to another.
     470             : class DataArrayBase {
     471             : public:
     472             :     DataArrayBase() = default;
     473             :     virtual ~DataArrayBase() = default;
     474             : 
     475             :     /// DataArrayBase can be copy-constructed
     476             :     DataArrayBase(const DataArrayBase&) = default;
     477             :     /// DataArrayBase can be copy-assigned
     478             :     DataArrayBase& operator=(const DataArrayBase&) = default;
     479             :     /// DataArrayBase can be move-constructed
     480             :     DataArrayBase(DataArrayBase&&) noexcept = default;
     481             :     /// DataArrayBase can be move-assigned
     482             :     DataArrayBase& operator=(DataArrayBase&&) noexcept = default;
     483             : 
     484             :     /// Convert a concrete `DataArrayBase` to a C-compatible `mts_array_t`
     485             :     ///
     486             :     /// The `mts_array_t` takes ownership of the data, which should be released
     487             :     /// with `mts_array_t::destroy`.
     488             :     static mts_array_t to_mts_array_t(std::unique_ptr<DataArrayBase> data) {
     489             :         mts_array_t array;
     490             :         std::memset(&array, 0, sizeof(array));
     491             : 
     492             :         array.ptr = data.release();
     493             : 
     494             :         array.destroy = [](void* array) {
     495             :             auto ptr = std::unique_ptr<DataArrayBase>(static_cast<DataArrayBase*>(array));
     496             :             // let ptr go out of scope
     497             :         };
     498             : 
     499             :         array.origin = [](const void* array, mts_data_origin_t* origin) {
     500             :             return details::catch_exceptions([](const void* array, mts_data_origin_t* origin){
     501             :                 const auto* cxx_array = static_cast<const DataArrayBase*>(array);
     502             :                 *origin = cxx_array->origin();
     503             :                 return MTS_SUCCESS;
     504             :             }, array, origin);
     505             :         };
     506             : 
     507             :         array.copy = [](const void* array, mts_array_t* new_array) {
     508             :             return details::catch_exceptions([](const void* array, mts_array_t* new_array){
     509             :                 const auto* cxx_array = static_cast<const DataArrayBase*>(array);
     510             :                 auto copy = cxx_array->copy();
     511             :                 *new_array = DataArrayBase::to_mts_array_t(std::move(copy));
     512             :                 return MTS_SUCCESS;
     513             :             }, array, new_array);
     514             :         };
     515             : 
     516             :         array.create = [](const void* array, const uintptr_t* shape, uintptr_t shape_count, mts_array_t* new_array) {
     517             :             return details::catch_exceptions([](
     518             :                 const void* array,
     519             :                 const uintptr_t* shape,
     520             :                 uintptr_t shape_count,
     521             :                 mts_array_t* new_array
     522             :             ) {
     523             :                 const auto* cxx_array = static_cast<const DataArrayBase*>(array);
     524             :                 auto cxx_shape = std::vector<size_t>();
     525             :                 for (size_t i=0; i<static_cast<size_t>(shape_count); i++) {
     526             :                     cxx_shape.push_back(static_cast<size_t>(shape[i]));
     527             :                 }
     528             :                 auto copy = cxx_array->create(std::move(cxx_shape));
     529             :                 *new_array = DataArrayBase::to_mts_array_t(std::move(copy));
     530             :                 return MTS_SUCCESS;
     531             :             }, array, shape, shape_count, new_array);
     532             :         };
     533             : 
     534             :         array.data = [](void* array, double** data) {
     535             :             return details::catch_exceptions([](void* array, double** data){
     536             :                 auto* cxx_array = static_cast<DataArrayBase*>(array);
     537             :                 *data = cxx_array->data();
     538             :                 return MTS_SUCCESS;
     539             :             }, array, data);
     540             :         };
     541             : 
     542             :         array.shape = [](const void* array, const uintptr_t** shape, uintptr_t* shape_count) {
     543             :             return details::catch_exceptions([](const void* array, const uintptr_t** shape, uintptr_t* shape_count){
     544             :                 const auto* cxx_array = static_cast<const DataArrayBase*>(array);
     545             :                 const auto& cxx_shape = cxx_array->shape();
     546             :                 *shape = cxx_shape.data();
     547             :                 *shape_count = static_cast<uintptr_t>(cxx_shape.size());
     548             :                 return MTS_SUCCESS;
     549             :             }, array, shape, shape_count);
     550             :         };
     551             : 
     552             :         array.reshape = [](void* array, const uintptr_t* shape, uintptr_t shape_count) {
     553             :             return details::catch_exceptions([](void* array, const uintptr_t* shape, uintptr_t shape_count){
     554             :                 auto* cxx_array = static_cast<DataArrayBase*>(array);
     555             :                 auto cxx_shape = std::vector<uintptr_t>(shape, shape + shape_count);
     556             :                 cxx_array->reshape(std::move(cxx_shape));
     557             :                 return MTS_SUCCESS;
     558             :             }, array, shape, shape_count);
     559             :         };
     560             : 
     561             :         array.swap_axes = [](void* array, uintptr_t axis_1, uintptr_t axis_2) {
     562             :             return details::catch_exceptions([](void* array, uintptr_t axis_1, uintptr_t axis_2){
     563             :                 auto* cxx_array = static_cast<DataArrayBase*>(array);
     564             :                 cxx_array->swap_axes(axis_1, axis_2);
     565             :                 return MTS_SUCCESS;
     566             :             }, array, axis_1, axis_2);
     567             :         };
     568             : 
     569             :         array.move_samples_from = [](
     570             :             void* array,
     571             :             const void* input,
     572             :             const mts_sample_mapping_t* samples,
     573             :             uintptr_t samples_count,
     574             :             uintptr_t property_start,
     575             :             uintptr_t property_end
     576             :         ) {
     577             :             return details::catch_exceptions([](
     578             :                 void* array,
     579             :                 const void* input,
     580             :                 const mts_sample_mapping_t* samples,
     581             :                 uintptr_t samples_count,
     582             :                 uintptr_t property_start,
     583             :                 uintptr_t property_end
     584             :             ) {
     585             :                 auto* cxx_array = static_cast<DataArrayBase*>(array);
     586             :                 const auto* cxx_input = static_cast<const DataArrayBase*>(input);
     587             :                 auto cxx_samples = std::vector<mts_sample_mapping_t>(samples, samples + samples_count);
     588             : 
     589             :                 cxx_array->move_samples_from(*cxx_input, cxx_samples, property_start, property_end);
     590             :                 return MTS_SUCCESS;
     591             :             }, array, input, samples, samples_count, property_start, property_end);
     592             :         };
     593             : 
     594             :         return array;
     595             :     }
     596             : 
    /// Get the "data origin" for this array.
    ///
    /// Users of `DataArrayBase` should register a single data
    /// origin with `mts_register_data_origin`, and use it for all compatible
    /// arrays.
    virtual mts_data_origin_t origin() const = 0;

    /// Make a copy of this DataArrayBase and return the new array. The new
    /// array is expected to have the same data origin and parameters (data
    /// type, data location, etc.)
    virtual std::unique_ptr<DataArrayBase> copy() const = 0;

    /// Create a new array with the same options as the current one (data type,
    /// data location, etc.) and the requested `shape`.
    ///
    /// The new array should be filled with zeros.
    virtual std::unique_ptr<DataArrayBase> create(std::vector<uintptr_t> shape) const = 0;

    /// Get a pointer to the underlying data storage.
    ///
    /// This function is allowed to fail if the data is not accessible in RAM,
    /// not stored as 64-bit floating point values, or not stored as a
    /// C-contiguous array.
    virtual double* data() & = 0;

    /// Deleted for rvalues: a pointer taken from a temporary array would
    /// dangle as soon as the temporary is destroyed.
    double* data() && = delete;

    /// Get the shape of this array
    virtual const std::vector<uintptr_t>& shape() const & = 0;

    /// Deleted for rvalues, for the same reason as `data() &&`: the returned
    /// reference would point into a destroyed temporary.
    const std::vector<uintptr_t>& shape() && = delete;

    /// Set the shape of this array to the given `shape`
    virtual void reshape(std::vector<uintptr_t> shape) = 0;

    /// Swap the axes `axis_1` and `axis_2` in this `array`.
    virtual void swap_axes(uintptr_t axis_1, uintptr_t axis_2) = 0;

    /// Set entries in the current array taking data from the `input` array.
    ///
    /// This array is guaranteed to be created by calling `mts_array_t::create`
    /// with one of the arrays in the same block or tensor map as the `input`.
    ///
    /// The `samples` indicate where the data should be moved from `input` to
    /// the current DataArrayBase.
    ///
    /// This function should copy data from `input[samples[i].input, ..., :]` to
    /// `array[samples[i].output, ..., property_start:property_end]` for `i` up
    /// to `samples_count`. All indexes are 0-based, and (as in the slice
    /// notation above) `property_end` is exclusive.
    virtual void move_samples_from(
        const DataArrayBase& input,
        std::vector<mts_sample_mapping_t> samples,
        uintptr_t property_start,
        uintptr_t property_end
    ) = 0;
     652             : };
     653             : 
     654             : 
     655             : /// Very basic implementation of DataArrayBase in C++.
     656             : ///
     657             : /// This is included as an example implementation of DataArrayBase, and to make
     658             : /// metatensor usable without additional dependencies. For other uses cases, it
     659             : /// might be better to implement DataArrayBase on your data, using
     660             : /// functionalities from `Eigen`, `Boost.Array`, etc.
     661             : class SimpleDataArray: public metatensor::DataArrayBase {
     662             : public:
     663             :     /// Create a SimpleDataArray with the given `shape`, and all elements set to
     664             :     /// `value`
     665             :     SimpleDataArray(std::vector<uintptr_t> shape, double value = 0.0):
     666             :         shape_(std::move(shape)), data_(details::product(shape_), value) {}
     667             : 
     668             :     /// Create a SimpleDataArray with the given `shape` and `data`.
     669             :     ///
     670             :     /// The data is interpreted as a row-major n-dimensional array.
     671             :     SimpleDataArray(std::vector<uintptr_t> shape, std::vector<double> data):
     672             :         shape_(std::move(shape)),
     673             :         data_(std::move(data))
     674             :     {
     675             :         if (data_.size() != details::product(shape_)) {
     676             :             throw Error("the shape and size of the data don't match in SimpleDataArray");
     677             :         }
     678             :     }
     679             : 
     680             :     ~SimpleDataArray() override = default;
     681             : 
     682             :     /// SimpleDataArray can be copy-constructed
     683             :     SimpleDataArray(const SimpleDataArray&) = default;
     684             :     /// SimpleDataArray can be copy-assigned
     685             :     SimpleDataArray& operator=(const SimpleDataArray&) = default;
     686             :     /// SimpleDataArray can be move-constructed
     687             :     SimpleDataArray(SimpleDataArray&&) noexcept = default;
     688             :     /// SimpleDataArray can be move-assigned
     689             :     SimpleDataArray& operator=(SimpleDataArray&&) noexcept = default;
     690             : 
     691             :     mts_data_origin_t origin() const override {
     692             :         mts_data_origin_t origin = 0;
     693             :         mts_register_data_origin("metatensor::SimpleDataArray", &origin);
     694             :         return origin;
     695             :     }
     696             : 
     697             :     double* data() & override {
     698             :         return data_.data();
     699             :     }
     700             : 
     701             :     const std::vector<uintptr_t>& shape() const & override {
     702             :         return shape_;
     703             :     }
     704             : 
     705             :     void reshape(std::vector<uintptr_t> shape) override {
     706             :         if (details::product(shape_) != details::product(shape)) {
     707             :             throw metatensor::Error("invalid shape in reshape");
     708             :         }
     709             :         shape_ = std::move(shape);
     710             :     }
     711             : 
     712             :     void swap_axes(uintptr_t axis_1, uintptr_t axis_2) override {
     713             :         auto new_data = std::vector<double>(details::product(shape_), 0.0);
     714             :         auto new_shape = shape_;
     715             :         std::swap(new_shape[axis_1], new_shape[axis_2]);
     716             : 
     717             :         for (size_t i=0; i<details::product(shape_); i++) {
     718             :             auto index = details::cartesian_index(shape_, i);
     719             :             std::swap(index[axis_1], index[axis_2]);
     720             : 
     721             :             new_data[details::linear_index(new_shape, index)] = data_[i];
     722             :         }
     723             : 
     724             :         shape_ = std::move(new_shape);
     725             :         data_ = std::move(new_data);
     726             :     }
     727             : 
     728             :     std::unique_ptr<DataArrayBase> copy() const override {
     729             :         return std::unique_ptr<DataArrayBase>(new SimpleDataArray(*this));
     730             :     }
     731             : 
     732             :     std::unique_ptr<DataArrayBase> create(std::vector<uintptr_t> shape) const override {
     733             :         return std::unique_ptr<DataArrayBase>(new SimpleDataArray(std::move(shape)));
     734             :     }
     735             : 
     736             :     void move_samples_from(
     737             :         const DataArrayBase& input,
     738             :         std::vector<mts_sample_mapping_t> samples,
     739             :         uintptr_t property_start,
     740             :         uintptr_t property_end
     741             :     ) override {
     742             :         const auto& input_array = dynamic_cast<const SimpleDataArray&>(input);
     743             :         assert(input_array.shape_.size() == this->shape_.size());
     744             : 
     745             :         size_t property_count = property_end - property_start;
     746             :         size_t property_dim = shape_.size() - 1;
     747             :         assert(input_array.shape_[property_dim] == property_count);
     748             : 
     749             :         auto input_index = std::vector<size_t>(shape_.size(), 0);
     750             :         auto output_index = std::vector<size_t>(shape_.size(), 0);
     751             : 
     752             :         for (const auto& sample: samples) {
     753             :             input_index[0] = sample.input;
     754             :             output_index[0] = sample.output;
     755             : 
     756             :             if (property_dim == 1) {
     757             :                 // no components
     758             :                 for (size_t property_i=0; property_i<property_count; property_i++) {
     759             :                     input_index[property_dim] = property_i;
     760             :                     output_index[property_dim] = property_i + property_start;
     761             : 
     762             :                     auto value = input_array.data_[details::linear_index(input_array.shape_, input_index)];
     763             :                     this->data_[details::linear_index(shape_, output_index)] = value;
     764             :                 }
     765             :             } else {
     766             :                 auto last_component_dim = shape_.size() - 2;
     767             :                 for (size_t component_i=1; component_i<shape_.size() - 1; component_i++) {
     768             :                     input_index[component_i] = 0;
     769             :                 }
     770             : 
     771             :                 bool done = false;
     772             :                 while (!done) {
     773             :                     for (size_t component_i=1; component_i<shape_.size() - 1; component_i++) {
     774             :                         output_index[component_i] = input_index[component_i];
     775             :                     }
     776             : 
     777             :                     for (size_t property_i=0; property_i<property_count; property_i++) {
     778             :                         input_index[property_dim] = property_i;
     779             :                         output_index[property_dim] = property_i + property_start;
     780             : 
     781             :                         auto value = input_array.data_[details::linear_index(input_array.shape_, input_index)];
     782             :                         this->data_[details::linear_index(shape_, output_index)] = value;
     783             :                     }
     784             : 
     785             :                     input_index[last_component_dim] += 1;
     786             :                     for (size_t component_i=last_component_dim; component_i>2; component_i--) {
     787             :                         if (input_index[component_i] >= shape_[component_i]) {
     788             :                             input_index[component_i] = 0;
     789             :                             input_index[component_i - 1] += 1;
     790             :                         }
     791             :                     }
     792             : 
     793             :                     if (input_index[1] >= shape_[1]) {
     794             :                         done = true;
     795             :                     }
     796             :                 }
     797             :             }
     798             :         }
     799             :     }
     800             : 
     801             :     /// Get a const view of the data managed by this SimpleDataArray
     802             :     NDArray<double> view() const {
     803             :         return NDArray<double>(data_.data(), shape_);
     804             :     }
     805             : 
     806             :     /// Get a mutable view of the data managed by this SimpleDataArray
     807             :     NDArray<double> view() {
     808             :         return NDArray<double>(data_.data(), shape_);
     809             :     }
     810             : 
     811             :     /// Extract a reference to SimpleDataArray out of an `mts_array_t`.
     812             :     ///
     813             :     /// This function fails if the `mts_array_t` does not contain a
     814             :     /// SimpleDataArray.
     815             :     static SimpleDataArray& from_mts_array(mts_array_t& array) {
     816             :         mts_data_origin_t origin = 0;
     817             :         auto status = array.origin(array.ptr, &origin);
     818             :         if (status != MTS_SUCCESS) {
     819             :             throw Error("failed to get data origin");
     820             :         }
     821             : 
     822             :         std::array<char, 64> buffer = {0};
     823             :         status = mts_get_data_origin(origin, buffer.data(), buffer.size());
     824             :         if (status != MTS_SUCCESS || std::string(buffer.data()) != "metatensor::SimpleDataArray") {
     825             :             throw Error("this array is not a metatensor::SimpleDataArray");
     826             :         }
     827             : 
     828             :         auto* base = static_cast<DataArrayBase*>(array.ptr);
     829             :         return dynamic_cast<SimpleDataArray&>(*base);
     830             :     }
     831             : 
     832             :     /// Extract a const reference to SimpleDataArray out of an `mts_array_t`.
     833             :     ///
     834             :     /// This function fails if the `mts_array_t` does not contain a
     835             :     /// SimpleDataArray.
     836             :     static const SimpleDataArray& from_mts_array(const mts_array_t& array) {
     837             :         mts_data_origin_t origin = 0;
     838             :         auto status = array.origin(array.ptr, &origin);
     839             :         if (status != MTS_SUCCESS) {
     840             :             throw Error("failed to get data origin");
     841             :         }
     842             : 
     843             :         std::array<char, 64> buffer = {0};
     844             :         status = mts_get_data_origin(origin, buffer.data(), buffer.size());
     845             :         if (status != MTS_SUCCESS || std::string(buffer.data()) != "metatensor::SimpleDataArray") {
     846             :             throw Error("this array is not a metatensor::SimpleDataArray");
     847             :         }
     848             : 
     849             :         const auto* base = static_cast<const DataArrayBase*>(array.ptr);
     850             :         return dynamic_cast<const SimpleDataArray&>(*base);
     851             :     }
     852             : 
     853             : private:
     854             :     std::vector<uintptr_t> shape_;
     855             :     std::vector<double> data_;
     856             : 
     857             :     friend bool operator==(const SimpleDataArray& lhs, const SimpleDataArray& rhs);
     858             : };
     859             : 
     860             : /// Two SimpleDataArray compare as equal if they have the exact same shape and
     861             : /// data.
     862             : inline bool operator==(const SimpleDataArray& lhs, const SimpleDataArray& rhs) {
     863             :     return lhs.shape_ == rhs.shape_ && lhs.data_ == rhs.data_;
     864             : }
     865             : 
     866             : /// Two SimpleDataArray compare as equal if they have the exact same shape and
     867             : /// data.
     868             : inline bool operator!=(const SimpleDataArray& lhs, const SimpleDataArray& rhs) {
     869             :     return !(lhs == rhs);
     870             : }
     871             : 
     872             : 
     873             : /// An implementation of `DataArrayBase` containing no data.
     874             : ///
     875             : /// This class only tracks it's shape, and can be used when only the metadata
     876             : /// of a `TensorBlock` is important, leaving the data unspecified.
     877             : class EmptyDataArray: public metatensor::DataArrayBase {
     878             : public:
     879             :     /// Create ae `EmptyDataArray` with the given `shape`
     880             :     EmptyDataArray(std::vector<uintptr_t> shape):
     881             :         shape_(std::move(shape)) {}
     882             : 
     883             :     ~EmptyDataArray() override = default;
     884             : 
     885             :     /// EmptyDataArray can be copy-constructed
     886             :     EmptyDataArray(const EmptyDataArray&) = default;
     887             :     /// EmptyDataArray can be copy-assigned
     888             :     EmptyDataArray& operator=(const EmptyDataArray&) = default;
     889             :     /// EmptyDataArray can be move-constructed
     890             :     EmptyDataArray(EmptyDataArray&&) noexcept = default;
     891             :     /// EmptyDataArray can be move-assigned
     892             :     EmptyDataArray& operator=(EmptyDataArray&&) noexcept = default;
     893             : 
     894             :     mts_data_origin_t origin() const override {
     895             :         mts_data_origin_t origin = 0;
     896             :         mts_register_data_origin("metatensor::EmptyDataArray", &origin);
     897             :         return origin;
     898             :     }
     899             : 
     900             :     double* data() & override {
     901             :         throw metatensor::Error("can not call `data` for an EmptyDataArray");
     902             :     }
     903             : 
     904             :     const std::vector<uintptr_t>& shape() const & override {
     905             :         return shape_;
     906             :     }
     907             : 
     908             :     void reshape(std::vector<uintptr_t> shape) override {
     909             :         if (details::product(shape_) != details::product(shape)) {
     910             :             throw metatensor::Error("invalid shape in reshape");
     911             :         }
     912             :         shape_ = std::move(shape);
     913             :     }
     914             : 
     915             :     void swap_axes(uintptr_t axis_1, uintptr_t axis_2) override {
     916             :         std::swap(shape_[axis_1], shape_[axis_2]);
     917             :     }
     918             : 
     919             :     std::unique_ptr<DataArrayBase> copy() const override {
     920             :         return std::unique_ptr<DataArrayBase>(new EmptyDataArray(*this));
     921             :     }
     922             : 
     923             :     std::unique_ptr<DataArrayBase> create(std::vector<uintptr_t> shape) const override {
     924             :         return std::unique_ptr<DataArrayBase>(new EmptyDataArray(std::move(shape)));
     925             :     }
     926             : 
     927             :     void move_samples_from(const DataArrayBase&, std::vector<mts_sample_mapping_t>, uintptr_t, uintptr_t) override {
     928             :         throw metatensor::Error("can not call `move_samples_from` for an EmptyDataArray");
     929             :     }
     930             : 
     931             : private:
     932             :     std::vector<uintptr_t> shape_;
     933             : };
     934             : 
     935             : namespace details {
     936             :     /// Default callback for data array creating in `TensorMap::load`, which
     937             :     /// will create a `SimpleDataArray`.
     938             :     inline mts_status_t default_create_array(
     939             :         const uintptr_t* shape_ptr,
     940             :         uintptr_t shape_count,
     941             :         mts_array_t* array
     942             :     ) {
     943             :         return details::catch_exceptions([](const uintptr_t* shape_ptr, uintptr_t shape_count, mts_array_t* array){
     944             :             auto shape = std::vector<size_t>();
     945             :             for (size_t i=0; i<shape_count; i++) {
     946             :                 shape.push_back(static_cast<size_t>(shape_ptr[i]));
     947             :             }
     948             : 
     949             :             auto cxx_array = std::unique_ptr<DataArrayBase>(new SimpleDataArray(shape));
     950             :             *array = DataArrayBase::to_mts_array_t(std::move(cxx_array));
     951             : 
     952             :             return MTS_SUCCESS;
     953             :         }, shape_ptr, shape_count, array);
     954             :     }
     955             : }
     956             : 
     957             : /******************************************************************************/
     958             : /******************************************************************************/
     959             : /*                                                                            */
     960             : /*                           I/O functionalities                              */
     961             : /*                                                                            */
     962             : /******************************************************************************/
     963             : /******************************************************************************/
     964             : 
     965             : namespace io {
    /// Save a `TensorMap` to the file at `path`.
    ///
    /// If the file exists, it will be overwritten. The recommended file
    /// extension when saving data is `.mts`, to prevent confusion with generic
    /// `.npz` files.
    ///
    /// A `TensorMap` is serialized using numpy's NPZ format, i.e. a ZIP file
    /// without compression (storage method is `STORED`), where each file is
    /// stored as a `.npy` array. See the C API documentation for more
    /// information on the format.
    void save(const std::string& path, const TensorMap& tensor);

    /// Save a `TensorMap` to an in-memory buffer.
    ///
    /// The `Buffer` template parameter can be set to any type that can be
    /// constructed from a pair of iterators over a `std::vector<uint8_t>`.
    template <typename Buffer = std::vector<uint8_t>>
    Buffer save_buffer(const TensorMap& tensor);

    /// Explicit specialization for the default `std::vector<uint8_t>` buffer
    /// type.
    template<>
    std::vector<uint8_t> save_buffer<std::vector<uint8_t>>(const TensorMap& tensor);
     987             : 
     988             :     /**************************************************************************/
     989             : 
    /// Save a `TensorBlock` to the file at `path`.
    ///
    /// If the file exists, it will be overwritten. The recommended file
    /// extension when saving data is `.mts`, to prevent confusion with generic
    /// `.npz` files.
    void save(const std::string& path, const TensorBlock& block);

    /// Save a `TensorBlock` to an in-memory buffer.
    ///
    /// The `Buffer` template parameter can be set to any type that can be
    /// constructed from a pair of iterators over a `std::vector<uint8_t>`.
    template <typename Buffer = std::vector<uint8_t>>
    Buffer save_buffer(const TensorBlock& block);

    /// Explicit specialization for the default `std::vector<uint8_t>` buffer
    /// type.
    template<>
    std::vector<uint8_t> save_buffer<std::vector<uint8_t>>(const TensorBlock& block);
    1006             : 
    1007             :     /**************************************************************************/
    1008             : 
    /// Save `Labels` to the file at `path`.
    ///
    /// If the file exists, it will be overwritten. The recommended file
    /// extension when saving data is `.mts`, to prevent confusion with generic
    /// `.npz` files.
    void save(const std::string& path, const Labels& labels);

    /// Save `Labels` to an in-memory buffer.
    ///
    /// The `Buffer` template parameter can be set to any type that can be
    /// constructed from a pair of iterators over a `std::vector<uint8_t>`.
    template <typename Buffer = std::vector<uint8_t>>
    Buffer save_buffer(const Labels& labels);

    /// Explicit specialization for the default `std::vector<uint8_t>` buffer
    /// type.
    template<>
    std::vector<uint8_t> save_buffer<std::vector<uint8_t>>(const Labels& labels);
    1025             : 
    1026             :     /**************************************************************************/
    1027             :     /**************************************************************************/
    1028             : 
    1029             :     /*!
    1030             :      * Load a previously saved `TensorMap` from the given path.
    1031             :      *
    1032             :      * \verbatim embed:rst:leading-asterisk
    1033             :      *
    1034             :      * ``create_array`` will be used to create new arrays when constructing the
    1035             :      * blocks and gradients, the default version will create data using
    1036             :      * :cpp:class:`SimpleDataArray`. See :c:func:`mts_create_array_callback_t`
    1037             :      * for more information.
    1038             :      *
    1039             :      * \endverbatim
    1040             :      *
    1041             :      * `TensorMap` are serialized using numpy's NPZ format, i.e. a ZIP file
    1042             :      * without compression (storage method is `STORED`), where each file is
    1043             :      * stored as a `.npy` array. See the C API documentation for more
    1044             :      * information on the format.
    1045             :      */
    1046             :     TensorMap load(
    1047             :         const std::string& path,
    1048             :         mts_create_array_callback_t create_array = details::default_create_array
    1049             :     );
    1050             : 
    1051             :     /*!
    1052             :      * Load a previously saved `TensorMap` from the given `buffer`, containing
    1053             :      * `buffer_count` elements.
    1054             :      *
    1055             :      * \verbatim embed:rst:leading-asterisk
    1056             :      *
    1057             :      * ``create_array`` will be used to create new arrays when constructing the
    1058             :      * blocks and gradients, the default version will create data using
    1059             :      * :cpp:class:`SimpleDataArray`. See :c:func:`mts_create_array_callback_t`
    1060             :      * for more information.
    1061             :      *
    1062             :      * \endverbatim
    1063             :      */
    1064             :     TensorMap load_buffer(
    1065             :         const uint8_t* buffer,
    1066             :         size_t buffer_count,
    1067             :         mts_create_array_callback_t create_array = details::default_create_array
    1068             :     );
    1069             : 
    1070             : 
    1071             :     /// Load a previously saved `TensorMap` from the given `buffer`.
    1072             :     ///
    1073             :     /// The `Buffer` template parameter would typically be a
    1074             :     /// `std::vector<uint8_t>` or a `std::string`, but any container with
    1075             :     /// contiguous data and an `item_type` with the same size as a `uint8_t` can
    1076             :     /// work.
    1077             :     template <typename Buffer>
    1078             :     TensorMap load_buffer(
    1079             :         const Buffer& buffer,
    1080             :         mts_create_array_callback_t create_array = details::default_create_array
    1081             :     );
    1082             : 
    1083             :     /**************************************************************************/
    1084             : 
    1085             :     /*!
    1086             :      * Load a previously saved `TensorBlock` from the given path.
    1087             :      *
    1088             :      * \verbatim embed:rst:leading-asterisk
    1089             :      *
    1090             :      * ``create_array`` will be used to create new arrays when constructing the
    1091             :      * blocks and gradients, the default version will create data using
    1092             :      * :cpp:class:`SimpleDataArray`. See :c:func:`mts_create_array_callback_t`
    1093             :      * for more information.
    1094             :      *
    1095             :      * \endverbatim
    1096             :      *
    1097             :      */
    1098             :     TensorBlock load_block(
    1099             :         const std::string& path,
    1100             :         mts_create_array_callback_t create_array = details::default_create_array
    1101             :     );
    1102             : 
    1103             :     /*!
    1104             :      * Load a previously saved `TensorBlock` from the given `buffer`, containing
    1105             :      * `buffer_count` elements.
    1106             :      *
    1107             :      * \verbatim embed:rst:leading-asterisk
    1108             :      *
    1109             :      * ``create_array`` will be used to create new arrays when constructing the
    1110             :      * blocks and gradients, the default version will create data using
    1111             :      * :cpp:class:`SimpleDataArray`. See :c:func:`mts_create_array_callback_t`
    1112             :      * for more information.
    1113             :      *
    1114             :      * \endverbatim
    1115             :      */
    1116             :     TensorBlock load_block_buffer(
    1117             :         const uint8_t* buffer,
    1118             :         size_t buffer_count,
    1119             :         mts_create_array_callback_t create_array = details::default_create_array
    1120             :     );
    1121             : 
    1122             : 
    1123             :     /// Load a previously saved `TensorBlock` from the given `buffer`.
    1124             :     ///
    1125             :     /// The `Buffer` template parameter would typically be a
    1126             :     /// `std::vector<uint8_t>` or a `std::string`, but any container with
    1127             :     /// contiguous data and an `item_type` with the same size as a `uint8_t` can
    1128             :     /// work.
    1129             :     template <typename Buffer>
    1130             :     TensorBlock load_block_buffer(
    1131             :         const Buffer& buffer,
    1132             :         mts_create_array_callback_t create_array = details::default_create_array
    1133             :     );
    1134             : 
    1135             :     /**************************************************************************/
    1136             : 
    1137             :     /// Load previously saved `Labels` from the given path.
    1138             :     Labels load_labels(const std::string& path);
    1139             : 
    1140             :     /// Load previously saved `Labels` from the given `buffer`, containing
    1141             :     /// `buffer_count` elements.
    1142             :     Labels load_labels_buffer(const uint8_t* buffer, size_t buffer_count);
    1143             : 
    1144             :     /// Load a previously saved `Labels` from the given `buffer`.
    1145             :     ///
    1146             :     /// The `Buffer` template parameter would typically be a
    1147             :     /// `std::vector<uint8_t>` or a `std::string`, but any container with
    1148             :     /// contiguous data and an `item_type` with the same size as a `uint8_t` can
    1149             :     /// work.
    1150             :     template <typename Buffer>
    1151             :     Labels load_labels_buffer(const Buffer& buffer);
    1152             : }
    1153             : 
    1154             : 
    1155             : /******************************************************************************/
    1156             : /******************************************************************************/
    1157             : /*                                                                            */
    1158             : /*                                Labels                                      */
    1159             : /*                                                                            */
    1160             : /******************************************************************************/
    1161             : /******************************************************************************/
    1162             : 
    1163             : 
    1164             : /// It is possible to store some user-provided data inside `Labels`, and access
    1165             : /// it later. This class is used to take ownership of the data and corresponding
    1166             : /// delete function before giving the data to metatensor.
    1167             : ///
    1168             : /// User data inside `Labels` is an advanced functionality, that most users
    1169             : /// should not need to interact with.
    1170             : class LabelsUserData {
    1171             : public:
    1172             :     /// Create `LabelsUserData` containing the given `data`.
    1173             :     ///
    1174             :     /// `deleter` will be called when the data is dropped, and should
    1175             :     /// free the corresponding memory.
    1176             :     LabelsUserData(void* data, void(*deleter)(void*)): data_(data), deleter_(deleter) {}
    1177             : 
    1178             :     ~LabelsUserData() {
    1179             :         if (deleter_ !=  nullptr) {
    1180             :             deleter_(data_);
    1181             :         }
    1182             :     }
    1183             : 
    1184             :     /// LabelsUserData is not copy-constructible
    1185             :     LabelsUserData(const LabelsUserData& other) = delete;
    1186             :     /// LabelsUserData can not be copy-assigned
    1187             :     LabelsUserData& operator=(const LabelsUserData& other) = delete;
    1188             : 
    1189             :     /// LabelsUserData is move-constructible
    1190             :     LabelsUserData(LabelsUserData&& other) noexcept: LabelsUserData(nullptr, nullptr) {
    1191             :         *this = std::move(other);
    1192             :     }
    1193             : 
    1194             :     /// LabelsUserData be move-assigned
    1195             :     LabelsUserData& operator=(LabelsUserData&& other) noexcept {
    1196             :         if (deleter_ !=  nullptr) {
    1197             :             deleter_(data_);
    1198             :         }
    1199             : 
    1200             :         data_ = other.data_;
    1201             :         deleter_ = other.deleter_;
    1202             : 
    1203             :         other.data_ = nullptr;
    1204             :         other.deleter_ = nullptr;
    1205             : 
    1206             :         return *this;
    1207             :     }
    1208             : 
    1209             : private:
    1210             :     friend class Labels;
    1211             : 
    1212             :     void* data_;
    1213             :     void(*deleter_)(void*);
    1214             : };
    1215             : 
    1216             : 
    1217             : /// A set of labels used to carry metadata associated with a tensor map.
    1218             : ///
    1219             : /// This is similar to an array of named tuples, but stored as a 2D array
    1220             : /// of shape `(count, size)`, with a set of names associated with the columns of
    1221             : /// this array (often called *dimensions*). Each row/entry in this array is
    1222             : /// unique, and they are often (but not always) sorted in lexicographic order.
    1223             : class Labels final {
    1224             : public:
    1225             :     /// Create a new set of Labels from the given `names` and `values`.
    1226             :     ///
    1227             :     /// Each entry in the values must contain `names.size()` elements.
    1228             :     ///
    1229             :     /// ```
    1230             :     /// auto labels = Labels({"first", "second"}, {
    1231             :     ///    {0, 1},
    1232             :     ///    {1, 4},
    1233             :     ///    {2, 1},
    1234             :     ///    {2, 3},
    1235             :     /// });
    1236             :     /// ```
    1237             :     Labels(
    1238             :         const std::vector<std::string>& names,
    1239             :         const std::vector<std::initializer_list<int32_t>>& values
    1240             :     ): Labels(names, NDArray<int32_t>(values, names.size()), InternalConstructor{}) {}
    1241             : 
    1242             :     /// This function does not check for uniqueness of the labels entries, which
    1243             :     /// should be enforced by the caller. Calling this function with non-unique
    1244             :     /// entries is invalid and can lead to crashes or infinite loops.
    1245             :     explicit Labels(
    1246             :         const std::vector<std::string>& names,
    1247             :         const std::vector<std::initializer_list<int32_t>>& values,
    1248             :         assume_unique
    1249             :     ): Labels(names, NDArray<int32_t>(values, names.size()), assume_unique{}, InternalConstructor{}) {}
    1250             : 
    1251             :     /// Create an empty set of Labels with the given names
    1252             :     explicit Labels(const std::vector<std::string>& names):
    1253             :         Labels(names, static_cast<const int32_t*>(nullptr), 0) {}
    1254             : 
    1255             :     /// Create labels with the given `names` and `values`. `values` must be an
    1256             :     /// array with `count x names.size()` elements.
    1257             :     Labels(const std::vector<std::string>& names, const int32_t* values, size_t count):
    1258             :         Labels(details::labels_from_cxx(names, values, count, false)) {}
    1259             : 
    1260             :     /// Unchecked variant, caller promises the labels are unique. Calling with
    1261             :     /// non-unique entries is invalid and can ead to crashes or infinite loops.
    1262             :     Labels(const std::vector<std::string>& names, const int32_t* values, size_t count, assume_unique):
    1263             :         Labels(details::labels_from_cxx(names, values, count, true)) {}
    1264             : 
    1265             :     ~Labels() {
    1266             :         mts_labels_free(&labels_);
    1267             :     }
    1268             : 
    1269             :     /// Labels is copy-constructible
    1270             :     Labels(const Labels& other): Labels() {
    1271             :         *this = other;
    1272             :     }
    1273             : 
    1274             :     /// Labels can be copy-assigned
    1275             :     Labels& operator=(const Labels& other) {
    1276             :         mts_labels_free(&labels_);
    1277             :         std::memset(&labels_, 0, sizeof(labels_));
    1278             :         details::check_status(mts_labels_clone(other.labels_, &labels_));
    1279             :         assert(this->labels_.internal_ptr_ != nullptr);
    1280             : 
    1281             :         this->values_ = NDArray<int32_t>(labels_.values, {labels_.count, labels_.size});
    1282             : 
    1283             :         this->names_.clear();
    1284             :         for (size_t i=0; i<this->labels_.size; i++) {
    1285             :             this->names_.push_back(this->labels_.names[i]);
    1286             :         }
    1287             : 
    1288             :         return *this;
    1289             :     }
    1290             : 
    1291             :     /// Labels is move-constructible
    1292             :     Labels(Labels&& other) noexcept: Labels() {
    1293             :         *this = std::move(other);
    1294             :     }
    1295             : 
    1296             :     /// Labels can be move-assigned
    1297             :     Labels& operator=(Labels&& other) noexcept {
    1298             :         mts_labels_free(&labels_);
    1299             :         this->labels_ = other.labels_;
    1300             :         assert(this->labels_.internal_ptr_ != nullptr);
    1301             :         std::memset(&other.labels_, 0, sizeof(other.labels_));
    1302             : 
    1303             :         this->values_ = std::move(other.values_);
    1304             :         this->names_ = std::move(other.names_);
    1305             : 
    1306             :         return *this;
    1307             :     }
    1308             : 
    1309             :     /// Get the names of the dimensions used in these `Labels`.
    1310             :     const std::vector<const char*>& names() const {
    1311             :         return names_;
    1312             :     }
    1313             : 
    1314             :     /// Get the number of entries in this set of Labels.
    1315             :     ///
    1316             :     /// This is the same as `shape()[0]` for the corresponding values array
    1317             :     size_t count() const {
    1318             :         return labels_.count;
    1319             :     }
    1320             : 
    1321             :     /// Get the number of dimensions in this set of Labels.
    1322             :     ///
    1323             :     /// This is the same as `shape()[1]` for the corresponding values array
    1324             :     size_t size() const {
    1325             :         return labels_.size;
    1326             :     }
    1327             : 
    1328             :     /// Convert from this set of Labels to the C `mts_labels_t`
    1329             :     mts_labels_t as_mts_labels_t() const {
    1330             :         assert(labels_.internal_ptr_ != nullptr);
    1331             :         return labels_;
    1332             :     }
    1333             : 
    1334             :     /// Get the user data pointer registered with these `Labels`.
    1335             :     ///
    1336             :     /// If no user data have been registered, this function will return
    1337             :     /// `nullptr`.
    1338             :     void* user_data() & {
    1339             :         assert(labels_.internal_ptr_ != nullptr);
    1340             : 
    1341             :         void* data = nullptr;
    1342             :         details::check_status(mts_labels_user_data(labels_, &data));
    1343             :         return data;
    1344             :     }
    1345             : 
    1346             :     void* user_data() && = delete;
    1347             : 
    1348             :     /// Register some user data pointer with these `Labels`.
    1349             :     ///
    1350             :     /// Any existing user data will be released (by calling the provided
    1351             :     /// `delete` function) before overwriting with the new data.
    1352             :     void set_user_data(LabelsUserData user_data) {
    1353             :         assert(labels_.internal_ptr_ != nullptr);
    1354             : 
    1355             :         details::check_status(mts_labels_set_user_data(
    1356             :             labels_,
    1357             :             user_data.data_,
    1358             :             user_data.deleter_
    1359             :         ));
    1360             : 
    1361             :         // the user data was moved inside `labels_`
    1362             :         user_data.data_ = nullptr;
    1363             :         user_data.deleter_ = nullptr;
    1364             :     }
    1365             : 
    1366             :     /// Get the position of the `entry` in this set of Labels, or -1 if the
    1367             :     /// entry is not part of these Labels.
    1368             :     int64_t position(std::initializer_list<int32_t> entry) const {
    1369             :         return this->position(entry.begin(), entry.size());
    1370             :     }
    1371             : 
    1372             :     /// Variant of `Labels::position` taking a fixed-size array as input
    1373             :     template<size_t N>
    1374             :     int64_t position(const std::array<int32_t, N>& entry) const {
    1375             :         return this->position(entry.data(), entry.size());
    1376             :     }
    1377             : 
    1378             :     /// Variant of `Labels::position` taking a vector as input
    1379             :     int64_t position(const std::vector<int32_t>& entry) const {
    1380             :         return this->position(entry.data(), entry.size());
    1381             :     }
    1382             : 
    1383             :     /// Variant of `Labels::position` taking a pointer and length as input
    1384             :     int64_t position(const int32_t* entry, size_t length) const {
    1385             :         assert(labels_.internal_ptr_ != nullptr);
    1386             : 
    1387             :         int64_t result = 0;
    1388             :         details::check_status(mts_labels_position(labels_, entry, length, &result));
    1389             :         return result;
    1390             :     }
    1391             : 
    1392             :     /// Get the array of values for these Labels
    1393             :     const NDArray<int32_t>& values() const & {
    1394             :         return values_;
    1395             :     }
    1396             : 
    1397             :     const NDArray<int32_t>& values() && = delete;
    1398             : 
    1399             :     /// Take the union of these `Labels` with `other`.
    1400             :     ///
    1401             :     /// If requested, this function can also give the positions in the union
    1402             :     /// where each entry of the input `Labels` ended up.
    1403             :     ///
    1404             :     /// No user data pointer is registered with the output, even if the inputs
    1405             :     /// have some.
    1406             :     ///
    1407             :     /// @param other the `Labels` we want to take the union with
    1408             :     /// @param first_mapping if you want the mapping from the positions of
    1409             :     ///        entries in `this` to the positions in the union, this should be
    1410             :     ///        a pointer to an array containing `this->count()` elements, to be
    1411             :     ///        filled by this function. Otherwise it should be a `nullptr`.
    1412             :     /// @param first_mapping_count number of elements in `first_mapping`
    1413             :     /// @param second_mapping if you want the mapping from the positions of
    1414             :     ///        entries in `other` to the positions in the union, this should be
    1415             :     ///        a pointer to an array containing `other.count()` elements, to be
    1416             :     ///        filled by this function. Otherwise it should be a `nullptr`.
    1417             :     /// @param second_mapping_count number of elements in `second_mapping`
    1418             :     Labels set_union(
    1419             :         const Labels& other,
    1420             :         int64_t* first_mapping = nullptr,
    1421             :         size_t first_mapping_count = 0,
    1422             :         int64_t* second_mapping = nullptr,
    1423             :         size_t second_mapping_count = 0
    1424             :     ) const {
    1425             :         mts_labels_t result;
    1426             :         std::memset(&result, 0, sizeof(result));
    1427             : 
    1428             :         details::check_status(mts_labels_union(
    1429             :             labels_,
    1430             :             other.labels_,
    1431             :             &result,
    1432             :             first_mapping,
    1433             :             first_mapping_count,
    1434             :             second_mapping,
    1435             :             second_mapping_count
    1436             :         ));
    1437             : 
    1438             :         return Labels(result);
    1439             :     }
    1440             : 
    1441             :     /// Take the union of these `Labels` with `other`.
    1442             :     ///
    1443             :     /// If requested, this function can also give the positions in the
    1444             :     /// union where each entry of the input `Labels` ended up.
    1445             :     ///
    1446             :     /// No user data pointer is registered with the output, even if the inputs
    1447             :     /// have some.
    1448             :     ///
    1449             :     /// @param other the `Labels` we want to take the union with
    1450             :     /// @param first_mapping if you want the mapping from the positions of
    1451             :     ///        entries in `this` to the positions in the union, this should be
    1452             :     ///        a vector containing `this->count()` elements, to be filled by
    1453             :     ///        this function. Otherwise it should be an empty vector.
    1454             :     /// @param second_mapping if you want the mapping from the positions of
    1455             :     ///        entries in `other` to the positions in the union, this should be
    1456             :     ///        a vector containing `other.count()` elements, to be filled by
    1457             :     ///        this function. Otherwise it should be an empty vector.
    1458             :     Labels set_union(
    1459             :         const Labels& other,
    1460             :         std::vector<int64_t>& first_mapping,
    1461             :         std::vector<int64_t>& second_mapping
    1462             :     ) const {
    1463             :         auto* first_mapping_ptr = first_mapping.data();
    1464             :         auto first_mapping_count = first_mapping.size();
    1465             :         if (first_mapping_count == 0) {
    1466             :             first_mapping_ptr = nullptr;
    1467             :         }
    1468             : 
    1469             :         auto* second_mapping_ptr = second_mapping.data();
    1470             :         auto second_mapping_count = second_mapping.size();
    1471             :         if (second_mapping_count == 0) {
    1472             :             second_mapping_ptr = nullptr;
    1473             :         }
    1474             : 
    1475             :         return this->set_union(
    1476             :             other,
    1477             :             first_mapping_ptr,
    1478             :             first_mapping_count,
    1479             :             second_mapping_ptr,
    1480             :             second_mapping_count
    1481             :         );
    1482             :     }
    1483             : 
    1484             :     /// Take the intersection of these `Labels` with `other`.
    1485             :     ///
    1486             :     /// If requested, this function can also give the positions in the
    1487             :     /// intersection where each entry of the input `Labels` ended up.
    1488             :     ///
    1489             :     /// No user data pointer is registered with the output, even if the inputs
    1490             :     /// have some.
    1491             :     ///
    1492             :     /// @param other the `Labels` we want to take the intersection with
    1493             :     /// @param first_mapping if you want the mapping from the positions of
    1494             :     ///        entries in `this` to the positions in the intersection, this
    1495             :     ///        should be a pointer to an array containing `this->count()`
    1496             :     ///        elements, to be filled by this function. Otherwise it should be a
    1497             :     ///        `nullptr`. If an entry in `this` is not used in the intersection,
    1498             :     ///        the mapping will be set to -1.
    1499             :     /// @param first_mapping_count number of elements in `first_mapping`
    1500             :     /// @param second_mapping if you want the mapping from the positions of
    1501             :     ///        entries in `other` to the positions in the intersection, this
    1502             :     ///        should be a pointer to an array containing `other.count()`
    1503             :     ///        elements, to be filled by this function. Otherwise it should be a
    1504             :     ///        `nullptr`. If an entry in `other` is not used in the
    1505             :     ///        intersection, the mapping will be set to -1.
    1506             :     /// @param second_mapping_count number of elements in `second_mapping`
    1507             :     Labels set_intersection(
    1508             :         const Labels& other,
    1509             :         int64_t* first_mapping = nullptr,
    1510             :         size_t first_mapping_count = 0,
    1511             :         int64_t* second_mapping = nullptr,
    1512             :         size_t second_mapping_count = 0
    1513             :     ) const {
    1514             :         mts_labels_t result;
    1515             :         std::memset(&result, 0, sizeof(result));
    1516             : 
    1517             :         details::check_status(mts_labels_intersection(
    1518             :             labels_,
    1519             :             other.labels_,
    1520             :             &result,
    1521             :             first_mapping,
    1522             :             first_mapping_count,
    1523             :             second_mapping,
    1524             :             second_mapping_count
    1525             :         ));
    1526             : 
    1527             :         return Labels(result);
    1528             :     }
    1529             : 
    1530             :     /// Take the intersection of this `Labels` with `other`.
    1531             :     ///
    1532             :     /// If requested, this function can also give the positions in the
    1533             :     /// intersection where each entry of the input `Labels` ended up.
    1534             :     ///
    1535             :     /// No user data pointer is registered with the output, even if the inputs
    1536             :     /// have some.
    1537             :     ///
    1538             :     /// @param other the `Labels` we want to take the intersection with
    1539             :     /// @param first_mapping if you want the mapping from the positions of
    1540             :     ///        entries in `this` to the positions in the intersection, this
    1541             :     ///        should be a vector containing `this->count()` elements, to be
    1542             :     ///        filled by this function. Otherwise it should be an empty vector.
    1543             :     ///        If an entry in `this` is not used in the intersection, the
    1544             :     ///        mapping will be set to -1.
    1545             :     /// @param second_mapping if you want the mapping from the positions of
    1546             :     ///        entries in `other` to the positions in the intersection, this
    1547             :     ///        should be a vector containing `other.count()` elements, to be
    1548             :     ///        filled by this function. Otherwise it should be an empty vector.
    1549             :     ///        If an entry in `other` is not used in the intersection, the
    1550             :     ///        mapping will be set to -1.
    1551             :     Labels set_intersection(
    1552             :         const Labels& other,
    1553             :         std::vector<int64_t>& first_mapping,
    1554             :         std::vector<int64_t>& second_mapping
    1555             :     ) const {
    1556             :         auto* first_mapping_ptr = first_mapping.data();
    1557             :         auto first_mapping_count = first_mapping.size();
    1558             :         if (first_mapping_count == 0) {
    1559             :             first_mapping_ptr = nullptr;
    1560             :         }
    1561             : 
    1562             :         auto* second_mapping_ptr = second_mapping.data();
    1563             :         auto second_mapping_count = second_mapping.size();
    1564             :         if (second_mapping_count == 0) {
    1565             :             second_mapping_ptr = nullptr;
    1566             :         }
    1567             : 
    1568             :         return this->set_intersection(
    1569             :             other,
    1570             :             first_mapping_ptr,
    1571             :             first_mapping_count,
    1572             :             second_mapping_ptr,
    1573             :             second_mapping_count
    1574             :         );
    1575             :     }
    1576             : 
    1577             :     /// Take the difference of these `Labels` with `other`.
    1578             :     ///
    1579             :     /// If requested, this function can also give the positions in the
    1580             :     /// difference where each entry of the input `Labels` ended up.
    1581             :     ///
    1582             :     /// No user data pointer is registered with the output, even if the inputs
    1583             :     /// have some.
    1584             :     ///
    1585             :     /// @param other the `Labels` we want to take the difference with
    1586             :     /// @param first_mapping if you want the mapping from the positions of
    1587             :     ///        entries in `this` to the positions in the difference, this
    1588             :     ///        should be a pointer to an array containing `this->count()`
    1589             :     ///        elements, to be filled by this function. Otherwise it should be a
    1590             :     ///        `nullptr`. If an entry in `this` is not used in the difference,
    1591             :     ///        the mapping will be set to -1.
    1592             :     /// @param first_mapping_count number of elements in `first_mapping`
    1593             :     Labels set_difference(
    1594             :         const Labels& other,
    1595             :         int64_t* first_mapping = nullptr,
    1596             :         size_t first_mapping_count = 0
    1597             :     ) const {
    1598             :         mts_labels_t result;
    1599             :         std::memset(&result, 0, sizeof(result));
    1600             : 
    1601             :         details::check_status(mts_labels_difference(
    1602             :             labels_,
    1603             :             other.labels_,
    1604             :             &result,
    1605             :             first_mapping,
    1606             :             first_mapping_count
    1607             :         ));
    1608             : 
    1609             :         return Labels(result);
    1610             :     }
    1611             : 
    1612             :     /// Take the difference of this `Labels` with `other`.
    1613             :     ///
    1614             :     /// If requested, this function can also give the positions in the
    1615             :     /// difference where each entry of the input `Labels` ended up.
    1616             :     ///
    1617             :     /// No user data pointer is registered with the output, even if the inputs
    1618             :     /// have some.
    1619             :     ///
    1620             :     /// @param other the `Labels` we want to take the difference with
    1621             :     /// @param first_mapping if you want the mapping from the positions of
    1622             :     ///        entries in `this` to the positions in the difference, this
    1623             :     ///        should be a vector containing `this->count()` elements, to be
    1624             :     ///        filled by this function. Otherwise it should be an empty vector.
    1625             :     ///        If an entry in `this` is not used in the difference, the
    1626             :     ///        mapping will be set to -1.
    1627             :     Labels set_difference(const Labels& other, std::vector<int64_t>& first_mapping) const {
    1628             :         auto* first_mapping_ptr = first_mapping.data();
    1629             :         auto first_mapping_count = first_mapping.size();
    1630             :         if (first_mapping_count == 0) {
    1631             :             first_mapping_ptr = nullptr;
    1632             :         }
    1633             : 
    1634             :         return this->set_difference(
    1635             :             other,
    1636             :             first_mapping_ptr,
    1637             :             first_mapping_count
    1638             :         );
    1639             :     }
    1640             : 
    1641             :     /// Select entries in these `Labels` that match the `selection`.
    1642             :     ///
    1643             :     /// The selection's names must be a subset of the names of these labels.
    1644             :     ///
    1645             :     /// All entries in these `Labels` that match one of the entry in the
    1646             :     /// `selection` for all the selection's dimension will be picked. Any entry
    1647             :     /// in the `selection` but not in these `Labels` will be ignored.
    1648             :     ///
    1649             :     /// @param selection definition of the selection criteria. Multiple entries
    1650             :     ///        are interpreted as a logical `or` operation.
    1651             :     /// @param selected on input, a pointer to an array with space for
    1652             :     ///        `*selected_count` entries. On output, the first `*selected_count`
    1653             :     ///        values will contain the index in `labels` of selected entries.
    1654             :     /// @param selected_count on input, size of the `selected` array. On output,
    1655             :     ///        this will contain the number of selected entries.
    1656             :     void select(const Labels& selection, int64_t* selected, size_t *selected_count) const {
    1657             :         details::check_status(mts_labels_select(
    1658             :             labels_,
    1659             :             selection.labels_,
    1660             :             selected,
    1661             :             selected_count
    1662             :         ));
    1663             :     }
    1664             : 
    1665             :     /// Select entries in these `Labels` that match the `selection`.
    1666             :     ///
    1667             :     /// This function does the same thing as the one above, but allocates and
    1668             :     /// return the list of selected indexes in a `std::vector`
    1669             :     std::vector<int64_t> select(const Labels& selection) const {
    1670             :         auto selected_count = this->count();
    1671             :         auto selected = std::vector<int64_t>(selected_count, -1);
    1672             : 
    1673             :         this->select(selection, selected.data(), &selected_count);
    1674             : 
    1675             :         selected.resize(selected_count);
    1676             :         return selected;
    1677             :     }
    1678             : 
    1679             :     /*!
    1680             :      * \verbatim embed:rst:leading-asterisk
    1681             :      *
    1682             :      * Load previously saved ``Labels`` from the given path.
    1683             :      *
    1684             :      * This is identical to :cpp:func:`metatensor::io::load_labels`, and
    1685             :      * provided as a convenience API.
    1686             :      *
    1687             :      * \endverbatim
    1688             :      */
    1689             :     static Labels load(const std::string& path) {
    1690             :         return metatensor::io::load_labels(path);
    1691             :     }
    1692             : 
    1693             :     /*!
    1694             :      * \verbatim embed:rst:leading-asterisk
    1695             :      *
    1696             :      * Load previously saved ``Labels`` from a in-memory buffer.
    1697             :      *
    1698             :      * This is identical to :cpp:func:`metatensor::io::load_labels_buffer`, and
    1699             :      * provided as a convenience API.
    1700             :      *
    1701             :      * \endverbatim
    1702             :      */
    1703             :     static Labels load_buffer(const uint8_t* buffer, size_t buffer_count) {
    1704             :         return metatensor::io::load_labels_buffer(buffer, buffer_count);
    1705             :     }
    1706             : 
    1707             :     /*!
    1708             :      * \verbatim embed:rst:leading-asterisk
    1709             :      *
    1710             :      * Load previously saved ``Labels`` from a in-memory buffer.
    1711             :      *
    1712             :      * This is identical to :cpp:func:`metatensor::io::load_labels_buffer`, and
    1713             :      * provided as a convenience API.
    1714             :      *
    1715             :      * \endverbatim
    1716             :      */
    1717             :     template <typename Buffer>
    1718             :     static Labels load_buffer(const Buffer& buffer) {
    1719             :         return metatensor::io::load_labels_buffer<Buffer>(buffer);
    1720             :     }
    1721             : 
    1722             :     /*!
    1723             :      * \verbatim embed:rst:leading-asterisk
    1724             :      *
    1725             :      * Save ``Labels`` to the given path.
    1726             :      *
    1727             :      * This is identical to :cpp:func:`metatensor::io::save`, and provided as a
    1728             :      * convenience API.
    1729             :      *
    1730             :      * \endverbatim
    1731             :      */
    1732             :     void save(const std::string& path) const {
    1733             :         metatensor::io::save(path, *this);
    1734             :     }
    1735             : 
    1736             :     /*!
    1737             :      * \verbatim embed:rst:leading-asterisk
    1738             :      *
    1739             :      * Save ``Labels`` to an in-memory buffer.
    1740             :      *
    1741             :      * This is identical to :cpp:func:`metatensor::io::save_buffer`, and
    1742             :      * provided as a convenience API.
    1743             :      *
    1744             :      * \endverbatim
    1745             :      */
    1746             :     std::vector<uint8_t> save_buffer() const {
    1747             :         return metatensor::io::save_buffer(*this);
    1748             :     }
    1749             : 
    1750             :     /*!
    1751             :      * \verbatim embed:rst:leading-asterisk
    1752             :      *
    1753             :      * Save ``Labels`` to an in-memory buffer.
    1754             :      *
    1755             :      * This is identical to :cpp:func:`metatensor::io::save_buffer`, and
    1756             :      * provided as a convenience API.
    1757             :      *
    1758             :      * \endverbatim
    1759             :      */
    1760             :     template <typename Buffer>
    1761             :     Buffer save_buffer() const {
    1762             :         return metatensor::io::save_buffer<Buffer>(*this);
    1763             :     }
    1764             : 
    1765             : private:
    1766             :     explicit Labels(): values_(static_cast<const int32_t*>(nullptr), {0, 0})
    1767             :     {
    1768             :         std::memset(&labels_, 0, sizeof(labels_));
    1769             :     }
    1770             : 
    1771             :     explicit Labels(mts_labels_t labels):
    1772             :         values_(labels.values, {labels.count, labels.size}),
    1773             :         labels_(labels)
    1774             :     {
    1775             :         assert(labels_.internal_ptr_ != nullptr);
    1776             : 
    1777             :         for (size_t i=0; i<labels_.size; i++) {
    1778             :             names_.push_back(labels_.names[i]);
    1779             :         }
    1780             :     }
    1781             : 
    1782             :     // the constructor below is ambiguous with the public constructor taking
    1783             :     // `std::initializer_list`, so we use a private dummy struct argument to
    1784             :     // remove the ambiguity.
    1785             :     struct InternalConstructor {};
    1786             :     Labels(const std::vector<std::string>& names, const NDArray<int32_t>& values, InternalConstructor):
    1787             :         Labels(names, values.data(), values.shape()[0]) {}
    1788             : 
    1789             :     Labels(const std::vector<std::string>& names, const NDArray<int32_t>& values, assume_unique, InternalConstructor):
    1790             :         Labels(names, values.data(), values.shape()[0], assume_unique{}) {}
    1791             : 
    1792             :     friend Labels details::labels_from_cxx(const std::vector<std::string>& names, const int32_t* values, size_t count, bool assume_unique);
    1793             :     friend Labels io::load_labels(const std::string &path);
    1794             :     friend Labels io::load_labels_buffer(const uint8_t* buffer, size_t buffer_count);
    1795             :     friend class TensorMap;
    1796             :     friend class TensorBlock;
    1797             : 
    1798             :     friend class metatensor_torch::LabelsHolder;
    1799             : 
    1800             :     std::vector<const char*> names_;
    1801             :     NDArray<int32_t> values_;
    1802             :     mts_labels_t labels_;
    1803             : 
    1804             :     friend bool operator==(const Labels& lhs, const Labels& rhs);
    1805             : };
    1806             : 
    1807             : namespace details {
    1808             :     inline metatensor::Labels labels_from_cxx(
    1809             :         const std::vector<std::string>& names,
    1810             :         const int32_t* values,
    1811             :         size_t count,
    1812             :         bool assume_unique = false
    1813             :     ) {
    1814             :         mts_labels_t labels;
    1815             :         std::memset(&labels, 0, sizeof(labels));
    1816             : 
    1817             :         auto c_names = std::vector<const char*>();
    1818             :         for (const auto& name: names) {
    1819             :             c_names.push_back(name.c_str());
    1820             :         }
    1821             : 
    1822             :         labels.names = c_names.data();
    1823             :         labels.size = c_names.size();
    1824             :         labels.count = count;
    1825             :         labels.values = values;
    1826             : 
    1827             :         if (assume_unique) {
    1828             :             details::check_status(mts_labels_create_assume_unique(&labels));
    1829             :         } else {
    1830             :             details::check_status(mts_labels_create(&labels));
    1831             :         }
    1832             : 
    1833             :         return metatensor::Labels(labels);
    1834             :     }
    1835             : }
    1836             : 
    1837             : 
    1838             : /// Two Labels compare equal only if they have the same names and values in the
    1839             : /// same order.
    1840             : inline bool operator==(const Labels& lhs, const Labels& rhs) {
    1841             :     if (lhs.names_.size() != rhs.names_.size()) {
    1842             :         return false;
    1843             :     }
    1844             : 
    1845             :     for (size_t i=0; i<lhs.names_.size(); i++) {
    1846             :         if (std::strcmp(lhs.names_[i], rhs.names_[i]) != 0) {
    1847             :             return false;
    1848             :         }
    1849             :     }
    1850             : 
    1851             :     return lhs.values() == rhs.values();
    1852             : }
    1853             : 
    1854             : /// Two Labels compare equal only if they have the same names and values in the
    1855             : /// same order.
    1856             : inline bool operator!=(const Labels& lhs, const Labels& rhs) {
    1857             :     return !(lhs == rhs);
    1858             : }
    1859             : 
    1860             : 
    1861             : /******************************************************************************/
    1862             : /******************************************************************************/
    1863             : /*                                                                            */
    1864             : /*                             TensorBlock                                    */
    1865             : /*                                                                            */
    1866             : /******************************************************************************/
    1867             : /******************************************************************************/
    1868             : 
    1869             : 
    1870             : /// Basic building block for a tensor map.
    1871             : ///
    1872             : /// A single block contains a n-dimensional `mts_array_t` (or `DataArrayBase`),
    1873             : /// and n sets of `Labels` (one for each dimension). The first dimension is the
    1874             : /// *samples* dimension, the last dimension is the *properties* dimension. Any
    1875             : /// intermediate dimension is called a *component* dimension.
    1876             : ///
    1877             : /// Samples should be used to describe *what* we are representing, while
    1878             : /// properties should contain information about *how* we are representing it.
    1879             : /// Finally, components should be used to describe vectorial or tensorial
    1880             : /// components of the data.
    1881             : ///
    1882             : /// A block can also contain gradients of the values with respect to a variety
    1883             : /// of parameters. In this case, each gradient has a separate set of samples,
    1884             : /// and possibly components but share the same property labels as the values.
    1885             : class TensorBlock final {
    1886             : public:
    1887             :     /// Create a new TensorBlock containing the given `values` array.
    1888             :     ///
    1889             :     /// The different dimensions of the values are described by `samples`,
    1890             :     /// `components` and `properties` `Labels`
    1891             :     TensorBlock(
    1892             :         std::unique_ptr<DataArrayBase> values,
    1893             :         const Labels& samples,
    1894             :         const std::vector<Labels>& components,
    1895             :         const Labels& properties
    1896             :     ):
    1897             :         block_(nullptr),
    1898             :         is_view_(false)
    1899             :     {
    1900             :         auto c_components = std::vector<mts_labels_t>();
    1901             :         for (const auto& component: components) {
    1902             :             c_components.push_back(component.as_mts_labels_t());
    1903             :         }
    1904             :         block_ = mts_block(
    1905             :             DataArrayBase::to_mts_array_t(std::move(values)),
    1906             :             samples.as_mts_labels_t(),
    1907             :             c_components.data(),
    1908             :             c_components.size(),
    1909             :             properties.as_mts_labels_t()
    1910             :         );
    1911             : 
    1912             :         details::check_pointer(block_);
    1913             :     }
    1914             : 
    1915             :     ~TensorBlock() {
    1916             :         if (!is_view_) {
    1917             :             mts_block_free(block_);
    1918             :         }
    1919             :     }
    1920             : 
    1921             :     /// TensorBlock can NOT be copy constructed, use TensorBlock::clone instead
    1922             :     TensorBlock(const TensorBlock&) = delete;
    1923             : 
    1924             :     /// TensorBlock can NOT be copy assigned, use TensorBlock::clone instead
    1925             :     TensorBlock& operator=(const TensorBlock& other) = delete;
    1926             : 
    1927             :     /// TensorBlock can be move constructed
    1928             :     TensorBlock(TensorBlock&& other) noexcept : TensorBlock() {
    1929             :         *this = std::move(other);
    1930             :     }
    1931             : 
    1932             :     /// TensorBlock can be moved assigned
    1933             :     TensorBlock& operator=(TensorBlock&& other) noexcept {
    1934             :         if (!is_view_) {
    1935             :             mts_block_free(block_);
    1936             :         }
    1937             : 
    1938             :         this->block_ = other.block_;
    1939             :         this->is_view_ = other.is_view_;
    1940             :         other.block_ = nullptr;
    1941             :         other.is_view_ = true;
    1942             : 
    1943             :         return *this;
    1944             :     }
    1945             : 
    1946             :     /// Make a copy of this `TensorBlock`, including all the data contained inside
    1947             :     TensorBlock clone() const {
    1948             :         auto copy = TensorBlock();
    1949             :         copy.is_view_ = false;
    1950             :         copy.block_ = mts_block_copy(this->block_);
    1951             :         details::check_pointer(copy.block_);
    1952             :         return copy;
    1953             :     }
    1954             : 
    1955             :     /// Get a copy of the metadata in this block (i.e. samples, components, and
    1956             :     /// properties), ignoring the data itself.
    1957             :     ///
    1958             :     /// The resulting block values will be an `EmptyDataArray` instance, which
    1959             :     /// does not contain any data.
    1960             :     TensorBlock clone_metadata_only() const {
    1961             :         auto block = TensorBlock(
    1962             :             std::unique_ptr<EmptyDataArray>(new EmptyDataArray(this->values_shape())),
    1963             :             this->samples(),
    1964             :             this->components(),
    1965             :             this->properties()
    1966             :         );
    1967             : 
    1968             :         for (const auto& parameter: this->gradients_list()) {
    1969             :             auto gradient = this->gradient(parameter);
    1970             :             block.add_gradient(parameter, gradient.clone_metadata_only());
    1971             :         }
    1972             : 
    1973             :         return block;
    1974             :     }
    1975             : 
    1976             :     /// Get a view in the values in this block
    1977             :     NDArray<double> values() & {
    1978             :         auto array = this->mts_array();
    1979             :         double* data = nullptr;
    1980             :         details::check_status(array.data(array.ptr, &data));
    1981             : 
    1982             :         return NDArray<double>(data, this->values_shape());
    1983             :     }
    1984             : 
    1985             :     NDArray<double> values() && = delete;
    1986             : 
    1987             :     /// Access the sample `Labels` for this block.
    1988             :     ///
    1989             :     /// The entries in these labels describe the first dimension of the
    1990             :     /// `values()` array.
    1991             :     Labels samples() const {
    1992             :         return this->labels(0);
    1993             :     }
    1994             : 
    1995             :     /// Access the component `Labels` for this block.
    1996             :     ///
    1997             :     /// The entries in these labels describe intermediate dimensions of the
    1998             :     /// `values()` array.
    1999             :     std::vector<Labels> components() const {
    2000             :         auto shape = this->values_shape();
    2001             : 
    2002             :         auto result = std::vector<Labels>();
    2003             :         for (size_t i=1; i<shape.size() - 1; i++) {
    2004             :             result.emplace_back(this->labels(i));
    2005             :         }
    2006             : 
    2007             :         return result;
    2008             :     }
    2009             : 
    2010             :     /// Access the property `Labels` for this block.
    2011             :     ///
    2012             :     /// The entries in these labels describe the last dimension of the
    2013             :     /// `values()` array. The properties are guaranteed to be the same for
    2014             :     /// a block and all of its gradients.
    2015             :     Labels properties() const {
    2016             :         auto shape = this->values_shape();
    2017             :         return this->labels(shape.size() - 1);
    2018             :     }
    2019             : 
    2020             :     /// Add a set of gradients with respect to `parameters` in this block.
    2021             :     ///
    2022             :     /// @param parameter add gradients with respect to this `parameter` (e.g.
    2023             :     ///                 `"positions"`, `"cell"`, ...)
    2024             :     /// @param gradient a `TensorBlock` whose values contain the gradients with
    2025             :     ///                 respect to the `parameter`. The labels of the gradient
    2026             :     ///                 `TensorBlock` should be organized as follows: its
    2027             :     ///                 `samples` must contain `"sample"` as the first label,
    2028             :     ///                 which establishes a correspondence with the `samples` of
    2029             :     ///                 the original `TensorBlock`; its components must contain
    2030             :     ///                 at least the same components as the original
    2031             :     ///                 `TensorBlock`, with any additional component coming
    2032             :     ///                 before those; its properties must match those of the
    2033             :     ///                 original `TensorBlock`.
    2034             :     void add_gradient(const std::string& parameter, TensorBlock gradient) {
    2035             :         if (is_view_) {
    2036             :             throw Error(
    2037             :                 "can not call TensorBlock::add_gradient on this block since "
    2038             :                 "it is a view inside a TensorMap"
    2039             :             );
    2040             :         }
    2041             : 
    2042             :         details::check_status(mts_block_add_gradient(
    2043             :             block_,
    2044             :             parameter.c_str(),
    2045             :             gradient.release()
    2046             :         ));
    2047             :     }
    2048             : 
    2049             :     /// Get a list of all gradients defined in this block.
    2050             :     std::vector<std::string> gradients_list() const {
    2051             :         const char*const * parameters = nullptr;
    2052             :         uintptr_t count = 0;
    2053             :         details::check_status(mts_block_gradients_list(
    2054             :             block_,
    2055             :             &parameters,
    2056             :             &count
    2057             :         ));
    2058             : 
    2059             :         auto result = std::vector<std::string>();
    2060             :         for (uint64_t i=0; i<count; i++) {
    2061             :             result.emplace_back(parameters[i]);
    2062             :         }
    2063             : 
    2064             :         return result;
    2065             :     }
    2066             : 
    2067             :     /// Get the gradient in this block with respect to the given `parameter`.
    2068             :     /// The gradient is returned as a TensorBlock itself.
    2069             :     ///
    2070             :     /// @param parameter check for gradients with respect to this `parameter`
    2071             :     ///                  (e.g. `"positions"`, `"cell"`, ...)
    2072             :     TensorBlock gradient(const std::string& parameter) const {
    2073             :         mts_block_t* gradient_block = nullptr;
    2074             :         details::check_status(
    2075             :             mts_block_gradient(block_, parameter.c_str(), &gradient_block)
    2076             :         );
    2077             :         details::check_pointer(gradient_block);
    2078             :         return TensorBlock::unsafe_view_from_ptr(gradient_block);
    2079             :     }
    2080             : 
    2081             :     /// Get the `mts_block_t` pointer corresponding to this block.
    2082             :     ///
    2083             :     /// The block pointer is still managed by the current `TensorBlock`
    2084             :     mts_block_t* as_mts_block_t() & {
    2085             :         if (is_view_) {
    2086             :             throw Error(
    2087             :                 "can not call non-const TensorBlock::as_mts_block_t on this "
    2088             :                 "block since it is a view inside a TensorMap"
    2089             :             );
    2090             :         }
    2091             :         return block_;
    2092             :     }
    2093             : 
    2094             :     /// const version of `as_mts_block_t`
    2095             :     const mts_block_t* as_mts_block_t() const & {
    2096             :         return block_;
    2097             :     }
    2098             : 
    2099             :     const mts_block_t* as_mts_block_t() && = delete;
    2100             : 
    2101             :     /// Create a new TensorBlock taking ownership of a raw `mts_block_t` pointer.
    2102             :     static TensorBlock unsafe_from_ptr(mts_block_t* ptr) {
    2103             :         auto block = TensorBlock();
    2104             :         block.block_ = ptr;
    2105             :         block.is_view_ = false;
    2106             :         return block;
    2107             :     }
    2108             : 
    2109             :     /// Create a new TensorBlock which is a view corresponding to a raw
    2110             :     /// `mts_block_t` pointer.
    2111             :     static TensorBlock unsafe_view_from_ptr(mts_block_t* ptr) {
    2112             :         auto block = TensorBlock();
    2113             :         block.block_ = ptr;
    2114             :         block.is_view_ = true;
    2115             :         return block;
    2116             :     }
    2117             : 
    2118             :     /// Get a raw `mts_array_t` corresponding to the values in this block.
    2119             :     mts_array_t mts_array() {
    2120             :         mts_array_t array;
    2121             :         std::memset(&array, 0, sizeof(array));
    2122             : 
    2123             :         details::check_status(
    2124             :             mts_block_data(block_, &array)
    2125             :         );
    2126             :         return array;
    2127             :     }
    2128             : 
    2129             :     /// Get the labels in this block associated with the given `axis`.
    2130             :     Labels labels(uintptr_t axis) const {
    2131             :         mts_labels_t labels;
    2132             :         std::memset(&labels, 0, sizeof(labels));
    2133             :         details::check_status(mts_block_labels(
    2134             :             block_, axis, &labels
    2135             :         ));
    2136             : 
    2137             :         return Labels(labels);
    2138             :     }
    2139             : 
    2140             :     /// Get the shape of the value array for this block
    2141          62 :     std::vector<uintptr_t> values_shape() const {
    2142          62 :         auto array = this->const_mts_array();
    2143             : 
    2144          62 :         const uintptr_t* shape = nullptr;
    2145          62 :         uintptr_t shape_count = 0;
    2146          62 :         details::check_status(array.shape(array.ptr, &shape, &shape_count));
    2147             :         assert(shape_count >= 2);
    2148             : 
    2149          62 :         return {shape, shape + shape_count};
    2150             :     }
    2151             : 
    2152             :     /*!
    2153             :      * \verbatim embed:rst:leading-asterisk
    2154             :      *
    2155             :      * Load a previously saved ``TensorBlock`` from the given path.
    2156             :      *
    2157             :      * This is identical to :cpp:func:`metatensor::io::load_block`, and provided
    2158             :      * as a convenience API.
    2159             :      *
    2160             :      * \endverbatim
    2161             :      */
    2162             :     static TensorBlock load(
    2163             :         const std::string& path,
    2164             :         mts_create_array_callback_t create_array = details::default_create_array
    2165             :     ) {
    2166             :         return metatensor::io::load_block(path, create_array);
    2167             :     }
    2168             : 
    2169             :     /*!
    2170             :      * \verbatim embed:rst:leading-asterisk
    2171             :      *
    2172             :      * Load a previously saved ``TensorBlock`` from a in-memory buffer.
    2173             :      *
    2174             :      * This is identical to :cpp:func:`metatensor::io::load_block_buffer`, and
    2175             :      * provided as a convenience API.
    2176             :      *
    2177             :      * \endverbatim
    2178             :      */
    2179             :     static TensorBlock load_buffer(
    2180             :         const uint8_t* buffer,
    2181             :         size_t buffer_count,
    2182             :         mts_create_array_callback_t create_array = details::default_create_array
    2183             :     ) {
    2184             :         return metatensor::io::load_block_buffer(buffer, buffer_count, create_array);
    2185             :     }
    2186             : 
    2187             :     /*!
    2188             :      * \verbatim embed:rst:leading-asterisk
    2189             :      *
    2190             :      * Load a previously saved ``TensorBlock`` from a in-memory buffer.
    2191             :      *
    2192             :      * This is identical to :cpp:func:`metatensor::io::load_block_buffer`, and
    2193             :      * provided as a convenience API.
    2194             :      *
    2195             :      * \endverbatim
    2196             :      */
    2197             :     template <typename Buffer>
    2198             :     static TensorBlock load_buffer(
    2199             :         const Buffer& buffer,
    2200             :         mts_create_array_callback_t create_array = details::default_create_array
    2201             :     ) {
    2202             :         return metatensor::io::load_block_buffer<Buffer>(buffer, create_array);
    2203             :     }
    2204             : 
    2205             :     /*!
    2206             :      * \verbatim embed:rst:leading-asterisk
    2207             :      *
    2208             :      * Save this ``TensorBlock`` to the given path.
    2209             :      *
    2210             :      * This is identical to :cpp:func:`metatensor::io::save`, and provided as a
    2211             :      * convenience API.
    2212             :      *
    2213             :      * \endverbatim
    2214             :      */
    2215             :     void save(const std::string& path) const {
    2216             :         metatensor::io::save(path, *this);
    2217             :     }
    2218             : 
    2219             :     /*!
    2220             :      * \verbatim embed:rst:leading-asterisk
    2221             :      *
    2222             :      * Save this ``TensorBlock`` to an in-memory buffer.
    2223             :      *
    2224             :      * This is identical to :cpp:func:`metatensor::io::save_buffer`, and
    2225             :      * provided as a convenience API.
    2226             :      *
    2227             :      * \endverbatim
    2228             :      */
    2229             :     std::vector<uint8_t> save_buffer() const {
    2230             :         return metatensor::io::save_buffer(*this);
    2231             :     }
    2232             : 
    2233             :     /*!
    2234             :      * \verbatim embed:rst:leading-asterisk
    2235             :      *
    2236             :      * Save this ``TensorBlock`` to an in-memory buffer.
    2237             :      *
    2238             :      * This is identical to :cpp:func:`metatensor::io::save_buffer`, and
    2239             :      * provided as a convenience API.
    2240             :      *
    2241             :      * \endverbatim
    2242             :      */
    2243             :     template <typename Buffer>
    2244             :     Buffer save_buffer() const {
    2245             :         return metatensor::io::save_buffer<Buffer>(*this);
    2246             :     }
    2247             : 
    2248             : private:
    2249             :     /// Constructor of a TensorBlock not associated with anything
    2250             :     explicit TensorBlock(): block_(nullptr), is_view_(true) {}
    2251             : 
    2252             :     /// Create a C++ TensorBlock from a C `mts_block_t` pointer. The C++
    2253             :     /// block takes ownership of the C pointer.
    2254             :     explicit TensorBlock(mts_block_t* block): block_(block), is_view_(false) {}
    2255             : 
    2256             :     /// Get the `mts_array_t` for this block.
    2257             :     ///
    2258             :     /// The returned `mts_array_t` should only be used in a const context
    2259          62 :     mts_array_t const_mts_array() const {
    2260             :         mts_array_t array;
    2261             :         std::memset(&array, 0, sizeof(array));
    2262             : 
    2263          62 :         details::check_status(
    2264          62 :             mts_block_data(block_, &array)
    2265             :         );
    2266          62 :         return array;
    2267             :     }
    2268             : 
    2269             :     /// Release the `mts_block_t` pointer corresponding to this `TensorBlock`.
    2270             :     ///
    2271             :     /// The block pointer is **no longer** managed by the current `TensorBlock`,
    2272             :     /// and should manually be freed when no longer required.
    2273             :     mts_block_t* release() {
    2274             :          if (is_view_) {
    2275             :             throw Error(
    2276             :                 "can not call TensorBlock::release on this "
    2277             :                 "block since it is a view inside a TensorMap"
    2278             :             );
    2279             :         }
    2280             :         auto* ptr = block_;
    2281             :         block_ = nullptr;
    2282             :         is_view_ = false;
    2283             :         return ptr;
    2284             :     }
    2285             : 
    // `TensorMap` calls the private `release()` when moving blocks inside a
    // new tensor map.
    friend class TensorMap;
    // The metatensor-torch wrapper (forward-declared at the top of this file)
    // also needs access to the private members.
    friend class metatensor_torch::TensorBlockHolder;
    // The loading functions in `metatensor::io` are friends, presumably so
    // they can build blocks through the private constructors (their bodies
    // are not visible here; TODO confirm).
    friend TensorBlock metatensor::io::load_block(
        const std::string& path,
        mts_create_array_callback_t create_array
    );
    friend TensorBlock metatensor::io::load_block_buffer(
        const uint8_t* buffer,
        size_t buffer_count,
        mts_create_array_callback_t create_array
    );

    /// Raw pointer to the C block. It is owned by this object unless
    /// `is_view_` is true (see `unsafe_from_ptr` / `unsafe_view_from_ptr`).
    mts_block_t* block_;
    /// `true` when this block is a view inside memory owned by a `TensorMap`.
    /// In that case `release()` and the non-const `as_mts_block_t()` throw.
    bool is_view_;
    2300             : };
    2301             : 
    2302             : 
    2303             : /******************************************************************************/
    2304             : /******************************************************************************/
    2305             : /*                                                                            */
    2306             : /*                               TensorMap                                    */
    2307             : /*                                                                            */
    2308             : /******************************************************************************/
    2309             : /******************************************************************************/
    2310             : 
    2311             : /// A TensorMap is the main user-facing class of this library, and can store any
    2312             : /// kind of data used in atomistic machine learning.
    2313             : ///
    2314             : /// A tensor map contains a list of `TensorBlock`, each one associated with a
    2315             : /// key. Users can access the blocks either one by one with the `block_by_id()`
    2316             : /// function.
    2317             : ///
    2318             : /// A tensor map provides functions to move some of these keys to the samples or
    2319             : /// properties labels of the blocks, moving from a sparse representation of the
    2320             : /// data to a dense one.
    2321             : class TensorMap final {
    2322             : public:
    2323             :     /// Create a new TensorMap with the given `keys` and `blocks`
    2324             :     TensorMap(Labels keys, std::vector<TensorBlock> blocks) {
    2325             :         auto c_blocks = std::vector<mts_block_t*>();
    2326             :         for (auto& block: blocks) {
    2327             :             // We will move the data inside the new map, let's release the
    2328             :             // pointers out of the TensorBlock now
    2329             :             c_blocks.push_back(block.release());
    2330             :         }
    2331             : 
    2332             :         tensor_ = mts_tensormap(
    2333             :             keys.as_mts_labels_t(),
    2334             :             c_blocks.data(),
    2335             :             c_blocks.size()
    2336             :         );
    2337             : 
    2338             :         details::check_pointer(tensor_);
    2339             :     }
    2340             : 
    2341             :     ~TensorMap() {
    2342             :         mts_tensormap_free(tensor_);
    2343             :     }
    2344             : 
    2345             :     /// TensorMap can NOT be copy constructed, use TensorMap::clone instead
    2346             :     TensorMap(const TensorMap&) = delete;
    2347             :     /// TensorMap can not be copy assigned, use TensorMap::clone instead
    2348             :     TensorMap& operator=(const TensorMap&) = delete;
    2349             : 
    2350             :     /// TensorMap can be move constructed
    2351             :     TensorMap(TensorMap&& other) noexcept : TensorMap(nullptr) {
    2352             :         *this = std::move(other);
    2353             :     }
    2354             : 
    2355             :     /// TensorMap can be move assigned
    2356             :     TensorMap& operator=(TensorMap&& other) noexcept {
    2357             :         mts_tensormap_free(tensor_);
    2358             : 
    2359             :         this->tensor_ = other.tensor_;
    2360             :         other.tensor_ = nullptr;
    2361             : 
    2362             :         return *this;
    2363             :     }
    2364             : 
    2365             :     /// Make a copy of this `TensorMap`, including all the data contained inside
    2366             :     TensorMap clone() const {
    2367             :         auto* copy = mts_tensormap_copy(this->tensor_);
    2368             :         details::check_pointer(copy);
    2369             :         return TensorMap(copy);
    2370             :     }
    2371             : 
    2372             :     /// Get a copy of the metadata in this `TensorMap` (i.e. keys, samples,
    2373             :     /// components, and properties), ignoring the data itself.
    2374             :     ///
    2375             :     /// The resulting blocks values will be an `EmptyDataArray` instance, which
    2376             :     /// does not contain any data.
    2377             :     TensorMap clone_metadata_only() const {
    2378             :         auto n_blocks = this->keys().count();
    2379             : 
    2380             :         auto blocks = std::vector<TensorBlock>();
    2381             :         blocks.reserve(n_blocks);
    2382             :         for (uintptr_t i=0; i<n_blocks; i++) {
    2383             :             mts_block_t* block_ptr = nullptr;
    2384             :             details::check_status(mts_tensormap_block_by_id(tensor_, &block_ptr, i));
    2385             :             details::check_pointer(block_ptr);
    2386             :             auto block = TensorBlock::unsafe_view_from_ptr(block_ptr);
    2387             : 
    2388             :             blocks.push_back(block.clone_metadata_only());
    2389             :         }
    2390             : 
    2391             :         return TensorMap(this->keys(), std::move(blocks));
    2392             :     }
    2393             : 
    2394             :     /// Get the set of keys labeling the blocks in this tensor map
    2395             :     Labels keys() const {
    2396             :         mts_labels_t keys;
    2397             :         std::memset(&keys, 0, sizeof(keys));
    2398             : 
    2399             :         details::check_status(mts_tensormap_keys(tensor_, &keys));
    2400             :         return Labels(keys);
    2401             :     }
    2402             : 
    2403             :     /// Get a (possibly empty) list of block indexes matching the `selection`
    2404             :     std::vector<uintptr_t> blocks_matching(const Labels& selection) const {
    2405             :         auto matching = std::vector<uintptr_t>(this->keys().count());
    2406             :         uintptr_t count = matching.size();
    2407             : 
    2408             :         details::check_status(mts_tensormap_blocks_matching(
    2409             :             tensor_,
    2410             :             matching.data(),
    2411             :             &count,
    2412             :             selection.as_mts_labels_t()
    2413             :         ));
    2414             : 
    2415             :         assert(count <= matching.size());
    2416             :         matching.resize(count);
    2417             :         return matching;
    2418             :     }
    2419             : 
    2420             :     /// Get a block inside this TensorMap by it's index/the index of the
    2421             :     /// corresponding key.
    2422             :     ///
    2423             :     /// The returned `TensorBlock` is a view inside memory owned by this
    2424             :     /// `TensorMap`, and is only valid as long as the `TensorMap` is kept alive.
    2425             :     TensorBlock block_by_id(uintptr_t index) & {
    2426             :         mts_block_t* block = nullptr;
    2427             :         details::check_status(mts_tensormap_block_by_id(tensor_, &block, index));
    2428             :         details::check_pointer(block);
    2429             : 
    2430             :         return TensorBlock::unsafe_view_from_ptr(block);
    2431             :     }
    2432             : 
    2433             :     TensorBlock block_by_id(uintptr_t index) && = delete;
    2434             : 
    2435             :     /// Merge blocks with the same value for selected keys dimensions along the
    2436             :     /// property axis.
    2437             :     ///
    2438             :     /// The dimensions (names) of `keys_to_move` will be moved from the keys to
    2439             :     /// the property labels, and blocks with the same remaining keys dimensions
    2440             :     /// will be merged together along the property axis.
    2441             :     ///
    2442             :     /// If `keys_to_move` does not contains any entries (i.e.
    2443             :     /// `keys_to_move.count() == 0`), then the new property labels will contain
    2444             :     /// entries corresponding to the merged blocks only. For example, merging a
    2445             :     /// block with key `a=0` and properties `p=1, 2` with a block with key `a=2`
    2446             :     /// and properties `p=1, 3` will produce a block with properties
    2447             :     /// `a, p = (0, 1), (0, 2), (2, 1), (2, 3)`.
    2448             :     ///
    2449             :     /// If `keys_to_move` contains entries, then the property labels must be the
    2450             :     /// same for all the merged blocks. In that case, the merged property labels
    2451             :     /// will contains each of the entries of `keys_to_move` and then the current
    2452             :     /// property labels. For example, using `a=2, 3` in `keys_to_move`, and
    2453             :     /// blocks with properties `p=1, 2` will result in `a, p = (2, 1), (2, 2),
    2454             :     /// (3, 1), (3, 2)`.
    2455             :     ///
    2456             :     /// The new sample labels will contains all of the merged blocks sample
    2457             :     /// labels. The order of the samples is controlled by `sort_samples`. If
    2458             :     /// `sort_samples` is true, samples are re-ordered to keep them
    2459             :     /// lexicographically sorted. Otherwise they are kept in the order in which
    2460             :     /// they appear in the blocks.
    2461             :     ///
    2462             :     /// @param keys_to_move description of the keys to move
    2463             :     /// @param sort_samples whether to sort the merged samples or keep them in
    2464             :     ///                     the order in which they appear in the original blocks
    2465             :     TensorMap keys_to_properties(const Labels& keys_to_move, bool sort_samples = true) const {
    2466             :         auto* ptr = mts_tensormap_keys_to_properties(
    2467             :             tensor_,
    2468             :             keys_to_move.as_mts_labels_t(),
    2469             :             sort_samples
    2470             :         );
    2471             : 
    2472             :         details::check_pointer(ptr);
    2473             :         return TensorMap(ptr);
    2474             :     }
    2475             : 
    2476             :     /// This function calls `keys_to_properties` with an empty set of `Labels`
    2477             :     /// with the dimensions defined in `keys_to_move`
    2478             :     TensorMap keys_to_properties(const std::vector<std::string>& keys_to_move, bool sort_samples = true) const {
    2479             :         return keys_to_properties(Labels(keys_to_move), sort_samples);
    2480             :     }
    2481             : 
    2482             :     /// This function calls `keys_to_properties` with an empty set of `Labels`
    2483             :     /// with a single dimension: `key_to_move`
    2484             :     TensorMap keys_to_properties(const std::string& key_to_move, bool sort_samples = true) const {
    2485             :         return keys_to_properties(std::vector<std::string>{key_to_move}, sort_samples);
    2486             :     }
    2487             : 
    2488             :     /// Merge blocks with the same value for selected keys dimensions along the
    2489             :     /// samples axis.
    2490             :     ///
    2491             :     /// The dimensions (names) of `keys_to_move` will be moved from the keys to
    2492             :     /// the sample labels, and blocks with the same remaining keys dimensions
    2493             :     /// will be merged together along the sample axis.
    2494             :     ///
    2495             :     /// If `keys_to_move` must be an empty set of `Labels`
    2496             :     /// (`keys_to_move.count() == 0`). The new sample labels will contain
    2497             :     /// entries corresponding to the merged blocks' keys.
    2498             :     ///
    2499             :     /// The order of the samples is controlled by `sort_samples`. If
    2500             :     /// `sort_samples` is true, samples are re-ordered to keep them
    2501             :     /// lexicographically sorted. Otherwise they are kept in the order in which
    2502             :     /// they appear in the blocks.
    2503             :     ///
    2504             :     /// This function is only implemented if all merged block have the same
    2505             :     /// property labels.
    2506             :     ///
    2507             :     /// @param keys_to_move description of the keys to move
    2508             :     /// @param sort_samples whether to sort the merged samples or keep them in
    2509             :     ///                     the order in which they appear in the original blocks
    2510             :     TensorMap keys_to_samples(const Labels& keys_to_move, bool sort_samples = true) const {
    2511             :         auto* ptr = mts_tensormap_keys_to_samples(
    2512             :             tensor_,
    2513             :             keys_to_move.as_mts_labels_t(),
    2514             :             sort_samples
    2515             :         );
    2516             : 
    2517             :         details::check_pointer(ptr);
    2518             :         return TensorMap(ptr);
    2519             :     }
    2520             : 
    2521             :     /// This function calls `keys_to_samples` with an empty set of `Labels`
    2522             :     /// with the dimensions defined in `keys_to_move`
    2523             :     TensorMap keys_to_samples(const std::vector<std::string>& keys_to_move, bool sort_samples = true) const {
    2524             :         return keys_to_samples(Labels(keys_to_move), sort_samples);
    2525             :     }
    2526             : 
    2527             :     /// This function calls `keys_to_samples` with an empty set of `Labels`
    2528             :     /// with a single dimension: `key_to_move`
    2529             :     TensorMap keys_to_samples(const std::string& key_to_move, bool sort_samples = true) const {
    2530             :         return keys_to_samples(std::vector<std::string>{key_to_move}, sort_samples);
    2531             :     }
    2532             : 
    2533             :     /// Move the given `dimensions` from the component labels to the property
    2534             :     /// labels for each block.
    2535             :     ///
    2536             :     /// @param dimensions name of the component dimensions to move to the
    2537             :     ///                  properties
    2538             :     TensorMap components_to_properties(const std::vector<std::string>& dimensions) const {
    2539             :         auto c_dimensions = std::vector<const char*>();
    2540             :         for (const auto& v: dimensions) {
    2541             :             c_dimensions.push_back(v.c_str());
    2542             :         }
    2543             : 
    2544             :         auto* ptr = mts_tensormap_components_to_properties(
    2545             :             tensor_,
    2546             :             c_dimensions.data(),
    2547             :             c_dimensions.size()
    2548             :         );
    2549             :         details::check_pointer(ptr);
    2550             :         return TensorMap(ptr);
    2551             :     }
    2552             : 
    2553             :     /// Call `components_to_properties` with a single dimension
    2554             :     TensorMap components_to_properties(const std::string& dimension) const {
    2555             :         const char* c_str = dimension.c_str();
    2556             :         auto* ptr = mts_tensormap_components_to_properties(
    2557             :             tensor_,
    2558             :             &c_str,
    2559             :             1
    2560             :         );
    2561             :         details::check_pointer(ptr);
    2562             :         return TensorMap(ptr);
    2563             :     }
    2564             : 
    2565             :     /*!
    2566             :      * \verbatim embed:rst:leading-asterisk
    2567             :      *
    2568             :      * Load a previously saved ``TensorMap`` from the given path.
    2569             :      *
    2570             :      * This is identical to :cpp:func:`metatensor::io::load`, and provided as a
    2571             :      * convenience API.
    2572             :      *
    2573             :      * \endverbatim
    2574             :      */
    2575             :     static TensorMap load(
    2576             :         const std::string& path,
    2577             :         mts_create_array_callback_t create_array = details::default_create_array
    2578             :     ) {
    2579             :         return metatensor::io::load(path, create_array);
    2580             :     }
    2581             : 
    2582             :     /*!
    2583             :      * \verbatim embed:rst:leading-asterisk
    2584             :      *
    2585             :      * Load a previously saved ``TensorMap`` from a in-memory buffer.
    2586             :      *
    2587             :      * This is identical to :cpp:func:`metatensor::io::load_buffer`, and
    2588             :      * provided as a convenience API.
    2589             :      *
    2590             :      * \endverbatim
    2591             :      */
    2592             :     static TensorMap load_buffer(
    2593             :         const uint8_t* buffer,
    2594             :         size_t buffer_count,
    2595             :         mts_create_array_callback_t create_array = details::default_create_array
    2596             :     ) {
    2597             :         return metatensor::io::load_buffer(buffer, buffer_count, create_array);
    2598             :     }
    2599             : 
    2600             :     /*!
    2601             :      * \verbatim embed:rst:leading-asterisk
    2602             :      *
    2603             :      * Load a previously saved ``TensorMap`` from a in-memory buffer.
    2604             :      *
    2605             :      * This is identical to :cpp:func:`metatensor::io::load_buffer`, and
    2606             :      * provided as a convenience API.
    2607             :      *
    2608             :      * \endverbatim
    2609             :      */
    2610             :     template <typename Buffer>
    2611             :     static TensorMap load_buffer(
    2612             :         const Buffer& buffer,
    2613             :         mts_create_array_callback_t create_array = details::default_create_array
    2614             :     ) {
    2615             :         return metatensor::io::load_buffer<Buffer>(buffer, create_array);
    2616             :     }
    2617             : 
    2618             :     /*!
    2619             :      * \verbatim embed:rst:leading-asterisk
    2620             :      *
    2621             :      * Save this ``TensorMap`` to the given path.
    2622             :      *
    2623             :      * This is identical to :cpp:func:`metatensor::io::save`, and provided as a
    2624             :      * convenience API.
    2625             :      *
    2626             :      * \endverbatim
    2627             :      */
    2628             :     void save(const std::string& path) const {
    2629             :         metatensor::io::save(path, *this);
    2630             :     }
    2631             : 
    2632             :     /*!
    2633             :      * \verbatim embed:rst:leading-asterisk
    2634             :      *
    2635             :      * Save this ``TensorMap`` to an in-memory buffer.
    2636             :      *
    2637             :      * This is identical to :cpp:func:`metatensor::io::save_buffer`, and
    2638             :      * provided as a convenience API.
    2639             :      *
    2640             :      * \endverbatim
    2641             :      */
    2642             :     std::vector<uint8_t> save_buffer() const {
    2643             :         return metatensor::io::save_buffer(*this);
    2644             :     }
    2645             : 
    2646             :     /*!
    2647             :      * \verbatim embed:rst:leading-asterisk
    2648             :      *
    2649             :      * Save this ``TensorMap`` to an in-memory buffer.
    2650             :      *
    2651             :      * This is identical to :cpp:func:`metatensor::io::save_buffer`, and
    2652             :      * provided as a convenience API.
    2653             :      *
    2654             :      * \endverbatim
    2655             :      */
    2656             :     template <typename Buffer>
    2657             :     Buffer save_buffer() const {
    2658             :         return metatensor::io::save_buffer<Buffer>(*this);
    2659             :     }
    2660             : 
    2661             :     /// Get the `mts_tensormap_t` pointer corresponding to this `TensorMap`.
    2662             :     ///
    2663             :     /// The tensor map pointer is still managed by the current `TensorMap`
    2664             :     mts_tensormap_t* as_mts_tensormap_t() & {
    2665             :         return tensor_;
    2666             :     }
    2667             : 
    2668             :     /// Get the const `mts_tensormap_t` pointer corresponding to this `TensorMap`.
    2669             :     ///
    2670             :     /// The tensor map pointer is still managed by the current `TensorMap`
    2671             :     const mts_tensormap_t* as_mts_tensormap_t() const & {
    2672             :         return tensor_;
    2673             :     }
    2674             : 
    2675             :     mts_tensormap_t* as_mts_tensormap_t() && = delete;
    2676             : 
    2677             :     /// Create a C++ TensorMap from a C `mts_tensormap_t` pointer. The C++
    2678             :     /// tensor map takes ownership of the C pointer.
    2679             :     explicit TensorMap(mts_tensormap_t* tensor): tensor_(tensor) {}
    2680             : 
    2681             : private:
    2682             :     mts_tensormap_t* tensor_;
    2683             : };
    2684             : 
    2685             : 
    2686             : /******************************************************************************/
    2687             : /******************************************************************************/
    2688             : /*                                                                            */
    2689             : /*                   I/O functionalities implementation                       */
    2690             : /*                                                                            */
    2691             : /******************************************************************************/
    2692             : /******************************************************************************/
    2693             : 
    2694             : 
    2695             : namespace io {
    2696             :     inline void save(const std::string& path, const TensorMap& tensor) {
    2697             :         details::check_status(mts_tensormap_save(path.c_str(), tensor.as_mts_tensormap_t()));
    2698             :     }
    2699             : 
    2700             :     template <typename Buffer>
    2701             :     Buffer save_buffer(const TensorMap& tensor) {
    2702             :         auto buffer = metatensor::io::save_buffer<std::vector<uint8_t>>(tensor);
    2703             :         return Buffer(buffer.begin(), buffer.end());
    2704             :     }
    2705             : 
    2706             :     template<>
    2707             :     inline std::vector<uint8_t> save_buffer<std::vector<uint8_t>>(const TensorMap& tensor) {
    2708             :         std::vector<uint8_t> buffer;
    2709             : 
    2710             :         auto* ptr = buffer.data();
    2711             :         auto size = buffer.size();
    2712             : 
    2713             :         auto realloc = [](void* user_data, uint8_t*, uintptr_t new_size) {
    2714             :             auto* buffer = reinterpret_cast<std::vector<uint8_t>*>(user_data);
    2715             :             buffer->resize(new_size, '\0');
    2716             :             return buffer->data();
    2717             :         };
    2718             : 
    2719             :         details::check_status(mts_tensormap_save_buffer(
    2720             :             &ptr,
    2721             :             &size,
    2722             :             &buffer,
    2723             :             realloc,
    2724             :             tensor.as_mts_tensormap_t()
    2725             :         ));
    2726             : 
    2727             :         buffer.resize(size, '\0');
    2728             : 
    2729             :         return buffer;
    2730             :     }
    2731             : 
    2732             :     /**************************************************************************/
    2733             : 
    2734             :     inline void save(const std::string& path, const TensorBlock& block) {
    2735             :         details::check_status(mts_block_save(path.c_str(), block.as_mts_block_t()));
    2736             :     }
    2737             : 
    2738             :     template <typename Buffer>
    2739             :     Buffer save_buffer(const TensorBlock& block) {
    2740             :         auto buffer = metatensor::io::save_buffer<std::vector<uint8_t>>(block);
    2741             :         return Buffer(buffer.begin(), buffer.end());
    2742             :     }
    2743             : 
    2744             :     template<>
    2745             :     inline std::vector<uint8_t> save_buffer<std::vector<uint8_t>>(const TensorBlock& block) {
    2746             :         std::vector<uint8_t> buffer;
    2747             : 
    2748             :         auto* ptr = buffer.data();
    2749             :         auto size = buffer.size();
    2750             : 
    2751             :         auto realloc = [](void* user_data, uint8_t*, uintptr_t new_size) {
    2752             :             auto* buffer = reinterpret_cast<std::vector<uint8_t>*>(user_data);
    2753             :             buffer->resize(new_size, '\0');
    2754             :             return buffer->data();
    2755             :         };
    2756             : 
    2757             :         details::check_status(mts_block_save_buffer(
    2758             :             &ptr,
    2759             :             &size,
    2760             :             &buffer,
    2761             :             realloc,
    2762             :             block.as_mts_block_t()
    2763             :         ));
    2764             : 
    2765             :         buffer.resize(size, '\0');
    2766             : 
    2767             :         return buffer;
    2768             :     }
    2769             : 
    2770             :     /**************************************************************************/
    2771             : 
    2772             :     inline void save(const std::string& path, const Labels& labels) {
    2773             :         details::check_status(mts_labels_save(path.c_str(), labels.as_mts_labels_t()));
    2774             :     }
    2775             : 
    2776             :     template <typename Buffer>
    2777             :     Buffer save_buffer(const Labels& labels) {
    2778             :         auto buffer = metatensor::io::save_buffer<std::vector<uint8_t>>(labels);
    2779             :         return Buffer(buffer.begin(), buffer.end());
    2780             :     }
    2781             : 
    2782             :     template<>
    2783             :     inline std::vector<uint8_t> save_buffer<std::vector<uint8_t>>(const Labels& labels) {
    2784             :         std::vector<uint8_t> buffer;
    2785             : 
    2786             :         auto* ptr = buffer.data();
    2787             :         auto size = buffer.size();
    2788             : 
    2789             :         auto realloc = [](void* user_data, uint8_t*, uintptr_t new_size) {
    2790             :             auto* buffer = reinterpret_cast<std::vector<uint8_t>*>(user_data);
    2791             :             buffer->resize(new_size, '\0');
    2792             :             return buffer->data();
    2793             :         };
    2794             : 
    2795             :         details::check_status(mts_labels_save_buffer(
    2796             :             &ptr,
    2797             :             &size,
    2798             :             &buffer,
    2799             :             realloc,
    2800             :             labels.as_mts_labels_t()
    2801             :         ));
    2802             : 
    2803             :         buffer.resize(size, '\0');
    2804             : 
    2805             :         return buffer;
    2806             :     }
    2807             : 
    2808             :     /**************************************************************************/
    2809             :     /**************************************************************************/
    2810             : 
    2811             :     inline TensorMap load(
    2812             :         const std::string& path,
    2813             :         mts_create_array_callback_t create_array
    2814             :     ) {
    2815             :         auto* ptr = mts_tensormap_load(path.c_str(), create_array);
    2816             :         details::check_pointer(ptr);
    2817             :         return TensorMap(ptr);
    2818             :     }
    2819             : 
    2820             :     inline TensorMap load_buffer(
    2821             :         const uint8_t* buffer,
    2822             :         size_t buffer_count,
    2823             :         mts_create_array_callback_t create_array
    2824             :     ) {
    2825             :         auto* ptr = mts_tensormap_load_buffer(buffer, buffer_count, create_array);
    2826             :         details::check_pointer(ptr);
    2827             :         return TensorMap(ptr);
    2828             :     }
    2829             : 
    2830             :     template <typename Buffer>
    2831             :     TensorMap load_buffer(
    2832             :         const Buffer& buffer,
    2833             :         mts_create_array_callback_t create_array
    2834             :     ) {
    2835             :         static_assert(
    2836             :             sizeof(typename Buffer::value_type) == sizeof(uint8_t),
    2837             :             "`Buffer` must be a container of uint8_t or equivalent"
    2838             :         );
    2839             : 
    2840             :         return metatensor::io::load_buffer(
    2841             :             reinterpret_cast<const uint8_t*>(buffer.data()),
    2842             :             buffer.size(),
    2843             :             create_array
    2844             :         );
    2845             :     }
    2846             : 
    2847             :     /**************************************************************************/
    2848             : 
    2849             :     inline TensorBlock load_block(
    2850             :         const std::string& path,
    2851             :         mts_create_array_callback_t create_array
    2852             :     ) {
    2853             :         auto* ptr = mts_block_load(path.c_str(), create_array);
    2854             :         details::check_pointer(ptr);
    2855             :         return TensorBlock(ptr);
    2856             :     }
    2857             : 
    2858             :     inline TensorBlock load_block_buffer(
    2859             :         const uint8_t* buffer,
    2860             :         size_t buffer_count,
    2861             :         mts_create_array_callback_t create_array
    2862             :     ) {
    2863             :         auto* ptr = mts_block_load_buffer(buffer, buffer_count, create_array);
    2864             :         details::check_pointer(ptr);
    2865             :         return TensorBlock(ptr);
    2866             :     }
    2867             : 
    2868             :     template <typename Buffer>
    2869             :     TensorBlock load_block_buffer(
    2870             :         const Buffer& buffer,
    2871             :         mts_create_array_callback_t create_array
    2872             :     ) {
    2873             :         static_assert(
    2874             :             sizeof(typename Buffer::value_type) == sizeof(uint8_t),
    2875             :             "`Buffer` must be a container of uint8_t or equivalent"
    2876             :         );
    2877             : 
    2878             :         return metatensor::io::load_block_buffer(
    2879             :             reinterpret_cast<const uint8_t*>(buffer.data()),
    2880             :             buffer.size(),
    2881             :             create_array
    2882             :         );
    2883             :     }
    2884             : 
    2885             :     /**************************************************************************/
    2886             : 
    2887             :     inline Labels load_labels(const std::string& path) {
    2888             :         mts_labels_t labels;
    2889             :         std::memset(&labels, 0, sizeof(labels));
    2890             : 
    2891             :         details::check_status(mts_labels_load(
    2892             :             path.c_str(), &labels
    2893             :         ));
    2894             : 
    2895             :         return Labels(labels);
    2896             :     }
    2897             : 
    2898             :     inline Labels load_labels_buffer(const uint8_t* buffer, size_t buffer_count) {
    2899             :         mts_labels_t labels;
    2900             :         std::memset(&labels, 0, sizeof(labels));
    2901             : 
    2902             :         details::check_status(mts_labels_load_buffer(
    2903             :             buffer, buffer_count, &labels
    2904             :         ));
    2905             : 
    2906             :         return Labels(labels);
    2907             :     }
    2908             : 
    2909             :     template <typename Buffer>
    2910             :     Labels load_labels_buffer(const Buffer& buffer) {
    2911             :         static_assert(
    2912             :             sizeof(typename Buffer::value_type) == sizeof(uint8_t),
    2913             :             "`Buffer` must be a container of uint8_t or equivalent"
    2914             :         );
    2915             : 
    2916             :         return metatensor::io::load_labels_buffer(
    2917             :             reinterpret_cast<const uint8_t*>(buffer.data()),
    2918             :             buffer.size()
    2919             :         );
    2920             :     }
    2921             : }
    2922             : 
    2923             : }
    2924             : 
    2925             : #endif /* METATENSOR_HPP */

Generated by: LCOV version 1.16