LCOV - code coverage report
Current view: top level - home/runner/.local/lib/python3.10/site-packages/torch/include/c10/util - hash.h (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 0 14 0.0 %
Date: 2026-06-05 17:04:24 Functions: 0 4 0.0 %

          Line data    Source code
       1             : #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
       2             : #pragma once
       3             : 
       4             : #include <c10/util/Exception.h>
       5             : #include <cstddef>
       6             : #include <functional>
       7             : #include <iomanip>
       8             : #include <ios>
       9             : #include <sstream>
      10             : #include <string>
      11             : #include <tuple>
      12             : #include <type_traits>
      13             : #include <utility>
      14             : #include <vector>
      15             : 
      16             : #include <c10/util/ArrayRef.h>
      17             : #include <c10/util/complex.h>
      18             : 
      19             : namespace c10 {
      20             : 
      21             : // NOTE: hash_combine and SHA1 hashing is based on implementation from Boost
      22             : //
      23             : // Boost Software License - Version 1.0 - August 17th, 2003
      24             : //
      25             : // Permission is hereby granted, free of charge, to any person or organization
      26             : // obtaining a copy of the software and accompanying documentation covered by
      27             : // this license (the "Software") to use, reproduce, display, distribute,
      28             : // execute, and transmit the Software, and to prepare derivative works of the
      29             : // Software, and to permit third-parties to whom the Software is furnished to
      30             : // do so, all subject to the following:
      31             : //
      32             : // The copyright notices in the Software and this entire statement, including
      33             : // the above license grant, this restriction and the following disclaimer,
      34             : // must be included in all copies of the Software, in whole or in part, and
      35             : // all derivative works of the Software, unless such copies or derivative
      36             : // works are solely in the form of machine-executable object code generated by
      37             : // a source language processor.
      38             : //
      39             : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
      40             : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
      41             : // FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
      42             : // SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
      43             : // FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
      44             : // ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
      45             : // DEALINGS IN THE SOFTWARE.
      46             : 
      47             : inline size_t hash_combine(size_t seed, size_t value) {
      48           0 :   return seed ^ (value + 0x9e3779b9 + (seed << 6u) + (seed >> 2u));
      49             : }
      50             : 
      51             : // Creates the SHA1 hash of a string. A 160-bit hash.
      52             : // Based on the implementation in Boost (see notice above).
      53             : // Note that SHA1 hashes are no longer considered cryptographically
      54             : //   secure, but are the standard hash for generating unique ids.
      55             : // Usage:
      56             : //   // Let 'code' be a std::string
      57             : //   c10::sha1 sha1_hash{code};
      58             : //   const auto hash_code = sha1_hash.str();
      59             : // TODO: Compare vs OpenSSL and/or CryptoPP implementations
      60             : struct sha1 {
      61             :   typedef unsigned int(digest_type)[5];
      62             : 
      63             :   sha1(const std::string& s = "") {
      64             :     if (!s.empty()) {
      65             :       reset();
      66             :       process_bytes(s.c_str(), s.size());
      67             :     }
      68             :   }
      69             : 
      70             :   void reset() {
      71             :     h_[0] = 0x67452301;
      72             :     h_[1] = 0xEFCDAB89;
      73             :     h_[2] = 0x98BADCFE;
      74             :     h_[3] = 0x10325476;
      75             :     h_[4] = 0xC3D2E1F0;
      76             : 
      77             :     block_byte_index_ = 0;
      78             :     bit_count_low = 0;
      79             :     bit_count_high = 0;
      80             :   }
      81             : 
      82             :   std::string str() {
      83             :     unsigned int digest[5];
      84             :     get_digest(digest);
      85             : 
      86             :     std::ostringstream buf;
      87             :     for (unsigned int i : digest) {
      88             :       buf << std::hex << std::setfill('0') << std::setw(8) << i;
      89             :     }
      90             : 
      91             :     return buf.str();
      92             :   }
      93             : 
      94             :  private:
      95             :   unsigned int left_rotate(unsigned int x, std::size_t n) {
      96             :     return (x << n) ^ (x >> (32 - n));
      97             :   }
      98             : 
      99             :   void process_block_impl() {
     100             :     unsigned int w[80];
     101             : 
     102             :     for (std::size_t i = 0; i < 16; ++i) {
     103             :       w[i] = (block_[i * 4 + 0] << 24);
     104             :       w[i] |= (block_[i * 4 + 1] << 16);
     105             :       w[i] |= (block_[i * 4 + 2] << 8);
     106             :       w[i] |= (block_[i * 4 + 3]);
     107             :     }
     108             : 
     109             :     for (std::size_t i = 16; i < 80; ++i) {
     110             :       w[i] = left_rotate((w[i - 3] ^ w[i - 8] ^ w[i - 14] ^ w[i - 16]), 1);
     111             :     }
     112             : 
     113             :     unsigned int a = h_[0];
     114             :     unsigned int b = h_[1];
     115             :     unsigned int c = h_[2];
     116             :     unsigned int d = h_[3];
     117             :     unsigned int e = h_[4];
     118             : 
     119             :     for (std::size_t i = 0; i < 80; ++i) {
     120             :       unsigned int f = 0;
     121             :       unsigned int k = 0;
     122             : 
     123             :       if (i < 20) {
     124             :         f = (b & c) | (~b & d);
     125             :         k = 0x5A827999;
     126             :       } else if (i < 40) {
     127             :         f = b ^ c ^ d;
     128             :         k = 0x6ED9EBA1;
     129             :       } else if (i < 60) {
     130             :         f = (b & c) | (b & d) | (c & d);
     131             :         k = 0x8F1BBCDC;
     132             :       } else {
     133             :         f = b ^ c ^ d;
     134             :         k = 0xCA62C1D6;
     135             :       }
     136             : 
     137             :       unsigned temp = left_rotate(a, 5) + f + e + k + w[i];
     138             :       e = d;
     139             :       d = c;
     140             :       c = left_rotate(b, 30);
     141             :       b = a;
     142             :       a = temp;
     143             :     }
     144             : 
     145             :     h_[0] += a;
     146             :     h_[1] += b;
     147             :     h_[2] += c;
     148             :     h_[3] += d;
     149             :     h_[4] += e;
     150             :   }
     151             : 
     152             :   void process_byte_impl(unsigned char byte) {
     153             :     block_[block_byte_index_++] = byte;
     154             : 
     155             :     if (block_byte_index_ == 64) {
     156             :       block_byte_index_ = 0;
     157             :       process_block_impl();
     158             :     }
     159             :   }
     160             : 
     161             :   void process_byte(unsigned char byte) {
     162             :     process_byte_impl(byte);
     163             : 
     164             :     // size_t max value = 0xFFFFFFFF
     165             :     // if (bit_count_low + 8 >= 0x100000000) { // would overflow
     166             :     // if (bit_count_low >= 0x100000000-8) {
     167             :     if (bit_count_low < 0xFFFFFFF8) {
     168             :       bit_count_low += 8;
     169             :     } else {
     170             :       bit_count_low = 0;
     171             : 
     172             :       if (bit_count_high <= 0xFFFFFFFE) {
     173             :         ++bit_count_high;
     174             :       } else {
     175             :         TORCH_CHECK(false, "sha1 too many bytes");
     176             :       }
     177             :     }
     178             :   }
     179             : 
     180             :   void process_block(void const* bytes_begin, void const* bytes_end) {
     181             :     unsigned char const* begin = static_cast<unsigned char const*>(bytes_begin);
     182             :     unsigned char const* end = static_cast<unsigned char const*>(bytes_end);
     183             :     for (; begin != end; ++begin) {
     184             :       process_byte(*begin);
     185             :     }
     186             :   }
     187             : 
     188             :   void process_bytes(void const* buffer, std::size_t byte_count) {
     189             :     unsigned char const* b = static_cast<unsigned char const*>(buffer);
     190             :     process_block(b, b + byte_count);
     191             :   }
     192             : 
     193             :   void get_digest(digest_type& digest) {
     194             :     // append the bit '1' to the message
     195             :     process_byte_impl(0x80);
     196             : 
     197             :     // append k bits '0', where k is the minimum number >= 0
     198             :     // such that the resulting message length is congruent to 56 (mod 64)
     199             :     // check if there is enough space for padding and bit_count
     200             :     if (block_byte_index_ > 56) {
     201             :       // finish this block
     202             :       while (block_byte_index_ != 0) {
     203             :         process_byte_impl(0);
     204             :       }
     205             : 
     206             :       // one more block
     207             :       while (block_byte_index_ < 56) {
     208             :         process_byte_impl(0);
     209             :       }
     210             :     } else {
     211             :       while (block_byte_index_ < 56) {
     212             :         process_byte_impl(0);
     213             :       }
     214             :     }
     215             : 
     216             :     // append length of message (before pre-processing)
     217             :     // as a 64-bit big-endian integer
     218             :     process_byte_impl(
     219             :         static_cast<unsigned char>((bit_count_high >> 24) & 0xFF));
     220             :     process_byte_impl(
     221             :         static_cast<unsigned char>((bit_count_high >> 16) & 0xFF));
     222             :     process_byte_impl(static_cast<unsigned char>((bit_count_high >> 8) & 0xFF));
     223             :     process_byte_impl(static_cast<unsigned char>((bit_count_high) & 0xFF));
     224             :     process_byte_impl(static_cast<unsigned char>((bit_count_low >> 24) & 0xFF));
     225             :     process_byte_impl(static_cast<unsigned char>((bit_count_low >> 16) & 0xFF));
     226             :     process_byte_impl(static_cast<unsigned char>((bit_count_low >> 8) & 0xFF));
     227             :     process_byte_impl(static_cast<unsigned char>((bit_count_low) & 0xFF));
     228             : 
     229             :     // get final digest
     230             :     digest[0] = h_[0];
     231             :     digest[1] = h_[1];
     232             :     digest[2] = h_[2];
     233             :     digest[3] = h_[3];
     234             :     digest[4] = h_[4];
     235             :   }
     236             : 
     237             :   unsigned int h_[5]{};
     238             :   unsigned char block_[64]{};
     239             :   std::size_t block_byte_index_{};
     240             :   std::size_t bit_count_low{};
     241             :   std::size_t bit_count_high{};
     242             : };
     243             : 
     244             : constexpr uint64_t twang_mix64(uint64_t key) noexcept {
     245             :   key = (~key) + (key << 21); // key *= (1 << 21) - 1; key -= 1;
     246             :   key = key ^ (key >> 24);
     247             :   key = key + (key << 3) + (key << 8); // key *= 1 + (1 << 3) + (1 << 8)
     248             :   key = key ^ (key >> 14);
     249             :   key = key + (key << 2) + (key << 4); // key *= 1 + (1 << 2) + (1 << 4)
     250             :   key = key ^ (key >> 28);
     251             :   key = key + (key << 31); // key *= 1 + (1 << 31)
     252             :   return key;
     253             : }
     254             : 
     255             : ////////////////////////////////////////////////////////////////////////////////
     256             : // c10::hash implementation
     257             : ////////////////////////////////////////////////////////////////////////////////
     258             : 
     259             : namespace _hash_detail {
     260             : 
     261             : // Use template argument deduction to shorten calls to c10::hash
     262             : template <typename T>
     263             : size_t simple_get_hash(const T& o);
     264             : 
     265             : template <typename T, typename V>
     266             : using type_if_not_enum = std::enable_if_t<!std::is_enum_v<T>, V>;
     267             : 
     268             : // Use SFINAE to dispatch to std::hash if possible, cast enum types to int
     269             : // automatically, and fall back to T::hash otherwise. NOTE: C++14 added support
     270             : // for hashing enum types to the standard, and some compilers implement it even
     271             : // when C++14 flags aren't specified. This is why we have to disable this
     272             : // overload if T is an enum type (and use the one below in this case).
     273             : template <typename T>
     274           0 : auto dispatch_hash(const T& o)
     275             :     -> decltype(std::hash<T>()(o), type_if_not_enum<T, size_t>()) {
     276           0 :   return std::hash<T>()(o);
     277             : }
     278             : 
     279             : template <typename T>
     280             : std::enable_if_t<std::is_enum_v<T>, size_t> dispatch_hash(const T& o) {
     281             :   using R = std::underlying_type_t<T>;
     282             :   return std::hash<R>()(static_cast<R>(o));
     283             : }
     284             : 
     285             : template <typename T>
     286             : auto dispatch_hash(const T& o) -> decltype(T::hash(o), size_t()) {
     287             :   return T::hash(o);
     288             : }
     289             : 
     290             : } // namespace _hash_detail
     291             : 
     292             : // Hasher struct
     293             : template <typename T>
     294             : struct hash {
     295             :   size_t operator()(const T& o) const {
     296           0 :     return _hash_detail::dispatch_hash(o);
     297             :   }
     298             : };
     299             : 
     300             : // Specialization for std::tuple
     301             : template <typename... Types>
     302             : struct hash<std::tuple<Types...>> {
     303             :   template <size_t idx, typename... Ts>
     304             :   struct tuple_hash {
     305           0 :     size_t operator()(const std::tuple<Ts...>& t) const {
     306           0 :       return hash_combine(
     307             :           _hash_detail::simple_get_hash(std::get<idx>(t)),
     308           0 :           tuple_hash<idx - 1, Ts...>()(t));
     309             :     }
     310             :   };
     311             : 
     312             :   template <typename... Ts>
     313             :   struct tuple_hash<0, Ts...> {
     314             :     size_t operator()(const std::tuple<Ts...>& t) const {
     315           0 :       return _hash_detail::simple_get_hash(std::get<0>(t));
     316             :     }
     317             :   };
     318             : 
     319             :   size_t operator()(const std::tuple<Types...>& t) const {
     320           0 :     return tuple_hash<sizeof...(Types) - 1, Types...>()(t);
     321             :   }
     322             : };
     323             : 
     324             : template <typename T1, typename T2>
     325             : struct hash<std::pair<T1, T2>> {
     326             :   size_t operator()(const std::pair<T1, T2>& pair) const {
     327             :     std::tuple<T1, T2> tuple = std::make_tuple(pair.first, pair.second);
     328             :     return _hash_detail::simple_get_hash(tuple);
     329             :   }
     330             : };
     331             : 
     332             : template <typename T>
     333             : struct hash<c10::ArrayRef<T>> {
     334             :   size_t operator()(c10::ArrayRef<T> v) const {
     335             :     size_t seed = 0;
     336             :     for (const auto& elem : v) {
     337             :       seed = hash_combine(seed, _hash_detail::simple_get_hash(elem));
     338             :     }
     339             :     return seed;
     340             :   }
     341             : };
     342             : 
     343             : // Specialization for std::vector
     344             : template <typename T>
     345             : struct hash<std::vector<T>> {
     346             :   size_t operator()(const std::vector<T>& v) const {
     347             :     return hash<c10::ArrayRef<T>>()(v);
     348             :   }
     349             : };
     350             : 
     351             : namespace _hash_detail {
     352             : 
     353             : template <typename T>
     354           0 : size_t simple_get_hash(const T& o) {
     355           0 :   return c10::hash<T>()(o);
     356             : }
     357             : 
     358             : } // namespace _hash_detail
     359             : 
     360             : // Use this function to actually hash multiple things in one line.
     361             : // Dispatches to c10::hash, so it can hash containers.
     362             : // Example:
     363             : //
     364             : // static size_t hash(const MyStruct& s) {
     365             : //   return get_hash(s.member1, s.member2, s.member3);
     366             : // }
     367             : template <typename... Types>
     368           0 : size_t get_hash(const Types&... args) {
     369           0 :   return c10::hash<decltype(std::tie(args...))>()(std::tie(args...));
     370             : }
     371             : 
     372             : // Specialization for c10::complex
     373             : template <typename T>
     374             : struct hash<c10::complex<T>> {
     375             :   size_t operator()(const c10::complex<T>& c) const {
     376           0 :     return get_hash(c.real(), c.imag());
     377             :   }
     378             : };
     379             : 
     380             : } // namespace c10
     381             : 
     382             : #else
     383             : #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
     384             : #endif  // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)

Generated by: LCOV version 1.16