gru_sequence.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include <memory>
8 #include <string>
9 #include <vector>
10 
11 #include "ngraph/op/op.hpp"
12 #include "ngraph/op/util/rnn_cell_base.hpp"
13 
14 namespace ngraph
15 {
16  namespace op
17  {
18  namespace v5
19  {
20  class NGRAPH_API GRUSequence : public util::RNNCellBase
21  {
22  public:
23  NGRAPH_RTTI_DECLARATION;
24  GRUSequence();
25 
26  GRUSequence(const Output<Node>& X,
27  const Output<Node>& H_t,
28  const Output<Node>& sequence_lengths,
29  const Output<Node>& W,
30  const Output<Node>& R,
31  const Output<Node>& B,
32  size_t hidden_size,
34  const std::vector<std::string>& activations =
35  std::vector<std::string>{"sigmoid", "tanh"},
36  const std::vector<float>& activations_alpha = {},
37  const std::vector<float>& activations_beta = {},
38  float clip = 0.f,
39  bool linear_before_reset = false);
40 
41  std::shared_ptr<Node>
42  clone_with_new_inputs(const OutputVector& new_args) const override;
43 
44  void validate_and_infer_types() override;
45 
46  bool visit_attributes(AttributeVisitor& visitor) override;
47  bool get_linear_before_reset() const { return m_linear_before_reset; }
48  op::RecurrentSequenceDirection get_direction() const { return m_direction; }
49 
50  protected:
52  bool m_linear_before_reset;
53  };
54  } // namespace v5
55  } // namespace op
56 } // namespace ngraph
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:59
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Base class for all recurrent network cells.
Definition: rnn_cell_base.hpp:55
Definition: gru_sequence.hpp:21
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
RecurrentSequenceDirection
This class defines possible recurrent sequence directions.
Definition: attr_types.hpp:424
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16