31 #include <unordered_map>
34 #include "ngraph/axis_vector.hpp"
35 #include "ngraph/graph_util.hpp"
36 #include "ngraph/node.hpp"
37 #include "ngraph/runtime/host_tensor.hpp"
38 #include "ngraph/runtime/tensor.hpp"
39 #include "ngraph/shape.hpp"
55 std::string join(
const T& v,
const std::string& sep =
", ")
57 std::ostringstream ss;
59 for (
const auto& x : v)
71 std::string vector_to_string(
const T& v)
73 std::ostringstream os;
74 os <<
"[ " << ngraph::join(v) <<
" ]";
79 size_t hash_combine(
const std::vector<size_t>& list);
81 void dump(std::ostream& out,
const void*,
size_t);
83 std::string to_lower(
const std::string& s);
85 std::string to_upper(
const std::string& s);
87 std::string trim(
const std::string& s);
89 std::vector<std::string> split(
const std::string& s,
char delimiter,
bool trim =
false);
92 std::string locale_string(T x)
95 ss.imbue(std::locale(
""));
105 if (m_active ==
false)
109 m_start_time = m_clock.now();
115 if (m_active ==
true)
117 auto end_time = m_clock.now();
118 m_last_time = end_time - m_start_time;
119 m_total_time += m_last_time;
124 size_t get_call_count()
const;
125 size_t get_seconds()
const;
126 size_t get_milliseconds()
const;
127 size_t get_microseconds()
const;
128 std::chrono::nanoseconds get_timer_value()
const;
129 size_t get_nanoseconds()
const;
131 size_t get_total_seconds()
const;
132 size_t get_total_milliseconds()
const;
133 size_t get_total_microseconds()
const;
134 size_t get_total_nanoseconds()
const;
137 std::chrono::high_resolution_clock m_clock;
138 std::chrono::time_point<std::chrono::high_resolution_clock> m_start_time;
139 bool m_active =
false;
140 std::chrono::nanoseconds m_total_time =
141 std::chrono::high_resolution_clock::duration::zero();
142 std::chrono::nanoseconds m_last_time = std::chrono::high_resolution_clock::duration::zero();
143 size_t m_total_count = 0;
147 template <
typename T>
151 std::stringstream ss;
157 if (ss.fail() || ss.rdbuf()->in_avail() != 0)
159 throw std::runtime_error(
"Could not parse literal '" + s +
"'");
170 NGRAPH_API
double parse_string<double>(
const std::string& s);
178 NGRAPH_API uint8_t parse_string<uint8_t>(
const std::string& s);
181 template <
typename T>
184 std::vector<T> result(ss.size());
185 std::transform(ss.begin(), ss.end(), result.begin(), [](
const std::string& s) {
186 return parse_string<T>(s);
191 template <
typename T>
192 T ceil_div(
const T& x,
const T& y)
194 return (x == 0 ? 0 : (1 + (x - 1) / y));
197 template <
typename T>
198 T subtract_or_zero(T x, T y)
200 return y > x ? 0 : x - y;
204 void* ngraph_malloc(
size_t size);
206 void ngraph_free(
void*);
209 size_t round_up(
size_t size,
size_t alignment);
211 template <
typename T>
214 extern template NGRAPH_API AxisVector apply_permutation<AxisVector>(AxisVector input,
217 extern template NGRAPH_API Coordinate apply_permutation<Coordinate>(Coordinate input,
220 extern template NGRAPH_API Strides apply_permutation<Strides>(Strides input, AxisVector order);
222 extern template NGRAPH_API Shape apply_permutation<Shape>(Shape input, AxisVector order);
225 NGRAPH_API PartialShape apply_permutation(PartialShape input, AxisVector order);
228 AxisVector get_default_order(
size_t rank);
231 AxisVector get_default_order(
const Shape& shape);
242 template <
typename T>
247 static_assert(std::is_enum<T>::value,
"EnumMask template type must be an enum");
252 static_assert(std::is_unsigned<value_type>::value,
"EnumMask enum must use unsigned type.");
258 constexpr
EnumMask(
const T& enum_value)
259 : m_value{static_cast<value_type>(enum_value)}
262 EnumMask(
const EnumMask& other)
263 : m_value{other.m_value}
266 EnumMask(std::initializer_list<T> enum_values)
269 for (
auto& v : enum_values)
271 m_value |=
static_cast<value_type
>(v);
274 value_type value()
const {
return m_value; }
278 bool is_set(
const EnumMask& p)
const {
return (m_value & p.m_value) == p.m_value; }
283 void set(
const EnumMask& p) { m_value |= p.m_value; }
284 void clear(
const EnumMask& p) { m_value &= ~p.m_value; }
285 void clear_all() { m_value = 0; }
286 bool operator[](
const EnumMask& p)
const {
return is_set(p); }
287 bool operator==(
const EnumMask& other)
const {
return m_value == other.m_value; }
288 bool operator!=(
const EnumMask& other)
const {
return m_value != other.m_value; }
289 EnumMask& operator=(
const EnumMask& other)
291 m_value = other.m_value;
294 EnumMask& operator&=(
const EnumMask& other)
296 m_value &= other.m_value;
300 EnumMask& operator|=(
const EnumMask& other)
302 m_value |= other.m_value;
306 EnumMask operator&(
const EnumMask& other)
const
308 return EnumMask(m_value & other.m_value);
311 EnumMask operator|(
const EnumMask& other)
const
313 return EnumMask(m_value | other.m_value);
316 friend std::ostream& operator<<(std::ostream& os,
const EnumMask& m)
324 explicit EnumMask(
const value_type& value)
345 std::string version,
size_t& major,
size_t& minor,
size_t& patch, std::string& extra);
347 template <
typename T>
348 T double_to_int(
double x,
double float_to_int_converter(
double))
350 if (!std::is_integral<T>())
352 throw std::runtime_error(
353 "Function double_to_int template parameter must be an integral type.");
356 x = float_to_int_converter(x);
358 double min_t =
static_cast<double>(std::numeric_limits<T>::min());
361 return std::numeric_limits<T>::min();
364 double max_t =
static_cast<double>(std::numeric_limits<T>::max());
367 return std::numeric_limits<T>::max();
370 return static_cast<T
>(x);
374 template <
typename T>
375 std::vector<T> read_vector(std::shared_ptr<ngraph::runtime::Tensor> tv)
377 if (ngraph::element::from<T>() != tv->get_element_type())
379 throw std::invalid_argument(
"read_vector type must match Tensor type");
382 size_t size = element_count *
sizeof(T);
383 std::vector<T> rc(element_count);
384 tv->read(rc.data(), size);
388 template <
typename T>
389 std::vector<T> host_tensor_2_vector(ngraph::HostTensorPtr tensor)
391 NGRAPH_CHECK(tensor !=
nullptr,
392 "Invalid Tensor received, can't read the data from a null pointer.");
394 switch (tensor->get_element_type())
396 case ngraph::element::Type_t::boolean:
398 auto p = tensor->get_data_ptr<ngraph::element::Type_t::boolean>();
399 return std::vector<T>(p, p + tensor->get_element_count());
401 case ngraph::element::Type_t::bf16:
403 auto p = tensor->get_data_ptr<ngraph::element::Type_t::bf16>();
404 return std::vector<T>(p, p + tensor->get_element_count());
406 case ngraph::element::Type_t::f16:
408 auto p = tensor->get_data_ptr<ngraph::element::Type_t::f16>();
409 return std::vector<T>(p, p + tensor->get_element_count());
411 case ngraph::element::Type_t::f32:
413 auto p = tensor->get_data_ptr<ngraph::element::Type_t::f32>();
414 return std::vector<T>(p, p + tensor->get_element_count());
416 case ngraph::element::Type_t::f64:
418 auto p = tensor->get_data_ptr<ngraph::element::Type_t::f64>();
419 return std::vector<T>(p, p + tensor->get_element_count());
421 case ngraph::element::Type_t::i8:
423 auto p = tensor->get_data_ptr<ngraph::element::Type_t::i8>();
424 return std::vector<T>(p, p + tensor->get_element_count());
426 case ngraph::element::Type_t::i16:
428 auto p = tensor->get_data_ptr<ngraph::element::Type_t::i16>();
429 return std::vector<T>(p, p + tensor->get_element_count());
431 case ngraph::element::Type_t::i32:
433 auto p = tensor->get_data_ptr<ngraph::element::Type_t::i32>();
434 return std::vector<T>(p, p + tensor->get_element_count());
436 case ngraph::element::Type_t::i64:
438 auto p = tensor->get_data_ptr<ngraph::element::Type_t::i64>();
439 return std::vector<T>(p, p + tensor->get_element_count());
441 case ngraph::element::Type_t::u1: NGRAPH_CHECK(
false,
"u1 element type is unsupported");
break;
442 case ngraph::element::Type_t::u8:
444 auto p = tensor->get_data_ptr<ngraph::element::Type_t::u8>();
445 return std::vector<T>(p, p + tensor->get_element_count());
447 case ngraph::element::Type_t::u16:
449 auto p = tensor->get_data_ptr<ngraph::element::Type_t::u16>();
450 return std::vector<T>(p, p + tensor->get_element_count());
452 case ngraph::element::Type_t::u32:
454 auto p = tensor->get_data_ptr<ngraph::element::Type_t::u32>();
455 return std::vector<T>(p, p + tensor->get_element_count());
457 case ngraph::element::Type_t::u64:
459 auto p = tensor->get_data_ptr<ngraph::element::Type_t::u64>();
460 return std::vector<T>(p, p + tensor->get_element_count());
462 default: NGRAPH_UNREACHABLE(
"unsupported element type");
466 std::vector<float> NGRAPH_API read_float_vector(std::shared_ptr<ngraph::runtime::Tensor> tv);
468 std::vector<int64_t> NGRAPH_API read_index_vector(std::shared_ptr<ngraph::runtime::Tensor> tv);
471 std::ostream& operator<<(std::ostream& os,
const ngraph::NodeVector& nv);
A vector of axes.
Definition: axis_vector.hpp:30
Class representing a dimension, which may be dynamic (undetermined until runtime),...
Definition: dimension.hpp:35
static Dimension dynamic()
Create a dynamic dimension.
Definition: dimension.hpp:130
bool is_clear(const EnumMask &p) const
Check if none of the bits in the input enum bit mask match.
Definition: util.hpp:282
bool is_set(const EnumMask &p) const
Check if all of the bits in the input enum bit mask match.
Definition: util.hpp:278
constexpr EnumMask()
Definition: util.hpp:254
bool is_any_clear(const EnumMask &p) const
Check if any of the bits in the input enum bit mask do not match.
Definition: util.hpp:280
std::underlying_type< T >::type value_type
Underlying integer type of the enum T (the template type must be an enum).
Definition: util.hpp:247
bool is_any_set(const EnumMask &p) const
Check if any of the bits in the input enum bit mask match.
Definition: util.hpp:276
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
NGRAPH_API void parse_version_string(std::string version, size_t &major, size_t &minor, size_t &patch, std::string &extra)
Function to query the parsed version information of the ngraph build that contains this function....
size_t shape_size(const SHAPE_TYPE &shape)
Number of elements spanned by a shape.
Definition: shape.hpp:71
T parse_string(const std::string &s)
Parses a string containing a literal of the underlying type.
Definition: util.hpp:148
NGRAPH_API float parse_string< float >(const std::string &s)
NGRAPH_API int8_t parse_string< int8_t >(const std::string &s)