Line data Source code
1 : #ifndef METATENSOR_HPP
2 : #define METATENSOR_HPP
3 :
#include <array>
#include <vector>
#include <string>
#include <memory>
#include <stdexcept>
#include <exception>
#include <functional>
#include <type_traits>
#include <initializer_list>

#include <cassert>
#include <climits>
#include <cstdint>
#include <cstring>
#include <cstdlib>

#include "metatensor.h"
20 :
/// This file contains the C++ API to metatensor, manually built on top of the C
/// API defined in `metatensor.h`. This API uses the standard C++ library where
/// convenient, but also allows dropping back to the C API if required, by
/// providing functions to extract the C API handles (named `as_mts_XXX`).
25 :
// This file assumes that a `char` is an 8-bit byte. `sizeof(char)` is 1 by
// definition, so the meaningful check is on the number of bits in a byte
// (`CHAR_BIT`); the `uint8_t` comparison documents the expected equivalence.
static_assert(CHAR_BIT == 8 && sizeof(char) == sizeof(uint8_t), "char must be 8-bits wide");
27 :
/// Forward declarations for classes living in the `metatensor_torch`
/// namespace, so that they can be named in this header without including
/// their definitions.
namespace metatensor_torch {
    class TensorMapHolder;
    class TensorBlockHolder;
    class LabelsHolder;
}
33 :
34 : namespace metatensor {
// Forward declarations of the main metatensor classes, so that they can be
// referenced before their definitions.
class TensorBlock;
class TensorMap;
class Labels;
38 :
/// Empty tag type, requesting the creation of `Labels` without checking that
/// their entries are unique.
struct assume_unique {
};
41 :
/// Exception type used to report all errors in metatensor.
class Error: public std::runtime_error {
public:
    /// Build an `Error` carrying `message`, which is then returned by
    /// `what()`.
    Error(const std::string& message):
        std::runtime_error(message)
    {}
};
48 :
49 : namespace details {
50 : /// Singleton class storing the last exception throw by a C++ callback.
51 : ///
52 : /// When passing callbacks from C++ to Rust, we need to convert exceptions
53 : /// into status code (see the `catch` blocks in this file). This class
54 : /// allows to save the message associated with an exception, and rethrow an
55 : /// exception with the same message later (the actual exception type is lost
56 : /// in the process).
57 : class LastCxxError {
58 : public:
59 : /// Set the last error message to `message`
60 : static void set_message(std::string message) {
61 : auto& stored_message = LastCxxError::get();
62 : stored_message = std::move(message);
63 : }
64 :
65 : /// Get the last error message
66 : static const std::string& message() {
67 0 : return LastCxxError::get();
68 : }
69 :
70 : private:
71 0 : static std::string& get() {
72 : #pragma clang diagnostic push
73 : #pragma clang diagnostic ignored "-Wexit-time-destructors"
74 : /// we are using a per-thread static value to store the last C++
75 : /// exception.
76 0 : static thread_local std::string STORED_MESSAGE;
77 : #pragma clang diagnostic pop
78 :
79 0 : return STORED_MESSAGE;
80 : }
81 : };
82 :
83 : /// Check if a return status from the C API indicates an error, and if it is
84 : /// the case, throw an exception of type `metatensor::Error` with the last
85 : /// error message from the library.
86 124 : inline void check_status(mts_status_t status) {
87 124 : if (status == MTS_SUCCESS) {
88 124 : return;
89 0 : } else if (status > 0) {
90 0 : throw Error(mts_last_error());
91 : } else { // status < 0
92 0 : throw Error("error in C++ callback: " + LastCxxError::message());
93 : }
94 : }
95 :
96 : /// Call the given `function` with the given `args` (the function should
97 : /// return an `mts_status_t`), catching any C++ exception, and translating
98 : /// them to negative metatensor error code.
99 : ///
100 : /// This is required to prevent callbacks unwinding through the C API.
101 : template<typename Function, typename ...Args>
102 : inline mts_status_t catch_exceptions(Function function, Args ...args) {
103 : try {
104 : return function(std::move(args)...);
105 : } catch (const std::exception& e) {
106 : details::LastCxxError::set_message(e.what());
107 : return -1;
108 : } catch (...) {
109 : details::LastCxxError::set_message("error was not an std::exception");
110 : return -128;
111 : }
112 : }
113 :
114 : /// Check if a pointer allocated by the C API is null, and if it is the
115 : /// case, throw an exception of type `metatensor::Error` with the last error
116 : /// message from the library.
117 : inline void check_pointer(const void* pointer) {
118 : if (pointer == nullptr) {
119 : throw Error(mts_last_error());
120 : }
121 : }
122 :
123 : /// Compute the product of all values in the `shape` vector
124 : inline size_t product(const std::vector<size_t>& shape) {
125 : size_t result = 1;
126 : for (auto size: shape) {
127 : result *= size;
128 : }
129 : return result;
130 : }
131 :
132 : /// Get the N-dimensional index corresponding to the given linear `index`
133 : /// and array `shape`
134 : inline std::vector<size_t> cartesian_index(const std::vector<size_t>& shape, size_t index) {
135 : auto result = std::vector<size_t>(shape.size(), 0);
136 : for (size_t i=0; i<shape.size(); i++) {
137 : result[i] = index % shape[i];
138 : index = index / shape[i];
139 : }
140 : assert(index == 0);
141 : return result;
142 : }
143 :
144 : /// Get the linear index corresponding to the N-dimensional
145 : /// `index[index_size]`, according to the given array `shape`
146 : inline size_t linear_index(const std::vector<size_t>& shape, const size_t* index, size_t index_size) {
147 : assert(index_size != 0);
148 : assert(index_size == shape.size());
149 :
150 : if (index_size == 1) {
151 : assert(index[0] < shape[0] && "out of bounds");
152 : return index[0];
153 : } else {
154 : assert(index[0] < shape[0]);
155 : auto linear_index = index[0];
156 : for (size_t i=1; i<index_size; i++) {
157 : assert(index[i] < shape[i] && "out of bounds");
158 : linear_index *= shape[i];
159 : linear_index += index[i];
160 : }
161 :
162 : return linear_index;
163 : }
164 : }
165 :
166 : template<size_t N>
167 : size_t linear_index(const std::vector<size_t>& shape, const std::array<size_t, N>& index) {
168 : return linear_index(shape, index.data(), index.size());
169 : }
170 :
171 : inline size_t linear_index(const std::vector<size_t>& shape, const std::vector<size_t>& index) {
172 : return linear_index(shape, index.data(), index.size());
173 : }
174 :
175 : Labels labels_from_cxx(const std::vector<std::string>& names, const int32_t* values, size_t count, bool assume_unique);
176 : }
177 :
178 : /******************************************************************************/
179 : /******************************************************************************/
180 : /******************************************************************************/
181 :
182 : /******************************************************************************/
183 : /******************************************************************************/
184 : /* */
185 : /* N-Dimensional arrays handling */
186 : /* */
187 : /******************************************************************************/
188 : /******************************************************************************/
189 :
190 :
191 : /// Simple N-dimensional array interface
192 : ///
193 : /// This class can either be a non-owning view inside some existing memory (for
194 : /// example memory allocated by Rust); or own its memory (in the form of an
195 : /// `std::vector<double>`). If the array does not own its memory, accessing it
196 : /// is only valid for as long as the memory is kept alive.
197 : ///
198 : /// The API of this class is very intentionally minimal to keep metatensor as
199 : /// simple as possible. Feel free to wrap the corresponding data inside types
200 : /// with richer API such as Eigen, Boost, etc.
201 : template<typename T>
202 : class NDArray {
203 : public:
204 : /// Create a new empty `NDArray`, with shape `[0, 0]`.
205 : NDArray(): NDArray(nullptr, {0, 0}, true) {}
206 :
207 : /// Create a new `NDArray` using a non-owning view in `const` memory with
208 : /// the given `shape`.
209 : ///
210 : /// `data` must point to contiguous memory containing the right number of
211 : /// elements as described by the `shape`, which will be interpreted as an
212 : /// N-dimensional array in row-major order. The resulting `NDArray` is
213 : /// only valid for as long as `data` is.
214 : NDArray(const T* data, std::vector<size_t> shape):
215 : NDArray(data, std::move(shape), /*is_const*/ true) {}
216 :
217 : /// Create a new `NDArray` using a non-owning view in non-`const` memory
218 : /// with the given `shape`.
219 : ///
220 : /// `data` must point to contiguous memory containing the right number of
221 : /// elements as described by the `shape`, which will be interpreted as an
222 : /// N-dimensional array in row-major order. The resulting `NDArray` is
223 : /// only valid for as long as `data` is.
224 : NDArray(T* data, std::vector<size_t> shape):
225 : NDArray(data, std::move(shape), /*is_const*/ false) {}
226 :
227 : /// Create a new `NDArray` *owning* its `data` with the given `shape`.
228 : NDArray(std::vector<T> data, std::vector<size_t> shape):
229 : NDArray(data.data(), std::move(shape), /*is_const*/ false)
230 : {
231 : using vector_t = std::vector<T>;
232 :
233 : owned_data_ = reinterpret_cast<void*>(new vector_t(std::move(data)));
234 : deleter_ = [](void* data){
235 : auto data_vector = reinterpret_cast<vector_t*>(data);
236 : delete data_vector;
237 : };
238 : }
239 :
240 : ~NDArray() {
241 : deleter_(this->owned_data_);
242 : }
243 :
244 : /// NDArray is not copy-constructible
245 : NDArray(const NDArray&) = delete;
246 : /// NDArray can not be copy-assigned
247 : NDArray& operator=(const NDArray& other) = delete;
248 :
249 : /// NDArray is move-constructible
250 : NDArray(NDArray&& other) noexcept: NDArray() {
251 : *this = std::move(other);
252 : }
253 :
254 : /// NDArray can be move-assigned
255 : NDArray& operator=(NDArray&& other) noexcept {
256 : this->deleter_(this->owned_data_);
257 :
258 : this->data_ = std::move(other.data_);
259 : this->shape_ = std::move(other.shape_);
260 : this->is_const_ = other.is_const_;
261 : this->owned_data_ = other.owned_data_;
262 : this->deleter_ = std::move(other.deleter_);
263 :
264 : other.data_ = nullptr;
265 : other.owned_data_ = nullptr;
266 : other.deleter_ = [](void*){};
267 :
268 : return *this;
269 : }
270 :
271 : /// Is this NDArray a view into external data?
272 : bool is_view() const {
273 : return owned_data_ == nullptr;
274 : }
275 :
276 : /// Get the value inside this `NDArray` at the given index
277 : ///
278 : /// ```
279 : /// auto array = NDArray(...);
280 : ///
281 : /// double value = array(2, 3, 1);
282 : /// ```
283 : template<typename ...Args>
284 : T operator()(Args... args) const & {
285 : auto index = std::array<size_t, sizeof... (Args)>{static_cast<size_t>(args)...};
286 : if (index.size() != shape_.size()) {
287 : throw Error(
288 : "expected " + std::to_string(shape_.size()) +
289 : " indexes in NDArray::operator(), got " + std::to_string(index.size())
290 : );
291 : }
292 : return data_[details::linear_index(shape_, index)];
293 : }
294 :
295 : /// Get a reference to the value inside this `NDArray` at the given index
296 : ///
297 : /// ```
298 : /// auto array = NDArray(...);
299 : ///
300 : /// array(2, 3, 1) = 5.2;
301 : /// ```
302 : template<typename ...Args>
303 : T& operator()(Args... args) & {
304 : if (is_const_) {
305 : throw Error("This NDArray is const, can not get non const access to it");
306 : }
307 :
308 : auto index = std::array<size_t, sizeof... (Args)>{static_cast<size_t>(args)...};
309 : if (index.size() != shape_.size()) {
310 : throw Error(
311 : "expected " + std::to_string(shape_.size()) +
312 : " indexes in Labels::operator(), got " + std::to_string(index.size())
313 : );
314 : }
315 : return data_[details::linear_index(shape_, index)];
316 : }
317 :
318 : template<typename ...Args>
319 : T& operator()(Args... args) && = delete;
320 :
321 : /// Get the data pointer for this array, i.e. the pointer to the first
322 : /// element.
323 : const T* data() const & {
324 : return data_;
325 : }
326 :
327 : /// Get the data pointer for this array, i.e. the pointer to the first
328 : /// element.
329 : T* data() & {
330 : if (is_const_) {
331 : throw Error("This NDArray is const, can not get non const access to it");
332 : }
333 : return data_;
334 : }
335 :
336 : const T* data() && = delete;
337 :
338 : /// Get the shape of this array
339 : const std::vector<size_t>& shape() const & {
340 : return shape_;
341 : }
342 :
343 : const std::vector<size_t>& shape() && = delete;
344 :
345 : /// Check if this array is empty, i.e. if at least one of the shape element
346 : /// is 0.
347 : bool is_empty() const {
348 : for (auto s: shape_) {
349 : if (s == 0) {
350 : return true;
351 : }
352 : }
353 : return false;
354 : }
355 :
356 : private:
357 : /// Create an `NDArray` from a pointer to the (row-major) data & shape.
358 : ///
359 : /// The `is_const` parameter controls whether this class should allow
360 : /// non-const access to the data.
361 : NDArray(const T* data, std::vector<size_t> shape, bool is_const):
362 : data_(const_cast<T*>(data)),
363 : shape_(std::move(shape)),
364 : is_const_(is_const),
365 : deleter_([](void*){})
366 : {
367 : validate();
368 : }
369 :
370 : /// Create a 2D NDArray from a vector of initializer lists. All the inner
371 : /// lists must have the same `size`.
372 : ///
373 : /// This allows creating an array (and in particular a set of Labels) with
374 : /// `NDArray({{1, 2}, {3, 4}, {5, 6}}, 2)`
375 : NDArray(const std::vector<std::initializer_list<T>>& data, size_t size):
376 : data_(nullptr),
377 : shape_({data.size(), size}),
378 : is_const_(false),
379 : deleter_([](void*){})
380 : {
381 : using vector_t = std::vector<T>;
382 : auto vector = std::vector<T>();
383 : vector.reserve(data.size() * size);
384 : for (auto row: std::move(data)) {
385 : if (row.size() != size) {
386 : throw Error(
387 : "invalid size for row: expected " + std::to_string(size) +
388 : " got " + std::to_string(row.size())
389 : );
390 : }
391 :
392 : for (auto entry: row) {
393 : vector.emplace_back(entry);
394 : }
395 : }
396 :
397 : data_ = vector.data();
398 : owned_data_ = reinterpret_cast<void*>(new vector_t(std::move(vector)));
399 : deleter_ = [](void* data){
400 : auto data_vector = reinterpret_cast<vector_t*>(data);
401 : delete data_vector;
402 : };
403 : validate();
404 : }
405 :
406 : friend class Labels;
407 :
408 : void validate() const {
409 : static_assert(
410 : std::is_arithmetic<T>::value,
411 : "NDArray only works with integers and floating points"
412 : );
413 :
414 : if (shape_.empty()) {
415 : throw Error("invalid parameters to NDArray, shape should contain at least one element");
416 : }
417 :
418 : size_t size = 1;
419 : for (auto s: shape_) {
420 : size *= s;
421 : }
422 :
423 : if (size != 0 && data_ == nullptr) {
424 : throw Error("invalid parameters to NDArray, got null data pointer and non zero size");
425 : }
426 : }
427 :
428 : /// Pointer to the data used by this array
429 : T* data_ = nullptr;
430 : /// Full shape of this array
431 : std::vector<size_t> shape_ = {0, 0};
432 : /// Is this array const? This will dynamically prevent calling non-const
433 : /// function on it.
434 : bool is_const_ = true;
435 : /// Type-erased owned data for this array. This is a `nullptr` if this array
436 : /// is a view.
437 : void* owned_data_ = nullptr;
438 : /// Custom delete function for the `owned_data_`.
439 : std::function<void(void*)> deleter_;
440 : };
441 :
442 :
443 : /// Compare this `NDArray` with another `NDarray`. The array are equal if
444 : /// and only if both the shape and data are equal.
445 : template<typename T>
446 : bool operator==(const NDArray<T>& lhs, const NDArray<T>& rhs) {
447 : if (lhs.shape() != rhs.shape()) {
448 : return false;
449 : }
450 : return std::memcmp(lhs.data(), rhs.data(), sizeof(T) * details::product(lhs.shape())) == 0;
451 : }
452 :
453 : /// Compare this `NDArray` with another `NDarray`. The array are equal if
454 : /// and only if both the shape and data are equal.
455 : template<typename T>
456 : bool operator!=(const NDArray<T>& lhs, const NDArray<T>& rhs) {
457 : return !(lhs == rhs);
458 : }
459 :
460 :
461 : /// `DataArrayBase` manages n-dimensional arrays used as data in a block or
462 : /// tensor map. The array itself if opaque to this library and can come from
463 : /// multiple sources: Rust program, a C/C++ program, a Fortran program, Python
464 : /// with numpy or torch. The data does not have to live on CPU, or even on the
465 : /// same machine where this code is executed.
466 : ///
467 : /// **WARNING**: all function implementations **MUST** be thread-safe, and can
468 : /// be called from multiple threads at the same time. The `DataArrayBase` itself
469 : /// might be moved from one thread to another.
470 : class DataArrayBase {
471 : public:
472 : DataArrayBase() = default;
473 : virtual ~DataArrayBase() = default;
474 :
475 : /// DataArrayBase can be copy-constructed
476 : DataArrayBase(const DataArrayBase&) = default;
477 : /// DataArrayBase can be copy-assigned
478 : DataArrayBase& operator=(const DataArrayBase&) = default;
479 : /// DataArrayBase can be move-constructed
480 : DataArrayBase(DataArrayBase&&) noexcept = default;
481 : /// DataArrayBase can be move-assigned
482 : DataArrayBase& operator=(DataArrayBase&&) noexcept = default;
483 :
484 : /// Convert a concrete `DataArrayBase` to a C-compatible `mts_array_t`
485 : ///
486 : /// The `mts_array_t` takes ownership of the data, which should be released
487 : /// with `mts_array_t::destroy`.
488 : static mts_array_t to_mts_array_t(std::unique_ptr<DataArrayBase> data) {
489 : mts_array_t array;
490 : std::memset(&array, 0, sizeof(array));
491 :
492 : array.ptr = data.release();
493 :
494 : array.destroy = [](void* array) {
495 : auto ptr = std::unique_ptr<DataArrayBase>(static_cast<DataArrayBase*>(array));
496 : // let ptr go out of scope
497 : };
498 :
499 : array.origin = [](const void* array, mts_data_origin_t* origin) {
500 : return details::catch_exceptions([](const void* array, mts_data_origin_t* origin){
501 : const auto* cxx_array = static_cast<const DataArrayBase*>(array);
502 : *origin = cxx_array->origin();
503 : return MTS_SUCCESS;
504 : }, array, origin);
505 : };
506 :
507 : array.copy = [](const void* array, mts_array_t* new_array) {
508 : return details::catch_exceptions([](const void* array, mts_array_t* new_array){
509 : const auto* cxx_array = static_cast<const DataArrayBase*>(array);
510 : auto copy = cxx_array->copy();
511 : *new_array = DataArrayBase::to_mts_array_t(std::move(copy));
512 : return MTS_SUCCESS;
513 : }, array, new_array);
514 : };
515 :
516 : array.create = [](const void* array, const uintptr_t* shape, uintptr_t shape_count, mts_array_t* new_array) {
517 : return details::catch_exceptions([](
518 : const void* array,
519 : const uintptr_t* shape,
520 : uintptr_t shape_count,
521 : mts_array_t* new_array
522 : ) {
523 : const auto* cxx_array = static_cast<const DataArrayBase*>(array);
524 : auto cxx_shape = std::vector<size_t>();
525 : for (size_t i=0; i<static_cast<size_t>(shape_count); i++) {
526 : cxx_shape.push_back(static_cast<size_t>(shape[i]));
527 : }
528 : auto copy = cxx_array->create(std::move(cxx_shape));
529 : *new_array = DataArrayBase::to_mts_array_t(std::move(copy));
530 : return MTS_SUCCESS;
531 : }, array, shape, shape_count, new_array);
532 : };
533 :
534 : array.data = [](void* array, double** data) {
535 : return details::catch_exceptions([](void* array, double** data){
536 : auto* cxx_array = static_cast<DataArrayBase*>(array);
537 : *data = cxx_array->data();
538 : return MTS_SUCCESS;
539 : }, array, data);
540 : };
541 :
542 : array.shape = [](const void* array, const uintptr_t** shape, uintptr_t* shape_count) {
543 : return details::catch_exceptions([](const void* array, const uintptr_t** shape, uintptr_t* shape_count){
544 : const auto* cxx_array = static_cast<const DataArrayBase*>(array);
545 : const auto& cxx_shape = cxx_array->shape();
546 : *shape = cxx_shape.data();
547 : *shape_count = static_cast<uintptr_t>(cxx_shape.size());
548 : return MTS_SUCCESS;
549 : }, array, shape, shape_count);
550 : };
551 :
552 : array.reshape = [](void* array, const uintptr_t* shape, uintptr_t shape_count) {
553 : return details::catch_exceptions([](void* array, const uintptr_t* shape, uintptr_t shape_count){
554 : auto* cxx_array = static_cast<DataArrayBase*>(array);
555 : auto cxx_shape = std::vector<uintptr_t>(shape, shape + shape_count);
556 : cxx_array->reshape(std::move(cxx_shape));
557 : return MTS_SUCCESS;
558 : }, array, shape, shape_count);
559 : };
560 :
561 : array.swap_axes = [](void* array, uintptr_t axis_1, uintptr_t axis_2) {
562 : return details::catch_exceptions([](void* array, uintptr_t axis_1, uintptr_t axis_2){
563 : auto* cxx_array = static_cast<DataArrayBase*>(array);
564 : cxx_array->swap_axes(axis_1, axis_2);
565 : return MTS_SUCCESS;
566 : }, array, axis_1, axis_2);
567 : };
568 :
569 : array.move_samples_from = [](
570 : void* array,
571 : const void* input,
572 : const mts_sample_mapping_t* samples,
573 : uintptr_t samples_count,
574 : uintptr_t property_start,
575 : uintptr_t property_end
576 : ) {
577 : return details::catch_exceptions([](
578 : void* array,
579 : const void* input,
580 : const mts_sample_mapping_t* samples,
581 : uintptr_t samples_count,
582 : uintptr_t property_start,
583 : uintptr_t property_end
584 : ) {
585 : auto* cxx_array = static_cast<DataArrayBase*>(array);
586 : const auto* cxx_input = static_cast<const DataArrayBase*>(input);
587 : auto cxx_samples = std::vector<mts_sample_mapping_t>(samples, samples + samples_count);
588 :
589 : cxx_array->move_samples_from(*cxx_input, cxx_samples, property_start, property_end);
590 : return MTS_SUCCESS;
591 : }, array, input, samples, samples_count, property_start, property_end);
592 : };
593 :
594 : return array;
595 : }
596 :
597 : /// Get "data origin" for this array in.
598 : ///
599 : /// Users of `DataArrayBase` should register a single data
600 : /// origin with `mts_register_data_origin`, and use it for all compatible
601 : /// arrays.
602 : virtual mts_data_origin_t origin() const = 0;
603 :
604 : /// Make a copy of this DataArrayBase and return the new array. The new
605 : /// array is expected to have the same data origin and parameters (data
606 : /// type, data location, etc.)
607 : virtual std::unique_ptr<DataArrayBase> copy() const = 0;
608 :
609 : /// Create a new array with the same options as the current one (data type,
610 : /// data location, etc.) and the requested `shape`.
611 : ///
612 : /// The new array should be filled with zeros.
613 : virtual std::unique_ptr<DataArrayBase> create(std::vector<uintptr_t> shape) const = 0;
614 :
615 : /// Get a pointer to the underlying data storage.
616 : ///
617 : /// This function is allowed to fail if the data is not accessible in RAM,
618 : /// not stored as 64-bit floating point values, or not stored as a
619 : /// C-contiguous array.
620 : virtual double* data() & = 0;
621 :
622 : double* data() && = delete;
623 :
624 : /// Get the shape of this array
625 : virtual const std::vector<uintptr_t>& shape() const & = 0;
626 :
627 : const std::vector<uintptr_t>& shape() && = delete;
628 :
629 : /// Set the shape of this array to the given `shape`
630 : virtual void reshape(std::vector<uintptr_t> shape) = 0;
631 :
632 : /// Swap the axes `axis_1` and `axis_2` in this `array`.
633 : virtual void swap_axes(uintptr_t axis_1, uintptr_t axis_2) = 0;
634 :
635 : /// Set entries in the current array taking data from the `input` array.
636 : ///
637 : /// This array is guaranteed to be created by calling `mts_array_t::create`
638 : /// with one of the arrays in the same block or tensor map as the `input`.
639 : ///
640 : /// The `samples` indicate where the data should be moved from `input` to
641 : /// the current DataArrayBase.
642 : ///
643 : /// This function should copy data from `input[samples[i].input, ..., :]` to
644 : /// `array[samples[i].output, ..., property_start:property_end]` for `i` up
645 : /// to `samples_count`. All indexes are 0-based.
646 : virtual void move_samples_from(
647 : const DataArrayBase& input,
648 : std::vector<mts_sample_mapping_t> samples,
649 : uintptr_t property_start,
650 : uintptr_t property_end
651 : ) = 0;
652 : };
653 :
654 :
655 : /// Very basic implementation of DataArrayBase in C++.
656 : ///
657 : /// This is included as an example implementation of DataArrayBase, and to make
658 : /// metatensor usable without additional dependencies. For other uses cases, it
659 : /// might be better to implement DataArrayBase on your data, using
660 : /// functionalities from `Eigen`, `Boost.Array`, etc.
661 : class SimpleDataArray: public metatensor::DataArrayBase {
662 : public:
663 : /// Create a SimpleDataArray with the given `shape`, and all elements set to
664 : /// `value`
665 : SimpleDataArray(std::vector<uintptr_t> shape, double value = 0.0):
666 : shape_(std::move(shape)), data_(details::product(shape_), value) {}
667 :
668 : /// Create a SimpleDataArray with the given `shape` and `data`.
669 : ///
670 : /// The data is interpreted as a row-major n-dimensional array.
671 : SimpleDataArray(std::vector<uintptr_t> shape, std::vector<double> data):
672 : shape_(std::move(shape)),
673 : data_(std::move(data))
674 : {
675 : if (data_.size() != details::product(shape_)) {
676 : throw Error("the shape and size of the data don't match in SimpleDataArray");
677 : }
678 : }
679 :
680 : ~SimpleDataArray() override = default;
681 :
682 : /// SimpleDataArray can be copy-constructed
683 : SimpleDataArray(const SimpleDataArray&) = default;
684 : /// SimpleDataArray can be copy-assigned
685 : SimpleDataArray& operator=(const SimpleDataArray&) = default;
686 : /// SimpleDataArray can be move-constructed
687 : SimpleDataArray(SimpleDataArray&&) noexcept = default;
688 : /// SimpleDataArray can be move-assigned
689 : SimpleDataArray& operator=(SimpleDataArray&&) noexcept = default;
690 :
691 : mts_data_origin_t origin() const override {
692 : mts_data_origin_t origin = 0;
693 : mts_register_data_origin("metatensor::SimpleDataArray", &origin);
694 : return origin;
695 : }
696 :
697 : double* data() & override {
698 : return data_.data();
699 : }
700 :
701 : const std::vector<uintptr_t>& shape() const & override {
702 : return shape_;
703 : }
704 :
705 : void reshape(std::vector<uintptr_t> shape) override {
706 : if (details::product(shape_) != details::product(shape)) {
707 : throw metatensor::Error("invalid shape in reshape");
708 : }
709 : shape_ = std::move(shape);
710 : }
711 :
712 : void swap_axes(uintptr_t axis_1, uintptr_t axis_2) override {
713 : auto new_data = std::vector<double>(details::product(shape_), 0.0);
714 : auto new_shape = shape_;
715 : std::swap(new_shape[axis_1], new_shape[axis_2]);
716 :
717 : for (size_t i=0; i<details::product(shape_); i++) {
718 : auto index = details::cartesian_index(shape_, i);
719 : std::swap(index[axis_1], index[axis_2]);
720 :
721 : new_data[details::linear_index(new_shape, index)] = data_[i];
722 : }
723 :
724 : shape_ = std::move(new_shape);
725 : data_ = std::move(new_data);
726 : }
727 :
728 : std::unique_ptr<DataArrayBase> copy() const override {
729 : return std::unique_ptr<DataArrayBase>(new SimpleDataArray(*this));
730 : }
731 :
732 : std::unique_ptr<DataArrayBase> create(std::vector<uintptr_t> shape) const override {
733 : return std::unique_ptr<DataArrayBase>(new SimpleDataArray(std::move(shape)));
734 : }
735 :
736 : void move_samples_from(
737 : const DataArrayBase& input,
738 : std::vector<mts_sample_mapping_t> samples,
739 : uintptr_t property_start,
740 : uintptr_t property_end
741 : ) override {
742 : const auto& input_array = dynamic_cast<const SimpleDataArray&>(input);
743 : assert(input_array.shape_.size() == this->shape_.size());
744 :
745 : size_t property_count = property_end - property_start;
746 : size_t property_dim = shape_.size() - 1;
747 : assert(input_array.shape_[property_dim] == property_count);
748 :
749 : auto input_index = std::vector<size_t>(shape_.size(), 0);
750 : auto output_index = std::vector<size_t>(shape_.size(), 0);
751 :
752 : for (const auto& sample: samples) {
753 : input_index[0] = sample.input;
754 : output_index[0] = sample.output;
755 :
756 : if (property_dim == 1) {
757 : // no components
758 : for (size_t property_i=0; property_i<property_count; property_i++) {
759 : input_index[property_dim] = property_i;
760 : output_index[property_dim] = property_i + property_start;
761 :
762 : auto value = input_array.data_[details::linear_index(input_array.shape_, input_index)];
763 : this->data_[details::linear_index(shape_, output_index)] = value;
764 : }
765 : } else {
766 : auto last_component_dim = shape_.size() - 2;
767 : for (size_t component_i=1; component_i<shape_.size() - 1; component_i++) {
768 : input_index[component_i] = 0;
769 : }
770 :
771 : bool done = false;
772 : while (!done) {
773 : for (size_t component_i=1; component_i<shape_.size() - 1; component_i++) {
774 : output_index[component_i] = input_index[component_i];
775 : }
776 :
777 : for (size_t property_i=0; property_i<property_count; property_i++) {
778 : input_index[property_dim] = property_i;
779 : output_index[property_dim] = property_i + property_start;
780 :
781 : auto value = input_array.data_[details::linear_index(input_array.shape_, input_index)];
782 : this->data_[details::linear_index(shape_, output_index)] = value;
783 : }
784 :
785 : input_index[last_component_dim] += 1;
786 : for (size_t component_i=last_component_dim; component_i>2; component_i--) {
787 : if (input_index[component_i] >= shape_[component_i]) {
788 : input_index[component_i] = 0;
789 : input_index[component_i - 1] += 1;
790 : }
791 : }
792 :
793 : if (input_index[1] >= shape_[1]) {
794 : done = true;
795 : }
796 : }
797 : }
798 : }
799 : }
800 :
801 : /// Get a const view of the data managed by this SimpleDataArray
802 : NDArray<double> view() const {
803 : return NDArray<double>(data_.data(), shape_);
804 : }
805 :
806 : /// Get a mutable view of the data managed by this SimpleDataArray
807 : NDArray<double> view() {
808 : return NDArray<double>(data_.data(), shape_);
809 : }
810 :
811 : /// Extract a reference to SimpleDataArray out of an `mts_array_t`.
812 : ///
813 : /// This function fails if the `mts_array_t` does not contain a
814 : /// SimpleDataArray.
815 : static SimpleDataArray& from_mts_array(mts_array_t& array) {
816 : mts_data_origin_t origin = 0;
817 : auto status = array.origin(array.ptr, &origin);
818 : if (status != MTS_SUCCESS) {
819 : throw Error("failed to get data origin");
820 : }
821 :
822 : std::array<char, 64> buffer = {0};
823 : status = mts_get_data_origin(origin, buffer.data(), buffer.size());
824 : if (status != MTS_SUCCESS || std::string(buffer.data()) != "metatensor::SimpleDataArray") {
825 : throw Error("this array is not a metatensor::SimpleDataArray");
826 : }
827 :
828 : auto* base = static_cast<DataArrayBase*>(array.ptr);
829 : return dynamic_cast<SimpleDataArray&>(*base);
830 : }
831 :
832 : /// Extract a const reference to SimpleDataArray out of an `mts_array_t`.
833 : ///
834 : /// This function fails if the `mts_array_t` does not contain a
835 : /// SimpleDataArray.
836 : static const SimpleDataArray& from_mts_array(const mts_array_t& array) {
837 : mts_data_origin_t origin = 0;
838 : auto status = array.origin(array.ptr, &origin);
839 : if (status != MTS_SUCCESS) {
840 : throw Error("failed to get data origin");
841 : }
842 :
843 : std::array<char, 64> buffer = {0};
844 : status = mts_get_data_origin(origin, buffer.data(), buffer.size());
845 : if (status != MTS_SUCCESS || std::string(buffer.data()) != "metatensor::SimpleDataArray") {
846 : throw Error("this array is not a metatensor::SimpleDataArray");
847 : }
848 :
849 : const auto* base = static_cast<const DataArrayBase*>(array.ptr);
850 : return dynamic_cast<const SimpleDataArray&>(*base);
851 : }
852 :
853 : private:
854 : std::vector<uintptr_t> shape_;
855 : std::vector<double> data_;
856 :
857 : friend bool operator==(const SimpleDataArray& lhs, const SimpleDataArray& rhs);
858 : };
859 :
860 : /// Two SimpleDataArray compare as equal if they have the exact same shape and
861 : /// data.
862 : inline bool operator==(const SimpleDataArray& lhs, const SimpleDataArray& rhs) {
863 : return lhs.shape_ == rhs.shape_ && lhs.data_ == rhs.data_;
864 : }
865 :
/// Two SimpleDataArray compare as different if their shape or their data
/// differ, i.e. this is the negation of `operator==`.
868 : inline bool operator!=(const SimpleDataArray& lhs, const SimpleDataArray& rhs) {
869 : return !(lhs == rhs);
870 : }
871 :
872 :
873 : /// An implementation of `DataArrayBase` containing no data.
874 : ///
/// This class only tracks its shape, and can be used when only the metadata
876 : /// of a `TensorBlock` is important, leaving the data unspecified.
877 : class EmptyDataArray: public metatensor::DataArrayBase {
878 : public:
879 : /// Create ae `EmptyDataArray` with the given `shape`
880 : EmptyDataArray(std::vector<uintptr_t> shape):
881 : shape_(std::move(shape)) {}
882 :
883 : ~EmptyDataArray() override = default;
884 :
885 : /// EmptyDataArray can be copy-constructed
886 : EmptyDataArray(const EmptyDataArray&) = default;
887 : /// EmptyDataArray can be copy-assigned
888 : EmptyDataArray& operator=(const EmptyDataArray&) = default;
889 : /// EmptyDataArray can be move-constructed
890 : EmptyDataArray(EmptyDataArray&&) noexcept = default;
891 : /// EmptyDataArray can be move-assigned
892 : EmptyDataArray& operator=(EmptyDataArray&&) noexcept = default;
893 :
894 : mts_data_origin_t origin() const override {
895 : mts_data_origin_t origin = 0;
896 : mts_register_data_origin("metatensor::EmptyDataArray", &origin);
897 : return origin;
898 : }
899 :
900 : double* data() & override {
901 : throw metatensor::Error("can not call `data` for an EmptyDataArray");
902 : }
903 :
904 : const std::vector<uintptr_t>& shape() const & override {
905 : return shape_;
906 : }
907 :
908 : void reshape(std::vector<uintptr_t> shape) override {
909 : if (details::product(shape_) != details::product(shape)) {
910 : throw metatensor::Error("invalid shape in reshape");
911 : }
912 : shape_ = std::move(shape);
913 : }
914 :
915 : void swap_axes(uintptr_t axis_1, uintptr_t axis_2) override {
916 : std::swap(shape_[axis_1], shape_[axis_2]);
917 : }
918 :
919 : std::unique_ptr<DataArrayBase> copy() const override {
920 : return std::unique_ptr<DataArrayBase>(new EmptyDataArray(*this));
921 : }
922 :
923 : std::unique_ptr<DataArrayBase> create(std::vector<uintptr_t> shape) const override {
924 : return std::unique_ptr<DataArrayBase>(new EmptyDataArray(std::move(shape)));
925 : }
926 :
927 : void move_samples_from(const DataArrayBase&, std::vector<mts_sample_mapping_t>, uintptr_t, uintptr_t) override {
928 : throw metatensor::Error("can not call `move_samples_from` for an EmptyDataArray");
929 : }
930 :
931 : private:
932 : std::vector<uintptr_t> shape_;
933 : };
934 :
935 : namespace details {
936 : /// Default callback for data array creating in `TensorMap::load`, which
937 : /// will create a `SimpleDataArray`.
938 : inline mts_status_t default_create_array(
939 : const uintptr_t* shape_ptr,
940 : uintptr_t shape_count,
941 : mts_array_t* array
942 : ) {
943 : return details::catch_exceptions([](const uintptr_t* shape_ptr, uintptr_t shape_count, mts_array_t* array){
944 : auto shape = std::vector<size_t>();
945 : for (size_t i=0; i<shape_count; i++) {
946 : shape.push_back(static_cast<size_t>(shape_ptr[i]));
947 : }
948 :
949 : auto cxx_array = std::unique_ptr<DataArrayBase>(new SimpleDataArray(shape));
950 : *array = DataArrayBase::to_mts_array_t(std::move(cxx_array));
951 :
952 : return MTS_SUCCESS;
953 : }, shape_ptr, shape_count, array);
954 : }
955 : }
956 :
957 : /******************************************************************************/
958 : /******************************************************************************/
959 : /* */
960 : /* I/O functionalities */
961 : /* */
962 : /******************************************************************************/
963 : /******************************************************************************/
964 :
namespace io {
    /// Save a `TensorMap` to the file at `path`.
    ///
    /// If the file exists, it will be overwritten. The recommended file
    /// extension when saving data is `.mts`, to prevent confusion with generic
    /// `.npz` files.
    ///
    /// `TensorMap` are serialized using numpy's NPZ format, i.e. a ZIP file
    /// without compression (storage method is `STORED`), where each file is
    /// stored as a `.npy` array. See the C API documentation for more
    /// information on the format.
    void save(const std::string& path, const TensorMap& tensor);

    /// Save a `TensorMap` to an in-memory buffer.
    ///
    /// The `Buffer` template parameter can be set to any type that can be
    /// constructed from a pair of iterators over `std::vector<uint8_t>`.
    template <typename Buffer = std::vector<uint8_t>>
    Buffer save_buffer(const TensorMap& tensor);

    /// Specialization of `save_buffer` for the default buffer type
    /// (`std::vector<uint8_t>`).
    template<>
    std::vector<uint8_t> save_buffer<std::vector<uint8_t>>(const TensorMap& tensor);

    /**************************************************************************/

    /// Save a `TensorBlock` to the file at `path`.
    ///
    /// If the file exists, it will be overwritten. The recommended file
    /// extension when saving data is `.mts`, to prevent confusion with generic
    /// `.npz` files.
    void save(const std::string& path, const TensorBlock& block);

    /// Save a `TensorBlock` to an in-memory buffer.
    ///
    /// The `Buffer` template parameter can be set to any type that can be
    /// constructed from a pair of iterators over `std::vector<uint8_t>`.
    template <typename Buffer = std::vector<uint8_t>>
    Buffer save_buffer(const TensorBlock& block);

    /// Specialization of `save_buffer` for the default buffer type
    /// (`std::vector<uint8_t>`).
    template<>
    std::vector<uint8_t> save_buffer<std::vector<uint8_t>>(const TensorBlock& block);

    /**************************************************************************/

    /// Save `Labels` to the file at `path`.
    ///
    /// If the file exists, it will be overwritten. The recommended file
    /// extension when saving data is `.mts`, to prevent confusion with generic
    /// `.npz` files.
    void save(const std::string& path, const Labels& labels);

    /// Save `Labels` to an in-memory buffer.
    ///
    /// The `Buffer` template parameter can be set to any type that can be
    /// constructed from a pair of iterators over `std::vector<uint8_t>`.
    template <typename Buffer = std::vector<uint8_t>>
    Buffer save_buffer(const Labels& labels);

    /// Specialization of `save_buffer` for the default buffer type
    /// (`std::vector<uint8_t>`).
    template<>
    std::vector<uint8_t> save_buffer<std::vector<uint8_t>>(const Labels& labels);

    /**************************************************************************/
    /**************************************************************************/

    /*!
     * Load a previously saved `TensorMap` from the given path.
     *
     * \verbatim embed:rst:leading-asterisk
     *
     * ``create_array`` will be used to create new arrays when constructing the
     * blocks and gradients, the default version will create data using
     * :cpp:class:`SimpleDataArray`. See :c:func:`mts_create_array_callback_t`
     * for more information.
     *
     * \endverbatim
     *
     * `TensorMap` are serialized using numpy's NPZ format, i.e. a ZIP file
     * without compression (storage method is `STORED`), where each file is
     * stored as a `.npy` array. See the C API documentation for more
     * information on the format.
     */
    TensorMap load(
        const std::string& path,
        mts_create_array_callback_t create_array = details::default_create_array
    );

    /*!
     * Load a previously saved `TensorMap` from the given `buffer`, containing
     * `buffer_count` bytes.
     *
     * \verbatim embed:rst:leading-asterisk
     *
     * ``create_array`` will be used to create new arrays when constructing the
     * blocks and gradients, the default version will create data using
     * :cpp:class:`SimpleDataArray`. See :c:func:`mts_create_array_callback_t`
     * for more information.
     *
     * \endverbatim
     */
    TensorMap load_buffer(
        const uint8_t* buffer,
        size_t buffer_count,
        mts_create_array_callback_t create_array = details::default_create_array
    );


    /// Load a previously saved `TensorMap` from the given `buffer`.
    ///
    /// The `Buffer` template parameter would typically be a
    /// `std::vector<uint8_t>` or a `std::string`, but any container with
    /// contiguous data and an element type with the same size as a `uint8_t`
    /// can work.
    template <typename Buffer>
    TensorMap load_buffer(
        const Buffer& buffer,
        mts_create_array_callback_t create_array = details::default_create_array
    );

    /**************************************************************************/

    /*!
     * Load a previously saved `TensorBlock` from the given path.
     *
     * \verbatim embed:rst:leading-asterisk
     *
     * ``create_array`` will be used to create new arrays when constructing the
     * blocks and gradients, the default version will create data using
     * :cpp:class:`SimpleDataArray`. See :c:func:`mts_create_array_callback_t`
     * for more information.
     *
     * \endverbatim
     */
    TensorBlock load_block(
        const std::string& path,
        mts_create_array_callback_t create_array = details::default_create_array
    );

    /*!
     * Load a previously saved `TensorBlock` from the given `buffer`, containing
     * `buffer_count` bytes.
     *
     * \verbatim embed:rst:leading-asterisk
     *
     * ``create_array`` will be used to create new arrays when constructing the
     * blocks and gradients, the default version will create data using
     * :cpp:class:`SimpleDataArray`. See :c:func:`mts_create_array_callback_t`
     * for more information.
     *
     * \endverbatim
     */
    TensorBlock load_block_buffer(
        const uint8_t* buffer,
        size_t buffer_count,
        mts_create_array_callback_t create_array = details::default_create_array
    );


    /// Load a previously saved `TensorBlock` from the given `buffer`.
    ///
    /// The `Buffer` template parameter would typically be a
    /// `std::vector<uint8_t>` or a `std::string`, but any container with
    /// contiguous data and an element type with the same size as a `uint8_t`
    /// can work.
    template <typename Buffer>
    TensorBlock load_block_buffer(
        const Buffer& buffer,
        mts_create_array_callback_t create_array = details::default_create_array
    );

    /**************************************************************************/

    /// Load previously saved `Labels` from the given path.
    Labels load_labels(const std::string& path);

    /// Load previously saved `Labels` from the given `buffer`, containing
    /// `buffer_count` bytes.
    Labels load_labels_buffer(const uint8_t* buffer, size_t buffer_count);

    /// Load previously saved `Labels` from the given `buffer`.
    ///
    /// The `Buffer` template parameter would typically be a
    /// `std::vector<uint8_t>` or a `std::string`, but any container with
    /// contiguous data and an element type with the same size as a `uint8_t`
    /// can work.
    template <typename Buffer>
    Labels load_labels_buffer(const Buffer& buffer);
}
1153 :
1154 :
1155 : /******************************************************************************/
1156 : /******************************************************************************/
1157 : /* */
1158 : /* Labels */
1159 : /* */
1160 : /******************************************************************************/
1161 : /******************************************************************************/
1162 :
1163 :
1164 : /// It is possible to store some user-provided data inside `Labels`, and access
1165 : /// it later. This class is used to take ownership of the data and corresponding
1166 : /// delete function before giving the data to metatensor.
1167 : ///
1168 : /// User data inside `Labels` is an advanced functionality, that most users
1169 : /// should not need to interact with.
class LabelsUserData {
public:
    /// Create `LabelsUserData` containing the given `data`.
    ///
    /// `deleter` will be called with `data` when this object is destroyed
    /// or overwritten by a move-assignment, and should free the
    /// corresponding memory. If `deleter` is `nullptr`, nothing is called.
    LabelsUserData(void* data, void(*deleter)(void*)): data_(data), deleter_(deleter) {}

    ~LabelsUserData() {
        if (deleter_ != nullptr) {
            deleter_(data_);
        }
    }

    /// LabelsUserData is not copy-constructible
    LabelsUserData(const LabelsUserData& other) = delete;
    /// LabelsUserData can not be copy-assigned
    LabelsUserData& operator=(const LabelsUserData& other) = delete;

    /// LabelsUserData is move-constructible. `other` is left empty (no data
    /// and no deleter).
    LabelsUserData(LabelsUserData&& other) noexcept: LabelsUserData(nullptr, nullptr) {
        *this = std::move(other);
    }

    /// LabelsUserData can be move-assigned.
    ///
    /// Any data currently held by `this` is released first, and `other` is
    /// left empty. Self-move-assignment is a no-op: without this guard the
    /// data would be released while still owned, and then lost.
    LabelsUserData& operator=(LabelsUserData&& other) noexcept {
        if (this == &other) {
            return *this;
        }

        if (deleter_ != nullptr) {
            deleter_(data_);
        }

        data_ = other.data_;
        deleter_ = other.deleter_;

        other.data_ = nullptr;
        other.deleter_ = nullptr;

        return *this;
    }

private:
    friend class Labels;

    void* data_;
    void(*deleter_)(void*);
};
1215 :
1216 :
1217 : /// A set of labels used to carry metadata associated with a tensor map.
1218 : ///
1219 : /// This is similar to an array of named tuples, but stored as a 2D array
1220 : /// of shape `(count, size)`, with a set of names associated with the columns of
1221 : /// this array (often called *dimensions*). Each row/entry in this array is
1222 : /// unique, and they are often (but not always) sorted in lexicographic order.
1223 : class Labels final {
1224 : public:
1225 : /// Create a new set of Labels from the given `names` and `values`.
1226 : ///
1227 : /// Each entry in the values must contain `names.size()` elements.
1228 : ///
1229 : /// ```
1230 : /// auto labels = Labels({"first", "second"}, {
1231 : /// {0, 1},
1232 : /// {1, 4},
1233 : /// {2, 1},
1234 : /// {2, 3},
1235 : /// });
1236 : /// ```
1237 : Labels(
1238 : const std::vector<std::string>& names,
1239 : const std::vector<std::initializer_list<int32_t>>& values
1240 : ): Labels(names, NDArray<int32_t>(values, names.size()), InternalConstructor{}) {}
1241 :
1242 : /// This function does not check for uniqueness of the labels entries, which
1243 : /// should be enforced by the caller. Calling this function with non-unique
1244 : /// entries is invalid and can lead to crashes or infinite loops.
1245 : explicit Labels(
1246 : const std::vector<std::string>& names,
1247 : const std::vector<std::initializer_list<int32_t>>& values,
1248 : assume_unique
1249 : ): Labels(names, NDArray<int32_t>(values, names.size()), assume_unique{}, InternalConstructor{}) {}
1250 :
1251 : /// Create an empty set of Labels with the given names
1252 : explicit Labels(const std::vector<std::string>& names):
1253 : Labels(names, static_cast<const int32_t*>(nullptr), 0) {}
1254 :
1255 : /// Create labels with the given `names` and `values`. `values` must be an
1256 : /// array with `count x names.size()` elements.
1257 : Labels(const std::vector<std::string>& names, const int32_t* values, size_t count):
1258 : Labels(details::labels_from_cxx(names, values, count, false)) {}
1259 :
1260 : /// Unchecked variant, caller promises the labels are unique. Calling with
1261 : /// non-unique entries is invalid and can ead to crashes or infinite loops.
1262 : Labels(const std::vector<std::string>& names, const int32_t* values, size_t count, assume_unique):
1263 : Labels(details::labels_from_cxx(names, values, count, true)) {}
1264 :
1265 : ~Labels() {
1266 : mts_labels_free(&labels_);
1267 : }
1268 :
1269 : /// Labels is copy-constructible
1270 : Labels(const Labels& other): Labels() {
1271 : *this = other;
1272 : }
1273 :
1274 : /// Labels can be copy-assigned
1275 : Labels& operator=(const Labels& other) {
1276 : mts_labels_free(&labels_);
1277 : std::memset(&labels_, 0, sizeof(labels_));
1278 : details::check_status(mts_labels_clone(other.labels_, &labels_));
1279 : assert(this->labels_.internal_ptr_ != nullptr);
1280 :
1281 : this->values_ = NDArray<int32_t>(labels_.values, {labels_.count, labels_.size});
1282 :
1283 : this->names_.clear();
1284 : for (size_t i=0; i<this->labels_.size; i++) {
1285 : this->names_.push_back(this->labels_.names[i]);
1286 : }
1287 :
1288 : return *this;
1289 : }
1290 :
1291 : /// Labels is move-constructible
1292 : Labels(Labels&& other) noexcept: Labels() {
1293 : *this = std::move(other);
1294 : }
1295 :
1296 : /// Labels can be move-assigned
1297 : Labels& operator=(Labels&& other) noexcept {
1298 : mts_labels_free(&labels_);
1299 : this->labels_ = other.labels_;
1300 : assert(this->labels_.internal_ptr_ != nullptr);
1301 : std::memset(&other.labels_, 0, sizeof(other.labels_));
1302 :
1303 : this->values_ = std::move(other.values_);
1304 : this->names_ = std::move(other.names_);
1305 :
1306 : return *this;
1307 : }
1308 :
1309 : /// Get the names of the dimensions used in these `Labels`.
1310 : const std::vector<const char*>& names() const {
1311 : return names_;
1312 : }
1313 :
1314 : /// Get the number of entries in this set of Labels.
1315 : ///
1316 : /// This is the same as `shape()[0]` for the corresponding values array
1317 : size_t count() const {
1318 : return labels_.count;
1319 : }
1320 :
1321 : /// Get the number of dimensions in this set of Labels.
1322 : ///
1323 : /// This is the same as `shape()[1]` for the corresponding values array
1324 : size_t size() const {
1325 : return labels_.size;
1326 : }
1327 :
1328 : /// Convert from this set of Labels to the C `mts_labels_t`
1329 : mts_labels_t as_mts_labels_t() const {
1330 : assert(labels_.internal_ptr_ != nullptr);
1331 : return labels_;
1332 : }
1333 :
1334 : /// Get the user data pointer registered with these `Labels`.
1335 : ///
1336 : /// If no user data have been registered, this function will return
1337 : /// `nullptr`.
1338 : void* user_data() & {
1339 : assert(labels_.internal_ptr_ != nullptr);
1340 :
1341 : void* data = nullptr;
1342 : details::check_status(mts_labels_user_data(labels_, &data));
1343 : return data;
1344 : }
1345 :
1346 : void* user_data() && = delete;
1347 :
1348 : /// Register some user data pointer with these `Labels`.
1349 : ///
1350 : /// Any existing user data will be released (by calling the provided
1351 : /// `delete` function) before overwriting with the new data.
1352 : void set_user_data(LabelsUserData user_data) {
1353 : assert(labels_.internal_ptr_ != nullptr);
1354 :
1355 : details::check_status(mts_labels_set_user_data(
1356 : labels_,
1357 : user_data.data_,
1358 : user_data.deleter_
1359 : ));
1360 :
1361 : // the user data was moved inside `labels_`
1362 : user_data.data_ = nullptr;
1363 : user_data.deleter_ = nullptr;
1364 : }
1365 :
1366 : /// Get the position of the `entry` in this set of Labels, or -1 if the
1367 : /// entry is not part of these Labels.
1368 : int64_t position(std::initializer_list<int32_t> entry) const {
1369 : return this->position(entry.begin(), entry.size());
1370 : }
1371 :
1372 : /// Variant of `Labels::position` taking a fixed-size array as input
1373 : template<size_t N>
1374 : int64_t position(const std::array<int32_t, N>& entry) const {
1375 : return this->position(entry.data(), entry.size());
1376 : }
1377 :
1378 : /// Variant of `Labels::position` taking a vector as input
1379 : int64_t position(const std::vector<int32_t>& entry) const {
1380 : return this->position(entry.data(), entry.size());
1381 : }
1382 :
1383 : /// Variant of `Labels::position` taking a pointer and length as input
1384 : int64_t position(const int32_t* entry, size_t length) const {
1385 : assert(labels_.internal_ptr_ != nullptr);
1386 :
1387 : int64_t result = 0;
1388 : details::check_status(mts_labels_position(labels_, entry, length, &result));
1389 : return result;
1390 : }
1391 :
1392 : /// Get the array of values for these Labels
1393 : const NDArray<int32_t>& values() const & {
1394 : return values_;
1395 : }
1396 :
1397 : const NDArray<int32_t>& values() && = delete;
1398 :
1399 : /// Take the union of these `Labels` with `other`.
1400 : ///
1401 : /// If requested, this function can also give the positions in the union
1402 : /// where each entry of the input `Labels` ended up.
1403 : ///
1404 : /// No user data pointer is registered with the output, even if the inputs
1405 : /// have some.
1406 : ///
1407 : /// @param other the `Labels` we want to take the union with
1408 : /// @param first_mapping if you want the mapping from the positions of
1409 : /// entries in `this` to the positions in the union, this should be
1410 : /// a pointer to an array containing `this->count()` elements, to be
1411 : /// filled by this function. Otherwise it should be a `nullptr`.
1412 : /// @param first_mapping_count number of elements in `first_mapping`
1413 : /// @param second_mapping if you want the mapping from the positions of
1414 : /// entries in `other` to the positions in the union, this should be
1415 : /// a pointer to an array containing `other.count()` elements, to be
1416 : /// filled by this function. Otherwise it should be a `nullptr`.
1417 : /// @param second_mapping_count number of elements in `second_mapping`
1418 : Labels set_union(
1419 : const Labels& other,
1420 : int64_t* first_mapping = nullptr,
1421 : size_t first_mapping_count = 0,
1422 : int64_t* second_mapping = nullptr,
1423 : size_t second_mapping_count = 0
1424 : ) const {
1425 : mts_labels_t result;
1426 : std::memset(&result, 0, sizeof(result));
1427 :
1428 : details::check_status(mts_labels_union(
1429 : labels_,
1430 : other.labels_,
1431 : &result,
1432 : first_mapping,
1433 : first_mapping_count,
1434 : second_mapping,
1435 : second_mapping_count
1436 : ));
1437 :
1438 : return Labels(result);
1439 : }
1440 :
1441 : /// Take the union of these `Labels` with `other`.
1442 : ///
1443 : /// If requested, this function can also give the positions in the
1444 : /// union where each entry of the input `Labels` ended up.
1445 : ///
1446 : /// No user data pointer is registered with the output, even if the inputs
1447 : /// have some.
1448 : ///
1449 : /// @param other the `Labels` we want to take the union with
1450 : /// @param first_mapping if you want the mapping from the positions of
1451 : /// entries in `this` to the positions in the union, this should be
1452 : /// a vector containing `this->count()` elements, to be filled by
1453 : /// this function. Otherwise it should be an empty vector.
1454 : /// @param second_mapping if you want the mapping from the positions of
1455 : /// entries in `other` to the positions in the union, this should be
1456 : /// a vector containing `other.count()` elements, to be filled by
1457 : /// this function. Otherwise it should be an empty vector.
1458 : Labels set_union(
1459 : const Labels& other,
1460 : std::vector<int64_t>& first_mapping,
1461 : std::vector<int64_t>& second_mapping
1462 : ) const {
1463 : auto* first_mapping_ptr = first_mapping.data();
1464 : auto first_mapping_count = first_mapping.size();
1465 : if (first_mapping_count == 0) {
1466 : first_mapping_ptr = nullptr;
1467 : }
1468 :
1469 : auto* second_mapping_ptr = second_mapping.data();
1470 : auto second_mapping_count = second_mapping.size();
1471 : if (second_mapping_count == 0) {
1472 : second_mapping_ptr = nullptr;
1473 : }
1474 :
1475 : return this->set_union(
1476 : other,
1477 : first_mapping_ptr,
1478 : first_mapping_count,
1479 : second_mapping_ptr,
1480 : second_mapping_count
1481 : );
1482 : }
1483 :
1484 : /// Take the intersection of these `Labels` with `other`.
1485 : ///
1486 : /// If requested, this function can also give the positions in the
1487 : /// intersection where each entry of the input `Labels` ended up.
1488 : ///
1489 : /// No user data pointer is registered with the output, even if the inputs
1490 : /// have some.
1491 : ///
1492 : /// @param other the `Labels` we want to take the intersection with
1493 : /// @param first_mapping if you want the mapping from the positions of
1494 : /// entries in `this` to the positions in the intersection, this
1495 : /// should be a pointer to an array containing `this->count()`
1496 : /// elements, to be filled by this function. Otherwise it should be a
1497 : /// `nullptr`. If an entry in `this` is not used in the intersection,
1498 : /// the mapping will be set to -1.
1499 : /// @param first_mapping_count number of elements in `first_mapping`
1500 : /// @param second_mapping if you want the mapping from the positions of
1501 : /// entries in `other` to the positions in the intersection, this
1502 : /// should be a pointer to an array containing `other.count()`
1503 : /// elements, to be filled by this function. Otherwise it should be a
1504 : /// `nullptr`. If an entry in `other` is not used in the
1505 : /// intersection, the mapping will be set to -1.
1506 : /// @param second_mapping_count number of elements in `second_mapping`
1507 : Labels set_intersection(
1508 : const Labels& other,
1509 : int64_t* first_mapping = nullptr,
1510 : size_t first_mapping_count = 0,
1511 : int64_t* second_mapping = nullptr,
1512 : size_t second_mapping_count = 0
1513 : ) const {
1514 : mts_labels_t result;
1515 : std::memset(&result, 0, sizeof(result));
1516 :
1517 : details::check_status(mts_labels_intersection(
1518 : labels_,
1519 : other.labels_,
1520 : &result,
1521 : first_mapping,
1522 : first_mapping_count,
1523 : second_mapping,
1524 : second_mapping_count
1525 : ));
1526 :
1527 : return Labels(result);
1528 : }
1529 :
1530 : /// Take the intersection of this `Labels` with `other`.
1531 : ///
1532 : /// If requested, this function can also give the positions in the
1533 : /// intersection where each entry of the input `Labels` ended up.
1534 : ///
1535 : /// No user data pointer is registered with the output, even if the inputs
1536 : /// have some.
1537 : ///
1538 : /// @param other the `Labels` we want to take the intersection with
1539 : /// @param first_mapping if you want the mapping from the positions of
1540 : /// entries in `this` to the positions in the intersection, this
1541 : /// should be a vector containing `this->count()` elements, to be
1542 : /// filled by this function. Otherwise it should be an empty vector.
1543 : /// If an entry in `this` is not used in the intersection, the
1544 : /// mapping will be set to -1.
1545 : /// @param second_mapping if you want the mapping from the positions of
1546 : /// entries in `other` to the positions in the intersection, this
1547 : /// should be a vector containing `other.count()` elements, to be
1548 : /// filled by this function. Otherwise it should be an empty vector.
1549 : /// If an entry in `other` is not used in the intersection, the
1550 : /// mapping will be set to -1.
1551 : Labels set_intersection(
1552 : const Labels& other,
1553 : std::vector<int64_t>& first_mapping,
1554 : std::vector<int64_t>& second_mapping
1555 : ) const {
1556 : auto* first_mapping_ptr = first_mapping.data();
1557 : auto first_mapping_count = first_mapping.size();
1558 : if (first_mapping_count == 0) {
1559 : first_mapping_ptr = nullptr;
1560 : }
1561 :
1562 : auto* second_mapping_ptr = second_mapping.data();
1563 : auto second_mapping_count = second_mapping.size();
1564 : if (second_mapping_count == 0) {
1565 : second_mapping_ptr = nullptr;
1566 : }
1567 :
1568 : return this->set_intersection(
1569 : other,
1570 : first_mapping_ptr,
1571 : first_mapping_count,
1572 : second_mapping_ptr,
1573 : second_mapping_count
1574 : );
1575 : }
1576 :
1577 : /// Take the difference of these `Labels` with `other`.
1578 : ///
1579 : /// If requested, this function can also give the positions in the
1580 : /// difference where each entry of the input `Labels` ended up.
1581 : ///
1582 : /// No user data pointer is registered with the output, even if the inputs
1583 : /// have some.
1584 : ///
1585 : /// @param other the `Labels` we want to take the difference with
1586 : /// @param first_mapping if you want the mapping from the positions of
1587 : /// entries in `this` to the positions in the difference, this
1588 : /// should be a pointer to an array containing `this->count()`
1589 : /// elements, to be filled by this function. Otherwise it should be a
1590 : /// `nullptr`. If an entry in `this` is not used in the difference,
1591 : /// the mapping will be set to -1.
1592 : /// @param first_mapping_count number of elements in `first_mapping`
1593 : Labels set_difference(
1594 : const Labels& other,
1595 : int64_t* first_mapping = nullptr,
1596 : size_t first_mapping_count = 0
1597 : ) const {
1598 : mts_labels_t result;
1599 : std::memset(&result, 0, sizeof(result));
1600 :
1601 : details::check_status(mts_labels_difference(
1602 : labels_,
1603 : other.labels_,
1604 : &result,
1605 : first_mapping,
1606 : first_mapping_count
1607 : ));
1608 :
1609 : return Labels(result);
1610 : }
1611 :
1612 : /// Take the difference of this `Labels` with `other`.
1613 : ///
1614 : /// If requested, this function can also give the positions in the
1615 : /// difference where each entry of the input `Labels` ended up.
1616 : ///
1617 : /// No user data pointer is registered with the output, even if the inputs
1618 : /// have some.
1619 : ///
1620 : /// @param other the `Labels` we want to take the difference with
1621 : /// @param first_mapping if you want the mapping from the positions of
1622 : /// entries in `this` to the positions in the difference, this
1623 : /// should be a vector containing `this->count()` elements, to be
1624 : /// filled by this function. Otherwise it should be an empty vector.
1625 : /// If an entry in `this` is not used in the difference, the
1626 : /// mapping will be set to -1.
1627 : Labels set_difference(const Labels& other, std::vector<int64_t>& first_mapping) const {
1628 : auto* first_mapping_ptr = first_mapping.data();
1629 : auto first_mapping_count = first_mapping.size();
1630 : if (first_mapping_count == 0) {
1631 : first_mapping_ptr = nullptr;
1632 : }
1633 :
1634 : return this->set_difference(
1635 : other,
1636 : first_mapping_ptr,
1637 : first_mapping_count
1638 : );
1639 : }
1640 :
1641 : /// Select entries in these `Labels` that match the `selection`.
1642 : ///
1643 : /// The selection's names must be a subset of the names of these labels.
1644 : ///
1645 : /// All entries in these `Labels` that match one of the entry in the
1646 : /// `selection` for all the selection's dimension will be picked. Any entry
1647 : /// in the `selection` but not in these `Labels` will be ignored.
1648 : ///
1649 : /// @param selection definition of the selection criteria. Multiple entries
1650 : /// are interpreted as a logical `or` operation.
1651 : /// @param selected on input, a pointer to an array with space for
1652 : /// `*selected_count` entries. On output, the first `*selected_count`
1653 : /// values will contain the index in `labels` of selected entries.
1654 : /// @param selected_count on input, size of the `selected` array. On output,
1655 : /// this will contain the number of selected entries.
1656 : void select(const Labels& selection, int64_t* selected, size_t *selected_count) const {
1657 : details::check_status(mts_labels_select(
1658 : labels_,
1659 : selection.labels_,
1660 : selected,
1661 : selected_count
1662 : ));
1663 : }
1664 :
1665 : /// Select entries in these `Labels` that match the `selection`.
1666 : ///
1667 : /// This function does the same thing as the one above, but allocates and
1668 : /// return the list of selected indexes in a `std::vector`
1669 : std::vector<int64_t> select(const Labels& selection) const {
1670 : auto selected_count = this->count();
1671 : auto selected = std::vector<int64_t>(selected_count, -1);
1672 :
1673 : this->select(selection, selected.data(), &selected_count);
1674 :
1675 : selected.resize(selected_count);
1676 : return selected;
1677 : }
1678 :
1679 : /*!
1680 : * \verbatim embed:rst:leading-asterisk
1681 : *
1682 : * Load previously saved ``Labels`` from the given path.
1683 : *
1684 : * This is identical to :cpp:func:`metatensor::io::load_labels`, and
1685 : * provided as a convenience API.
1686 : *
1687 : * \endverbatim
1688 : */
1689 : static Labels load(const std::string& path) {
1690 : return metatensor::io::load_labels(path);
1691 : }
1692 :
1693 : /*!
1694 : * \verbatim embed:rst:leading-asterisk
1695 : *
1696 : * Load previously saved ``Labels`` from a in-memory buffer.
1697 : *
1698 : * This is identical to :cpp:func:`metatensor::io::load_labels_buffer`, and
1699 : * provided as a convenience API.
1700 : *
1701 : * \endverbatim
1702 : */
1703 : static Labels load_buffer(const uint8_t* buffer, size_t buffer_count) {
1704 : return metatensor::io::load_labels_buffer(buffer, buffer_count);
1705 : }
1706 :
1707 : /*!
1708 : * \verbatim embed:rst:leading-asterisk
1709 : *
1710 : * Load previously saved ``Labels`` from a in-memory buffer.
1711 : *
1712 : * This is identical to :cpp:func:`metatensor::io::load_labels_buffer`, and
1713 : * provided as a convenience API.
1714 : *
1715 : * \endverbatim
1716 : */
1717 : template <typename Buffer>
1718 : static Labels load_buffer(const Buffer& buffer) {
1719 : return metatensor::io::load_labels_buffer<Buffer>(buffer);
1720 : }
1721 :
1722 : /*!
1723 : * \verbatim embed:rst:leading-asterisk
1724 : *
1725 : * Save ``Labels`` to the given path.
1726 : *
1727 : * This is identical to :cpp:func:`metatensor::io::save`, and provided as a
1728 : * convenience API.
1729 : *
1730 : * \endverbatim
1731 : */
1732 : void save(const std::string& path) const {
1733 : metatensor::io::save(path, *this);
1734 : }
1735 :
1736 : /*!
1737 : * \verbatim embed:rst:leading-asterisk
1738 : *
1739 : * Save ``Labels`` to an in-memory buffer.
1740 : *
1741 : * This is identical to :cpp:func:`metatensor::io::save_buffer`, and
1742 : * provided as a convenience API.
1743 : *
1744 : * \endverbatim
1745 : */
1746 : std::vector<uint8_t> save_buffer() const {
1747 : return metatensor::io::save_buffer(*this);
1748 : }
1749 :
1750 : /*!
1751 : * \verbatim embed:rst:leading-asterisk
1752 : *
1753 : * Save ``Labels`` to an in-memory buffer.
1754 : *
1755 : * This is identical to :cpp:func:`metatensor::io::save_buffer`, and
1756 : * provided as a convenience API.
1757 : *
1758 : * \endverbatim
1759 : */
1760 : template <typename Buffer>
1761 : Buffer save_buffer() const {
1762 : return metatensor::io::save_buffer<Buffer>(*this);
1763 : }
1764 :
1765 : private:
1766 : explicit Labels(): values_(static_cast<const int32_t*>(nullptr), {0, 0})
1767 : {
1768 : std::memset(&labels_, 0, sizeof(labels_));
1769 : }
1770 :
1771 : explicit Labels(mts_labels_t labels):
1772 : values_(labels.values, {labels.count, labels.size}),
1773 : labels_(labels)
1774 : {
1775 : assert(labels_.internal_ptr_ != nullptr);
1776 :
1777 : for (size_t i=0; i<labels_.size; i++) {
1778 : names_.push_back(labels_.names[i]);
1779 : }
1780 : }
1781 :
1782 : // the constructor below is ambiguous with the public constructor taking
1783 : // `std::initializer_list`, so we use a private dummy struct argument to
1784 : // remove the ambiguity.
1785 : struct InternalConstructor {};
1786 : Labels(const std::vector<std::string>& names, const NDArray<int32_t>& values, InternalConstructor):
1787 : Labels(names, values.data(), values.shape()[0]) {}
1788 :
1789 : Labels(const std::vector<std::string>& names, const NDArray<int32_t>& values, assume_unique, InternalConstructor):
1790 : Labels(names, values.data(), values.shape()[0], assume_unique{}) {}
1791 :
1792 : friend Labels details::labels_from_cxx(const std::vector<std::string>& names, const int32_t* values, size_t count, bool assume_unique);
1793 : friend Labels io::load_labels(const std::string &path);
1794 : friend Labels io::load_labels_buffer(const uint8_t* buffer, size_t buffer_count);
1795 : friend class TensorMap;
1796 : friend class TensorBlock;
1797 :
1798 : friend class metatensor_torch::LabelsHolder;
1799 :
1800 : std::vector<const char*> names_;
1801 : NDArray<int32_t> values_;
1802 : mts_labels_t labels_;
1803 :
1804 : friend bool operator==(const Labels& lhs, const Labels& rhs);
1805 : };
1806 :
1807 : namespace details {
1808 : inline metatensor::Labels labels_from_cxx(
1809 : const std::vector<std::string>& names,
1810 : const int32_t* values,
1811 : size_t count,
1812 : bool assume_unique = false
1813 : ) {
1814 : mts_labels_t labels;
1815 : std::memset(&labels, 0, sizeof(labels));
1816 :
1817 : auto c_names = std::vector<const char*>();
1818 : for (const auto& name: names) {
1819 : c_names.push_back(name.c_str());
1820 : }
1821 :
1822 : labels.names = c_names.data();
1823 : labels.size = c_names.size();
1824 : labels.count = count;
1825 : labels.values = values;
1826 :
1827 : if (assume_unique) {
1828 : details::check_status(mts_labels_create_assume_unique(&labels));
1829 : } else {
1830 : details::check_status(mts_labels_create(&labels));
1831 : }
1832 :
1833 : return metatensor::Labels(labels);
1834 : }
1835 : }
1836 :
1837 :
1838 : /// Two Labels compare equal only if they have the same names and values in the
1839 : /// same order.
1840 : inline bool operator==(const Labels& lhs, const Labels& rhs) {
1841 : if (lhs.names_.size() != rhs.names_.size()) {
1842 : return false;
1843 : }
1844 :
1845 : for (size_t i=0; i<lhs.names_.size(); i++) {
1846 : if (std::strcmp(lhs.names_[i], rhs.names_[i]) != 0) {
1847 : return false;
1848 : }
1849 : }
1850 :
1851 : return lhs.values() == rhs.values();
1852 : }
1853 :
1854 : /// Two Labels compare equal only if they have the same names and values in the
1855 : /// same order.
1856 : inline bool operator!=(const Labels& lhs, const Labels& rhs) {
1857 : return !(lhs == rhs);
1858 : }
1859 :
1860 :
1861 : /******************************************************************************/
1862 : /******************************************************************************/
1863 : /* */
1864 : /* TensorBlock */
1865 : /* */
1866 : /******************************************************************************/
1867 : /******************************************************************************/
1868 :
1869 :
1870 : /// Basic building block for a tensor map.
1871 : ///
1872 : /// A single block contains a n-dimensional `mts_array_t` (or `DataArrayBase`),
1873 : /// and n sets of `Labels` (one for each dimension). The first dimension is the
1874 : /// *samples* dimension, the last dimension is the *properties* dimension. Any
1875 : /// intermediate dimension is called a *component* dimension.
1876 : ///
1877 : /// Samples should be used to describe *what* we are representing, while
1878 : /// properties should contain information about *how* we are representing it.
1879 : /// Finally, components should be used to describe vectorial or tensorial
1880 : /// components of the data.
1881 : ///
1882 : /// A block can also contain gradients of the values with respect to a variety
1883 : /// of parameters. In this case, each gradient has a separate set of samples,
1884 : /// and possibly components but share the same property labels as the values.
1885 : class TensorBlock final {
1886 : public:
1887 : /// Create a new TensorBlock containing the given `values` array.
1888 : ///
1889 : /// The different dimensions of the values are described by `samples`,
1890 : /// `components` and `properties` `Labels`
1891 : TensorBlock(
1892 : std::unique_ptr<DataArrayBase> values,
1893 : const Labels& samples,
1894 : const std::vector<Labels>& components,
1895 : const Labels& properties
1896 : ):
1897 : block_(nullptr),
1898 : is_view_(false)
1899 : {
1900 : auto c_components = std::vector<mts_labels_t>();
1901 : for (const auto& component: components) {
1902 : c_components.push_back(component.as_mts_labels_t());
1903 : }
1904 : block_ = mts_block(
1905 : DataArrayBase::to_mts_array_t(std::move(values)),
1906 : samples.as_mts_labels_t(),
1907 : c_components.data(),
1908 : c_components.size(),
1909 : properties.as_mts_labels_t()
1910 : );
1911 :
1912 : details::check_pointer(block_);
1913 : }
1914 :
1915 : ~TensorBlock() {
1916 : if (!is_view_) {
1917 : mts_block_free(block_);
1918 : }
1919 : }
1920 :
1921 : /// TensorBlock can NOT be copy constructed, use TensorBlock::clone instead
1922 : TensorBlock(const TensorBlock&) = delete;
1923 :
1924 : /// TensorBlock can NOT be copy assigned, use TensorBlock::clone instead
1925 : TensorBlock& operator=(const TensorBlock& other) = delete;
1926 :
1927 : /// TensorBlock can be move constructed
1928 : TensorBlock(TensorBlock&& other) noexcept : TensorBlock() {
1929 : *this = std::move(other);
1930 : }
1931 :
1932 : /// TensorBlock can be moved assigned
1933 : TensorBlock& operator=(TensorBlock&& other) noexcept {
1934 : if (!is_view_) {
1935 : mts_block_free(block_);
1936 : }
1937 :
1938 : this->block_ = other.block_;
1939 : this->is_view_ = other.is_view_;
1940 : other.block_ = nullptr;
1941 : other.is_view_ = true;
1942 :
1943 : return *this;
1944 : }
1945 :
1946 : /// Make a copy of this `TensorBlock`, including all the data contained inside
1947 : TensorBlock clone() const {
1948 : auto copy = TensorBlock();
1949 : copy.is_view_ = false;
1950 : copy.block_ = mts_block_copy(this->block_);
1951 : details::check_pointer(copy.block_);
1952 : return copy;
1953 : }
1954 :
1955 : /// Get a copy of the metadata in this block (i.e. samples, components, and
1956 : /// properties), ignoring the data itself.
1957 : ///
1958 : /// The resulting block values will be an `EmptyDataArray` instance, which
1959 : /// does not contain any data.
1960 : TensorBlock clone_metadata_only() const {
1961 : auto block = TensorBlock(
1962 : std::unique_ptr<EmptyDataArray>(new EmptyDataArray(this->values_shape())),
1963 : this->samples(),
1964 : this->components(),
1965 : this->properties()
1966 : );
1967 :
1968 : for (const auto& parameter: this->gradients_list()) {
1969 : auto gradient = this->gradient(parameter);
1970 : block.add_gradient(parameter, gradient.clone_metadata_only());
1971 : }
1972 :
1973 : return block;
1974 : }
1975 :
1976 : /// Get a view in the values in this block
1977 : NDArray<double> values() & {
1978 : auto array = this->mts_array();
1979 : double* data = nullptr;
1980 : details::check_status(array.data(array.ptr, &data));
1981 :
1982 : return NDArray<double>(data, this->values_shape());
1983 : }
1984 :
1985 : NDArray<double> values() && = delete;
1986 :
1987 : /// Access the sample `Labels` for this block.
1988 : ///
1989 : /// The entries in these labels describe the first dimension of the
1990 : /// `values()` array.
1991 : Labels samples() const {
1992 : return this->labels(0);
1993 : }
1994 :
1995 : /// Access the component `Labels` for this block.
1996 : ///
1997 : /// The entries in these labels describe intermediate dimensions of the
1998 : /// `values()` array.
1999 : std::vector<Labels> components() const {
2000 : auto shape = this->values_shape();
2001 :
2002 : auto result = std::vector<Labels>();
2003 : for (size_t i=1; i<shape.size() - 1; i++) {
2004 : result.emplace_back(this->labels(i));
2005 : }
2006 :
2007 : return result;
2008 : }
2009 :
2010 : /// Access the property `Labels` for this block.
2011 : ///
2012 : /// The entries in these labels describe the last dimension of the
2013 : /// `values()` array. The properties are guaranteed to be the same for
2014 : /// a block and all of its gradients.
2015 : Labels properties() const {
2016 : auto shape = this->values_shape();
2017 : return this->labels(shape.size() - 1);
2018 : }
2019 :
2020 : /// Add a set of gradients with respect to `parameters` in this block.
2021 : ///
2022 : /// @param parameter add gradients with respect to this `parameter` (e.g.
2023 : /// `"positions"`, `"cell"`, ...)
2024 : /// @param gradient a `TensorBlock` whose values contain the gradients with
2025 : /// respect to the `parameter`. The labels of the gradient
2026 : /// `TensorBlock` should be organized as follows: its
2027 : /// `samples` must contain `"sample"` as the first label,
2028 : /// which establishes a correspondence with the `samples` of
2029 : /// the original `TensorBlock`; its components must contain
2030 : /// at least the same components as the original
2031 : /// `TensorBlock`, with any additional component coming
2032 : /// before those; its properties must match those of the
2033 : /// original `TensorBlock`.
2034 : void add_gradient(const std::string& parameter, TensorBlock gradient) {
2035 : if (is_view_) {
2036 : throw Error(
2037 : "can not call TensorBlock::add_gradient on this block since "
2038 : "it is a view inside a TensorMap"
2039 : );
2040 : }
2041 :
2042 : details::check_status(mts_block_add_gradient(
2043 : block_,
2044 : parameter.c_str(),
2045 : gradient.release()
2046 : ));
2047 : }
2048 :
2049 : /// Get a list of all gradients defined in this block.
2050 : std::vector<std::string> gradients_list() const {
2051 : const char*const * parameters = nullptr;
2052 : uintptr_t count = 0;
2053 : details::check_status(mts_block_gradients_list(
2054 : block_,
2055 : ¶meters,
2056 : &count
2057 : ));
2058 :
2059 : auto result = std::vector<std::string>();
2060 : for (uint64_t i=0; i<count; i++) {
2061 : result.emplace_back(parameters[i]);
2062 : }
2063 :
2064 : return result;
2065 : }
2066 :
2067 : /// Get the gradient in this block with respect to the given `parameter`.
2068 : /// The gradient is returned as a TensorBlock itself.
2069 : ///
2070 : /// @param parameter check for gradients with respect to this `parameter`
2071 : /// (e.g. `"positions"`, `"cell"`, ...)
2072 : TensorBlock gradient(const std::string& parameter) const {
2073 : mts_block_t* gradient_block = nullptr;
2074 : details::check_status(
2075 : mts_block_gradient(block_, parameter.c_str(), &gradient_block)
2076 : );
2077 : details::check_pointer(gradient_block);
2078 : return TensorBlock::unsafe_view_from_ptr(gradient_block);
2079 : }
2080 :
2081 : /// Get the `mts_block_t` pointer corresponding to this block.
2082 : ///
2083 : /// The block pointer is still managed by the current `TensorBlock`
2084 : mts_block_t* as_mts_block_t() & {
2085 : if (is_view_) {
2086 : throw Error(
2087 : "can not call non-const TensorBlock::as_mts_block_t on this "
2088 : "block since it is a view inside a TensorMap"
2089 : );
2090 : }
2091 : return block_;
2092 : }
2093 :
2094 : /// const version of `as_mts_block_t`
2095 : const mts_block_t* as_mts_block_t() const & {
2096 : return block_;
2097 : }
2098 :
2099 : const mts_block_t* as_mts_block_t() && = delete;
2100 :
2101 : /// Create a new TensorBlock taking ownership of a raw `mts_block_t` pointer.
2102 : static TensorBlock unsafe_from_ptr(mts_block_t* ptr) {
2103 : auto block = TensorBlock();
2104 : block.block_ = ptr;
2105 : block.is_view_ = false;
2106 : return block;
2107 : }
2108 :
2109 : /// Create a new TensorBlock which is a view corresponding to a raw
2110 : /// `mts_block_t` pointer.
2111 : static TensorBlock unsafe_view_from_ptr(mts_block_t* ptr) {
2112 : auto block = TensorBlock();
2113 : block.block_ = ptr;
2114 : block.is_view_ = true;
2115 : return block;
2116 : }
2117 :
2118 : /// Get a raw `mts_array_t` corresponding to the values in this block.
2119 : mts_array_t mts_array() {
2120 : mts_array_t array;
2121 : std::memset(&array, 0, sizeof(array));
2122 :
2123 : details::check_status(
2124 : mts_block_data(block_, &array)
2125 : );
2126 : return array;
2127 : }
2128 :
2129 : /// Get the labels in this block associated with the given `axis`.
2130 : Labels labels(uintptr_t axis) const {
2131 : mts_labels_t labels;
2132 : std::memset(&labels, 0, sizeof(labels));
2133 : details::check_status(mts_block_labels(
2134 : block_, axis, &labels
2135 : ));
2136 :
2137 : return Labels(labels);
2138 : }
2139 :
2140 : /// Get the shape of the value array for this block
2141 62 : std::vector<uintptr_t> values_shape() const {
2142 62 : auto array = this->const_mts_array();
2143 :
2144 62 : const uintptr_t* shape = nullptr;
2145 62 : uintptr_t shape_count = 0;
2146 62 : details::check_status(array.shape(array.ptr, &shape, &shape_count));
2147 : assert(shape_count >= 2);
2148 :
2149 62 : return {shape, shape + shape_count};
2150 : }
2151 :
2152 : /*!
2153 : * \verbatim embed:rst:leading-asterisk
2154 : *
2155 : * Load a previously saved ``TensorBlock`` from the given path.
2156 : *
2157 : * This is identical to :cpp:func:`metatensor::io::load_block`, and provided
2158 : * as a convenience API.
2159 : *
2160 : * \endverbatim
2161 : */
2162 : static TensorBlock load(
2163 : const std::string& path,
2164 : mts_create_array_callback_t create_array = details::default_create_array
2165 : ) {
2166 : return metatensor::io::load_block(path, create_array);
2167 : }
2168 :
2169 : /*!
2170 : * \verbatim embed:rst:leading-asterisk
2171 : *
2172 : * Load a previously saved ``TensorBlock`` from a in-memory buffer.
2173 : *
2174 : * This is identical to :cpp:func:`metatensor::io::load_block_buffer`, and
2175 : * provided as a convenience API.
2176 : *
2177 : * \endverbatim
2178 : */
2179 : static TensorBlock load_buffer(
2180 : const uint8_t* buffer,
2181 : size_t buffer_count,
2182 : mts_create_array_callback_t create_array = details::default_create_array
2183 : ) {
2184 : return metatensor::io::load_block_buffer(buffer, buffer_count, create_array);
2185 : }
2186 :
2187 : /*!
2188 : * \verbatim embed:rst:leading-asterisk
2189 : *
2190 : * Load a previously saved ``TensorBlock`` from a in-memory buffer.
2191 : *
2192 : * This is identical to :cpp:func:`metatensor::io::load_block_buffer`, and
2193 : * provided as a convenience API.
2194 : *
2195 : * \endverbatim
2196 : */
2197 : template <typename Buffer>
2198 : static TensorBlock load_buffer(
2199 : const Buffer& buffer,
2200 : mts_create_array_callback_t create_array = details::default_create_array
2201 : ) {
2202 : return metatensor::io::load_block_buffer<Buffer>(buffer, create_array);
2203 : }
2204 :
2205 : /*!
2206 : * \verbatim embed:rst:leading-asterisk
2207 : *
2208 : * Save this ``TensorBlock`` to the given path.
2209 : *
2210 : * This is identical to :cpp:func:`metatensor::io::save`, and provided as a
2211 : * convenience API.
2212 : *
2213 : * \endverbatim
2214 : */
2215 : void save(const std::string& path) const {
2216 : metatensor::io::save(path, *this);
2217 : }
2218 :
2219 : /*!
2220 : * \verbatim embed:rst:leading-asterisk
2221 : *
2222 : * Save this ``TensorBlock`` to an in-memory buffer.
2223 : *
2224 : * This is identical to :cpp:func:`metatensor::io::save_buffer`, and
2225 : * provided as a convenience API.
2226 : *
2227 : * \endverbatim
2228 : */
2229 : std::vector<uint8_t> save_buffer() const {
2230 : return metatensor::io::save_buffer(*this);
2231 : }
2232 :
2233 : /*!
2234 : * \verbatim embed:rst:leading-asterisk
2235 : *
2236 : * Save this ``TensorBlock`` to an in-memory buffer.
2237 : *
2238 : * This is identical to :cpp:func:`metatensor::io::save_buffer`, and
2239 : * provided as a convenience API.
2240 : *
2241 : * \endverbatim
2242 : */
2243 : template <typename Buffer>
2244 : Buffer save_buffer() const {
2245 : return metatensor::io::save_buffer<Buffer>(*this);
2246 : }
2247 :
2248 : private:
2249 : /// Constructor of a TensorBlock not associated with anything
2250 : explicit TensorBlock(): block_(nullptr), is_view_(true) {}
2251 :
2252 : /// Create a C++ TensorBlock from a C `mts_block_t` pointer. The C++
2253 : /// block takes ownership of the C pointer.
2254 : explicit TensorBlock(mts_block_t* block): block_(block), is_view_(false) {}
2255 :
2256 : /// Get the `mts_array_t` for this block.
2257 : ///
2258 : /// The returned `mts_array_t` should only be used in a const context
2259 62 : mts_array_t const_mts_array() const {
2260 : mts_array_t array;
2261 : std::memset(&array, 0, sizeof(array));
2262 :
2263 62 : details::check_status(
2264 62 : mts_block_data(block_, &array)
2265 : );
2266 62 : return array;
2267 : }
2268 :
2269 : /// Release the `mts_block_t` pointer corresponding to this `TensorBlock`.
2270 : ///
2271 : /// The block pointer is **no longer** managed by the current `TensorBlock`,
2272 : /// and should manually be freed when no longer required.
2273 : mts_block_t* release() {
2274 : if (is_view_) {
2275 : throw Error(
2276 : "can not call TensorBlock::release on this "
2277 : "block since it is a view inside a TensorMap"
2278 : );
2279 : }
2280 : auto* ptr = block_;
2281 : block_ = nullptr;
2282 : is_view_ = false;
2283 : return ptr;
2284 : }
2285 :
2286 : friend class TensorMap;
2287 : friend class metatensor_torch::TensorBlockHolder;
2288 : friend TensorBlock metatensor::io::load_block(
2289 : const std::string& path,
2290 : mts_create_array_callback_t create_array
2291 : );
2292 : friend TensorBlock metatensor::io::load_block_buffer(
2293 : const uint8_t* buffer,
2294 : size_t buffer_count,
2295 : mts_create_array_callback_t create_array
2296 : );
2297 :
2298 : mts_block_t* block_;
2299 : bool is_view_;
2300 : };
2301 :
2302 :
2303 : /******************************************************************************/
2304 : /******************************************************************************/
2305 : /* */
2306 : /* TensorMap */
2307 : /* */
2308 : /******************************************************************************/
2309 : /******************************************************************************/
2310 :
2311 : /// A TensorMap is the main user-facing class of this library, and can store any
2312 : /// kind of data used in atomistic machine learning.
2313 : ///
/// A tensor map contains a list of `TensorBlock`, each one associated with a
/// key. Users can access the blocks one by one with the `block_by_id()`
/// function, or find the indexes of the blocks matching a selection with the
/// `blocks_matching()` function.
2317 : ///
2318 : /// A tensor map provides functions to move some of these keys to the samples or
2319 : /// properties labels of the blocks, moving from a sparse representation of the
2320 : /// data to a dense one.
2321 : class TensorMap final {
2322 : public:
2323 : /// Create a new TensorMap with the given `keys` and `blocks`
2324 : TensorMap(Labels keys, std::vector<TensorBlock> blocks) {
2325 : auto c_blocks = std::vector<mts_block_t*>();
2326 : for (auto& block: blocks) {
2327 : // We will move the data inside the new map, let's release the
2328 : // pointers out of the TensorBlock now
2329 : c_blocks.push_back(block.release());
2330 : }
2331 :
2332 : tensor_ = mts_tensormap(
2333 : keys.as_mts_labels_t(),
2334 : c_blocks.data(),
2335 : c_blocks.size()
2336 : );
2337 :
2338 : details::check_pointer(tensor_);
2339 : }
2340 :
2341 : ~TensorMap() {
2342 : mts_tensormap_free(tensor_);
2343 : }
2344 :
2345 : /// TensorMap can NOT be copy constructed, use TensorMap::clone instead
2346 : TensorMap(const TensorMap&) = delete;
2347 : /// TensorMap can not be copy assigned, use TensorMap::clone instead
2348 : TensorMap& operator=(const TensorMap&) = delete;
2349 :
2350 : /// TensorMap can be move constructed
2351 : TensorMap(TensorMap&& other) noexcept : TensorMap(nullptr) {
2352 : *this = std::move(other);
2353 : }
2354 :
2355 : /// TensorMap can be move assigned
2356 : TensorMap& operator=(TensorMap&& other) noexcept {
2357 : mts_tensormap_free(tensor_);
2358 :
2359 : this->tensor_ = other.tensor_;
2360 : other.tensor_ = nullptr;
2361 :
2362 : return *this;
2363 : }
2364 :
2365 : /// Make a copy of this `TensorMap`, including all the data contained inside
2366 : TensorMap clone() const {
2367 : auto* copy = mts_tensormap_copy(this->tensor_);
2368 : details::check_pointer(copy);
2369 : return TensorMap(copy);
2370 : }
2371 :
2372 : /// Get a copy of the metadata in this `TensorMap` (i.e. keys, samples,
2373 : /// components, and properties), ignoring the data itself.
2374 : ///
2375 : /// The resulting blocks values will be an `EmptyDataArray` instance, which
2376 : /// does not contain any data.
2377 : TensorMap clone_metadata_only() const {
2378 : auto n_blocks = this->keys().count();
2379 :
2380 : auto blocks = std::vector<TensorBlock>();
2381 : blocks.reserve(n_blocks);
2382 : for (uintptr_t i=0; i<n_blocks; i++) {
2383 : mts_block_t* block_ptr = nullptr;
2384 : details::check_status(mts_tensormap_block_by_id(tensor_, &block_ptr, i));
2385 : details::check_pointer(block_ptr);
2386 : auto block = TensorBlock::unsafe_view_from_ptr(block_ptr);
2387 :
2388 : blocks.push_back(block.clone_metadata_only());
2389 : }
2390 :
2391 : return TensorMap(this->keys(), std::move(blocks));
2392 : }
2393 :
2394 : /// Get the set of keys labeling the blocks in this tensor map
2395 : Labels keys() const {
2396 : mts_labels_t keys;
2397 : std::memset(&keys, 0, sizeof(keys));
2398 :
2399 : details::check_status(mts_tensormap_keys(tensor_, &keys));
2400 : return Labels(keys);
2401 : }
2402 :
2403 : /// Get a (possibly empty) list of block indexes matching the `selection`
2404 : std::vector<uintptr_t> blocks_matching(const Labels& selection) const {
2405 : auto matching = std::vector<uintptr_t>(this->keys().count());
2406 : uintptr_t count = matching.size();
2407 :
2408 : details::check_status(mts_tensormap_blocks_matching(
2409 : tensor_,
2410 : matching.data(),
2411 : &count,
2412 : selection.as_mts_labels_t()
2413 : ));
2414 :
2415 : assert(count <= matching.size());
2416 : matching.resize(count);
2417 : return matching;
2418 : }
2419 :
2420 : /// Get a block inside this TensorMap by it's index/the index of the
2421 : /// corresponding key.
2422 : ///
2423 : /// The returned `TensorBlock` is a view inside memory owned by this
2424 : /// `TensorMap`, and is only valid as long as the `TensorMap` is kept alive.
2425 : TensorBlock block_by_id(uintptr_t index) & {
2426 : mts_block_t* block = nullptr;
2427 : details::check_status(mts_tensormap_block_by_id(tensor_, &block, index));
2428 : details::check_pointer(block);
2429 :
2430 : return TensorBlock::unsafe_view_from_ptr(block);
2431 : }
2432 :
2433 : TensorBlock block_by_id(uintptr_t index) && = delete;
2434 :
2435 : /// Merge blocks with the same value for selected keys dimensions along the
2436 : /// property axis.
2437 : ///
2438 : /// The dimensions (names) of `keys_to_move` will be moved from the keys to
2439 : /// the property labels, and blocks with the same remaining keys dimensions
2440 : /// will be merged together along the property axis.
2441 : ///
2442 : /// If `keys_to_move` does not contains any entries (i.e.
2443 : /// `keys_to_move.count() == 0`), then the new property labels will contain
2444 : /// entries corresponding to the merged blocks only. For example, merging a
2445 : /// block with key `a=0` and properties `p=1, 2` with a block with key `a=2`
2446 : /// and properties `p=1, 3` will produce a block with properties
2447 : /// `a, p = (0, 1), (0, 2), (2, 1), (2, 3)`.
2448 : ///
2449 : /// If `keys_to_move` contains entries, then the property labels must be the
2450 : /// same for all the merged blocks. In that case, the merged property labels
2451 : /// will contains each of the entries of `keys_to_move` and then the current
2452 : /// property labels. For example, using `a=2, 3` in `keys_to_move`, and
2453 : /// blocks with properties `p=1, 2` will result in `a, p = (2, 1), (2, 2),
2454 : /// (3, 1), (3, 2)`.
2455 : ///
2456 : /// The new sample labels will contains all of the merged blocks sample
2457 : /// labels. The order of the samples is controlled by `sort_samples`. If
2458 : /// `sort_samples` is true, samples are re-ordered to keep them
2459 : /// lexicographically sorted. Otherwise they are kept in the order in which
2460 : /// they appear in the blocks.
2461 : ///
2462 : /// @param keys_to_move description of the keys to move
2463 : /// @param sort_samples whether to sort the merged samples or keep them in
2464 : /// the order in which they appear in the original blocks
2465 : TensorMap keys_to_properties(const Labels& keys_to_move, bool sort_samples = true) const {
2466 : auto* ptr = mts_tensormap_keys_to_properties(
2467 : tensor_,
2468 : keys_to_move.as_mts_labels_t(),
2469 : sort_samples
2470 : );
2471 :
2472 : details::check_pointer(ptr);
2473 : return TensorMap(ptr);
2474 : }
2475 :
2476 : /// This function calls `keys_to_properties` with an empty set of `Labels`
2477 : /// with the dimensions defined in `keys_to_move`
2478 : TensorMap keys_to_properties(const std::vector<std::string>& keys_to_move, bool sort_samples = true) const {
2479 : return keys_to_properties(Labels(keys_to_move), sort_samples);
2480 : }
2481 :
2482 : /// This function calls `keys_to_properties` with an empty set of `Labels`
2483 : /// with a single dimension: `key_to_move`
2484 : TensorMap keys_to_properties(const std::string& key_to_move, bool sort_samples = true) const {
2485 : return keys_to_properties(std::vector<std::string>{key_to_move}, sort_samples);
2486 : }
2487 :
2488 : /// Merge blocks with the same value for selected keys dimensions along the
2489 : /// samples axis.
2490 : ///
2491 : /// The dimensions (names) of `keys_to_move` will be moved from the keys to
2492 : /// the sample labels, and blocks with the same remaining keys dimensions
2493 : /// will be merged together along the sample axis.
2494 : ///
2495 : /// If `keys_to_move` must be an empty set of `Labels`
2496 : /// (`keys_to_move.count() == 0`). The new sample labels will contain
2497 : /// entries corresponding to the merged blocks' keys.
2498 : ///
2499 : /// The order of the samples is controlled by `sort_samples`. If
2500 : /// `sort_samples` is true, samples are re-ordered to keep them
2501 : /// lexicographically sorted. Otherwise they are kept in the order in which
2502 : /// they appear in the blocks.
2503 : ///
2504 : /// This function is only implemented if all merged block have the same
2505 : /// property labels.
2506 : ///
2507 : /// @param keys_to_move description of the keys to move
2508 : /// @param sort_samples whether to sort the merged samples or keep them in
2509 : /// the order in which they appear in the original blocks
2510 : TensorMap keys_to_samples(const Labels& keys_to_move, bool sort_samples = true) const {
2511 : auto* ptr = mts_tensormap_keys_to_samples(
2512 : tensor_,
2513 : keys_to_move.as_mts_labels_t(),
2514 : sort_samples
2515 : );
2516 :
2517 : details::check_pointer(ptr);
2518 : return TensorMap(ptr);
2519 : }
2520 :
2521 : /// This function calls `keys_to_samples` with an empty set of `Labels`
2522 : /// with the dimensions defined in `keys_to_move`
2523 : TensorMap keys_to_samples(const std::vector<std::string>& keys_to_move, bool sort_samples = true) const {
2524 : return keys_to_samples(Labels(keys_to_move), sort_samples);
2525 : }
2526 :
2527 : /// This function calls `keys_to_samples` with an empty set of `Labels`
2528 : /// with a single dimension: `key_to_move`
2529 : TensorMap keys_to_samples(const std::string& key_to_move, bool sort_samples = true) const {
2530 : return keys_to_samples(std::vector<std::string>{key_to_move}, sort_samples);
2531 : }
2532 :
2533 : /// Move the given `dimensions` from the component labels to the property
2534 : /// labels for each block.
2535 : ///
2536 : /// @param dimensions name of the component dimensions to move to the
2537 : /// properties
2538 : TensorMap components_to_properties(const std::vector<std::string>& dimensions) const {
2539 : auto c_dimensions = std::vector<const char*>();
2540 : for (const auto& v: dimensions) {
2541 : c_dimensions.push_back(v.c_str());
2542 : }
2543 :
2544 : auto* ptr = mts_tensormap_components_to_properties(
2545 : tensor_,
2546 : c_dimensions.data(),
2547 : c_dimensions.size()
2548 : );
2549 : details::check_pointer(ptr);
2550 : return TensorMap(ptr);
2551 : }
2552 :
2553 : /// Call `components_to_properties` with a single dimension
2554 : TensorMap components_to_properties(const std::string& dimension) const {
2555 : const char* c_str = dimension.c_str();
2556 : auto* ptr = mts_tensormap_components_to_properties(
2557 : tensor_,
2558 : &c_str,
2559 : 1
2560 : );
2561 : details::check_pointer(ptr);
2562 : return TensorMap(ptr);
2563 : }
2564 :
2565 : /*!
2566 : * \verbatim embed:rst:leading-asterisk
2567 : *
2568 : * Load a previously saved ``TensorMap`` from the given path.
2569 : *
2570 : * This is identical to :cpp:func:`metatensor::io::load`, and provided as a
2571 : * convenience API.
2572 : *
2573 : * \endverbatim
2574 : */
2575 : static TensorMap load(
2576 : const std::string& path,
2577 : mts_create_array_callback_t create_array = details::default_create_array
2578 : ) {
2579 : return metatensor::io::load(path, create_array);
2580 : }
2581 :
2582 : /*!
2583 : * \verbatim embed:rst:leading-asterisk
2584 : *
2585 : * Load a previously saved ``TensorMap`` from a in-memory buffer.
2586 : *
2587 : * This is identical to :cpp:func:`metatensor::io::load_buffer`, and
2588 : * provided as a convenience API.
2589 : *
2590 : * \endverbatim
2591 : */
2592 : static TensorMap load_buffer(
2593 : const uint8_t* buffer,
2594 : size_t buffer_count,
2595 : mts_create_array_callback_t create_array = details::default_create_array
2596 : ) {
2597 : return metatensor::io::load_buffer(buffer, buffer_count, create_array);
2598 : }
2599 :
2600 : /*!
2601 : * \verbatim embed:rst:leading-asterisk
2602 : *
2603 : * Load a previously saved ``TensorMap`` from a in-memory buffer.
2604 : *
2605 : * This is identical to :cpp:func:`metatensor::io::load_buffer`, and
2606 : * provided as a convenience API.
2607 : *
2608 : * \endverbatim
2609 : */
2610 : template <typename Buffer>
2611 : static TensorMap load_buffer(
2612 : const Buffer& buffer,
2613 : mts_create_array_callback_t create_array = details::default_create_array
2614 : ) {
2615 : return metatensor::io::load_buffer<Buffer>(buffer, create_array);
2616 : }
2617 :
2618 : /*!
2619 : * \verbatim embed:rst:leading-asterisk
2620 : *
2621 : * Save this ``TensorMap`` to the given path.
2622 : *
2623 : * This is identical to :cpp:func:`metatensor::io::save`, and provided as a
2624 : * convenience API.
2625 : *
2626 : * \endverbatim
2627 : */
2628 : void save(const std::string& path) const {
2629 : metatensor::io::save(path, *this);
2630 : }
2631 :
2632 : /*!
2633 : * \verbatim embed:rst:leading-asterisk
2634 : *
2635 : * Save this ``TensorMap`` to an in-memory buffer.
2636 : *
2637 : * This is identical to :cpp:func:`metatensor::io::save_buffer`, and
2638 : * provided as a convenience API.
2639 : *
2640 : * \endverbatim
2641 : */
2642 : std::vector<uint8_t> save_buffer() const {
2643 : return metatensor::io::save_buffer(*this);
2644 : }
2645 :
2646 : /*!
2647 : * \verbatim embed:rst:leading-asterisk
2648 : *
2649 : * Save this ``TensorMap`` to an in-memory buffer.
2650 : *
2651 : * This is identical to :cpp:func:`metatensor::io::save_buffer`, and
2652 : * provided as a convenience API.
2653 : *
2654 : * \endverbatim
2655 : */
2656 : template <typename Buffer>
2657 : Buffer save_buffer() const {
2658 : return metatensor::io::save_buffer<Buffer>(*this);
2659 : }
2660 :
2661 : /// Get the `mts_tensormap_t` pointer corresponding to this `TensorMap`.
2662 : ///
2663 : /// The tensor map pointer is still managed by the current `TensorMap`
2664 : mts_tensormap_t* as_mts_tensormap_t() & {
2665 : return tensor_;
2666 : }
2667 :
2668 : /// Get the const `mts_tensormap_t` pointer corresponding to this `TensorMap`.
2669 : ///
2670 : /// The tensor map pointer is still managed by the current `TensorMap`
2671 : const mts_tensormap_t* as_mts_tensormap_t() const & {
2672 : return tensor_;
2673 : }
2674 :
2675 : mts_tensormap_t* as_mts_tensormap_t() && = delete;
2676 :
2677 : /// Create a C++ TensorMap from a C `mts_tensormap_t` pointer. The C++
2678 : /// tensor map takes ownership of the C pointer.
2679 : explicit TensorMap(mts_tensormap_t* tensor): tensor_(tensor) {}
2680 :
private:
// C handle owned by this TensorMap. The `explicit TensorMap(mts_tensormap_t*)`
// constructor takes ownership of it, and `as_mts_tensormap_t()` hands out
// non-owning copies of the pointer.
mts_tensormap_t* tensor_;
2683 : };
2684 :
2685 :
2686 : /******************************************************************************/
2687 : /******************************************************************************/
2688 : /* */
2689 : /* I/O functionalities implementation */
2690 : /* */
2691 : /******************************************************************************/
2692 : /******************************************************************************/
2693 :
2694 :
2695 : namespace io {
2696 : inline void save(const std::string& path, const TensorMap& tensor) {
2697 : details::check_status(mts_tensormap_save(path.c_str(), tensor.as_mts_tensormap_t()));
2698 : }
2699 :
2700 : template <typename Buffer>
2701 : Buffer save_buffer(const TensorMap& tensor) {
2702 : auto buffer = metatensor::io::save_buffer<std::vector<uint8_t>>(tensor);
2703 : return Buffer(buffer.begin(), buffer.end());
2704 : }
2705 :
2706 : template<>
2707 : inline std::vector<uint8_t> save_buffer<std::vector<uint8_t>>(const TensorMap& tensor) {
2708 : std::vector<uint8_t> buffer;
2709 :
2710 : auto* ptr = buffer.data();
2711 : auto size = buffer.size();
2712 :
2713 : auto realloc = [](void* user_data, uint8_t*, uintptr_t new_size) {
2714 : auto* buffer = reinterpret_cast<std::vector<uint8_t>*>(user_data);
2715 : buffer->resize(new_size, '\0');
2716 : return buffer->data();
2717 : };
2718 :
2719 : details::check_status(mts_tensormap_save_buffer(
2720 : &ptr,
2721 : &size,
2722 : &buffer,
2723 : realloc,
2724 : tensor.as_mts_tensormap_t()
2725 : ));
2726 :
2727 : buffer.resize(size, '\0');
2728 :
2729 : return buffer;
2730 : }
2731 :
2732 : /**************************************************************************/
2733 :
2734 : inline void save(const std::string& path, const TensorBlock& block) {
2735 : details::check_status(mts_block_save(path.c_str(), block.as_mts_block_t()));
2736 : }
2737 :
2738 : template <typename Buffer>
2739 : Buffer save_buffer(const TensorBlock& block) {
2740 : auto buffer = metatensor::io::save_buffer<std::vector<uint8_t>>(block);
2741 : return Buffer(buffer.begin(), buffer.end());
2742 : }
2743 :
2744 : template<>
2745 : inline std::vector<uint8_t> save_buffer<std::vector<uint8_t>>(const TensorBlock& block) {
2746 : std::vector<uint8_t> buffer;
2747 :
2748 : auto* ptr = buffer.data();
2749 : auto size = buffer.size();
2750 :
2751 : auto realloc = [](void* user_data, uint8_t*, uintptr_t new_size) {
2752 : auto* buffer = reinterpret_cast<std::vector<uint8_t>*>(user_data);
2753 : buffer->resize(new_size, '\0');
2754 : return buffer->data();
2755 : };
2756 :
2757 : details::check_status(mts_block_save_buffer(
2758 : &ptr,
2759 : &size,
2760 : &buffer,
2761 : realloc,
2762 : block.as_mts_block_t()
2763 : ));
2764 :
2765 : buffer.resize(size, '\0');
2766 :
2767 : return buffer;
2768 : }
2769 :
2770 : /**************************************************************************/
2771 :
2772 : inline void save(const std::string& path, const Labels& labels) {
2773 : details::check_status(mts_labels_save(path.c_str(), labels.as_mts_labels_t()));
2774 : }
2775 :
2776 : template <typename Buffer>
2777 : Buffer save_buffer(const Labels& labels) {
2778 : auto buffer = metatensor::io::save_buffer<std::vector<uint8_t>>(labels);
2779 : return Buffer(buffer.begin(), buffer.end());
2780 : }
2781 :
2782 : template<>
2783 : inline std::vector<uint8_t> save_buffer<std::vector<uint8_t>>(const Labels& labels) {
2784 : std::vector<uint8_t> buffer;
2785 :
2786 : auto* ptr = buffer.data();
2787 : auto size = buffer.size();
2788 :
2789 : auto realloc = [](void* user_data, uint8_t*, uintptr_t new_size) {
2790 : auto* buffer = reinterpret_cast<std::vector<uint8_t>*>(user_data);
2791 : buffer->resize(new_size, '\0');
2792 : return buffer->data();
2793 : };
2794 :
2795 : details::check_status(mts_labels_save_buffer(
2796 : &ptr,
2797 : &size,
2798 : &buffer,
2799 : realloc,
2800 : labels.as_mts_labels_t()
2801 : ));
2802 :
2803 : buffer.resize(size, '\0');
2804 :
2805 : return buffer;
2806 : }
2807 :
2808 : /**************************************************************************/
2809 : /**************************************************************************/
2810 :
2811 : inline TensorMap load(
2812 : const std::string& path,
2813 : mts_create_array_callback_t create_array
2814 : ) {
2815 : auto* ptr = mts_tensormap_load(path.c_str(), create_array);
2816 : details::check_pointer(ptr);
2817 : return TensorMap(ptr);
2818 : }
2819 :
2820 : inline TensorMap load_buffer(
2821 : const uint8_t* buffer,
2822 : size_t buffer_count,
2823 : mts_create_array_callback_t create_array
2824 : ) {
2825 : auto* ptr = mts_tensormap_load_buffer(buffer, buffer_count, create_array);
2826 : details::check_pointer(ptr);
2827 : return TensorMap(ptr);
2828 : }
2829 :
2830 : template <typename Buffer>
2831 : TensorMap load_buffer(
2832 : const Buffer& buffer,
2833 : mts_create_array_callback_t create_array
2834 : ) {
2835 : static_assert(
2836 : sizeof(typename Buffer::value_type) == sizeof(uint8_t),
2837 : "`Buffer` must be a container of uint8_t or equivalent"
2838 : );
2839 :
2840 : return metatensor::io::load_buffer(
2841 : reinterpret_cast<const uint8_t*>(buffer.data()),
2842 : buffer.size(),
2843 : create_array
2844 : );
2845 : }
2846 :
2847 : /**************************************************************************/
2848 :
2849 : inline TensorBlock load_block(
2850 : const std::string& path,
2851 : mts_create_array_callback_t create_array
2852 : ) {
2853 : auto* ptr = mts_block_load(path.c_str(), create_array);
2854 : details::check_pointer(ptr);
2855 : return TensorBlock(ptr);
2856 : }
2857 :
2858 : inline TensorBlock load_block_buffer(
2859 : const uint8_t* buffer,
2860 : size_t buffer_count,
2861 : mts_create_array_callback_t create_array
2862 : ) {
2863 : auto* ptr = mts_block_load_buffer(buffer, buffer_count, create_array);
2864 : details::check_pointer(ptr);
2865 : return TensorBlock(ptr);
2866 : }
2867 :
2868 : template <typename Buffer>
2869 : TensorBlock load_block_buffer(
2870 : const Buffer& buffer,
2871 : mts_create_array_callback_t create_array
2872 : ) {
2873 : static_assert(
2874 : sizeof(typename Buffer::value_type) == sizeof(uint8_t),
2875 : "`Buffer` must be a container of uint8_t or equivalent"
2876 : );
2877 :
2878 : return metatensor::io::load_block_buffer(
2879 : reinterpret_cast<const uint8_t*>(buffer.data()),
2880 : buffer.size(),
2881 : create_array
2882 : );
2883 : }
2884 :
2885 : /**************************************************************************/
2886 :
2887 : inline Labels load_labels(const std::string& path) {
2888 : mts_labels_t labels;
2889 : std::memset(&labels, 0, sizeof(labels));
2890 :
2891 : details::check_status(mts_labels_load(
2892 : path.c_str(), &labels
2893 : ));
2894 :
2895 : return Labels(labels);
2896 : }
2897 :
2898 : inline Labels load_labels_buffer(const uint8_t* buffer, size_t buffer_count) {
2899 : mts_labels_t labels;
2900 : std::memset(&labels, 0, sizeof(labels));
2901 :
2902 : details::check_status(mts_labels_load_buffer(
2903 : buffer, buffer_count, &labels
2904 : ));
2905 :
2906 : return Labels(labels);
2907 : }
2908 :
2909 : template <typename Buffer>
2910 : Labels load_labels_buffer(const Buffer& buffer) {
2911 : static_assert(
2912 : sizeof(typename Buffer::value_type) == sizeof(uint8_t),
2913 : "`Buffer` must be a container of uint8_t or equivalent"
2914 : );
2915 :
2916 : return metatensor::io::load_labels_buffer(
2917 : reinterpret_cast<const uint8_t*>(buffer.data()),
2918 : buffer.size()
2919 : );
2920 : }
2921 : }
2922 :
2923 : }
2924 :
2925 : #endif /* METATENSOR_HPP */
|