partial_shape.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include <stddef.h>
8 
9 #include "ngraph/attribute_adapter.hpp"
10 #include "ngraph/dimension.hpp"
11 #include "ngraph/op/util/attr_types.hpp"
12 #include "ngraph/rank.hpp"
13 #include "ngraph/shape.hpp"
14 
15 namespace ngraph
16 {
17  namespace op
18  {
19  struct AutoBroadcastSpec; // Forward declaration; defined in ngraph/op/util/attr_types.hpp.
20  }
21 
22  /// \brief Class representing a shape that may be partially or totally dynamic.
23  ///
24  /// XXX: THIS CLASS IS EXPERIMENTAL AND THE ENTIRE DESIGN IS SUBJECT TO CHANGE.
25  ///
26  /// A PartialShape may have:
27  ///
28  /// \li Dynamic rank. (Informal notation: `?`)
29  /// \li Static rank, but dynamic dimensions on some or all axes.
30  /// (Informal notation examples: `{1,2,?,4}`, `{?,?,?}`)
31  /// \li Static rank, and static dimensions on all axes.
32  /// (Informal notation examples: `{1,2,3,4}`, `{6}`, `{}`)
33  class NGRAPH_API PartialShape
34  {
35  using Dimensions = std::vector<Dimension>;
36 
37  public:
38  using iterator = Dimensions::iterator;
39  using const_iterator = Dimensions::const_iterator;
40  using reverse_iterator = Dimensions::reverse_iterator;
41  using const_reverse_iterator = Dimensions::const_reverse_iterator;
42 
43  /// \brief Constructs a shape with static rank from an initializer list of Dimension.
44  /// \param init The Dimension values for the constructed shape.
45  ///
46  /// Examples:
47  ///
48  /// \code{.cpp}
49  /// PartialShape s{2,3,4}; // rank=3, all dimensions static
50  /// PartialShape s{}; // rank=0
51  /// PartialShape s{2,Dimension::dynamic(),3}; // rank=3, dimension 1 dynamic
52  /// \endcode
53  PartialShape(std::initializer_list<Dimension> init);
54 
55  /// \brief Constructs a PartialShape with static rank from a vector of Dimension.
56  /// \param dimensions The Dimension values for the constructed shape.
57  PartialShape(const std::vector<Dimension>& dimensions);
58 
59  /// \brief Constructs a PartialShape with static rank from a vector of dimensions values.
60  /// \param dimensions The Dimension values for the constructed shape.
61  PartialShape(const std::vector<Dimension::value_type>& dimensions);
62 
63  /// \brief Constructs a static PartialShape with zero rank (the shape of a scalar).
65 
66  /// \brief Constructs a static PartialShape from a Shape.
67  /// \param shape The Shape to convert into PartialShape.
68  PartialShape(const Shape& shape);
69 
70  /// \brief Check if this shape is static.
71  /// \return `true` if this shape is static, else `false`.
72  ///
73  /// A shape is considered static if it has static rank, and all dimensions of the shape
74  /// are static.
75  bool is_static() const;
76 
77  /// \brief Check if this shape is dynamic.
78  /// \return `false` if this shape is static, else `true`.
79  ///
80  /// A shape is considered static if it has static rank, and all dimensions of the shape
81  /// are static.
82  bool is_dynamic() const { return !is_static(); }
83  /// \brief Get the rank of the shape.
84  /// \return The rank of the shape. This will be Rank::dynamic() if the rank of
85  /// the shape is dynamic.
86  Rank rank() const { return m_rank_is_static ? Rank(m_dimensions.size()) : Rank::dynamic(); }
87  /// \brief Construct a PartialShape with the given rank and all dimensions (if any) dynamic.
88  /// \return A PartialShape with the given rank, and all dimensions (if any) dynamic.
90  /// \brief Check whether this shape is compatible with the argument, i.e., whether it is
91  /// possible to merge them.
92  /// \param s The shape to be checked for compatibility with this shape.
93  /// \return `true` if this shape is compatible with `s`, else `false`.
94  ///
95  /// Two shapes are compatible if
96  /// \li one or both of them has dynamic rank, or
97  /// \li both shapes have dynamic and equal rank, and their dimensions are elementwise
98  /// compatible (see Dimension::compatible()).
99  bool compatible(const PartialShape& s) const;
100 
101  /// \brief Check whether this shape represents the same scheme as the argument.
102  /// \param s The shape whose scheme is being compared with this shape.
103  /// \return `true` if this shape represents the same scheme as `s`, else `false`.
104  ///
105  /// Two shapes `s1` and `s2` represent the same scheme if
106  /// \li they both have dynamic rank, or
107  /// \li they both have static and equal rank `r`, and for every `i` from `0` to `r-1`,
108  /// `s1[i]` represents the same scheme as `s2[i]` (see Dimension::same_scheme()).
109  bool same_scheme(const PartialShape& s) const;
110 
111  /// \brief Check whether this shape is a relaxation of the argument.
112  /// \param s The shape which is being compared against this shape.
113  /// \return `true` if this shape relaxes `s`, else `false`.
114  ///
115  /// Intuitively, a PartialShape `s1` is said to _relax_ `s2` (or _is a
116  /// relaxation_ of `s2`) if it is "more permissive" than `s2`. In other
117  /// words, `s1` is a relaxation of `s2` if anything you can form by
118  /// plugging things into the dynamic dimensions of `s2` is also
119  /// something you can form by plugging things into the dynamic
120  /// dimensions of `s1`, but not necessarily the other way around.
121  ///
122  /// `s1.relaxes(s2)` is equivalent to `s2.refines(s1)`.
123  ///
124  /// Formally, PartialShape `s1` is said to _relax_ PartialShape `s2`
125  /// if:
126  /// \li For every `i` from `0` to `r-1`,
127  /// either `s1[i]` contains s2[i].
128  bool relaxes(const PartialShape& s) const;
129 
130  /// \brief Check whether this shape is a refinement of the argument.
131  /// \param s The shape which is being compared against this shape.
132  /// \return `true` if this shape refines `s`, else `false`.
133  ///
134  /// Intuitively, a PartialShape `s1` is said to _relax_ `s2` (or _is a
135  /// relaxation_ of `s2`) if it is "less permissive" than `s2`. In other
136  /// words, `s1` is a relaxation of `s2` if anything you can form by
137  /// plugging things into the dynamic dimensions of `s1` is also
138  /// something you can form by plugging things into the dynamic
139  /// dimensions of `s2`, but not necessarily the other way around.
140  ///
141  /// `s1.refines(s2)` is equivalent to `s2.relaxes(s1)`.
142  ///
143  /// Formally, PartialShape `s1` is said to _refine_ PartialShape `s2`
144  /// if:
145  /// \li `s2` has dynamic rank, or
146  /// \li `s1` and `s2` both have static rank `r`, and for every `i` from `0` to `r-1`,
147  /// either `s2[i]` is dynamic, or `s1[i]` == `s2[i]`.
148  bool refines(const PartialShape& s) const;
149 
150  /// \brief Checks that this shape's rank is compatible with `r`, and, if this shape's
151  /// rank is dynamic and `r` is static, updates this shape to have a rank of `r`
152  /// with dimensions all dynamic.
153  /// \return `true` if this shape's rank is compatible with `r`, else `false`.
154  bool merge_rank(Rank r);
155 
156  /// \brief Convert a static PartialShape to a Shape.
157  /// \return A new Shape `s` where `s[i] = size_t((*this)[i])`.
158  /// \throws std::invalid_argument If this PartialShape is dynamic.
159  Shape to_shape() const;
160 
161  /// \brief Returns `true` if all static dimensions of the tensor are non-negative, else
162  /// `false`.
163  bool all_non_negative() const;
164 
165  /// \brief Index operator for PartialShape.
166  /// \param i The index of the dimension being selected.
167  /// \return A reference to the `i`th Dimension of this shape.
168  const Dimension& operator[](size_t i) const;
169  /// \brief Index operator for PartialShape.
170  /// \param i The index of the dimension being selected.
171  /// \return A reference to the `i`th Dimension of this shape.
172  Dimension& operator[](size_t i);
173  /// \brief Returns a vector of the dimensions. This has no meaning if dynamic.
174  explicit operator std::vector<Dimension>() const { return m_dimensions; }
175  friend NGRAPH_API std::ostream& operator<<(std::ostream& str, const PartialShape& shape);
176  friend PartialShape operator+(const PartialShape& s1, const PartialShape& s2);
177  bool operator==(const PartialShape& partial_shape) const;
178  bool operator!=(const PartialShape& partial_shape) const;
179  /// Get the max bounding shape
181  /// Get the min bounding shape
183  /// Get the unique shape
184  Shape get_shape() const;
185 
186  /// \brief Try to merge one shape into another.
187  /// \param[in,out] dst The shape that `src` will be merged into.
188  /// \param src The shape that will be merged into `dst`.
189  /// \return `true` if merging succeeds, else `false`.
190  ///
191  /// Merges `src` into `dst`, returning `true` on success and `false` on failure. If
192  /// `false` is returned, the effect on `dst` is unspecified.
193  ///
194  /// To merge two partial shapes `s1` and `s2` is to find the most permissive partial shape
195  /// `s` that is no more permissive than `s1` or `s2`, if `s` exists. For example:
196  ///
197  /// \code
198  /// merge(?,?) -> ?
199  /// merge(?,{?,?}) -> {?,?}
200  /// merge({?,?},{?,?}) -> {?,?}
201  /// merge({1,2,3,4},?) -> {1,2,3,4}
202  /// merge({1,2},{1,?}) -> {1,2}
203  /// merge({1,2,?,?},{1,?,3,?}) -> {1,2,3,?}
204  /// merge({1,2,3},{1,2,3}) -> {1,2,3}
205  ///
206  /// merge({1,?},{2,?}) fails [dimension 0 constraints are inconsistent]
207  /// merge({?,?},{?,?,?}) fails [ranks are inconsistent]
208  /// \endcode
209  ///
210  /// This function (merge_into) performs the "merge" operation described above on `dst` and
211  /// `src`, but overwrites `dst` with the result and returns `true` if merging is
212  /// successful; if merging is unsuccessful, the function returns `false` and may make
213  /// unspecified changes to `dst`.
214  static bool merge_into(PartialShape& dst, const PartialShape& src);
215 
216  /// \brief Try to merge one shape into another along with implicit broadcasting
218  const PartialShape& src,
219  const op::AutoBroadcastSpec& autob);
220 
221  /// \brief Returns a read/write iterator that points to the first
222  /// element in the shape. Iteration is done in ordinary
223  /// element order.
224  iterator begin() noexcept { return m_dimensions.begin(); }
225  /// \brief Returns a read-only (constant) iterator that points to the
226  /// first element in the shape. Iteration is done in ordinary
227  /// element order.
228  const_iterator begin() const noexcept { return cbegin(); }
229  /// \brief Returns a read/write iterator that points one past the last
230  /// element in the shape. Iteration is done in ordinary
231  /// element order.
232  iterator end() noexcept { return m_dimensions.end(); }
233  /// \brief Returns a read-only (constant) iterator that points one past
234  /// the last element in the shape. Iteration is done in ordinary
235  /// element order.
236  const_iterator end() const noexcept { return cend(); }
237  /// \brief Returns a read/write reverse iterator that points to the
238  /// last element in the shape. Iteration is done in reverse
239  /// element order.
240  reverse_iterator rbegin() noexcept { return m_dimensions.rbegin(); }
241  /// \brief Returns a read-only (constant) reverse iterator that points
242  /// to the last element in the shape. Iteration is done in
243  /// reverse element order.
244  const_reverse_iterator rbegin() const noexcept { return crbegin(); }
245  /// \brief Returns a read/write reverse iterator that points to one
246  /// before the first element in the shape. Iteration is done
247  /// in reverse element order.
248  reverse_iterator rend() noexcept { return m_dimensions.rend(); }
249  /// \brief Returns a read-only (constant) reverse iterator that points
250  /// to one before the first element in the shape. Iteration
251  /// is done in reverse element order.
252  const_reverse_iterator rend() const noexcept { return crend(); }
253  /// \brief Returns a read-only (constant) iterator that points to the
254  /// first element in the shape. Iteration is done in ordinary
255  /// element order.
256  const_iterator cbegin() const noexcept { return m_dimensions.cbegin(); }
257  /// \brief Returns a read-only (constant) iterator that points one past
258  /// the last element in the shape. Iteration is done in ordinary
259  /// element order.
260  const_iterator cend() const noexcept { return m_dimensions.cend(); }
261  /// \brief Returns a read-only (constant) reverse iterator that points
262  /// to the last element in the shape. Iteration is done in
263  /// reverse element order.
264  const_reverse_iterator crbegin() const noexcept { return m_dimensions.crbegin(); }
265  /// \brief Returns a read-only (constant) reverse iterator that points
266  /// to one before the first element in the shape. Iteration
267  /// is done in reverse element order.
268  const_reverse_iterator crend() const noexcept { return m_dimensions.crend(); }
269 
270  private:
271  // Private constructor for PartialShape::dynamic().
272  PartialShape(bool rank_is_static, const std::vector<Dimension>& dimensions);
273 
274  // True if the shape's rank is static.
275  bool m_rank_is_static;
276 
277  /// \brief Shape types. The shape type is lazily evaluated by calling the is_static()
278  /// method.
279  ///
280  /// \details It is highly recommended to avoid using the Dimension& operator[](size_t)
281  /// operator. It sets the shape type to SHAPE_IS_UPDATED and disables shape type caching.
282  /// Thus, the is_static method will have linear complexity because the shape is not
283  /// guaranteed to remain static or dynamic.
284  mutable enum class ShapeType {
285  SHAPE_IS_UNKNOWN, // The shape type is unknown and should be calculated by checking all
286  // dimensions.
287  SHAPE_IS_UPDATED, // User has retained a link to one dimension. Therefore, we can't
288  // guarantee that the shape will remain static or dynamic, and its
289  // type will always be evaluated.
290  SHAPE_IS_STATIC, // The shape type is known and static. Also there are no any retained
291  // dimensions by non-constant reference.
292  SHAPE_IS_DYNAMIC // The shape type is dynamic and there are no any retained dimensions
293  // by non-constant reference.
294  } m_shape_type{ShapeType::SHAPE_IS_UNKNOWN};
295 
296  // Shape dimensions. This has no meaning if m_rank_is_static is false.
297  Dimensions m_dimensions;
298  };
299 
300  /// \brief Elementwise addition of two PartialShape objects.
301  /// \param s1 Left operand for addition.
302  /// \param s2 Right operand for addition.
303  /// \return The result of elementwise adding `s1` to `s2` (see description).
304  /// \throws std::invalid_argument If `s1` and `s2` have inconsistent ranks.
305  ///
306  /// \li If `s1` or `s2` has dynamic rank, returns PartialShape::dynamic().
307  /// \li If `s1` and `s2` both have static rank, and their ranks are unequal, throws
308  /// std::invalid_argument.
309  /// \li If `s1` and `s2` both have static rank, and their ranks are equal,
310  /// returns a new shape whose `i`th dimension is `s1[i] + s2[i]`.
311  PartialShape operator+(const PartialShape& s1, const PartialShape& s2);
312 
313  /// \brief Inserts a human-readable representation of a PartialShape into an output stream.
314  /// \param str The output stream targeted for insertion.
315  /// \param shape The shape to be inserted into `str`.
316  /// \return A reference to `str` after insertion.
317  ///
318  /// The output to the stream is in "informal" notation. In other words:
319  ///
320  /// \li If `shape` has dynamic rank, inserts the string `?`.
321  /// \li If `shape` has static rank, inserts the string `{`, then inserts each dimension
322  /// of `shape` into the output stream separated by commas, then inserts `}`.
323  ///
324  /// Example:
325  ///
326  /// \code{.cpp}
327  /// PartialShape s1{PartialShape::dynamic()};
328  /// PartialShape s2{};
329  /// PartialShape s3{1,Dimension::dynamic(),2,3};
330  /// PartialShape s4{2,3,4};
331  /// std::cout << s1 << std::endl
332  /// << s2 << std::endl
333  /// << s3 << std::endl
334  /// << s4 << std::endl;
335  /// \endcode
336  ///
337  /// Output:
338  ///
339  /// \code
340  /// ?
341  /// {}
342  /// {1,?,2,3}
343  /// {2,3,4}
344  /// \endcode
345  NGRAPH_API
346  std::ostream& operator<<(std::ostream& str, const PartialShape& shape);
347 
348  template <> // Exposes a PartialShape attribute as a ValueAccessor over std::vector<int64_t>.
349  class NGRAPH_API AttributeAdapter<PartialShape> : public ValueAccessor<std::vector<int64_t>>
350  {
351  public:
352  AttributeAdapter(PartialShape& value) // Binds the adapter to `value`; no copy is made.
353  : m_ref(value)
354  {
355  }
356 
357  const std::vector<int64_t>& get() override; // Returns the value as int64_t entries.
358  void set(const std::vector<int64_t>& value) override; // Sets the value from int64_t entries.
359  static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<PartialShape>", 0};
360  const DiscreteTypeInfo& get_type_info() const override { return type_info; }
361  operator PartialShape&() { return m_ref; } // Access to the adapted shape itself.
362 
363  protected:
364  PartialShape& m_ref; // The adapted shape; must outlive this adapter.
365  std::vector<int64_t> m_buffer; // NOTE(review): presumably backs the reference returned by get() - confirm in the .cpp.
366  bool m_buffer_valid{false}; // NOTE(review): presumably marks m_buffer as in sync with m_ref - confirm in the .cpp.
367  };
368 } // namespace ngraph
void set(const std::vector< int64_t > &value) override
Sets the value.
const std::vector< int64_t > & get() override
Returns the value.
An AttributeAdapter "captures" an attribute as an AT& and makes it available as a ValueAccessor<VAT>.
Definition: attribute_adapter.hpp:161
Class representing a dimension, which may be dynamic (undetermined until runtime),...
Definition: dimension.hpp:23
static Dimension dynamic()
Create a dynamic dimension.
Definition: dimension.hpp:118
Class representing a shape that may be partially or totally dynamic.
Definition: partial_shape.hpp:34
bool all_non_negative() const
Returns true if all static dimensions of the tensor are non-negative, else false.
Rank rank() const
Get the rank of the shape.
Definition: partial_shape.hpp:86
bool same_scheme(const PartialShape &s) const
Check whether this shape represents the same scheme as the argument.
bool is_static() const
Check if this shape is static.
const_iterator end() const noexcept
Returns a read-only (constant) iterator that points one past the last element in the shape....
Definition: partial_shape.hpp:236
const_iterator cend() const noexcept
Returns a read-only (constant) iterator that points one past the last element in the shape....
Definition: partial_shape.hpp:260
bool is_dynamic() const
Check if this shape is dynamic.
Definition: partial_shape.hpp:82
Shape get_shape() const
Get the unique shape.
const_reverse_iterator rend() const noexcept
Returns a read-only (constant) reverse iterator that points to one before the first element in the sh...
Definition: partial_shape.hpp:252
const Dimension & operator[](size_t i) const
Index operator for PartialShape.
bool merge_rank(Rank r)
Checks that this shape's rank is compatible with r, and, if this shape's rank is dynamic and r is sta...
static bool broadcast_merge_into(PartialShape &dst, const PartialShape &src, const op::AutoBroadcastSpec &autob)
Try to merge one shape into another along with implicit broadcasting.
bool relaxes(const PartialShape &s) const
Check whether this shape is a relaxation of the argument.
const_reverse_iterator rbegin() const noexcept
Returns a read-only (constant) reverse iterator that points to the last element in the shape....
Definition: partial_shape.hpp:244
reverse_iterator rend() noexcept
Returns a read/write reverse iterator that points to one before the first element in the shape....
Definition: partial_shape.hpp:248
PartialShape()
Constructs a static PartialShape with zero rank (the shape of a scalar).
friend NGRAPH_API std::ostream & operator<<(std::ostream &str, const PartialShape &shape)
Inserts a human-readable representation of a PartialShape into an output stream.
friend PartialShape operator+(const PartialShape &s1, const PartialShape &s2)
Elementwise addition of two PartialShape objects.
const_reverse_iterator crbegin() const noexcept
Returns a read-only (constant) reverse iterator that points to the last element in the shape....
Definition: partial_shape.hpp:264
bool compatible(const PartialShape &s) const
Check whether this shape is compatible with the argument, i.e., whether it is possible to merge them.
PartialShape(std::initializer_list< Dimension > init)
Constructs a shape with static rank from an initializer list of Dimension.
const_iterator begin() const noexcept
Returns a read-only (constant) iterator that points to the first element in the shape....
Definition: partial_shape.hpp:228
Dimension & operator[](size_t i)
Index operator for PartialShape.
Shape get_max_shape() const
Get the max bounding shape.
iterator end() noexcept
Returns a read/write iterator that points one past the last element in the shape. Iteration is done i...
Definition: partial_shape.hpp:232
PartialShape(const std::vector< Dimension::value_type > &dimensions)
Constructs a PartialShape with static rank from a vector of dimensions values.
const_reverse_iterator crend() const noexcept
Returns a read-only (constant) reverse iterator that points to one before the first element in the sh...
Definition: partial_shape.hpp:268
Shape get_min_shape() const
Get the min bounding shape.
PartialShape(const std::vector< Dimension > &dimensions)
Constructs a PartialShape with static rank from a vector of Dimension.
bool refines(const PartialShape &s) const
Check whether this shape is a refinement of the argument.
Shape to_shape() const
Convert a static PartialShape to a Shape.
iterator begin() noexcept
Returns a read/write iterator that points to the first element in the shape. Iteration is done in ord...
Definition: partial_shape.hpp:224
static PartialShape dynamic(Rank r=Rank::dynamic())
Construct a PartialShape with the given rank and all dimensions (if any) dynamic.
static bool merge_into(PartialShape &dst, const PartialShape &src)
Try to merge one shape into another.
PartialShape(const Shape &shape)
Constructs a static PartialShape from a Shape.
const_iterator cbegin() const noexcept
Returns a read-only (constant) iterator that points to the first element in the shape....
Definition: partial_shape.hpp:256
reverse_iterator rbegin() noexcept
Returns a read/write reverse iterator that points to the last element in the shape....
Definition: partial_shape.hpp:240
Shape for a tensor.
Definition: shape.hpp:19
Provides access to an attribute of type AT as a value accessor type VAT.
Definition: attribute_adapter.hpp:49
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16
Dimension Rank
Alias for Dimension, used when the value represents the number of axes in a shape,...
Definition: rank.hpp:15
PartialShape operator+(const PartialShape &s1, const PartialShape &s2)
Elementwise addition of two PartialShape objects.
Definition: type.hpp:27
Implicit broadcast specification.
Definition: attr_types.hpp:311