rnn_sequence.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include <memory>
8 #include <string>
9 #include <vector>
10 
11 #include "ngraph/op/op.hpp"
12 #include "ngraph/op/util/rnn_cell_base.hpp"
13 
14 namespace ngraph
15 {
16  namespace op
17  {
18  namespace v5
19  {
20  class NGRAPH_API RNNSequence : public util::RNNCellBase
21  {
22  public:
23  NGRAPH_RTTI_DECLARATION;
24 
25  RNNSequence();
26 
28  const Output<Node>& X,
29  const Output<Node>& H_t,
30  const Output<Node>& sequence_lengths,
31  const Output<Node>& W,
32  const Output<Node>& R,
33  const Output<Node>& B,
34  size_t hidden_size,
36  const std::vector<std::string>& activations = std::vector<std::string>{"tanh"},
37  const std::vector<float>& activations_alpha = {},
38  const std::vector<float>& activations_beta = {},
39  float clip = 0.f);
40 
41  std::shared_ptr<Node>
42  clone_with_new_inputs(const OutputVector& new_args) const override;
43 
44  void validate_and_infer_types() override;
45 
46  bool visit_attributes(AttributeVisitor& visitor) override;
47 
48  op::RecurrentSequenceDirection get_direction() const { return m_direction; }
49 
50  protected:
52  };
53  } // namespace v5
54  } // namespace op
55 } // namespace ngraph
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:59
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Base class for all recurrent network cells.
Definition: rnn_cell_base.hpp:55
Definition: rnn_sequence.hpp:21
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
RecurrentSequenceDirection
This class defines possible recurrent sequence directions.
Definition: attr_types.hpp:424
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16