ctc_greedy_decoder_seq_len.hpp
1 //*****************************************************************************
2 // Copyright 2017-2021 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include "ngraph/op/op.hpp"
20 
21 namespace ngraph
22 {
23  namespace op
24  {
25  namespace v6
26  {
27  /// \brief Operator performing CTCGreedyDecoder
28  ///
29  class NGRAPH_API CTCGreedyDecoderSeqLen : public Op
30  {
31  public:
32  NGRAPH_RTTI_DECLARATION;
33  CTCGreedyDecoderSeqLen() = default;
34  /// \brief Constructs a CTCGreedyDecoderSeqLen operation
35  ///
36  /// \param input 3-D tensor of logits on which greedy decoding is
37  /// performed
38  /// \param seq_len 1-D tensor of sequence lengths
39  /// \param merge_repeated Whether to merge repeated labels
40  /// \param classes_index_type Specifies the output classes_index tensor type
41  /// \param sequence_length_type Specifies the output sequence_length tensor type
43  const Output<Node>& seq_len,
44  const bool merge_repeated = true,
45  const element::Type& classes_index_type = element::i32,
46  const element::Type& sequence_length_type = element::i32);
47  /// \brief Constructs a CTCGreedyDecoderSeqLen operation
48  ///
49  /// \param input 3-D tensor of logits on which greedy decoding is
50  /// performed
51  /// \param seq_len 1-D tensor of sequence lengths
52  /// \param blank_index Scalar or 1-D tensor with 1 element used to mark a
53  /// blank index
54  /// \param merge_repeated Whether to merge repeated labels
55  /// \param classes_index_type Specifies the output classes_index tensor type
56  /// \param sequence_length_type Specifies the output sequence_length tensor type
58  const Output<Node>& seq_len,
59  const Output<Node>& blank_index,
60  const bool merge_repeated = true,
61  const element::Type& classes_index_type = element::i32,
62  const element::Type& sequence_length_type = element::i32);
63 
64  void validate_and_infer_types() override;
65  bool visit_attributes(AttributeVisitor& visitor) override;
66 
67  std::shared_ptr<Node>
68  clone_with_new_inputs(const OutputVector& new_args) const override;
69 
70  /// \brief Get merge_repeated attribute
71  ///
72  /// \return Current value of merge_repeated attribute
73  ///
74  bool get_merge_repeated() const { return m_merge_repeated; }
75  /// \brief Get classes_index_type attribute
76  ///
77  /// \return Current value of classes_index_type attribute
78  ///
79  const element::Type& get_classes_index_type() const { return m_classes_index_type; }
80  /// \brief Set classes_index_type attribute
81  ///
82  /// \param classes_index_type Type of classes_index
83  ///
84  void set_classes_index_type(const element::Type& classes_index_type)
85  {
86  m_classes_index_type = classes_index_type;
87  validate_and_infer_types();
88  }
89 
90  /// \brief Get sequence_length_type attribute
91  ///
92  /// \return Current value of sequence_length_type attribute
93  ///
95  {
96  return m_sequence_length_type;
97  }
98 
99  /// \brief Set sequence_length_type attribute
100  ///
101  /// \param sequence_length_type Type of sequence length
102  ///
103  void set_sequence_length_type(const element::Type& sequence_length_type)
104  {
105  m_sequence_length_type = sequence_length_type;
106  validate_and_infer_types();
107  }
108 
109  private:
110  bool m_merge_repeated;
111  element::Type m_classes_index_type{element::i32};
112  element::Type m_sequence_length_type{element::i32};
113  };
114  } // namespace v6
115  } // namespace op
116 } // namespace ngraph
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:71
A handle for one of a node's outputs.
Definition: node_output.hpp:42
Definition: element_type.hpp:61
Root of all actual ops.
Definition: op.hpp:29
Operator performing CTCGreedyDecoder.
Definition: ctc_greedy_decoder_seq_len.hpp:30
void set_sequence_length_type(const element::Type &sequence_length_type)
Set sequence_length_type attribute.
Definition: ctc_greedy_decoder_seq_len.hpp:103
const element::Type & get_sequence_length_type() const
Get sequence_length_type attribute.
Definition: ctc_greedy_decoder_seq_len.hpp:94
void set_classes_index_type(const element::Type &classes_index_type)
Set classes_index_type attribute.
Definition: ctc_greedy_decoder_seq_len.hpp:84
CTCGreedyDecoderSeqLen(const Output< Node > &input, const Output< Node > &seq_len, const Output< Node > &blank_index, const bool merge_repeated=true, const element::Type &classes_index_type=element::i32, const element::Type &sequence_length_type=element::i32)
Constructs a CTCGreedyDecoderSeqLen operation.
const element::Type & get_classes_index_type() const
Get classes_index_type attribute.
Definition: ctc_greedy_decoder_seq_len.hpp:79
CTCGreedyDecoderSeqLen(const Output< Node > &input, const Output< Node > &seq_len, const bool merge_repeated=true, const element::Type &classes_index_type=element::i32, const element::Type &sequence_length_type=element::i32)
Constructs a CTCGreedyDecoderSeqLen operation.
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types.
bool get_merge_repeated() const
Get merge_repeated attribute.
Definition: ctc_greedy_decoder_seq_len.hpp:74
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28