softmax.hpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include "ngraph/op/op.hpp"
20 
21 namespace ngraph
22 {
23  namespace op
24  {
25  namespace v0
26  {
27  /// \brief Softmax operation.
28  ///
29  class NGRAPH_DEPRECATED(
30  "This operation is deprecated and will be removed soon. "
31  "Use v1::Softmax instead of it.") NGRAPH_API Softmax : public Op
32  {
33  NGRAPH_SUPPRESS_DEPRECATED_START
34  public:
35  static constexpr NodeTypeInfo type_info{"Softmax", 0};
36  const NodeTypeInfo& get_type_info() const override { return type_info; }
37  Softmax() = default;
38  /// \brief Constructs a softmax operation.
39  ///
40  /// \param arg Node that produces the first input tensor.<br>
41  /// `[d0, ...]`
42  /// \param axes The axis positions (0-based) on which to calculate the softmax.
43  ///
44  /// Output `[d0, ...]`
45  ///
46  Softmax(const Output<Node>& arg, const AxisSet& axes);
47  /// \brief Constructs a softmax operation.
48  ///
49  /// \param arg Node that produces the first input tensor.<br>
50  /// `[d0, ...]`
51  /// \param axes node produces the axis positions (0-based) on which to calculate the
52  /// softmax.
53  ///
54  /// Output `[d0, ...]`
55  ///
56  Softmax(const Output<Node>& arg, const Output<Node>& axes);
57 
58  void validate_and_infer_types() override;
59 
60  virtual std::shared_ptr<Node>
61  clone_with_new_inputs(const OutputVector& new_args) const override;
62 
63  bool are_axes_constant() const;
64  const AxisSet get_axes() const;
65  void set_axes(const AxisSet& axes);
66 
67  bool evaluate(const HostTensorVector& outputs,
68  const HostTensorVector& inputs) const override;
69  NGRAPH_SUPPRESS_DEPRECATED_END
70  };
71  }
72 
73  namespace v1
74  {
75  class NGRAPH_API Softmax : public Op
76  {
77  public:
78  static constexpr NodeTypeInfo type_info{"Softmax", 1};
79  const NodeTypeInfo& get_type_info() const override { return type_info; }
80  Softmax()
81  : m_axis(0)
82  {
83  }
84  /// \brief Constructs a softmax operation.
85  ///
86  /// \param arg Node that produces the first input tensor.<br>
87  /// `[d0, ...]`
88  /// \param axis The axis position (0-based) on which to calculate the softmax.
89  ///
90  /// Output `[d0, ...]`
91  ///
92  Softmax(const Output<Node>& arg, const size_t axis);
93 
94  bool visit_attributes(AttributeVisitor& visitor) override;
95  void validate_and_infer_types() override;
96 
97  size_t get_version() const override { return 1; }
98  virtual std::shared_ptr<Node>
99  clone_with_new_inputs(const OutputVector& new_args) const override;
100 
101  size_t get_axis() const { return m_axis; }
102  void set_axis(const size_t axis) { m_axis = axis; }
103  bool evaluate(const HostTensorVector& outputs,
104  const HostTensorVector& inputs) const override;
105 
106  private:
107  size_t m_axis;
108  };
109  }
110 
111  // default opset version
112  NGRAPH_SUPPRESS_DEPRECATED_START
113  using v0::Softmax;
114  NGRAPH_SUPPRESS_DEPRECATED_END
115  }
116 }
ngraph::op::v1::Softmax::get_type_info
const NodeTypeInfo & get_type_info() const override
Definition: softmax.hpp:79
ngraph
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
ngraph::op::v1::Softmax::validate_and_infer_types
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
ngraph::op::v1::Softmax
Definition: softmax.hpp:76
ngraph::AttributeVisitor
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:70
ngraph::op::v1::Softmax::get_version
size_t get_version() const override
Definition: softmax.hpp:97
ngraph::op::v1::Softmax::Softmax
Softmax(const Output< Node > &arg, const size_t axis)
Constructs a softmax operation.
ngraph::op::Op
Root of all actual ops.
Definition: op.hpp:29