10 #include "ngraph/coordinate_diff.hpp"
11 #include "ngraph/node.hpp"
12 #include "ngraph/runtime/aligned_buffer.hpp"
13 #include "ngraph/runtime/host_tensor.hpp"
14 #include "ngraph/runtime/shared_buffer.hpp"
15 #include "ngraph/type/element_type.hpp"
16 #include "ngraph/type/element_type_traits.hpp"
17 #include "ngraph/util.hpp"
35 Constant(
const std::shared_ptr<runtime::Tensor>& tensor);
46 const std::vector<T>& values)
49 NODE_VALIDATION_CHECK(
51 values.size() == 1 || values.size() ==
shape_size(m_shape),
52 "Did not get the expected number of literals for a constant of shape ",
61 if (values.size() == 1)
63 fill_data(type, values.front());
69 m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical();
81 class =
typename std::enable_if<std::is_fundamental<T>::value>::type>
85 fill_data(type, value);
86 m_all_elements_bitwise_identical =
true;
92 using Type_t = element::Type_t;
93 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
94 #pragma GCC diagnostic push
95 #pragma GCC diagnostic error "-Wswitch"
96 #pragma GCC diagnostic error "-Wswitch-enum"
100 case Type_t::boolean: fill_data<Type_t::boolean>(value);
break;
101 case Type_t::bf16: fill_data<Type_t::bf16>(value);
break;
102 case Type_t::f16: fill_data<Type_t::f16>(value);
break;
103 case Type_t::f32: fill_data<Type_t::f32>(value);
break;
104 case Type_t::f64: fill_data<Type_t::f64>(value);
break;
105 case Type_t::i4: fill_data<Type_t::i4>(value);
break;
106 case Type_t::i8: fill_data<Type_t::i8>(value);
break;
107 case Type_t::i16: fill_data<Type_t::i16>(value);
break;
108 case Type_t::i32: fill_data<Type_t::i32>(value);
break;
109 case Type_t::i64: fill_data<Type_t::i64>(value);
break;
110 case Type_t::u1: fill_data<Type_t::u1>(value);
break;
111 case Type_t::u4: fill_data<Type_t::u4>(value);
break;
112 case Type_t::u8: fill_data<Type_t::u8>(value);
break;
113 case Type_t::u16: fill_data<Type_t::u16>(value);
break;
114 case Type_t::u32: fill_data<Type_t::u32>(value);
break;
115 case Type_t::u64: fill_data<Type_t::u64>(value);
break;
116 case Type_t::undefined:
117 case Type_t::dynamic:
throw std::runtime_error(
"unsupported type");
119 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
120 #pragma GCC diagnostic pop
132 const std::vector<std::string>& values);
146 template <
typename T>
150 : m_element_type(type)
154 constructor_validate_and_infer_types();
164 infer_element_type();
165 set_output_type(0, m_element_type, m_shape);
171 const HostTensorVector& inputs)
const override;
173 bool evaluate_lower(
const HostTensorVector& outputs)
const override;
174 bool evaluate_upper(
const HostTensorVector& outputs)
const override;
177 bool constant_fold(OutputVector& outputs,
const OutputVector& inputs)
override
223 template <
typename T>
226 const std::vector<T>& values)
228 return std::make_shared<Constant>(type, shape, values);
236 template <
typename T>
239 std::initializer_list<T> values)
241 return std::make_shared<Constant>(type, shape, std::vector<T>{values});
249 static std::shared_ptr<Constant>
252 return std::make_shared<Constant>(type, shape, memory);
255 virtual std::shared_ptr<Node>
256 clone_with_new_inputs(
const OutputVector& new_args)
const override;
261 template <
typename T>
262 std::vector<T> get_vector()
const
264 const T* p = get_data_ptr<T>();
266 throw std::runtime_error(
"Cannot create vector! Buffer is not allocated.");
267 return std::vector<T>(p, p +
shape_size(m_shape));
274 template <
typename T>
277 auto source_type = get_element_type();
279 using Type_t = element::Type_t;
280 #if defined(_MSC_VER)
281 #pragma warning(push)
282 #pragma warning(disable : 4244)
286 case Type_t::boolean: cast_vector<Type_t::boolean>(rc);
break;
287 case Type_t::bf16: cast_vector<Type_t::bf16>(rc);
break;
288 case Type_t::f16: cast_vector<Type_t::f16>(rc);
break;
289 case Type_t::f32: cast_vector<Type_t::f32>(rc);
break;
290 case Type_t::f64: cast_vector<Type_t::f64>(rc);
break;
291 case Type_t::i4: cast_vector<Type_t::i4>(rc);
break;
292 case Type_t::i8: cast_vector<Type_t::i8>(rc);
break;
293 case Type_t::i16: cast_vector<Type_t::i16>(rc);
break;
294 case Type_t::i32: cast_vector<Type_t::i32>(rc);
break;
295 case Type_t::i64: cast_vector<Type_t::i64>(rc);
break;
296 case Type_t::u1: cast_vector<Type_t::u1>(rc);
break;
297 case Type_t::u4: cast_vector<Type_t::u4>(rc);
break;
298 case Type_t::u8: cast_vector<Type_t::u8>(rc);
break;
299 case Type_t::u16: cast_vector<Type_t::u16>(rc);
break;
300 case Type_t::u32: cast_vector<Type_t::u32>(rc);
break;
301 case Type_t::u64: cast_vector<Type_t::u64>(rc);
break;
302 default:
throw std::runtime_error(
"unsupported type");
304 #if defined(_MSC_VER)
// Returns a read-only, untyped pointer to the constant's data buffer, or
// nullptr when no buffer (m_data) is currently held.
310 const void* get_data_ptr()
const {
return (m_data ? m_data->get_ptr() :
nullptr); }
311 template <
typename T>
312 const T* get_data_ptr()
const
314 if (
sizeof(T) > m_element_type.size() &&
shape_size(m_shape) > 0)
319 return static_cast<const T*
>(get_data_ptr());
322 template <element::Type_t ET>
323 const typename element_type_traits<ET>::value_type* get_data_ptr()
const
325 NGRAPH_CHECK(ET == get_element_type(),
326 "get_data_ptr() called for incorrect element type.");
327 return static_cast<const typename element_type_traits<ET>::value_type*
>(
331 bool get_all_data_elements_bitwise_identical()
const
333 return m_all_elements_bitwise_identical;
335 std::string convert_value_to_string(
size_t index)
const;
342 m_alloc_buffer_on_visit_attributes = val;
346 template <element::Type_t Type,
347 typename StorageDataType = fundamental_type_for<Type>,
348 typename std::enable_if<Type != element::Type_t::u1 &&
349 Type != element::Type_t::u4 &&
350 Type != element::Type_t::i4,
352 StorageDataType get_element_value(
size_t index)
const
354 return get_data_ptr<Type>()[index];
357 template <element::Type_t Type,
358 typename StorageDataType = fundamental_type_for<Type>,
359 typename std::enable_if<Type == element::Type_t::u1, bool>::type =
true>
360 StorageDataType get_element_value(
size_t index)
const
362 return (get_data_ptr<uint8_t>()[index / 8] >> (7 - (index % 8))) & 1;
365 template <element::Type_t Type,
366 typename StorageDataType = fundamental_type_for<Type>,
367 typename std::enable_if<Type == element::Type_t::u4, bool>::type =
true>
368 StorageDataType get_element_value(
size_t index)
const
370 return (get_data_ptr<uint8_t>()[index / 2] >> (index % 2 ? 0 : 4)) & 0x0F;
373 template <element::Type_t Type,
374 typename StorageDataType = fundamental_type_for<Type>,
375 typename std::enable_if<Type == element::Type_t::i4, bool>::type =
true>
376 StorageDataType get_element_value(
size_t index)
const
378 const uint8_t i4data =
379 (get_data_ptr<uint8_t>()[index / 2] >> (index % 2 ? 0 : 4)) & 0x0F;
380 const bool is_negative_number = (i4data >> 3) & 0b1;
381 const int8_t data = is_negative_number ? i4data | 0xF0 : i4data;
385 template <element::Type_t Type,
387 typename std::enable_if<Type != element::Type_t::u1 &&
388 Type != element::Type_t::u4 &&
389 Type != element::Type_t::i4,
391 void cast_vector(std::vector<OUT_T>& output_vector)
const
396 using IN_T = fundamental_type_for<Type>;
397 auto source_vector = get_vector<IN_T>();
398 output_vector.reserve(source_vector.size());
400 std::transform(source_vector.begin(),
402 std::back_inserter(output_vector),
403 [](IN_T c) { return static_cast<OUT_T>(c); });
406 template <element::Type_t Type,
408 typename std::enable_if<Type == element::Type_t::u1, bool>::type =
true>
409 void cast_vector(std::vector<OUT_T>& output)
const
411 using IN_T = fundamental_type_for<Type>;
412 const auto element_number =
shape_size(m_shape);
413 const auto source_begin = get_data_ptr<uint8_t>();
414 const auto source_end = std::next(source_begin, (element_number + 7) / 8);
415 const auto round_element_no = element_number % 8
416 ? element_number - element_number % 8 + 8
418 output.reserve(round_element_no);
419 std::for_each(source_begin, source_end, [&](IN_T c) {
420 for (
const auto i : {7, 6, 5, 4, 3, 2, 1, 0})
422 const uint8_t data = (c >> i) & 0x01;
423 output.push_back(data);
426 output.resize(element_number);
429 template <element::Type_t Type,
431 typename std::enable_if<Type == element::Type_t::u4, bool>::type =
true>
432 void cast_vector(std::vector<OUT_T>& output)
const
434 using IN_T = fundamental_type_for<Type>;
435 const auto element_number =
shape_size(m_shape);
436 const auto source_begin = get_data_ptr<uint8_t>();
437 const auto source_end = std::next(source_begin, (element_number + 1) / 2);
438 const auto round_element_no =
439 element_number % 2 ? element_number + 1 : element_number;
440 output.reserve(round_element_no);
441 std::for_each(source_begin, source_end, [&](IN_T c) {
442 for (
const auto i : {4, 0})
444 const uint8_t data = (c >> i) & 0x0F;
445 output.push_back(data);
448 output.resize(element_number);
450 template <element::Type_t Type,
452 typename std::enable_if<Type == element::Type_t::i4, bool>::type =
true>
453 void cast_vector(std::vector<OUT_T>& output)
const
455 using IN_T = fundamental_type_for<Type>;
456 const auto element_number =
shape_size(m_shape);
457 const auto source_begin = get_data_ptr<uint8_t>();
458 const auto source_end = std::next(source_begin, (element_number + 1) / 2);
459 const auto round_element_no =
460 element_number % 2 ? element_number + 1 : element_number;
461 output.reserve(round_element_no);
462 std::for_each(source_begin, source_end, [&](IN_T c) {
463 for (
const auto i : {4, 0})
465 const uint8_t i4data = (c >> i) & 0x0F;
466 const bool is_negative_number = (i4data >> 3) & 0b1;
467 const int8_t data = is_negative_number ? i4data | 0xF0 : i4data;
468 output.push_back(data);
471 output.resize(element_number);
474 template <element::Type_t Type,
476 typename StorageDataType = fundamental_type_for<Type>,
477 typename std::enable_if<Type != element::Type_t::u1 &&
478 Type != element::Type_t::u4 &&
479 Type != element::Type_t::i4,
481 void fill_data(
const T& value)
484 const auto v =
static_cast<StorageDataType
>(value);
485 std::fill_n(get_data_ptr_nc<Type>(), size, v);
488 template <element::Type_t Type,
490 typename StorageDataType = fundamental_type_for<Type>,
491 typename std::enable_if<Type == element::Type_t::u1, bool>::type =
true>
492 void fill_data(
const T& value)
494 const StorageDataType v = value ? 0xFF : 0x00;
495 std::fill_n(get_data_ptr_nc<Type>(), mem_size(), v);
498 template <element::Type_t Type,
500 typename StorageDataType = fundamental_type_for<Type>,
501 typename std::enable_if<Type == element::Type_t::u4 ||
502 Type == element::Type_t::i4,
504 void fill_data(
const T& value)
506 uint8_t v = value_in_range<Type>(value);
509 std::fill_n(get_data_ptr_nc<Type>(), mem_size(), v);
512 void allocate_buffer();
// Mutable counterpart of get_data_ptr(): returns a writable, untyped pointer
// to the data buffer, or nullptr when no buffer (m_data) is currently held.
514 void* get_data_ptr_nc() {
return (m_data ? m_data->get_ptr() :
nullptr); }
516 template <element::Type_t ET>
517 typename element_type_traits<ET>::value_type* get_data_ptr_nc()
519 NGRAPH_CHECK(ET == get_element_type(),
520 "get_data_ptr_nc() called for incorrect element type.");
521 return static_cast<typename element_type_traits<ET>::value_type*
>(
525 Constant(
const OutputVector& args)
// Hook invoked before the output type is set during type inference; the
// default implementation does nothing. It is virtual so derived classes can
// determine m_element_type themselves.
531 virtual void infer_element_type() {}
532 template <
typename T>
533 void write_values(
const std::vector<T>& values)
535 write_to_buffer(values);
538 template <element::Type_t Type,
540 typename StorageDataType = fundamental_type_for<Type>,
541 typename std::enable_if<Type != element::Type_t::u1 &&
542 Type != element::Type_t::u4 &&
543 Type != element::Type_t::i4,
545 void write_buffer(
const std::vector<T>& source)
547 auto p = get_data_ptr_nc<Type>();
548 for (
size_t i = 0; i < source.size(); i++)
550 p[i] =
static_cast<StorageDataType
>(source[i]);
554 template <element::Type_t Type,
556 typename StorageDataType = fundamental_type_for<Type>,
557 typename std::enable_if<Type == element::Type_t::u4 ||
558 Type == element::Type_t::i4,
560 void write_buffer(
const std::vector<T>& source)
562 auto p = get_data_ptr_nc<Type>();
564 for (; i < source.size() / 2; i++)
566 const auto v1 = value_in_range<Type>(source[i * 2]) & 0x0F;
567 const auto v2 = value_in_range<Type>(source[i * 2 + 1]) & 0x0F;
568 const auto v = (v1 << 4) | v2;
569 p[i] =
static_cast<StorageDataType
>(v);
571 if (source.size() % 2)
573 const auto v1 = value_in_range<Type>(source[i * 2]) & 0x0F;
574 const auto v = v1 << 4;
575 p[i] =
static_cast<StorageDataType
>(v);
579 template <element::Type_t Type,
581 typename StorageDataType = fundamental_type_for<Type>,
582 typename std::enable_if<Type == element::Type_t::u1, bool>::type =
true>
583 void write_buffer(
const std::vector<T>& source)
585 auto p = get_data_ptr_nc<Type>();
587 for (; i < source.size() / 8; i++)
590 for (
int j = 0; j != 8; j++)
592 const uint8_t b = source[i * 8 + j] ? 0x01 << (7 - j) : 0;
595 p[i] =
static_cast<StorageDataType
>(v);
598 for (
unsigned j = 0; j != source.size() % 8; j++)
600 const uint8_t b = source[i * 8 + j] ? 0x01 << (7 - j) : 0;
603 p[i] =
static_cast<StorageDataType
>(v);
606 template <
typename T>
607 void write_to_buffer(
const std::vector<T>& source)
609 const auto& target_type = m_element_type;
610 size_t target_element_count =
shape_size(m_shape);
611 if (source.size() != target_element_count)
613 throw std::runtime_error(
"Constant initializer does not match shape");
615 using Type_t = element::Type_t;
616 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
617 #pragma GCC diagnostic push
618 #pragma GCC diagnostic error "-Wswitch"
619 #pragma GCC diagnostic error "-Wswitch-enum"
623 case Type_t::boolean: write_buffer<Type_t::boolean>(source);
break;
624 case Type_t::bf16: write_buffer<Type_t::bf16>(source);
break;
625 case Type_t::f16: write_buffer<Type_t::f16>(source);
break;
626 case Type_t::f32: write_buffer<Type_t::f32>(source);
break;
627 case Type_t::f64: write_buffer<Type_t::f64>(source);
break;
628 case Type_t::i4: write_buffer<Type_t::i4>(source);
break;
629 case Type_t::i8: write_buffer<Type_t::i8>(source);
break;
630 case Type_t::i16: write_buffer<Type_t::i16>(source);
break;
631 case Type_t::i32: write_buffer<Type_t::i32>(source);
break;
632 case Type_t::i64: write_buffer<Type_t::i64>(source);
break;
633 case Type_t::u1: write_buffer<Type_t::u1>(source);
break;
634 case Type_t::u4: write_buffer<Type_t::u4>(source);
break;
635 case Type_t::u8: write_buffer<Type_t::u8>(source);
break;
636 case Type_t::u16: write_buffer<Type_t::u16>(source);
break;
637 case Type_t::u32: write_buffer<Type_t::u32>(source);
break;
638 case Type_t::u64: write_buffer<Type_t::u64>(source);
break;
639 case element::Type_t::undefined:
640 case element::Type_t::dynamic:
throw std::runtime_error(
"unsupported type");
642 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
643 #pragma GCC diagnostic pop
647 ngraph::element::Type_t Type,
649 typename std::enable_if<Type == ngraph::element::Type_t::u4, bool>::type =
true>
650 static ngraph::fundamental_type_for<Type> value_in_range(
const ValueT& value)
652 const auto result = ngraph::fundamental_type_for<Type>(value);
653 NGRAPH_CHECK(0 <= result && result <= 15,
654 "assigned value out of range u4 values");
659 ngraph::element::Type_t Type,
661 typename std::enable_if<Type == ngraph::element::Type_t::i4, bool>::type =
true>
662 static ngraph::fundamental_type_for<Type> value_in_range(
const ValueT& value)
664 const auto result = ngraph::fundamental_type_for<Type>(value);
665 NGRAPH_CHECK(-8 <= result && result <= 7,
666 "assigned value out of range i4 values");
670 bool are_all_data_elements_bitwise_identical()
const;
// Compile-time constant 64. Presumably the byte alignment used when the host
// data buffer is allocated (allocate_buffer is defined out of view) — TODO confirm.
671 static constexpr
size_t host_alignment() {
return 64; }
673 size_t mem_size()
const
675 const bool bitwidth_less_than_byte = m_element_type.bitwidth() < 8;
676 if (bitwidth_less_than_byte)
679 const auto bitwidth = size * m_element_type.bitwidth();
682 return bitwidth / 8 + (bitwidth % 8 ? 1 : 0);
684 return shape_size(m_shape) * m_element_type.size();
687 element::Type m_element_type;
689 std::shared_ptr<runtime::AlignedBuffer> m_data;
690 bool m_all_elements_bitwise_identical;
691 bool m_alloc_buffer_on_visit_attributes =
true;
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:59
A set of axes.
Definition: axis_set.hpp:19
A vector of axes.
Definition: axis_vector.hpp:18
A difference (signed) of tensor element coordinates.
Definition: coordinate_diff.hpp:18
Coordinates for a tensor element.
Definition: coordinate.hpp:18
Shape for a tensor.
Definition: shape.hpp:19
Strides for a tensor.
Definition: strides.hpp:18
Definition: element_type.hpp:51
Base error for ngraph runtime errors.
Definition: except.hpp:16
Root of all actual ops.
Definition: op.hpp:17
Class for constants.
Definition: constant.hpp:27
std::vector< T > cast_vector() const
Return the Constant's value as a vector cast to type T.
Definition: constant.hpp:275
static std::shared_ptr< Constant > create(const element::Type &type, const Shape &shape, std::initializer_list< T > values)
Wrapper around constructing a shared_ptr of a Constant.
Definition: constant.hpp:237
std::vector< std::string > get_value_strings() const
bool has_evaluate() const override
Reports whether an evaluate method is available for the current operation.
void alloc_buffer_on_visit_attributes(bool val)
Allows skipping buffer allocation during the visit_attributes call.
Definition: constant.hpp:340
Constant(const element::Type &type, const Shape &shape, const void *data)
Constructs a tensor constant with the supplied data.
AxisSet get_axis_set_val() const
Returns the value of the constant node as an AxisSet object. Can only be used on element::i64 nodes an...
Constant(const std::shared_ptr< runtime::Tensor > &tensor)
Initialize a constant from tensor.
static std::shared_ptr< Constant > create(const element::Type &type, const Shape &shape, const std::vector< T > &values)
Wrapper around constructing a shared_ptr of a Constant.
Definition: constant.hpp:224
AxisVector get_axis_vector_val() const
Returns the value of the constant node as an AxisVector object. Can only be used on element::i64 nodes...
const NodeTypeInfo & get_type_info() const override
Definition: constant.hpp:30
Constant(const element::Type &type, const Shape &shape)
Creates an uninitialized constant.
Constant(const element::Type &type, const Shape &shape, const std::vector< std::string > &values)
Constructs a tensor constant. This constructor is mainly intended to support deserialization of constants.
CoordinateDiff get_coordinate_diff_val() const
Returns the value of the constant node as a CoordinateDiff object. Can only be used on element::i64 no...
bool evaluate(const HostTensorVector &outputs, const HostTensorVector &inputs) const override
Evaluates the op on input_values putting results in output_values.
void set_data_shape(const Shape &shape)
Updates the Constant's shape. The new shape's size must equal the number of data elements.
Strides get_strides_val() const
Returns the value of the constant node as a Strides object. Can only be used on element::i64 nodes and...
Constant(const element::Type &type, const Shape &shape, T value)
Constructs a uniform tensor constant.
Definition: constant.hpp:82
static std::shared_ptr< Constant > create(const element::Type &type, const Shape &shape, const void *memory)
Wrapper around constructing a shared_ptr of a Constant.
Definition: constant.hpp:250
Constant(const element::Type &type, const Shape &shape, std::shared_ptr< runtime::SharedBuffer< T >> data)
Constructs a tensor constant with the supplied data.
Definition: constant.hpp:147
Shape get_shape_val() const
Returns the value of the constant node as a Shape object. Can only be used on element::i64 nodes and i...
Coordinate get_coordinate_val() const
Returns the value of the constant node as a Coordinate object. Can only be used on element::i64 nodes ...
Constant(const element::Type &type, const Shape &shape, const std::vector< T > &values)
Constructs a tensor constant.
Definition: constant.hpp:44
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
Definition: constant.hpp:162
SharedBuffer class to store a pointer to a pre-allocated buffer.
Definition: shared_buffer.hpp:18
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16
size_t shape_size(const SHAPE_TYPE &shape)
Number of elements spanned by a shape.
Definition: shape.hpp:59