namespace ov::op::rnn

namespace rnn {

// global functions

// Declaration only; the definition is not part of this listing.
//
// Template parameters:
//   TShape - element type of the `input_shapes` vector.
//
// Parameters:
//   op             - pointer-to-const RNN cell-base operation.
//   input_shapes   - one shape object per operation input.
//   expected_ranks - ranks to compare against.
//                    NOTE(review): presumably expected_ranks[i] is the rank
//                    required of input_shapes[i] - confirm against the definition.
//
// Returns nothing. NOTE(review): how a rank mismatch is reported is not visible
// from this declaration - check the definition before relying on it.
template <class TShape>
void validate_inputs_rank(
    const op::util::RNNCellBase* op,
    const std::vector<TShape>& input_shapes,
    const std::vector<Rank>& expected_ranks
    );

// Declaration only; the definition is not part of this listing.
//
// Template parameters:
//   TShape  - element type of the `input_shapes` vector.
//   TRShape - element type of the returned vector; defaults to
//             result_shape_t<TShape>.
//
// Parameters:
//   op                  - pointer-to-const RNN cell-base operation.
//   input_shapes        - one shape object per operation input.
//   num_gates           - gate count.
//                         NOTE(review): presumably the per-cell gate count of
//                         the concrete RNN variant - confirm.
//   num_state_nodes     - state-node count.
//                         NOTE(review): presumably the number of state
//                         inputs/outputs - confirm.
//   linear_before_reset - defaults to false when omitted.
//                         NOTE(review): likely only meaningful for GRU-style
//                         cells - confirm.
//
// Returns a std::vector<TRShape>. NOTE(review): presumably the inferred output
// shapes - confirm against the definition.
template <class TShape, class TRShape = result_shape_t<TShape>>
std::vector<TRShape> cell_base_shape_infer(
    const op::util::RNNCellBase* op,
    const std::vector<TShape>& input_shapes,
    size_t num_gates,
    size_t num_state_nodes,
    bool linear_before_reset = false
    );

// Declaration only; the definition is not part of this listing.
//
// Template parameters:
//   TShape  - element type of the `input_shapes` vector.
//   TRShape - element type of the returned vector; defaults to
//             result_shape_t<TShape>.
//
// Parameters:
//   op                  - pointer-to-const RNN cell-base operation.
//   input_shapes        - one shape object per operation input.
//   num_gates           - gate count.
//                         NOTE(review): presumably the per-cell gate count of
//                         the concrete RNN variant - confirm.
//   num_state_nodes     - state-node count.
//                         NOTE(review): presumably the number of state
//                         inputs/outputs - confirm.
//   direction           - recurrent sequence direction.
//                         NOTE(review): how it affects the result is not
//                         visible from this declaration - confirm.
//   linear_before_reset - defaults to false when omitted.
//                         NOTE(review): likely only meaningful for GRU-style
//                         sequences - confirm.
//
// Returns a std::vector<TRShape>. NOTE(review): presumably the inferred output
// shapes - confirm against the definition.
template <class TShape, class TRShape = result_shape_t<TShape>>
std::vector<TRShape> seq_base_shape_infer(
    const op::util::RNNCellBase* op,
    const std::vector<TShape>& input_shapes,
    size_t num_gates,
    size_t num_state_nodes,
    op::RecurrentSequenceDirection direction,
    bool linear_before_reset = false
    );

} // namespace rnn