ctc_greedy_decoder_seq_len.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include "ngraph/op/op.hpp"
8 
9 namespace ngraph
10 {
11  namespace op
12  {
13  namespace v6
14  {
15  /// \brief Operator performing CTCGreedyDecoder
16  ///
17  class NGRAPH_API CTCGreedyDecoderSeqLen : public Op
18  {
19  public:
20  NGRAPH_RTTI_DECLARATION;
21  CTCGreedyDecoderSeqLen() = default;
22  /// \brief Constructs a CTCGreedyDecoderSeqLen operation
23  ///
24  /// \param input 3-D tensor of logits on which greedy decoding is
25  /// performed
26  /// \param seq_len 1-D tensor of sequence lengths
27  /// \param merge_repeated Whether to merge repeated labels
28  /// \param classes_index_type Specifies the output classes_index tensor type
29  /// \param sequence_length_type Specifies the output sequence_length tensor type
31  const Output<Node>& seq_len,
32  const bool merge_repeated = true,
33  const element::Type& classes_index_type = element::i32,
34  const element::Type& sequence_length_type = element::i32);
35  /// \brief Constructs a CTCGreedyDecoderSeqLen operation
36  ///
37  /// \param input 3-D tensor of logits on which greedy decoding is
38  /// performed
39  /// \param seq_len 1-D tensor of sequence lengths
40  /// \param blank_index Scalar or 1-D tensor with 1 element used to mark a
41  /// blank index
42  /// \param merge_repeated Whether to merge repeated labels
43  /// \param classes_index_type Specifies the output classes_index tensor type
44  /// \param sequence_length_type Specifies the output sequence_length tensor type
46  const Output<Node>& seq_len,
47  const Output<Node>& blank_index,
48  const bool merge_repeated = true,
49  const element::Type& classes_index_type = element::i32,
50  const element::Type& sequence_length_type = element::i32);
51 
52  void validate_and_infer_types() override;
53  bool visit_attributes(AttributeVisitor& visitor) override;
54 
55  std::shared_ptr<Node>
56  clone_with_new_inputs(const OutputVector& new_args) const override;
57 
58  /// \brief Get merge_repeated attribute
59  ///
60  /// \return Current value of merge_repeated attribute
61  ///
62  bool get_merge_repeated() const { return m_merge_repeated; }
63  /// \brief Get classes_index_type attribute
64  ///
65  /// \return Current value of classes_index_type attribute
66  ///
67  const element::Type& get_classes_index_type() const { return m_classes_index_type; }
68  /// \brief Set classes_index_type attribute
69  ///
70  /// \param classes_index_type Type of classes_index
71  ///
72  void set_classes_index_type(const element::Type& classes_index_type)
73  {
74  m_classes_index_type = classes_index_type;
75  validate_and_infer_types();
76  }
77 
78  /// \brief Get sequence_length_type attribute
79  ///
80  /// \return Current value of sequence_length_type attribute
81  ///
83  {
84  return m_sequence_length_type;
85  }
86 
87  /// \brief Set sequence_length_type attribute
88  ///
89  /// \param sequence_length_type Type of sequence length
90  ///
91  void set_sequence_length_type(const element::Type& sequence_length_type)
92  {
93  m_sequence_length_type = sequence_length_type;
94  validate_and_infer_types();
95  }
96 
97  private:
98  bool m_merge_repeated;
99  element::Type m_classes_index_type{element::i32};
100  element::Type m_sequence_length_type{element::i32};
101  };
102  } // namespace v6
103  } // namespace op
104 } // namespace ngraph
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:59
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Definition: element_type.hpp:51
Root of all actual ops.
Definition: op.hpp:17
Operator performing CTCGreedyDecoder.
Definition: ctc_greedy_decoder_seq_len.hpp:18
void set_sequence_length_type(const element::Type &sequence_length_type)
Set sequence_length_type attribute.
Definition: ctc_greedy_decoder_seq_len.hpp:91
const element::Type & get_sequence_length_type() const
Get sequence_length_type attribute.
Definition: ctc_greedy_decoder_seq_len.hpp:82
void set_classes_index_type(const element::Type &classes_index_type)
Set classes_index_type attribute.
Definition: ctc_greedy_decoder_seq_len.hpp:72
CTCGreedyDecoderSeqLen(const Output< Node > &input, const Output< Node > &seq_len, const Output< Node > &blank_index, const bool merge_repeated=true, const element::Type &classes_index_type=element::i32, const element::Type &sequence_length_type=element::i32)
Constructs a CTCGreedyDecoderSeqLen operation.
const element::Type & get_classes_index_type() const
Get classes_index_type attribute.
Definition: ctc_greedy_decoder_seq_len.hpp:67
CTCGreedyDecoderSeqLen(const Output< Node > &input, const Output< Node > &seq_len, const bool merge_repeated=true, const element::Type &classes_index_type=element::i32, const element::Type &sequence_length_type=element::i32)
Constructs a CTCGreedyDecoderSeqLen operation.
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
bool get_merge_repeated() const
Get merge_repeated attribute.
Definition: ctc_greedy_decoder_seq_len.hpp:62
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16