reverse_sequence.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include "ngraph/op/op.hpp"
8 
9 namespace ngraph
10 {
11  namespace op
12  {
13  namespace v0
14  {
15  class NGRAPH_API ReverseSequence : public Op
16  {
17  public:
18  static constexpr NodeTypeInfo type_info{"ReverseSequence", 0};
19  const NodeTypeInfo& get_type_info() const override { return type_info; }
20  ReverseSequence() = default;
21  /// \brief Constructs a ReverseSequence operation.
22  ///
23  /// \param arg tensor with input data to reverse
24  /// \param seq_lengths 1D tensor of integers with sequence lengths in the input
25  /// tensor.
26  /// \param batch_axis index of the batch dimension.
27  /// \param seq_axis index of the sequence dimension.
29  const Output<Node>& seq_lengths,
30  int64_t batch_axis,
31  int64_t seq_axis);
32 
33  bool visit_attributes(AttributeVisitor& visitor) override;
34  void validate_and_infer_types() override;
35 
36  virtual std::shared_ptr<Node>
37  clone_with_new_inputs(const OutputVector& new_args) const override;
38 
39  size_t get_batch_axis() const { return m_normalized_batch_axis; }
40  int64_t get_origin_batch_axis() const { return m_batch_axis; }
41  void set_batch_axis(int64_t batch_axis) { m_batch_axis = batch_axis; }
42  size_t get_sequence_axis() const { return m_normalized_seq_axis; }
43  int64_t get_origin_sequence_axis() const { return m_seq_axis; }
44  void set_sequence_axis(int64_t sequence_axis) { m_seq_axis = sequence_axis; }
45 
46  private:
47  int64_t m_batch_axis;
48  int64_t m_seq_axis = 1;
49  size_t m_normalized_batch_axis;
50  size_t m_normalized_seq_axis;
51  };
52  } // namespace v0
53  using v0::ReverseSequence;
54  } // namespace op
55 } // namespace ngraph
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:59
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Root of all actual ops.
Definition: op.hpp:17
Definition: reverse_sequence.hpp:16
ReverseSequence(const Output< Node > &arg, const Output< Node > &seq_lengths, int64_t batch_axis, int64_t seq_axis)
Constructs a ReverseSequence operation.
const NodeTypeInfo & get_type_info() const override
Definition: reverse_sequence.hpp:19
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types.
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16
Definition: type.hpp:27