namespace rnn {
// Namespace-scope shape-inference helpers for RNN cell / sequence operations.
// This section contains declarations only.

/// @brief Rank validation for the inputs of an RNN-based operation.
///
/// NOTE(review): behaviour inferred from the name and signature only —
/// presumably checks each entry of @p input_shapes against the rank at the
/// same index in @p expected_ranks; confirm against the definition.
///
/// @tparam TShape        Shape type of the inputs.
/// @param op             Operation being validated (raw pointer to a const
///                       RNNCellBase).
/// @param input_shapes   Shapes of the operation's inputs.
/// @param expected_ranks Expected ranks for the inputs.
template <class TShape>
void validate_inputs_rank(const op::util::RNNCellBase* op,
                          const std::vector<TShape>& input_shapes,
                          const std::vector<Rank>& expected_ranks);

/// @brief Shared shape inference for RNN *cell* operations.
///
/// @tparam TShape   Shape type of the inputs.
/// @tparam TRShape  Shape type of the results; defaults to
///                  result_shape_t<TShape>.
/// @param op                  Operation whose output shapes are inferred.
/// @param input_shapes        Shapes of the operation's inputs.
/// @param num_gates           NOTE(review): presumably the number of gates in
///                            the cell; confirm against the definition.
/// @param num_state_nodes     NOTE(review): presumably the number of state
///                            inputs/outputs; confirm against the definition.
/// @param linear_before_reset Defaults to false. NOTE(review): presumably
///                            mirrors the GRU attribute of the same name.
/// @return Inferred result shapes (one element per output — TODO confirm).
template <class TShape, class TRShape = result_shape_t<TShape>>
std::vector<TRShape> cell_base_shape_infer(const op::util::RNNCellBase* op,
                                           const std::vector<TShape>& input_shapes,
                                           size_t num_gates,
                                           size_t num_state_nodes,
                                           bool linear_before_reset = false);

/// @brief Shared shape inference for RNN *sequence* operations.
///
/// @tparam TShape   Shape type of the inputs.
/// @tparam TRShape  Shape type of the results; defaults to
///                  result_shape_t<TShape>.
/// @param op                  Operation whose output shapes are inferred.
/// @param input_shapes        Shapes of the operation's inputs.
/// @param num_gates           NOTE(review): presumably the number of gates in
///                            the cell; confirm against the definition.
/// @param num_state_nodes     NOTE(review): presumably the number of state
///                            inputs/outputs; confirm against the definition.
/// @param direction           Sequence direction of the operation.
/// @param linear_before_reset Defaults to false. NOTE(review): presumably
///                            mirrors the GRU attribute of the same name.
/// @return Inferred result shapes (one element per output — TODO confirm).
template <class TShape, class TRShape = result_shape_t<TShape>>
std::vector<TRShape> seq_base_shape_infer(const op::util::RNNCellBase* op,
                                          const std::vector<TShape>& input_shapes,
                                          size_t num_gates,
                                          size_t num_state_nodes,
                                          op::RecurrentSequenceDirection direction,
                                          bool linear_before_reset = false);
}  // namespace rnn