matmul.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include "ngraph/node.hpp"
8 #include "ngraph/op/op.hpp"
9 
10 namespace ngraph
11 {
12  namespace op
13  {
14  namespace v0
15  {
16  /// \brief Operator performing Matrix Multiplication.
17  class NGRAPH_API MatMul : public Op
18  {
19  public:
20  NGRAPH_RTTI_DECLARATION;
21  MatMul() = default;
22  /// \brief Constructs an Matrix Multiplication operation.
23  ///
24  /// \param A Matrix A
25  /// \param B Matrix B
26  /// \param transpose_a If matrix A should be transposed.
27  /// \param transpose_b If matrix B should be transposed.
28  MatMul(const Output<Node>& A,
29  const Output<Node>& B,
30  const bool& transpose_a = 0,
31  const bool& transpose_b = 0);
32 
33  bool visit_attributes(AttributeVisitor& visitor) override;
34  void validate_and_infer_types() override;
35 
36  virtual std::shared_ptr<Node>
37  clone_with_new_inputs(const OutputVector& new_args) const override;
38 
39  bool evaluate(const HostTensorVector& outputs,
40  const HostTensorVector& inputs) const override;
41  bool has_evaluate() const override;
42 
43  bool get_transpose_a() const { return m_transpose_a; }
44  bool get_transpose_b() const { return m_transpose_b; }
45  void set_transpose_a(bool transpose_a) { m_transpose_a = transpose_a; }
46  void set_transpose_b(bool transpose_b) { m_transpose_b = transpose_b; }
47 
48  private:
49  bool m_transpose_a;
50  bool m_transpose_b;
51  };
52  } // namespace v0
53  using v0::MatMul;
54  } // namespace op
55 } // namespace ngraph
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:59
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Root of all actual ops.
Definition: op.hpp:17
Operator performing Matrix Multiplication.
Definition: matmul.hpp:18
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
MatMul(const Output< Node > &A, const Output< Node > &B, const bool &transpose_a=0, const bool &transpose_b=0)
Constructs a Matrix Multiplication operation.
bool has_evaluate() const override
Reports whether an evaluate method is available for the current operation.
bool evaluate(const HostTensorVector &outputs, const HostTensorVector &inputs) const override
Evaluates the op on inputs, putting the results in outputs.
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16