Line data Source code
1 : #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2 : #pragma once
3 :
4 : #include <c10/util/Exception.h>
5 : #include <cstddef>
6 : #include <functional>
7 : #include <iomanip>
8 : #include <ios>
9 : #include <sstream>
10 : #include <string>
11 : #include <tuple>
12 : #include <type_traits>
13 : #include <utility>
14 : #include <vector>
15 :
16 : #include <c10/util/ArrayRef.h>
17 : #include <c10/util/complex.h>
18 :
19 : namespace c10 {
20 :
// NOTE: hash_combine and SHA1 hashing are based on the implementations from Boost
22 : //
23 : // Boost Software License - Version 1.0 - August 17th, 2003
24 : //
25 : // Permission is hereby granted, free of charge, to any person or organization
26 : // obtaining a copy of the software and accompanying documentation covered by
27 : // this license (the "Software") to use, reproduce, display, distribute,
28 : // execute, and transmit the Software, and to prepare derivative works of the
29 : // Software, and to permit third-parties to whom the Software is furnished to
30 : // do so, all subject to the following:
31 : //
32 : // The copyright notices in the Software and this entire statement, including
33 : // the above license grant, this restriction and the following disclaimer,
34 : // must be included in all copies of the Software, in whole or in part, and
35 : // all derivative works of the Software, unless such copies or derivative
36 : // works are solely in the form of machine-executable object code generated by
37 : // a source language processor.
38 : //
39 : // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
40 : // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
41 : // FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
42 : // SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
43 : // FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
44 : // ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
45 : // DEALINGS IN THE SOFTWARE.
46 :
// Folds `value` into `seed` and returns the combined hash (Boost-style
// hash_combine). The result depends on argument order.
inline size_t hash_combine(size_t seed, size_t value) {
  // Mixing term: the new value, a fixed odd constant, and two shifted copies
  // of the seed so that every seed bit influences the result.
  const size_t mixed = value + 0x9e3779b9 + (seed << 6u) + (seed >> 2u);
  return seed ^ mixed;
}
50 :
51 : // Creates the SHA1 hash of a string. A 160-bit hash.
52 : // Based on the implementation in Boost (see notice above).
53 : // Note that SHA1 hashes are no longer considered cryptographically
54 : // secure, but are the standard hash for generating unique ids.
55 : // Usage:
56 : // // Let 'code' be a std::string
57 : // c10::sha1 sha1_hash{code};
58 : // const auto hash_code = sha1_hash.str();
59 : // TODO: Compare vs OpenSSL and/or CryptoPP implementations
60 : struct sha1 {
61 : typedef unsigned int(digest_type)[5];
62 :
63 : sha1(const std::string& s = "") {
64 : if (!s.empty()) {
65 : reset();
66 : process_bytes(s.c_str(), s.size());
67 : }
68 : }
69 :
70 : void reset() {
71 : h_[0] = 0x67452301;
72 : h_[1] = 0xEFCDAB89;
73 : h_[2] = 0x98BADCFE;
74 : h_[3] = 0x10325476;
75 : h_[4] = 0xC3D2E1F0;
76 :
77 : block_byte_index_ = 0;
78 : bit_count_low = 0;
79 : bit_count_high = 0;
80 : }
81 :
82 : std::string str() {
83 : unsigned int digest[5];
84 : get_digest(digest);
85 :
86 : std::ostringstream buf;
87 : for (unsigned int i : digest) {
88 : buf << std::hex << std::setfill('0') << std::setw(8) << i;
89 : }
90 :
91 : return buf.str();
92 : }
93 :
94 : private:
95 : unsigned int left_rotate(unsigned int x, std::size_t n) {
96 : return (x << n) ^ (x >> (32 - n));
97 : }
98 :
99 : void process_block_impl() {
100 : unsigned int w[80];
101 :
102 : for (std::size_t i = 0; i < 16; ++i) {
103 : w[i] = (block_[i * 4 + 0] << 24);
104 : w[i] |= (block_[i * 4 + 1] << 16);
105 : w[i] |= (block_[i * 4 + 2] << 8);
106 : w[i] |= (block_[i * 4 + 3]);
107 : }
108 :
109 : for (std::size_t i = 16; i < 80; ++i) {
110 : w[i] = left_rotate((w[i - 3] ^ w[i - 8] ^ w[i - 14] ^ w[i - 16]), 1);
111 : }
112 :
113 : unsigned int a = h_[0];
114 : unsigned int b = h_[1];
115 : unsigned int c = h_[2];
116 : unsigned int d = h_[3];
117 : unsigned int e = h_[4];
118 :
119 : for (std::size_t i = 0; i < 80; ++i) {
120 : unsigned int f = 0;
121 : unsigned int k = 0;
122 :
123 : if (i < 20) {
124 : f = (b & c) | (~b & d);
125 : k = 0x5A827999;
126 : } else if (i < 40) {
127 : f = b ^ c ^ d;
128 : k = 0x6ED9EBA1;
129 : } else if (i < 60) {
130 : f = (b & c) | (b & d) | (c & d);
131 : k = 0x8F1BBCDC;
132 : } else {
133 : f = b ^ c ^ d;
134 : k = 0xCA62C1D6;
135 : }
136 :
137 : unsigned temp = left_rotate(a, 5) + f + e + k + w[i];
138 : e = d;
139 : d = c;
140 : c = left_rotate(b, 30);
141 : b = a;
142 : a = temp;
143 : }
144 :
145 : h_[0] += a;
146 : h_[1] += b;
147 : h_[2] += c;
148 : h_[3] += d;
149 : h_[4] += e;
150 : }
151 :
152 : void process_byte_impl(unsigned char byte) {
153 : block_[block_byte_index_++] = byte;
154 :
155 : if (block_byte_index_ == 64) {
156 : block_byte_index_ = 0;
157 : process_block_impl();
158 : }
159 : }
160 :
161 : void process_byte(unsigned char byte) {
162 : process_byte_impl(byte);
163 :
164 : // size_t max value = 0xFFFFFFFF
165 : // if (bit_count_low + 8 >= 0x100000000) { // would overflow
166 : // if (bit_count_low >= 0x100000000-8) {
167 : if (bit_count_low < 0xFFFFFFF8) {
168 : bit_count_low += 8;
169 : } else {
170 : bit_count_low = 0;
171 :
172 : if (bit_count_high <= 0xFFFFFFFE) {
173 : ++bit_count_high;
174 : } else {
175 : TORCH_CHECK(false, "sha1 too many bytes");
176 : }
177 : }
178 : }
179 :
180 : void process_block(void const* bytes_begin, void const* bytes_end) {
181 : unsigned char const* begin = static_cast<unsigned char const*>(bytes_begin);
182 : unsigned char const* end = static_cast<unsigned char const*>(bytes_end);
183 : for (; begin != end; ++begin) {
184 : process_byte(*begin);
185 : }
186 : }
187 :
188 : void process_bytes(void const* buffer, std::size_t byte_count) {
189 : unsigned char const* b = static_cast<unsigned char const*>(buffer);
190 : process_block(b, b + byte_count);
191 : }
192 :
193 : void get_digest(digest_type& digest) {
194 : // append the bit '1' to the message
195 : process_byte_impl(0x80);
196 :
197 : // append k bits '0', where k is the minimum number >= 0
198 : // such that the resulting message length is congruent to 56 (mod 64)
199 : // check if there is enough space for padding and bit_count
200 : if (block_byte_index_ > 56) {
201 : // finish this block
202 : while (block_byte_index_ != 0) {
203 : process_byte_impl(0);
204 : }
205 :
206 : // one more block
207 : while (block_byte_index_ < 56) {
208 : process_byte_impl(0);
209 : }
210 : } else {
211 : while (block_byte_index_ < 56) {
212 : process_byte_impl(0);
213 : }
214 : }
215 :
216 : // append length of message (before pre-processing)
217 : // as a 64-bit big-endian integer
218 : process_byte_impl(
219 : static_cast<unsigned char>((bit_count_high >> 24) & 0xFF));
220 : process_byte_impl(
221 : static_cast<unsigned char>((bit_count_high >> 16) & 0xFF));
222 : process_byte_impl(static_cast<unsigned char>((bit_count_high >> 8) & 0xFF));
223 : process_byte_impl(static_cast<unsigned char>((bit_count_high) & 0xFF));
224 : process_byte_impl(static_cast<unsigned char>((bit_count_low >> 24) & 0xFF));
225 : process_byte_impl(static_cast<unsigned char>((bit_count_low >> 16) & 0xFF));
226 : process_byte_impl(static_cast<unsigned char>((bit_count_low >> 8) & 0xFF));
227 : process_byte_impl(static_cast<unsigned char>((bit_count_low) & 0xFF));
228 :
229 : // get final digest
230 : digest[0] = h_[0];
231 : digest[1] = h_[1];
232 : digest[2] = h_[2];
233 : digest[3] = h_[3];
234 : digest[4] = h_[4];
235 : }
236 :
237 : unsigned int h_[5]{};
238 : unsigned char block_[64]{};
239 : std::size_t block_byte_index_{};
240 : std::size_t bit_count_low{};
241 : std::size_t bit_count_high{};
242 : };
243 :
// 64-bit integer mixing function: alternates multiplications by small
// constants with xor-shifts. All arithmetic wraps modulo 2^64.
constexpr uint64_t twang_mix64(uint64_t key) noexcept {
  key = (key << 21) + ~key; // key * ((1 << 21) - 1) - 1
  key ^= key >> 24;
  key *= 265; // 1 + (1 << 3) + (1 << 8)
  key ^= key >> 14;
  key *= 21; // 1 + (1 << 2) + (1 << 4)
  key ^= key >> 28;
  key += key << 31; // key * (1 + (1 << 31))
  return key;
}
254 :
255 : ////////////////////////////////////////////////////////////////////////////////
256 : // c10::hash implementation
257 : ////////////////////////////////////////////////////////////////////////////////
258 :
namespace _hash_detail {

// Declared here and defined after the c10::hash specializations; lets callers
// rely on template argument deduction instead of spelling out c10::hash<T>.
template <typename T>
size_t simple_get_hash(const T& o);

// Names the type V, but only when T is not an enumeration.
template <typename T, typename V>
using type_if_not_enum = std::enable_if_t<!std::is_enum_v<T>, V>;

// dispatch_hash is an overload set selected through SFINAE:
//   1. non-enum types that std::hash accepts  -> std::hash<T>
//   2. enumeration types                      -> std::hash of the underlying
//                                                integer value
//   3. types providing a static T::hash(o)    -> T::hash
// Overload 1 is explicitly disabled for enums because some standard libraries
// provide std::hash for enum types, which would otherwise make the call
// ambiguous with overload 2.
template <typename T>
auto dispatch_hash(const T& o)
    -> decltype(std::hash<T>()(o), type_if_not_enum<T, size_t>()) {
  auto std_hasher = std::hash<T>();
  return std_hasher(o);
}

template <typename T>
auto dispatch_hash(const T& o) -> std::enable_if_t<std::is_enum_v<T>, size_t> {
  using underlying = std::underlying_type_t<T>;
  const auto raw_value = static_cast<underlying>(o);
  return std::hash<underlying>()(raw_value);
}

template <typename T>
auto dispatch_hash(const T& o) -> decltype(T::hash(o), size_t()) {
  return T::hash(o);
}

} // namespace _hash_detail
291 :
292 : // Hasher struct
293 : template <typename T>
294 : struct hash {
295 : size_t operator()(const T& o) const {
296 0 : return _hash_detail::dispatch_hash(o);
297 : }
298 : };
299 :
300 : // Specialization for std::tuple
301 : template <typename... Types>
302 : struct hash<std::tuple<Types...>> {
303 : template <size_t idx, typename... Ts>
304 : struct tuple_hash {
305 0 : size_t operator()(const std::tuple<Ts...>& t) const {
306 0 : return hash_combine(
307 : _hash_detail::simple_get_hash(std::get<idx>(t)),
308 0 : tuple_hash<idx - 1, Ts...>()(t));
309 : }
310 : };
311 :
312 : template <typename... Ts>
313 : struct tuple_hash<0, Ts...> {
314 : size_t operator()(const std::tuple<Ts...>& t) const {
315 0 : return _hash_detail::simple_get_hash(std::get<0>(t));
316 : }
317 : };
318 :
319 : size_t operator()(const std::tuple<Types...>& t) const {
320 0 : return tuple_hash<sizeof...(Types) - 1, Types...>()(t);
321 : }
322 : };
323 :
324 : template <typename T1, typename T2>
325 : struct hash<std::pair<T1, T2>> {
326 : size_t operator()(const std::pair<T1, T2>& pair) const {
327 : std::tuple<T1, T2> tuple = std::make_tuple(pair.first, pair.second);
328 : return _hash_detail::simple_get_hash(tuple);
329 : }
330 : };
331 :
332 : template <typename T>
333 : struct hash<c10::ArrayRef<T>> {
334 : size_t operator()(c10::ArrayRef<T> v) const {
335 : size_t seed = 0;
336 : for (const auto& elem : v) {
337 : seed = hash_combine(seed, _hash_detail::simple_get_hash(elem));
338 : }
339 : return seed;
340 : }
341 : };
342 :
343 : // Specialization for std::vector
344 : template <typename T>
345 : struct hash<std::vector<T>> {
346 : size_t operator()(const std::vector<T>& v) const {
347 : return hash<c10::ArrayRef<T>>()(v);
348 : }
349 : };
350 :
351 : namespace _hash_detail {
352 :
353 : template <typename T>
354 0 : size_t simple_get_hash(const T& o) {
355 0 : return c10::hash<T>()(o);
356 : }
357 :
358 : } // namespace _hash_detail
359 :
360 : // Use this function to actually hash multiple things in one line.
361 : // Dispatches to c10::hash, so it can hash containers.
362 : // Example:
363 : //
364 : // static size_t hash(const MyStruct& s) {
365 : // return get_hash(s.member1, s.member2, s.member3);
366 : // }
367 : template <typename... Types>
368 0 : size_t get_hash(const Types&... args) {
369 0 : return c10::hash<decltype(std::tie(args...))>()(std::tie(args...));
370 : }
371 :
372 : // Specialization for c10::complex
373 : template <typename T>
374 : struct hash<c10::complex<T>> {
375 : size_t operator()(const c10::complex<T>& c) const {
376 0 : return get_hash(c.real(), c.imag());
377 : }
378 : };
379 :
380 : } // namespace c10
381 :
382 : #else
383 : #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
384 : #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|