23 #include "ngraph/coordinate_diff.hpp"
24 #include "ngraph/node.hpp"
25 #include "ngraph/runtime/aligned_buffer.hpp"
26 #include "ngraph/runtime/host_tensor.hpp"
27 #include "ngraph/runtime/shared_buffer.hpp"
28 #include "ngraph/type/element_type.hpp"
29 #include "ngraph/type/element_type_traits.hpp"
30 #include "ngraph/util.hpp"
/// \brief Initialize a constant from a tensor.
/// \param tensor The runtime tensor to initialize the constant from.
48 Constant(
const std::shared_ptr<runtime::Tensor>& tensor);
59 const std::vector<T>& values)
62 NODE_VALIDATION_CHECK(
64 values.size() == 1 || values.size() ==
shape_size(m_shape),
65 "Did not get the expected number of literals for a constant of shape ",
74 if (values.size() == 1)
76 write_values(std::vector<T>(
shape_size(m_shape), values[0]));
82 constructor_validate_and_infer_types();
83 m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
95 class =
typename std::enable_if<std::is_fundamental<T>::value>::type>
100 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
101 #pragma GCC diagnostic push
102 #pragma GCC diagnostic error "-Wswitch"
103 #pragma GCC diagnostic error "-Wswitch-enum"
107 case element::Type_t::boolean:
109 get_data_ptr_nc<element::Type_t::boolean>(),
115 case element::Type_t::bf16:
117 get_data_ptr_nc<element::Type_t::bf16>(),
123 case element::Type_t::f16:
125 get_data_ptr_nc<element::Type_t::f16>(),
131 case element::Type_t::f32:
133 get_data_ptr_nc<element::Type_t::f32>(),
139 case element::Type_t::f64:
141 get_data_ptr_nc<element::Type_t::f64>(),
147 case element::Type_t::i8:
149 get_data_ptr_nc<element::Type_t::i8>(),
155 case element::Type_t::i16:
157 get_data_ptr_nc<element::Type_t::i16>(),
163 case element::Type_t::i32:
165 get_data_ptr_nc<element::Type_t::i32>(),
171 case element::Type_t::i64:
173 get_data_ptr_nc<element::Type_t::i64>(),
179 case element::Type_t::u8:
181 get_data_ptr_nc<element::Type_t::u8>(),
187 case element::Type_t::u16:
189 get_data_ptr_nc<element::Type_t::u16>(),
195 case element::Type_t::u32:
197 get_data_ptr_nc<element::Type_t::u32>(),
203 case element::Type_t::u64:
205 get_data_ptr_nc<element::Type_t::u64>(),
211 case element::Type_t::u1:
throw std::runtime_error(
"unsupported type");
212 case element::Type_t::undefined:
throw std::runtime_error(
"unsupported type");
213 case element::Type_t::dynamic:
throw std::runtime_error(
"unsupported type");
215 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
216 #pragma GCC diagnostic pop
218 constructor_validate_and_infer_types();
219 m_all_elements_bitwise_identical =
true;
230 const std::vector<std::string>& values);
244 template <
typename T>
248 : m_element_type(type)
252 constructor_validate_and_infer_types();
262 infer_element_type();
263 set_output_type(0, m_element_type, m_shape);
269 const HostTensorVector& inputs)
const override;
270 bool evaluate_lower(
const HostTensorVector& outputs)
const override;
271 bool evaluate_upper(
const HostTensorVector& outputs)
const override;
274 bool constant_fold(OutputVector& outputs,
const OutputVector& inputs)
override
320 template <
typename T>
321 static std::shared_ptr<op::v0::Constant>
324 auto result = std::make_shared<op::v0::Constant>(type, shape, values);
325 result->validate_and_infer_types();
334 template <
typename T>
335 static std::shared_ptr<op::v0::Constant>
339 std::make_shared<op::v0::Constant>(type, shape, std::vector<T>{values});
340 result->validate_and_infer_types();
344 virtual std::shared_ptr<Node>
345 clone_with_new_inputs(
const OutputVector& new_args)
const override;
350 template <
typename T>
351 std::vector<T> get_vector()
const
353 const T* p = get_data_ptr<T>();
355 throw std::runtime_error(
"Cannot create vector! Buffer is not allocated.");
356 return std::vector<T>(p, p +
shape_size(m_shape));
363 template <
typename T>
366 auto source_type = get_element_type();
369 #if defined(_MSC_VER)
370 #pragma warning(push)
371 #pragma warning(disable : 4244)
375 case element::Type_t::boolean:
377 cast_vector<char>(rc);
380 case element::Type_t::bf16:
382 cast_vector<bfloat16>(rc);
385 case element::Type_t::f16:
387 cast_vector<float16>(rc);
390 case element::Type_t::f32:
392 cast_vector<float>(rc);
395 case element::Type_t::f64:
397 cast_vector<double>(rc);
400 case element::Type_t::i8:
402 cast_vector<int8_t>(rc);
405 case element::Type_t::i16:
407 cast_vector<int16_t>(rc);
410 case element::Type_t::i32:
412 cast_vector<int32_t>(rc);
415 case element::Type_t::i64:
417 cast_vector<int64_t>(rc);
420 case element::Type_t::u8:
422 cast_vector<uint8_t>(rc);
425 case element::Type_t::u16:
427 cast_vector<uint16_t>(rc);
430 case element::Type_t::u32:
432 cast_vector<uint32_t>(rc);
435 case element::Type_t::u64:
437 cast_vector<uint64_t>(rc);
440 default:
throw std::runtime_error(
"unsupported type");
442 #if defined(_MSC_VER)
/// \brief Returns a read-only, untyped pointer to the constant's data buffer.
/// \return m_data->get_ptr() when a buffer is held (m_data non-null), otherwise nullptr.
448 const void* get_data_ptr()
const {
return (m_data ? m_data->get_ptr() :
nullptr); }
449 template <
typename T>
450 const T* get_data_ptr()
const
452 if (
sizeof(T) > m_element_type.size() &&
shape_size(m_shape) > 0)
457 return static_cast<const T*
>(get_data_ptr());
460 template <element::Type_t ET>
461 const typename element_type_traits<ET>::value_type* get_data_ptr()
const
463 NGRAPH_CHECK(ET == get_element_type(),
464 "get_data_ptr() called for incorrect element type.");
465 return static_cast<const typename element_type_traits<ET>::value_type*
>(
469 bool get_all_data_elements_bitwise_identical()
const
471 return m_all_elements_bitwise_identical;
473 std::string convert_value_to_string(
size_t index)
const;
480 m_alloc_buffer_on_visit_attributes = val;
484 template <
typename IN_T,
typename OUT_T>
485 void cast_vector(std::vector<OUT_T>& output_vector)
const
487 auto source_vector = get_vector<IN_T>();
488 output_vector.reserve(source_vector.size());
490 std::transform(source_vector.begin(),
492 std::back_inserter(output_vector),
493 [](IN_T c) { return static_cast<OUT_T>(c); });
496 void allocate_buffer();
/// \brief Mutable counterpart of get_data_ptr(); write_values() passes this pointer
/// to write_to_buffer() so the constant's buffer can be filled.
/// \return m_data->get_ptr() when a buffer is held (m_data non-null), otherwise nullptr.
498 void* get_data_ptr_nc() {
return (m_data ? m_data->get_ptr() :
nullptr); }
499 template <element::Type_t ET>
502 NGRAPH_CHECK(ET == get_element_type(),
503 "get_data_ptr_nc() called for incorrect element type.");
508 Constant(
const OutputVector& args)
/// \brief Virtual hook invoked by validate_and_infer_types() before the output type
/// (m_element_type, m_shape) is set; does nothing by default.
/// NOTE(review): presumably overridden by subclasses that derive m_element_type — confirm.
514 virtual void infer_element_type() {}
515 template <
typename T>
516 void write_values(
const std::vector<T>& values)
519 m_element_type, m_shape, values, get_data_ptr_nc(),
shape_size(m_shape));
522 template <
typename T,
typename U>
523 static void write_buffer(
void* target,
const std::vector<U>& source,
size_t count)
525 T* p =
reinterpret_cast<T*
>(target);
526 for (
size_t i = 0; i < count; i++)
528 p[i] =
static_cast<T
>(source[i]);
532 template <
typename T>
533 static void write_to_buffer(
const element::Type& target_type,
535 const std::vector<T>& source,
537 size_t target_element_count)
539 if (source.size() != target_element_count)
541 throw std::runtime_error(
"Constant initializer does not match shape");
543 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
544 #pragma GCC diagnostic push
545 #pragma GCC diagnostic error "-Wswitch"
546 #pragma GCC diagnostic error "-Wswitch-enum"
550 case element::Type_t::boolean:
551 write_buffer<char, T>(target, source, target_element_count);
553 case element::Type_t::bf16:
554 write_buffer<bfloat16, T>(target, source, target_element_count);
556 case element::Type_t::f16:
557 write_buffer<float16, T>(target, source, target_element_count);
559 case element::Type_t::f32:
560 write_buffer<float, T>(target, source, target_element_count);
562 case element::Type_t::f64:
563 write_buffer<double, T>(target, source, target_element_count);
565 case element::Type_t::i8:
566 write_buffer<int8_t, T>(target, source, target_element_count);
568 case element::Type_t::i16:
569 write_buffer<int16_t, T>(target, source, target_element_count);
571 case element::Type_t::i32:
572 write_buffer<int32_t, T>(target, source, target_element_count);
574 case element::Type_t::i64:
575 write_buffer<int64_t, T>(target, source, target_element_count);
577 case element::Type_t::u8:
578 write_buffer<uint8_t, T>(target, source, target_element_count);
580 case element::Type_t::u16:
581 write_buffer<uint16_t, T>(target, source, target_element_count);
583 case element::Type_t::u32:
584 write_buffer<uint32_t, T>(target, source, target_element_count);
586 case element::Type_t::u64:
587 write_buffer<uint64_t, T>(target, source, target_element_count);
589 case element::Type_t::u1:
throw std::runtime_error(
"unsupported type");
590 case element::Type_t::undefined:
throw std::runtime_error(
"unsupported type");
591 case element::Type_t::dynamic:
throw std::runtime_error(
"unsupported type");
593 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
594 #pragma GCC diagnostic pop
/// \brief Alignment value for host-side data buffers; always 64.
/// NOTE(review): presumably bytes, used when allocating m_data — confirm in allocate_buffer().
598 static constexpr
size_t host_alignment() {
return 64; }
599 element::Type m_element_type;
601 std::shared_ptr<runtime::AlignedBuffer> m_data;
602 bool m_all_elements_bitwise_identical;
603 bool are_all_data_elements_bitwise_identical()
const;
604 bool m_alloc_buffer_on_visit_attributes =
true;
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:71
A set of axes.
Definition: axis_set.hpp:31
A vector of axes.
Definition: axis_vector.hpp:30
A difference (signed) of tensor element coordinates.
Definition: coordinate_diff.hpp:30
Coordinates for a tensor element.
Definition: coordinate.hpp:30
Shape for a tensor.
Definition: shape.hpp:31
Strides for a tensor.
Definition: strides.hpp:30
Definition: element_type.hpp:61
Base error for ngraph runtime errors.
Definition: except.hpp:28
Root of all actual ops.
Definition: op.hpp:29
Class for constants.
Definition: constant.hpp:40
std::vector< T > cast_vector() const
Return the Constant's value as a vector cast to type T.
Definition: constant.hpp:364
std::vector< std::string > get_value_strings() const
void alloc_buffer_on_visit_attributes(bool val)
Allows avoiding buffer allocation on the visit_attributes call.
Definition: constant.hpp:478
Constant(const element::Type &type, const Shape &shape, const void *data)
Constructs a tensor constant with the supplied data.
AxisSet get_axis_set_val() const
Returns the value of the constant node as an AxisSet object. Can only be used on element::i64 nodes an...
Constant(const std::shared_ptr< runtime::Tensor > &tensor)
Initialize a constant from tensor.
AxisVector get_axis_vector_val() const
Returns the value of the constant node as an AxisVector object. Can only be used on element::i64 nodes...
const NodeTypeInfo & get_type_info() const override
Definition: constant.hpp:43
Constant(const element::Type &type, const Shape &shape)
Creates an uninitialized constant.
Constant(const element::Type &type, const Shape &shape, const std::vector< std::string > &values)
Constructs a tensor constant. This constructor is mainly to support deserialization of constants.
CoordinateDiff get_coordinate_diff_val() const
Returns the value of the constant node as a CoordinateDiff object. Can only be used on element::i64 no...
bool evaluate(const HostTensorVector &outputs, const HostTensorVector &inputs) const override
Evaluates the op on input_values putting results in output_values.
void set_data_shape(const Shape &shape)
Updates the Constant's shape. The new shape size must equal the data element count.
Strides get_strides_val() const
Returns the value of the constant node as a Strides object. Can only be used on element::i64 nodes and...
Constant(const element::Type &type, const Shape &shape, T value)
Constructs a uniform tensor constant.
Definition: constant.hpp:96
static std::shared_ptr< op::v0::Constant > create(const element::Type &type, Shape shape, const std::vector< T > values)
Wrapper around constructing a shared_ptr of a Constant.
Definition: constant.hpp:322
Constant(const element::Type &type, const Shape &shape, std::shared_ptr< runtime::SharedBuffer< T >> data)
Constructs a tensor constant with the supplied data.
Definition: constant.hpp:245
Shape get_shape_val() const
Returns the value of the constant node as a Shape object. Can only be used on element::i64 nodes and i...
Coordinate get_coordinate_val() const
Returns the value of the constant node as a Coordinate object. Can only be used on element::i64 nodes ...
static std::shared_ptr< op::v0::Constant > create(const element::Type &type, Shape shape, std::initializer_list< T > values)
Wrapper around constructing a shared_ptr of a Constant.
Definition: constant.hpp:336
Constant(const element::Type &type, const Shape &shape, const std::vector< T > &values)
Constructs a tensor constant.
Definition: constant.hpp:57
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
Definition: constant.hpp:260
SharedBuffer class that stores a pointer to a pre-allocated buffer.
Definition: shared_buffer.hpp:30
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
size_t shape_size(const SHAPE_TYPE &shape)
Number of elements spanned by a shape.
Definition: shape.hpp:71
Definition: element_type_traits.hpp:25