util.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include <algorithm>
8 #include <chrono>
9 #include <cmath>
10 #include <cstdlib> // llvm 8.1 gets confused about `malloc` otherwise
11 #include <functional>
12 #include <iostream>
13 #include <map>
14 #include <memory>
15 #include <sstream>
16 #include <string>
17 #include <typeindex>
18 #include <typeinfo>
19 #include <unordered_map>
20 #include <vector>
21 
22 #include "ngraph/axis_vector.hpp"
23 #include "ngraph/graph_util.hpp"
24 #include "ngraph/node.hpp"
25 #include "ngraph/runtime/host_tensor.hpp"
26 #include "ngraph/runtime/tensor.hpp"
27 #include "ngraph/shape.hpp"
28 
29 namespace ngraph
30 {
// Forward declarations so that these names can be referred to (e.g. through pointers
// or references) without requiring their complete definitions at this point.
class Function;
class Node;
class stopwatch;

namespace runtime
{
    class Backend;
    class Tensor;
    class Value;
} // namespace runtime
41 
/// \brief Concatenates the elements of a range into a single string.
///
/// Every element is written with `operator<<`; `sep` is inserted between adjacent
/// elements (never before the first or after the last). An empty range yields "".
///
/// \param v   Any range usable in a range-based for loop.
/// \param sep Separator placed between adjacent elements; defaults to ", ".
/// \return The joined text.
template <typename T>
std::string join(const T& v, const std::string& sep = ", ")
{
    std::ostringstream out;
    bool first = true;
    for (const auto& element : v)
    {
        if (!first)
        {
            out << sep;
        }
        first = false;
        out << element;
    }
    return out.str();
}
57 
58  template <typename T>
59  std::string vector_to_string(const T& v)
60  {
61  std::ostringstream os;
62  os << "[ " << ngraph::join(v) << " ]";
63  return os.str();
64  }
65 
66  NGRAPH_API
67  size_t hash_combine(const std::vector<size_t>& list);
68  NGRAPH_API
69  void dump(std::ostream& out, const void*, size_t);
70  NGRAPH_API
71  std::string to_lower(const std::string& s);
72  NGRAPH_API
73  std::string to_upper(const std::string& s);
74  NGRAPH_API
75  std::string trim(const std::string& s);
76  NGRAPH_API
77  std::vector<std::string> split(const std::string& s, char delimiter, bool trim = false);
78 
/// \brief Formats `x` with `operator<<` using the user-preferred locale taken from the
/// environment (`std::locale("")`) instead of the classic "C" locale.
template <typename T>
std::string locale_string(T x)
{
    std::ostringstream formatted;
    formatted.imbue(std::locale(""));
    formatted << x;
    return formatted.str();
}
86  }
87 
88  class NGRAPH_API stopwatch
89  {
90  public:
91  void start()
92  {
93  if (m_active == false)
94  {
95  m_total_count++;
96  m_active = true;
97  m_start_time = m_clock.now();
98  }
99  }
100 
101  void stop()
102  {
103  if (m_active == true)
104  {
105  auto end_time = m_clock.now();
106  m_last_time = end_time - m_start_time;
107  m_total_time += m_last_time;
108  m_active = false;
109  }
110  }
111 
112  size_t get_call_count() const;
113  size_t get_seconds() const;
114  size_t get_milliseconds() const;
115  size_t get_microseconds() const;
116  std::chrono::nanoseconds get_timer_value() const;
117  size_t get_nanoseconds() const;
118 
119  size_t get_total_seconds() const;
120  size_t get_total_milliseconds() const;
121  size_t get_total_microseconds() const;
122  size_t get_total_nanoseconds() const;
123 
124  private:
125  std::chrono::high_resolution_clock m_clock;
126  std::chrono::time_point<std::chrono::high_resolution_clock> m_start_time;
127  bool m_active = false;
128  std::chrono::nanoseconds m_total_time =
129  std::chrono::high_resolution_clock::duration::zero();
130  std::chrono::nanoseconds m_last_time = std::chrono::high_resolution_clock::duration::zero();
131  size_t m_total_count = 0;
132  };
133 
/// \brief Parses a string containing a literal of the underlying type `T`.
///
/// The value is read with the stream extraction operator. Parsing fails when extraction
/// fails or when any characters remain after the literal (including trailing whitespace).
/// For unsigned types a leading '-' is rejected explicitly, because stream extraction
/// would otherwise accept e.g. "-1" and silently wrap it to the type's maximum value.
///
/// \param s The text to parse.
/// \return The parsed value.
/// \throws std::runtime_error if `s` is not a valid literal of type `T`.
template <typename T>
T parse_string(const std::string& s)
{
    // Stream extraction skips leading whitespace, so look at the first non-space char.
    const auto first = s.find_first_not_of(" \t\n\v\f\r");
    if (std::is_unsigned<T>::value && first != std::string::npos && s[first] == '-')
    {
        throw std::runtime_error("Could not parse literal '" + s + "'");
    }

    T result{};
    std::stringstream ss;

    ss << s;
    ss >> result;

    // Check that (1) parsing succeeded and (2) the entire string was used.
    if (ss.fail() || ss.rdbuf()->in_avail() != 0)
    {
        throw std::runtime_error("Could not parse literal '" + s + "'");
    }

    return result;
}
152 
153  /// template specializations for float and double to handle INFINITY, -INFINITY
154  /// and NaN values.
155  template <>
156  NGRAPH_API float parse_string<float>(const std::string& s);
157  template <>
158  NGRAPH_API double parse_string<double>(const std::string& s);
159 
160  /// template specializations for int8_t and uint8_t to handle the fact that default
161  /// implementation ends up treating values as characters so that the number "0" turns into
162  /// the parsed value 48, which is it's ASCII value
163  template <>
164  NGRAPH_API int8_t parse_string<int8_t>(const std::string& s);
165  template <>
166  NGRAPH_API uint8_t parse_string<uint8_t>(const std::string& s);
167 
/// \brief Parses every string in `ss` as a literal of type `T`.
///
/// \param ss The strings to parse.
/// \return The parsed values, in the same order as `ss`.
/// \throws Whatever the single-string parse_string<T> throws for an invalid literal.
template <typename T>
std::vector<T> parse_string(const std::vector<std::string>& ss)
{
    std::vector<T> values;
    values.reserve(ss.size());
    for (const std::string& literal : ss)
    {
        values.push_back(parse_string<T>(literal));
    }
    return values;
}
178 
/// \brief Returns ceil(x / y) for integral operands of any sign.
///
/// `ceil_div(0, y)` returns 0 without performing a division.
/// Behavior is undefined when y == 0 and x != 0 (as for plain integer division).
///
/// \param x Dividend.
/// \param y Divisor.
/// \return The smallest integer not less than the exact quotient x / y.
template <typename T>
T ceil_div(const T& x, const T& y)
{
    if (x == 0)
    {
        return 0;
    }
    const T quotient = x / y;
    // Integer division truncates toward zero. That already equals the ceiling when the
    // division is exact or the exact quotient is negative; otherwise round up by one.
    const bool exact = quotient * y == x;
    const bool quotient_positive = (x > 0) == (y > 0);
    if (exact || !quotient_positive)
    {
        return quotient;
    }
    return static_cast<T>(quotient + 1);
}
184 
/// \brief Saturating subtraction: returns x - y, or 0 when y is greater than x.
///
/// This avoids the wrap-around that plain subtraction would produce for unsigned T.
template <typename T>
T subtract_or_zero(T x, T y)
{
    if (y > x)
    {
        return 0;
    }
    return x - y;
}
190 
191  NGRAPH_API
192  void* ngraph_malloc(size_t size);
193  NGRAPH_API
194  void ngraph_free(void*);
195 
196  NGRAPH_API
197  size_t round_up(size_t size, size_t alignment);
198  bool is_valid_permutation(ngraph::AxisVector permutation, ngraph::Rank rank = Rank::dynamic());
199  template <typename T>
200  T apply_permutation(T input, ngraph::AxisVector order);
201 
202  extern template NGRAPH_API AxisVector apply_permutation<AxisVector>(AxisVector input,
203  AxisVector order);
204 
205  extern template NGRAPH_API Coordinate apply_permutation<Coordinate>(Coordinate input,
206  AxisVector order);
207 
208  extern template NGRAPH_API Strides apply_permutation<Strides>(Strides input, AxisVector order);
209 
210  extern template NGRAPH_API Shape apply_permutation<Shape>(Shape input, AxisVector order);
211 
212  template <>
213  NGRAPH_API PartialShape apply_permutation(PartialShape input, AxisVector order);
214 
215  NGRAPH_API
216  AxisVector get_default_order(size_t rank);
217 
218  NGRAPH_API
219  AxisVector get_default_order(const Rank& rank);
220 
221  NGRAPH_API
222  AxisVector get_default_order(const Shape& shape);
223 
224  NGRAPH_API
225  AxisVector get_default_order(const PartialShape& shape);
226 
//
// EnumMask is intended to work with a scoped enum type. It stores a combination
// of enum values and provides easy access to, and manipulation of, those values
// as a bit mask.
//
// EnumMask deliberately provides no set_all() or invert() operation. The underlying
// integer usually has more bits than the enum uses, and inverting would set those
// unused bits too. For example, if the enum values occupy only the low 4 bits,
// inverting 0b0010 yields 0b1111...1101 rather than 0b1101.
//
template <typename T>
class EnumMask
{
public:
    // Make sure the template type is an enum.
    static_assert(std::is_enum<T>::value, "EnumMask template type must be an enum");
    /// Underlying integer type of the enum; the mask is stored in this type.
    typedef typename std::underlying_type<T>::type value_type;
    // Some bit operations are not safe for signed values, so the enum is required to
    // use an unsigned underlying type.
    static_assert(std::is_unsigned<value_type>::value, "EnumMask enum must use unsigned type.");

    /// Creates an empty mask.
    constexpr EnumMask()
        : m_value{0}
    {
    }
    /// Creates a mask holding one enum value. Deliberately implicit so an enum value can
    /// be passed wherever an EnumMask is expected.
    constexpr EnumMask(const T& enum_value)
        : m_value{static_cast<value_type>(enum_value)}
    {
    }
    // Defaulted (not hand-written) so the type stays trivially copyable and can be
    // copied in constant expressions.
    EnumMask(const EnumMask&) = default;
    EnumMask& operator=(const EnumMask&) = default;
    /// Creates a mask holding every value in `enum_values`.
    EnumMask(std::initializer_list<T> enum_values)
        : m_value{0}
    {
        for (const auto& v : enum_values)
        {
            m_value |= static_cast<value_type>(v);
        }
    }
    /// Returns the raw bit mask.
    constexpr value_type value() const { return m_value; }
    /// Check if any of the input parameter enum bit mask match
    constexpr bool is_any_set(const EnumMask& p) const { return (m_value & p.m_value) != 0; }
    /// Check if all of the input parameter enum bit mask match
    constexpr bool is_set(const EnumMask& p) const { return (m_value & p.m_value) == p.m_value; }
    /// Check if any of the input parameter enum bit mask does not match
    constexpr bool is_any_clear(const EnumMask& p) const { return !is_set(p); }
    /// Check if all of the input parameter enum bit mask do not match
    constexpr bool is_clear(const EnumMask& p) const { return !is_any_set(p); }
    /// Sets every bit of `p`.
    void set(const EnumMask& p) { m_value |= p.m_value; }
    /// Clears every bit of `p`.
    void clear(const EnumMask& p) { m_value &= ~p.m_value; }
    /// Clears all bits.
    void clear_all() { m_value = 0; }
    /// Same as is_set(p).
    constexpr bool operator[](const EnumMask& p) const { return is_set(p); }
    constexpr bool operator==(const EnumMask& other) const { return m_value == other.m_value; }
    constexpr bool operator!=(const EnumMask& other) const { return m_value != other.m_value; }
    EnumMask& operator&=(const EnumMask& other)
    {
        m_value &= other.m_value;
        return *this;
    }

    EnumMask& operator|=(const EnumMask& other)
    {
        m_value |= other.m_value;
        return *this;
    }

    constexpr EnumMask operator&(const EnumMask& other) const
    {
        return EnumMask(static_cast<value_type>(m_value & other.m_value));
    }

    constexpr EnumMask operator|(const EnumMask& other) const
    {
        return EnumMask(static_cast<value_type>(m_value | other.m_value));
    }

    friend std::ostream& operator<<(std::ostream& os, const EnumMask& m)
    {
        os << m.m_value;
        return os;
    }

private:
    /// Only used internally: wraps an already-combined raw mask.
    constexpr explicit EnumMask(const value_type& value)
        : m_value{value}
    {
    }

    value_type m_value;
};
325 
326  /// \brief Function to query parsed version information of the version of ngraph which
327  /// contains this function. Version information strictly follows Semantic Versioning
328  /// http://semver.org
329  /// \param version The major part of the version
330  /// \param major Returns the major part of the version
331  /// \param minor Returns the minor part of the version
332  /// \param patch Returns the patch part of the version
333  /// \param extra Returns the extra part of the version. This includes everything following
334  /// the patch version number.
335  ///
336  /// \note Throws a runtime_error if there is an error during parsing
337  NGRAPH_API
339  std::string version, size_t& major, size_t& minor, size_t& patch, std::string& extra);
340 
/// \brief Converts a double to the integral type `T`, saturating at T's limits.
///
/// `x` is first passed through `float_to_int_converter` (e.g. a rounding function).
/// Values below numeric_limits<T>::min() yield min(), and values at or above the double
/// representation of numeric_limits<T>::max() yield max().
///
/// \param x                      Value to convert.
/// \param float_to_int_converter Function applied to `x` before the range checks.
/// \return The converted, saturated value.
/// \throws std::runtime_error if `T` is not an integral type.
template <typename T>
T double_to_int(double x, double float_to_int_converter(double))
{
    if (!std::is_integral<T>())
    {
        throw std::runtime_error(
            "Function double_to_int template parameter must be an integral type.");
    }

    x = float_to_int_converter(x);

    const double min_t = static_cast<double>(std::numeric_limits<T>::min());
    if (x < min_t)
    {
        return std::numeric_limits<T>::min();
    }

    // For 64-bit types max() is not representable as a double and rounds *up* (to 2^63
    // or 2^64). A value equal to max_t is then already out of range for T, and converting
    // it would be undefined behavior. `>=` saturates that case and is harmless when
    // max() is exactly representable.
    const double max_t = static_cast<double>(std::numeric_limits<T>::max());
    if (x >= max_t)
    {
        return std::numeric_limits<T>::max();
    }

    // NOTE(review): a NaN compares false above and reaches this conversion, which is
    // undefined behavior; decide on a defined result for NaN inputs.
    return static_cast<T>(x);
}
366 } // end namespace ngraph
367 
368 template <typename T>
369 std::vector<T> read_vector(std::shared_ptr<ngraph::runtime::Tensor> tv)
370 {
371  if (ngraph::element::from<T>() != tv->get_element_type())
372  {
373  throw std::invalid_argument("read_vector type must match Tensor type");
374  }
375  size_t element_count = ngraph::shape_size(tv->get_shape());
376  size_t size = element_count * sizeof(T);
377  std::vector<T> rc(element_count);
378  tv->read(rc.data(), size);
379  return rc;
380 }
381 
382 template <typename T>
383 std::vector<T> host_tensor_2_vector(ngraph::HostTensorPtr tensor)
384 {
385  NGRAPH_CHECK(tensor != nullptr,
386  "Invalid Tensor received, can't read the data from a null pointer.");
387 
388  switch (tensor->get_element_type())
389  {
390  case ngraph::element::Type_t::boolean:
391  {
392  auto p = tensor->get_data_ptr<ngraph::element::Type_t::boolean>();
393  return std::vector<T>(p, p + tensor->get_element_count());
394  }
395  case ngraph::element::Type_t::bf16:
396  {
397  auto p = tensor->get_data_ptr<ngraph::element::Type_t::bf16>();
398  return std::vector<T>(p, p + tensor->get_element_count());
399  }
400  case ngraph::element::Type_t::f16:
401  {
402  auto p = tensor->get_data_ptr<ngraph::element::Type_t::f16>();
403  return std::vector<T>(p, p + tensor->get_element_count());
404  }
405  case ngraph::element::Type_t::f32:
406  {
407  auto p = tensor->get_data_ptr<ngraph::element::Type_t::f32>();
408  return std::vector<T>(p, p + tensor->get_element_count());
409  }
410  case ngraph::element::Type_t::f64:
411  {
412  auto p = tensor->get_data_ptr<ngraph::element::Type_t::f64>();
413  return std::vector<T>(p, p + tensor->get_element_count());
414  }
415  case ngraph::element::Type_t::i8:
416  {
417  auto p = tensor->get_data_ptr<ngraph::element::Type_t::i8>();
418  return std::vector<T>(p, p + tensor->get_element_count());
419  }
420  case ngraph::element::Type_t::i16:
421  {
422  auto p = tensor->get_data_ptr<ngraph::element::Type_t::i16>();
423  return std::vector<T>(p, p + tensor->get_element_count());
424  }
425  case ngraph::element::Type_t::i32:
426  {
427  auto p = tensor->get_data_ptr<ngraph::element::Type_t::i32>();
428  return std::vector<T>(p, p + tensor->get_element_count());
429  }
430  case ngraph::element::Type_t::i64:
431  {
432  auto p = tensor->get_data_ptr<ngraph::element::Type_t::i64>();
433  return std::vector<T>(p, p + tensor->get_element_count());
434  }
435  case ngraph::element::Type_t::u1: NGRAPH_CHECK(false, "u1 element type is unsupported"); break;
436  case ngraph::element::Type_t::u8:
437  {
438  auto p = tensor->get_data_ptr<ngraph::element::Type_t::u8>();
439  return std::vector<T>(p, p + tensor->get_element_count());
440  }
441  case ngraph::element::Type_t::u16:
442  {
443  auto p = tensor->get_data_ptr<ngraph::element::Type_t::u16>();
444  return std::vector<T>(p, p + tensor->get_element_count());
445  }
446  case ngraph::element::Type_t::u32:
447  {
448  auto p = tensor->get_data_ptr<ngraph::element::Type_t::u32>();
449  return std::vector<T>(p, p + tensor->get_element_count());
450  }
451  case ngraph::element::Type_t::u64:
452  {
453  auto p = tensor->get_data_ptr<ngraph::element::Type_t::u64>();
454  return std::vector<T>(p, p + tensor->get_element_count());
455  }
456  default: NGRAPH_UNREACHABLE("unsupported element type");
457  }
458 }
459 
460 std::vector<float> NGRAPH_API read_float_vector(std::shared_ptr<ngraph::runtime::Tensor> tv);
461 
462 std::vector<int64_t> NGRAPH_API read_index_vector(std::shared_ptr<ngraph::runtime::Tensor> tv);
463 
464 NGRAPH_API
465 std::ostream& operator<<(std::ostream& os, const ngraph::NodeVector& nv);
A vector of axes.
Definition: axis_vector.hpp:18
Class representing a dimension, which may be dynamic (undetermined until runtime),...
Definition: dimension.hpp:23
static Dimension dynamic()
Create a dynamic dimension.
Definition: dimension.hpp:118
Definition: util.hpp:238
bool is_clear(const EnumMask &p) const
Check if all of the input parameter enum bit mask do not match.
Definition: util.hpp:276
bool is_set(const EnumMask &p) const
Check if all of the input parameter enum bit mask match.
Definition: util.hpp:272
constexpr EnumMask()
Definition: util.hpp:248
bool is_any_clear(const EnumMask &p) const
Check if any of the input parameter enum bit mask does not match.
Definition: util.hpp:274
std::underlying_type< T >::type value_type
Make sure the template type is an enum.
Definition: util.hpp:241
bool is_any_set(const EnumMask &p) const
Check if any of the input parameter enum bit mask match.
Definition: util.hpp:270
Definition: util.hpp:89
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16
Dimension Rank
Alias for Dimension, used when the value represents the number of axes in a shape,...
Definition: rank.hpp:15
NGRAPH_API void parse_version_string(std::string version, size_t &major, size_t &minor, size_t &patch, std::string &extra)
Function to query parsed version information of the version of ngraph which contains this function....
size_t shape_size(const SHAPE_TYPE &shape)
Number of elements spanned by a shape.
Definition: shape.hpp:59
T parse_string(const std::string &s)
Parses a string containing a literal of the underlying type.
Definition: util.hpp:136
NGRAPH_API float parse_string< float >(const std::string &s)
NGRAPH_API int8_t parse_string< int8_t >(const std::string &s)