shape.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include <cstdio>
8 #include <vector>
9 
10 #include "ngraph/attribute_adapter.hpp"
11 #include "ngraph/axis_set.hpp"
12 #include "ngraph/ngraph_visibility.hpp"
13 #include "ngraph/strides.hpp"
14 
15 namespace ngraph
16 {
17  /// \brief Shape for a tensor.
18  class Shape : public std::vector<size_t>
19  {
20  public:
21  NGRAPH_API Shape();
22 
23  NGRAPH_API Shape(const std::initializer_list<size_t>& axis_lengths);
24 
25  NGRAPH_API Shape(const std::vector<size_t>& axis_lengths);
26 
27  NGRAPH_API Shape(const Shape& axis_lengths);
28 
29  NGRAPH_API explicit Shape(size_t n, size_t initial_value = 0);
30 
31  NGRAPH_API ~Shape();
32 
33  template <class InputIterator>
34  Shape(InputIterator first, InputIterator last)
35  : std::vector<size_t>(first, last)
36  {
37  }
38 
39  NGRAPH_API Shape& operator=(const Shape& v);
40  NGRAPH_API Shape& operator=(Shape&& v) noexcept;
41  };
42 
43  template <>
44  class NGRAPH_API AttributeAdapter<Shape>
45  : public IndirectVectorValueAccessor<Shape, std::vector<int64_t>>
46 
47  {
48  public:
49  AttributeAdapter(Shape& value)
51  {
52  }
53  static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<Shape>", 0};
54  const DiscreteTypeInfo& get_type_info() const override { return type_info; }
55  };
56 
/// \brief Number of elements spanned by a shape, i.e. the product of its axis lengths.
///
/// An empty shape (a scalar) spans exactly one element, and any zero-length axis yields 0.
/// The product is accumulated in size_t, so it wraps on overflow instead of reporting an
/// error.
template <typename SHAPE_TYPE>
size_t shape_size(const SHAPE_TYPE& shape)
{
    size_t element_count{1}; // the empty product
    for (const auto& axis_length : shape)
        element_count *= axis_length;
    return element_count;
}
68 
/// \brief Row-major (C-order) strides for a shape.
///
/// The last axis gets stride 1, and each earlier axis gets the product of all axis lengths
/// after it. The result has one entry per axis; an empty shape yields an empty vector.
template <typename SHAPE_TYPE>
std::vector<size_t> row_major_strides(const SHAPE_TYPE& shape)
{
    std::vector<size_t> strides(shape.size());
    size_t inner_volume = 1; // product of the axes already visited (to the right)
    auto axis_it = shape.rbegin();
    auto stride_it = strides.rbegin();
    while (axis_it != shape.rend() && stride_it != strides.rend())
    {
        *stride_it = inner_volume;
        inner_volume *= *axis_it;
        ++axis_it;
        ++stride_it;
    }
    return strides;
}
83 
/// \brief Row-major stride of a single axis: the product of the lengths of every axis
///        after `axis`.
///
/// Returns 1 for the last axis and for any `axis` at or beyond the rank. Axes are
/// multiplied from the innermost outward. Note that `axis + 1` is computed in size_t.
template <typename SHAPE_TYPE>
size_t row_major_stride(const SHAPE_TYPE& shape, size_t axis)
{
    size_t stride = 1;
    const size_t first_inner_axis = axis + 1;
    size_t i = shape.size();
    while (i > first_inner_axis)
    {
        --i;
        stride *= shape[i];
    }
    return stride;
}
94 
/// \brief True when `shape` has no axes (rank 0).
template <typename SHAPE_TYPE>
inline bool is_scalar(const SHAPE_TYPE& shape)
{
    const bool has_no_axes = shape.size() == 0;
    return has_no_axes;
}
100 
/// \brief True when `shape` has exactly one axis (rank 1).
template <typename SHAPE_TYPE>
inline bool is_vector(const SHAPE_TYPE& shape)
{
    const bool has_single_axis = shape.size() == 1;
    return has_single_axis;
}
106 
107  NGRAPH_API
108  std::ostream& operator<<(std::ostream& s, const Shape& shape);
109 } // namespace ngraph
An AttributeAdapter "captures" an attribute as an AT& and makes it available as a ValueAccessor<VAT>.
Definition: attribute_adapter.hpp:161
Definition: attribute_adapter.hpp:126
Shape for a tensor.
Definition: shape.hpp:19
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16
size_t shape_size(const SHAPE_TYPE &shape)
Number of elements spanned by a shape.
Definition: shape.hpp:59
std::vector< size_t > row_major_strides(const SHAPE_TYPE &shape)
Row-major strides for a shape.
Definition: shape.hpp:71
Definition: type.hpp:27