ie_infer_async_request_thread_safe_default.hpp
1 // Copyright (C) 2018-2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
10 
11 #include <cpp_interfaces/impl/ie_infer_request_internal.hpp>
12 #include <cpp_interfaces/interface/ie_iinfer_async_request_internal.hpp>
14 #include <ie_system_conf.h>
15 
#include <algorithm>
#include <chrono>
#include <cstdint>
#include <exception>
#include <functional>
#include <future>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
25 
26 namespace InferenceEngine {
27 
28 /**
29  * @ingroup ie_dev_api_async_infer_request_api
30  * @brief Base class with default implementation of asynchronous multi staged inference request.
31  * To customize pipeline stages derived class should change the content
32  * of AsyncInferRequestThreadSafeDefault::_pipeline member container.
33  * It consists of pairs of tasks and executors which will run the task.
34  * The class is recommended to be used by plugins as a base class for asynchronous inference request implementation.
35  * @note To synchronize derived context with stages
36  * derived class should call AsyncInferRequestThreadSafeDefault::StopAndWait() function in destructor.
37  * @par Example
38  * Here is an example of asynchronous inference request implementation for some accelerator device.
39  * It uses 5 different executors to run different stages of a synchronous inference request.
40  *
41  * @snippet example_async_infer_request.cpp async_infer_request:define_pipeline
42  */
44  enum InferState {Idle, Busy, Canceled, Stop};
45  using Futures = std::vector<std::shared_future<void>>;
46  using Promise = std::shared_ptr<std::promise<void>>;
47  enum Stage_e : std::uint8_t { executor, task };
48  InferRequestInternal::Ptr _syncRequest;
49 
50  friend struct DisableCallbackGuard;
51  struct DisableCallbackGuard {
52  explicit DisableCallbackGuard(AsyncInferRequestThreadSafeDefault* this_)
53  : _this{this_} {
54  std::lock_guard<std::mutex> lock{_this->_mutex};
55  std::swap(_callback, _this->_callback);
56  }
57  ~DisableCallbackGuard() {
58  std::lock_guard<std::mutex> lock{_this->_mutex};
59  _this->_callback = _callback;
60  }
61  AsyncInferRequestThreadSafeDefault* _this = nullptr;
62  IInferRequest::CompletionCallback _callback = nullptr;
63  };
64 
65  struct ImmediateStreamsExecutor : public InferenceEngine::ITaskExecutor {
66  explicit ImmediateStreamsExecutor(const IStreamsExecutor::Ptr& streamsExecutor) : _streamsExecutor{streamsExecutor} {}
67  void run(InferenceEngine::Task task) override {_streamsExecutor->Execute(std::move(task));}
68  IStreamsExecutor::Ptr _streamsExecutor;
69  };
70 
71  template<typename F>
72  void InferImpl(const F& f) {
73  _syncRequest->checkBlobs();
74  InferState state = InferState::Idle;
75  {
76  std::lock_guard<std::mutex> lock{_mutex};
77  state = _state;
78  switch (_state) {
79  case InferState::Busy :
80  THROW_IE_EXCEPTION_WITH_STATUS(REQUEST_BUSY);
81  case InferState::Canceled :
82  THROW_IE_EXCEPTION_WITH_STATUS(INFER_CANCELLED);
83  case InferState::Idle : {
84  _futures.erase(std::remove_if(std::begin(_futures), std::end(_futures),
85  [](const std::shared_future<void>& future) {
86  if (future.valid()) {
87  return (std::future_status::ready ==
88  future.wait_for(std::chrono::milliseconds {0}));
89  } else {
90  return true;
91  }
92  }),
93  _futures.end());
94  _promise = {};
95  _futures.emplace_back(_promise.get_future().share());
96  } break;
97  case InferState::Stop : break;
98  }
99  _state = InferState::Busy;
100  }
101  if (state != InferState::Stop) {
102  try {
103  f();
104  } catch (...) {
105  _promise.set_exception(std::current_exception());
106  std::lock_guard<std::mutex> lock{_mutex};
107  _state = InferState::Idle;
108  throw;
109  }
110  }
111  }
112 
113 protected:
114  /**
115  * @brief Throws exception if inference request is busy or canceled
116  */
117  void CheckState() const {
118  std::lock_guard<std::mutex> lock {_mutex};
119  switch (_state) {
120  case InferState::Busy :
121  THROW_IE_EXCEPTION_WITH_STATUS(REQUEST_BUSY);
122  case InferState::Canceled :
123  THROW_IE_EXCEPTION_WITH_STATUS(INFER_CANCELLED);
124  default: break;
125  }
126  }
127 
public:
/**
 * @brief A shared pointer to AsyncInferRequestThreadSafeDefault (shared-ownership handle)
 */
using Ptr = std::shared_ptr<AsyncInferRequestThreadSafeDefault>;
133 
134  /**
135  * @brief Wraps a InferRequestInternal::Ptr implementation and constructs a
136  * AsyncInferRequestThreadSafeDefault::_pipeline where `taskExecutor` is used to run InferRequestInternal::Infer
137  * asynchronously.
138  *
139  * @param[in] request The synchronous request
140  * @param[in] taskExecutor The task executor
141  * @param[in] callbackExecutor The callback executor
142  */
144  const ITaskExecutor::Ptr& taskExecutor,
145  const ITaskExecutor::Ptr& callbackExecutor) :
146  _syncRequest {request},
147  _requestExecutor {taskExecutor},
148  _callbackExecutor {callbackExecutor},
149  _pipeline {{taskExecutor, [this] {_syncRequest->InferImpl();}}},
150  _syncPipeline {{std::make_shared<ImmediateExecutor>(), [this] {_syncRequest->InferImpl();}}} {
151  auto streamsExecutor = std::dynamic_pointer_cast<IStreamsExecutor>(taskExecutor);
152  if (streamsExecutor != nullptr) {
153  _syncPipeline = {{std::make_shared<ImmediateStreamsExecutor>(std::move(streamsExecutor)), [this] {_syncRequest->InferImpl();}}};
154  }
155  }
156 
157  /**
158  * @brief Destroys the object, stops AsyncInferRequestThreadSafeDefault::_pipeline and waits for a finish.
159  */
161  StopAndWait();
162  }
163 
164  /**
165  * @brief Waits for completion of all pipeline stages
166  * If the pipeline raises an exception it will be rethrown here
167  * @param millis_timeout A timeout is `ms` to wait or special enum value of IInferRequest::WaitMode
168  * @return A status code
169  */
170  StatusCode Wait(int64_t millis_timeout) override {
171  if (millis_timeout < IInferRequest::WaitMode::RESULT_READY) {
172  THROW_IE_EXCEPTION_WITH_STATUS(PARAMETER_MISMATCH)
173  << " Timeout can't be less "
174  << IInferRequest::WaitMode::RESULT_READY << " for InferRequest::Wait\n";
175  }
176  auto status = std::future_status::deferred;
177 
178  // Just use the last '_futures' member to wait pipeline completion
179  auto future = [&] {
180  std::lock_guard<std::mutex> lock {_mutex};
181  return _futures.empty() ? std::shared_future<void> {} : _futures.back();
182  }();
183 
184  if (!future.valid()) {
185  return StatusCode::INFER_NOT_STARTED;
186  }
187 
188  switch (millis_timeout) {
189  case IInferRequest::WaitMode::RESULT_READY: {
190  future.wait();
191  status = std::future_status::ready;
192  } break;
193  case IInferRequest::WaitMode::STATUS_ONLY: {
194  status = future.wait_for(std::chrono::milliseconds {0});
195  } break;
196  default: {
197  status = future.wait_for(std::chrono::milliseconds {millis_timeout});
198  } break;
199  }
200 
201  if (std::future_status::ready == status) {
202  future.get();
203  return StatusCode::OK;
204  } else {
205  return StatusCode::RESULT_NOT_READY;
206  }
207  }
208 
209  void StartAsync() override {
210  InferImpl([&] {StartAsync_ThreadUnsafe();});
211  }
212 
213  void Infer() override {
214  DisableCallbackGuard disableCallbackGuard{this};
215  InferImpl([&] {Infer_ThreadUnsafe();});
216  Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
217  }
218 
219  std::map<std::string, InferenceEngineProfileInfo> GetPerformanceCounts() const override {
220  CheckState();
221  return _syncRequest->GetPerformanceCounts();
222  }
223 
224  void SetBlob(const std::string& name, const Blob::Ptr& data) override {
225  CheckState();
226  _syncRequest->SetBlob(name, data);
227  }
228 
229  void SetBlob(const std::string& name, const Blob::Ptr& data, const PreProcessInfo& info) override {
230  CheckState();
231  _syncRequest->SetBlob(name, data, info);
232  }
233 
234  Blob::Ptr GetBlob(const std::string& name) override {
235  CheckState();
236  return _syncRequest->GetBlob(name);
237  }
238 
239  const PreProcessInfo& GetPreProcess(const std::string& name) const override {
240  return _syncRequest->GetPreProcess(name);
241  }
242 
243  void SetBatch(int batch) override {
244  CheckState();
245  _syncRequest->SetBatch(batch);
246  };
247 
248  void GetUserData(void** data) override {
249  CheckState();
250  if (data == nullptr) THROW_IE_EXCEPTION << NOT_ALLOCATED_str;
251  *data = _userData;
252  }
253 
254  void SetUserData(void* data) override {
255  CheckState();
256  _userData = data;
257  }
258 
259  void SetCompletionCallback(IInferRequest::CompletionCallback callback) override {
260  CheckState();
261  _callback = callback;
262  }
263 
264  /**
265  * @brief Sets the pointer to public interface.
266  * @note Needed to correctly handle ownership between objects
267  * @param[in] ptr A shared pointer to a public IInferRequest interface.
268  */
270  _publicInterface = std::shared_ptr<IInferRequest>(ptr.get(), [](IInferRequest*) {});
271  }
272 
273  std::vector<InferenceEngine::IVariableStateInternal::Ptr> QueryState() override {
274  return _syncRequest->QueryState();
275  }
276 
277  void ThrowIfCanceled() const {
278  std::lock_guard<std::mutex> lock{_mutex};
279  if (_state == InferState::Canceled) {
280  THROW_IE_EXCEPTION_WITH_STATUS(INFER_CANCELLED);
281  }
282  }
283 
284  void Cancel() override {
285  std::lock_guard<std::mutex> lock{_mutex};
286  if (_state == InferState::Busy) {
287  _state = InferState::Canceled;
288  }
289  }
290 
protected:
/**
 * @brief Each pipeline stage is a @ref Task that is executed by specified ITaskExecutor implementation.
 * The first pair element is the executor, the second is the task.
 */
using Stage = std::pair<ITaskExecutor::Ptr, Task>;
/**
 * @brief Pipeline is vector of stages; stages run in order, each one submitted to its own executor
 * after the previous stage's task returns.
 */
using Pipeline = std::vector<Stage>;
300 
301  /**
302  * @brief Creates and run the first stage task. If destructor was not called add a new std::future to the
303  * AsyncInferRequestThreadSafeDefault::_futures list that would be used to wait
304  * AsyncInferRequestThreadSafeDefault::_pipeline finish
305  * @param[in] itBeginStage Iterator to begin of pipeline
306  * @param[in] itEndStage End pipeline iterator
307  * @param[in] callbackExecutor Final or error stage executor
308  */
309  void RunFirstStage(const Pipeline::iterator itBeginStage, const Pipeline::iterator itEndStage,
310  const ITaskExecutor::Ptr callbackExecutor = {}) {
311  auto& firstStageExecutor = std::get<Stage_e::executor>(*itBeginStage);
312  IE_ASSERT(nullptr != firstStageExecutor);
313  firstStageExecutor->run(MakeNextStageTask(itBeginStage, itEndStage, std::move(callbackExecutor)));
314  }
315 
316  /**
317  * @brief Forbids pipeline start and wait for all started pipelines.
318  * @note Should be called in derived class destructor to wait for completion of usage of derived context captured by
319  * pipeline tasks
320  */
321  void StopAndWait() {
322  _callback = nullptr;
323  Futures futures;
324  InferState state = InferState::Idle;
325  {
326  std::lock_guard<std::mutex> lock{_mutex};
327  state = _state;
328  if (state != InferState::Stop) {
329  _state = InferState::Stop;
330  futures = std::move(_futures);
331  }
332  }
333  if (state != InferState::Stop) {
334  for (auto&& future : futures) {
335  if (future.valid()) {
336  future.wait();
337  }
338  }
339  }
340  }
341 
342 
ITaskExecutor::Ptr _requestExecutor; //!< Used to run inference CPU tasks; set to the constructor's taskExecutor.
ITaskExecutor::Ptr _callbackExecutor; //!< Used to run post inference callback in asynchronous pipeline.
ITaskExecutor::Ptr _syncCallbackExecutor; //!< Used to run post inference callback in synchronous pipeline; not set by the constructor, so null by default (the final stage then runs inline).
Pipeline _pipeline; //!< Pipeline variable that should be filled by inherited class; defaults to one stage running _syncRequest->InferImpl() on taskExecutor.
Pipeline _syncPipeline; //!< Synchronous pipeline variable that should be filled by inherited class.
348 
349  /**
350  * @brief Starts an asynchronous pipeline thread unsafe.
351  * @note Used by StartAsync which ensures thread-safety and calls this method after.
352  */
353  virtual void StartAsync_ThreadUnsafe() {
354  RunFirstStage(_pipeline.begin(), _pipeline.end(), _callbackExecutor);
355  }
356 
357  /**
358  * @brief Performs inference of pipeline in syncronous mode
359  * @note Used by Infer which ensures thread-safety and calls this method after.
360  */
361  virtual void Infer_ThreadUnsafe() {
362  RunFirstStage(_syncPipeline.begin(), _syncPipeline.end(), _syncCallbackExecutor);
363  }
364 
365  /**
366  * @brief Implements Infer() using StartAsync() and Wait()
367  */
369  StartAsync_ThreadUnsafe();
370  }
371 
372 private:
373  /**
374  * @brief Create a task with next pipeline stage.
375  * Each call to MakeNextStageTask() generates @ref Task objects for each stage.
376  * On last stage or if the exception is raised from `_pipeline` task
377  * the last stage task is called or passed to callback executor if it is presented. The last stage task call the
378  * callback, if it is presented, capture the `_promise` member and use it to forward completion or exception to the
379  * one of `_futures` member
380  * @param[in] itStage Iterator to next stage of pipeline
381  * @param[in] itEndStage End pipeline iterator
382  * @param[in] callbackExecutor Executor that will run final stage with callback call
383  * @return A next stage task
384  */
385  Task MakeNextStageTask(const Pipeline::iterator itStage, const Pipeline::iterator itEndStage,
386  const ITaskExecutor::Ptr callbackExecutor) {
387  return std::bind([this, itStage, itEndStage](ITaskExecutor::Ptr& callbackExecutor) mutable {
388  StatusCode requestStatus = StatusCode::OK;
389  std::exception_ptr localCurrentException = nullptr;
390  auto& thisStage = *itStage;
391  auto itNextStage = itStage + 1;
392 
393  try {
394  auto& stageTask = std::get<Stage_e::task>(thisStage);
395  IE_ASSERT(nullptr != stageTask);
396  stageTask();
397  if (itEndStage != itNextStage) {
398  auto& nextStage = *itNextStage;
399  auto& nextStageExecutor = std::get<Stage_e::executor>(nextStage);
400  IE_ASSERT(nullptr != nextStageExecutor);
401  nextStageExecutor->run(MakeNextStageTask(itNextStage, itEndStage, std::move(callbackExecutor)));
402  }
403  } catch (InferenceEngine::details::InferenceEngineException& ie_ex) {
404  requestStatus = ie_ex.hasStatus() ? ie_ex.getStatus() : StatusCode::GENERAL_ERROR;
405  localCurrentException = std::make_exception_ptr(ie_ex);
406  } catch (...) {
407  requestStatus = StatusCode::GENERAL_ERROR;
408  localCurrentException = std::current_exception();
409  }
410 
411  if ((itEndStage == itNextStage) || (nullptr != localCurrentException)) {
412  auto lastStageTask = [this, requestStatus, localCurrentException]() mutable {
413  auto promise = std::move(_promise);
414  IInferRequest::CompletionCallback callback = nullptr;
415  {
416  std::lock_guard<std::mutex> lock{_mutex};
417  _state = InferState::Idle;
418  callback = _callback;
419  }
420  if (nullptr != callback) {
421  InferenceEngine::CurrentException() = localCurrentException;
422  try {
423  callback(_publicInterface, requestStatus);
424  } catch (...) {
425  localCurrentException = std::current_exception();
426  }
428  }
429  if (nullptr == localCurrentException) {
430  promise.set_value();
431  } else {
432  promise.set_exception(localCurrentException);
433  }
434  };
435 
436  if (nullptr == callbackExecutor) {
437  lastStageTask();
438  } else {
439  callbackExecutor->run(std::move(lastStageTask));
440  }
441  }
442  }, std::move(callbackExecutor));
443  }
444 
void* _userData = nullptr;  //!< Opaque pointer stored by SetUserData() and returned by GetUserData().
IInferRequest::CompletionCallback _callback = nullptr;  //!< User completion callback; read by the last pipeline stage under `_mutex`.
IInferRequest::Ptr _publicInterface;  //!< Non-owning alias (no-op deleter) of the public interface, passed to the callback.
std::promise<void> _promise;  //!< Promise of the current run; re-armed by InferImpl() and moved out by the last stage.
mutable std::mutex _mutex;  //!< Protects `_state` and `_futures`; also held when `_callback` is read by the last stage or swapped by DisableCallbackGuard.
Futures _futures;  //!< Futures of started pipelines, used by Wait() and StopAndWait().
InferState _state = InferState::Idle;  //!< Current lifecycle state of the request.
452 };
453 } // namespace InferenceEngine
Base class with default implementation of asynchronous multi staged inference request....
Definition: ie_infer_async_request_thread_safe_default.hpp:43
void InferUsingAsync()
Implements Infer() using StartAsync() and Wait()
Definition: ie_infer_async_request_thread_safe_default.hpp:368
void SetUserData(void *data) override
Set arbitrary data for the request.
Definition: ie_infer_async_request_thread_safe_default.hpp:254
StatusCode Wait(int64_t millis_timeout) override
Waits for completion of all pipeline stages If the pipeline raises an exception it will be rethrown h...
Definition: ie_infer_async_request_thread_safe_default.hpp:170
std::map< std::string, InferenceEngineProfileInfo > GetPerformanceCounts() const override
Queries performance measures per layer to get feedback of what is the most time consuming layer....
Definition: ie_infer_async_request_thread_safe_default.hpp:219
void RunFirstStage(const Pipeline::iterator itBeginStage, const Pipeline::iterator itEndStage, const ITaskExecutor::Ptr callbackExecutor={})
Creates and runs the first stage task. If the destructor has not been called, adds a new std::future to the Async...
Definition: ie_infer_async_request_thread_safe_default.hpp:309
std::pair< ITaskExecutor::Ptr, Task > Stage
Each pipeline stage is a Task that is executed by specified ITaskExecutor implementation.
Definition: ie_infer_async_request_thread_safe_default.hpp:295
void SetBlob(const std::string &name, const Blob::Ptr &data) override
Set input/output data to infer.
Definition: ie_infer_async_request_thread_safe_default.hpp:224
std::shared_ptr< AsyncInferRequestThreadSafeDefault > Ptr
A shared pointer to AsyncInferRequestThreadSafeDefault.
Definition: ie_infer_async_request_thread_safe_default.hpp:132
void Cancel() override
Cancel current inference request execution.
Definition: ie_infer_async_request_thread_safe_default.hpp:284
void Infer() override
Infers specified input(s) in synchronous mode.
Definition: ie_infer_async_request_thread_safe_default.hpp:213
void SetBatch(int batch) override
Sets new batch size when dynamic batching is enabled in executable network that created this request.
Definition: ie_infer_async_request_thread_safe_default.hpp:243
ITaskExecutor::Ptr _syncCallbackExecutor
Used to run post inference callback in synchronous pipeline.
Definition: ie_infer_async_request_thread_safe_default.hpp:345
void SetCompletionCallback(IInferRequest::CompletionCallback callback) override
Set callback function which will be called on success or failure of asynchronous request.
Definition: ie_infer_async_request_thread_safe_default.hpp:259
void SetPointerToPublicInterface(InferenceEngine::IInferRequest::Ptr ptr)
Sets the pointer to public interface.
Definition: ie_infer_async_request_thread_safe_default.hpp:269
Blob::Ptr GetBlob(const std::string &name) override
Get input/output data to infer.
Definition: ie_infer_async_request_thread_safe_default.hpp:234
virtual void StartAsync_ThreadUnsafe()
Starts an asynchronous pipeline without thread-safety guarantees.
Definition: ie_infer_async_request_thread_safe_default.hpp:353
std::vector< Stage > Pipeline
Pipeline is vector of stages.
Definition: ie_infer_async_request_thread_safe_default.hpp:299
std::vector< InferenceEngine::IVariableStateInternal::Ptr > QueryState() override
Queries memory states.
Definition: ie_infer_async_request_thread_safe_default.hpp:273
void StartAsync() override
Start inference of specified input(s) in asynchronous mode.
Definition: ie_infer_async_request_thread_safe_default.hpp:209
Pipeline _syncPipeline
Synchronous pipeline variable that should be filled by inherited class.
Definition: ie_infer_async_request_thread_safe_default.hpp:347
void CheckState() const
Throws exception if inference request is busy or canceled.
Definition: ie_infer_async_request_thread_safe_default.hpp:117
void SetBlob(const std::string &name, const Blob::Ptr &data, const PreProcessInfo &info) override
Sets pre-process for input data.
Definition: ie_infer_async_request_thread_safe_default.hpp:229
~AsyncInferRequestThreadSafeDefault()
Destroys the object, stops AsyncInferRequestThreadSafeDefault::_pipeline and waits for a finish.
Definition: ie_infer_async_request_thread_safe_default.hpp:160
ITaskExecutor::Ptr _callbackExecutor
Used to run post inference callback in asynchronous pipeline.
Definition: ie_infer_async_request_thread_safe_default.hpp:344
virtual void Infer_ThreadUnsafe()
Performs inference of pipeline in synchronous mode.
Definition: ie_infer_async_request_thread_safe_default.hpp:361
void GetUserData(void **data) override
Get arbitrary data for the request.
Definition: ie_infer_async_request_thread_safe_default.hpp:248
Pipeline _pipeline
Pipeline variable that should be filled by inherited class.
Definition: ie_infer_async_request_thread_safe_default.hpp:346
AsyncInferRequestThreadSafeDefault(const InferRequestInternal::Ptr &request, const ITaskExecutor::Ptr &taskExecutor, const ITaskExecutor::Ptr &callbackExecutor)
Wraps an InferRequestInternal::Ptr implementation and constructs an AsyncInferRequestThreadSafeDefault:...
Definition: ie_infer_async_request_thread_safe_default.hpp:143
void StopAndWait()
Forbids pipeline start and wait for all started pipelines.
Definition: ie_infer_async_request_thread_safe_default.hpp:321
ITaskExecutor::Ptr _requestExecutor
Used to run inference CPU tasks.
Definition: ie_infer_async_request_thread_safe_default.hpp:343
const PreProcessInfo & GetPreProcess(const std::string &name) const override
Gets pre-process for input data.
Definition: ie_infer_async_request_thread_safe_default.hpp:239
std::shared_ptr< Blob > Ptr
An internal API of asynchronous inference request to be implemented by plugin, which is used in Infer...
Definition: ie_iinfer_async_request_internal.hpp:22
void(* CompletionCallback)(InferenceEngine::IInferRequest::Ptr context, InferenceEngine::StatusCode code)
std::shared_ptr< IInferRequest > Ptr
std::shared_ptr< IStreamsExecutor > Ptr
Definition: ie_istreams_executor.hpp:36
Interface for Task Executor. Inference Engine uses InferenceEngine::ITaskExecutor interface to run al...
Definition: ie_itask_executor.hpp:46
std::shared_ptr< ITaskExecutor > Ptr
Definition: ie_itask_executor.hpp:51
std::shared_ptr< InferRequestInternal > Ptr
A shared pointer to a InferRequestInternal implementation.
Definition: ie_infer_request_internal.hpp:37
Wrappers from c++ function to c-style one.
#define THROW_IE_EXCEPTION_WITH_STATUS(__status)
Throws an exception along with the status (which is eventually converted to the typed exception)
Definition: exception2status.hpp:19
#define NOT_ALLOCATED_str
Defines the not allocated message.
Definition: exception2status.hpp:142
std::function< void()> Task
Inference Engine Task Executor can use any copyable callable without parameters and output as a task....
Definition: ie_itask_executor.hpp:25
#define THROW_IE_EXCEPTION
#define IE_ASSERT(EXPRESSION)
A header file for Inference Engine Immediate Executor implementation.
A header file for Inference Engine Streams-based Executor Interface.
A header file for Inference Engine Task Executor Interface.
Abstraction over platform specific implementations.
Inference Engine Plugin API namespace.
std::exception_ptr & CurrentException()
Provides the reference to static thread_local std::exception_ptr.