10 #include <cpp_interfaces/interface/ie_iinfer_async_request_internal.hpp>
11 #include <cpp_interfaces/impl/ie_infer_async_request_thread_safe_internal.hpp>
43 using AtomicCallback = std::atomic<IInferRequest::CompletionCallback>;
44 using Futures = std::vector<std::shared_future<void>>;
45 using Promise = std::shared_ptr<std::promise<void>>;
46 enum Stage_e : std::uint8_t { executor, task };
47 struct DisableCallbackGuard{
48 explicit DisableCallbackGuard(AtomicCallback& callback)
49 : _callbackRef(callback), _callback(callback.exchange(
nullptr)) {}
50 ~DisableCallbackGuard() {
51 _callbackRef = _callback;
53 AtomicCallback& _callbackRef;
62 using Ptr = std::shared_ptr<AsyncInferRequestThreadSafeDefault>;
76 : _syncRequest {request},
79 _pipeline {{taskExecutor, [
this] {_syncRequest->Infer();}}},
80 _syncPipeline{{std::make_shared<ImmediateExecutor>(), [
this] {_syncRequest->Infer();}}} {
97 if (millis_timeout < IInferRequest::WaitMode::RESULT_READY) {
99 << IInferRequest::WaitMode::RESULT_READY <<
" for InferRequest::Wait\n";
101 auto status = std::future_status::deferred;
105 std::lock_guard<std::mutex> lock {_mutex};
106 return _futures.empty() ? std::shared_future<void> {} : _futures.back();
109 if (!future.valid()) {
110 return StatusCode::INFER_NOT_STARTED;
113 switch (millis_timeout) {
114 case IInferRequest::WaitMode::RESULT_READY: {
116 status = std::future_status::ready;
118 case IInferRequest::WaitMode::STATUS_ONLY: {
119 status = future.wait_for(std::chrono::milliseconds {0});
122 status = future.wait_for(std::chrono::milliseconds {millis_timeout});
126 if (std::future_status::ready == status) {
128 return StatusCode::OK;
130 return StatusCode::RESULT_NOT_READY;
140 _publicInterface = std::shared_ptr<IInferRequest>(ptr.get(), [](
IInferRequest*) {});
147 using Stage = std::pair<ITaskExecutor::Ptr, Task>;
161 void RunFirstStage(
const Pipeline::iterator itBeginStage,
const Pipeline::iterator itEndStage,
165 std::lock_guard<std::mutex> lock(_mutex);
167 _futures.erase(std::remove_if(std::begin(_futures), std::end(_futures),
168 [](
const std::shared_future<void>& future) {
169 if (future.valid()) {
170 return (std::future_status::ready ==
171 future.wait_for(std::chrono::milliseconds {0}));
178 _futures.emplace_back(_promise.get_future().share());
185 auto& firstStageExecutor = std::get<Stage_e::executor>(*itBeginStage);
186 IE_ASSERT(
nullptr != firstStageExecutor);
187 firstStageExecutor->run(MakeNextStageTask(itBeginStage, itEndStage, std::move(callbackExecutor)));
189 _promise.set_exception(std::current_exception());
203 std::lock_guard<std::mutex> lock(_mutex);
206 for (
auto&& future : _futures) {
207 if (future.valid()) {
219 DisableCallbackGuard disableCallbackGuard{_callback};
220 StartAsync_ThreadUnsafe();
221 Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
228 DisableCallbackGuard disableCallbackGuard{_callback};
229 _syncRequest->checkBlobs();
230 RunFirstStage(_syncPipeline.begin(), _syncPipeline.end(), _syncCallbackExecutor);
231 Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
241 _syncRequest->checkBlobs();
242 RunFirstStage(_pipeline.begin(), _pipeline.end(), _callbackExecutor);
250 _syncRequest->GetPerformanceCounts(perfMap);
254 _syncRequest->SetBlob(name, data);
258 _syncRequest->SetBlob(name, data, info);
262 _syncRequest->GetBlob(name, data);
266 _syncRequest->GetPreProcess(name, info);
270 _callback = callback;
283 _syncRequest->SetBatch(batch);
299 Task MakeNextStageTask(
const Pipeline::iterator itStage,
const Pipeline::iterator itEndStage,
301 return std::bind([
this, itStage, itEndStage](
ITaskExecutor::Ptr& callbackExecutor)
mutable {
303 std::exception_ptr localCurrentException =
nullptr;
304 auto& thisStage = *itStage;
305 auto itNextStage = itStage + 1;
308 auto& stageTask = std::get<Stage_e::task>(thisStage);
311 if (itEndStage != itNextStage) {
312 auto& nextStage = *itNextStage;
313 auto& nextStageExecutor = std::get<Stage_e::executor>(nextStage);
315 nextStageExecutor->run(MakeNextStageTask(itNextStage, itEndStage, std::move(callbackExecutor)));
317 }
catch (InferenceEngine::details::InferenceEngineException& ie_ex) {
318 requestStatus = ie_ex.hasStatus() ? ie_ex.getStatus() : StatusCode::GENERAL_ERROR;
319 localCurrentException = std::make_exception_ptr(ie_ex);
321 requestStatus = StatusCode::GENERAL_ERROR;
322 localCurrentException = std::current_exception();
325 if ((itEndStage == itNextStage) || (
nullptr != localCurrentException)) {
326 auto lastStageTask = [
this, requestStatus, localCurrentException]()
mutable {
327 auto promise = std::move(_promise);
328 auto callback = _callback.load();
329 if (setIsRequestBusy(
false)) {
330 if (
nullptr != callback) {
333 callback(_publicInterface, requestStatus);
335 localCurrentException = std::current_exception();
339 if (
nullptr == localCurrentException) {
342 promise.set_exception(localCurrentException);
347 if (
nullptr == callbackExecutor) {
350 callbackExecutor->run(std::move(lastStageTask));
353 }, std::move(callbackExecutor));
356 void* _userData =
nullptr;
357 AtomicCallback _callback = {
nullptr};
359 std::promise<void> _promise;
360 mutable std::mutex _mutex;
Inference Engine Plugin API namespace.
std::shared_ptr< AsyncInferRequestThreadSafeDefault > Ptr
A shared pointer to AsyncInferRequestThreadSafeDefault.
Definition: ie_infer_async_request_thread_safe_default.hpp:62
Wrapper of asynchronous inference request to support thread-safe execution.
Definition: ie_infer_async_request_thread_safe_internal.hpp:21
void Infer_ThreadUnsafe() override
Performs inference of pipeline in synchronous mode.
Definition: ie_infer_async_request_thread_safe_default.hpp:245
void SetBlob_ThreadUnsafe(const char *name, const Blob::Ptr &data) override
Sets the blob thread unsafe.
Definition: ie_infer_async_request_thread_safe_default.hpp:253
Wrappers from c++ function to c-style one.
Abstraction over platform specific implementations.
void InferUsingSync()
Implements Infer() using synchronous pipeline and Wait()
Definition: ie_infer_async_request_thread_safe_default.hpp:227
ITaskExecutor::Ptr _callbackExecutor
Used to run post inference callback in asynchronous pipeline.
Definition: ie_infer_async_request_thread_safe_default.hpp:235
void SetBatch_ThreadUnsafe(int batch) override
Sets the dynamic batch thread unsafe.
Definition: ie_infer_async_request_thread_safe_default.hpp:282
A header file for Inference Engine Task Executor Interface.
ITaskExecutor::Ptr _syncCallbackExecutor
Used to run post inference callback in synchronous pipeline.
Definition: ie_infer_async_request_thread_safe_default.hpp:236
void GetPreProcess_ThreadUnsafe(const char *name, const PreProcessInfo **info) const override
Gets the preprocessing information thread unsafe.
Definition: ie_infer_async_request_thread_safe_default.hpp:265
void SetCompletionCallback_ThreadUnsafe(IInferRequest::CompletionCallback callback) override
Sets the completion callback thread unsafe.
Definition: ie_infer_async_request_thread_safe_default.hpp:269
StatusCode Wait(int64_t millis_timeout) override
Waits for completion of all pipeline stages If the pipeline raises an exception it will be rethrown h...
Definition: ie_infer_async_request_thread_safe_default.hpp:96
void SetBlob_ThreadUnsafe(const char *name, const Blob::Ptr &data, const PreProcessInfo &info) override
Sets the blob with preprocessing information thread unsafe.
Definition: ie_infer_async_request_thread_safe_default.hpp:257
~AsyncInferRequestThreadSafeDefault()
Destroys the object, stops AsyncInferRequestThreadSafeDefault::_pipeline and waits for a finish.
Definition: ie_infer_async_request_thread_safe_default.hpp:86
void StartAsync_ThreadUnsafe() override
Starts an asynchronous pipeline thread unsafe.
Definition: ie_infer_async_request_thread_safe_default.hpp:240
void InferUsingAsync()
Implements Infer() using StartAsync() and Wait()
Definition: ie_infer_async_request_thread_safe_default.hpp:218
void StopAndWait()
Forbids pipeline start and wait for all started pipelines.
Definition: ie_infer_async_request_thread_safe_default.hpp:200
#define NOT_ALLOCATED_str
Defines the not allocated message.
Definition: exception2status.hpp:138
std::shared_ptr< InferRequestInternal > Ptr
A shared pointer to a InferRequestInternal implementation.
Definition: ie_infer_request_internal.hpp:37
Pipeline _pipeline
Pipeline variable that should be filled by inherited class.
Definition: ie_infer_async_request_thread_safe_default.hpp:237
std::shared_ptr< ITaskExecutor > Ptr
Definition: ie_itask_executor.hpp:51
std::shared_ptr< Blob > Ptr
Pipeline _syncPipeline
Synchronous pipeline variable that should be filled by inherited class.
Definition: ie_infer_async_request_thread_safe_default.hpp:238
void(* CompletionCallback)(InferenceEngine::IInferRequest::Ptr context, InferenceEngine::StatusCode code)
#define THROW_IE_EXCEPTION
std::exception_ptr & CurrentException()
Provides the reference to static thread_local std::exception_ptr.
void RunFirstStage(const Pipeline::iterator itBeginStage, const Pipeline::iterator itEndStage, const ITaskExecutor::Ptr callbackExecutor={})
Creates and run the first stage task. If destructor was not called add a new std::future to the Async...
Definition: ie_infer_async_request_thread_safe_default.hpp:161
void SetUserData_ThreadUnsafe(void *data) override
Sets the user data thread unsafe.
Definition: ie_infer_async_request_thread_safe_default.hpp:278
std::pair< ITaskExecutor::Ptr, Task > Stage
Each pipeline stage is a Task that is executed by specified ITaskExecutor implementation.
Definition: ie_infer_async_request_thread_safe_default.hpp:147
#define PARAMETER_MISMATCH_str
Defines the parameter mismatch message.
Definition: exception2status.hpp:96
std::shared_ptr< IInferRequest > Ptr
std::vector< Stage > Pipeline
Pipeline is vector of stages.
Definition: ie_infer_async_request_thread_safe_default.hpp:151
ITaskExecutor::Ptr _requestExecutor
Used to run inference CPU tasks.
Definition: ie_infer_async_request_thread_safe_default.hpp:234
void GetPerformanceCounts_ThreadUnsafe(std::map< std::string, InferenceEngineProfileInfo > &perfMap) const override
Gets the performance counts thread unsafe.
Definition: ie_infer_async_request_thread_safe_default.hpp:249
void SetPointerToPublicInterface(InferenceEngine::IInferRequest::Ptr ptr)
Sets the pointer to public interface.
Definition: ie_infer_async_request_thread_safe_default.hpp:139
void GetUserData_ThreadUnsafe(void **data) override
Gets the user data thread unsafe.
Definition: ie_infer_async_request_thread_safe_default.hpp:273
void GetBlob_ThreadUnsafe(const char *name, Blob::Ptr &data) override
Gets the input or output blob thread unsafe.
Definition: ie_infer_async_request_thread_safe_default.hpp:261
#define IE_ASSERT(EXPRESSION)
Base class with default implementation of asynchronous multi staged inference request....
Definition: ie_infer_async_request_thread_safe_default.hpp:42
std::function< void()> Task
Inference Engine Task Executor can use any copyable callable without parameters and output as a task....
Definition: ie_itask_executor.hpp:25
AsyncInferRequestThreadSafeDefault(const InferRequestInternal::Ptr &request, const ITaskExecutor::Ptr &taskExecutor, const ITaskExecutor::Ptr &callbackExecutor)
Wraps a InferRequestInternal::Ptr implementation and constructs a AsyncInferRequestThreadSafeDefault:...
Definition: ie_infer_async_request_thread_safe_default.hpp:73