Inference Pipeline

To run inference with OpenVINO™ Runtime, an application pipeline typically performs the following steps:

    1. Create Core object

    • 1.1. (Optional) Load extensions

    2. Read a model from a drive

    • 2.1. (Optional) Perform model preprocessing

    3. Load the model to the device

    4. Create an inference request

    5. Fill input tensors with data

    6. Start inference

    7. Process the inference results
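
The steps above can be sketched end to end with the API 2.0 Python calls used later in this guide. This is a minimal sketch, not a reference implementation: `run_pipeline`, `fill_inputs`, and the default device are hypothetical names chosen for illustration, and `core` is expected to be an `openvino.runtime.Core` object.

```python
def run_pipeline(core, model_path, fill_inputs, device="CPU"):
    """Run one synchronous inference using the API 2.0 call sequence.

    `core` is expected to be an openvino.runtime.Core object (assumption).
    `fill_inputs` is a hypothetical callback that writes data into the
    input tensors of the infer request it receives.
    """
    # Step 2: read the model from a drive
    model = core.read_model(model_path)
    # Step 3: load (compile) the model for the target device
    compiled_model = core.compile_model(model, device)
    # Step 4: create an inference request
    infer_request = compiled_model.create_infer_request()
    # Step 5: fill input tensors with data
    fill_inputs(infer_request)
    # Step 6: start inference and block until it completes
    infer_request.infer()
    # Step 7: return the output tensor (assumes a single-output model)
    return infer_request.get_output_tensor()
```

With the `openvino` package installed, a call could look like `run_pipeline(ov.Core(), "model.xml", my_fill_function)`, where `my_fill_function` is whatever code fills the inputs in your application.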

The following code shows, for each step, how to change the application code to migrate to OpenVINO™ Runtime API 2.0.

1. Create Core

Inference Engine API:

C++:
InferenceEngine::Core core;

Python:
import numpy as np
import openvino.inference_engine as ie
core = ie.IECore()

OpenVINO™ Runtime API 2.0:

C++:
ov::Core core;

Python:
import openvino.runtime as ov
core = ov.Core()

1.1 (Optional) Load extensions

To load a model with custom operations, you need to add extensions for those operations. We highly recommend writing extensions with the OpenVINO Extensibility API, but existing Inference Engine extensions can also be loaded into the new OpenVINO™ Runtime:

Inference Engine API:

C++:
core.AddExtension(std::make_shared<InferenceEngine::Extension>("path_to_extension_library.so"));

Python:
core.add_extension("path_to_extension_library.so", "CPU")

OpenVINO™ Runtime API 2.0:

C++:
core.add_extension(std::make_shared<InferenceEngine::Extension>("path_to_extension_library.so"));

Python:
core.add_extension("path_to_extension_library.so")

2. Read a model from a drive

Inference Engine API:

C++:
InferenceEngine::CNNNetwork network = core.ReadNetwork("model.xml");

Python:
network = core.read_network("model.xml")

OpenVINO™ Runtime API 2.0:

C++:
std::shared_ptr<ov::Model> model = core.read_model("model.xml");

Python:
model = core.read_model("model.xml")

The model that is read has the same structure as the example in the Model Creation migration guide.

Note that you can combine the read and compile stages into a single call: ov::Core::compile_model(filename, devicename).
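
A minimal sketch of that shortcut in Python, assuming `core` is an `openvino.runtime.Core` object; the helper name `compile_from_file` and the default device are illustrative only.

```python
def compile_from_file(core, model_path, device="CPU"):
    """Read and compile a model in a single call.

    `core` is expected to be an openvino.runtime.Core object (assumption).
    Passing a file path instead of a model object lets compile_model
    perform the read step internally.
    """
    return core.compile_model(model_path, device)
```

The separate `read_model` step is still needed when the model object has to be changed before compilation, for example to add the preprocessing described in step 2.1.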

2.1 (Optional) Perform model preprocessing

When the application’s input data doesn’t exactly match the model’s input format, preprocessing steps may be necessary. See the detailed guide on how to migrate preprocessing in OpenVINO Runtime API 2.0.

3. Load the Model to the Device

Inference Engine API:

C++:
InferenceEngine::ExecutableNetwork exec_network = core.LoadNetwork(network, "CPU");

Python:
# Load network to the device and create infer requests
exec_network = core.load_network(network, "CPU", num_requests=4)

OpenVINO™ Runtime API 2.0:

C++:
ov::CompiledModel compiled_model = core.compile_model(model, "CPU");

Python:
compiled_model = core.compile_model(model, "CPU")

If you need to configure OpenVINO Runtime devices with additional configuration parameters, refer to the Configure devices guide.
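
As a rough illustration of passing configuration at compile time, the sketch below forwards a property dictionary as the third argument of `compile_model`. The helper name `compile_with_hint` is hypothetical, and the `PERFORMANCE_HINT` key with the values `LATENCY` and `THROUGHPUT` is an assumed example; the Configure devices guide remains the authority on which properties a given device accepts.

```python
def compile_with_hint(core, model, device="CPU", hint="LATENCY"):
    """Compile `model` with a performance hint passed as device configuration.

    `core` is expected to be an openvino.runtime.Core object (assumption).
    The property key and its values are assumptions for illustration; see
    the Configure devices guide for the properties your device supports.
    """
    if hint not in ("LATENCY", "THROUGHPUT"):
        raise ValueError(f"unsupported hint: {hint}")
    # Configuration is passed as a plain dictionary of property names to values
    config = {"PERFORMANCE_HINT": hint}
    return core.compile_model(model, device, config)
```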

4. Create an Inference Request

Inference Engine API:

C++:
InferenceEngine::InferRequest infer_request = exec_network.CreateInferRequest();

Python:
# Infer requests were already created by load_network() in the previous step

OpenVINO™ Runtime API 2.0:

C++:
ov::InferRequest infer_request = compiled_model.create_infer_request();

Python:
infer_request = compiled_model.create_infer_request()

5. Fill input tensors

The Inference Engine API fills inputs as I32 precision (not aligned with the original model):

C++:
InferenceEngine::Blob::Ptr input_blob1 = infer_request.GetBlob(inputs.begin()->first);
// fill first blob
InferenceEngine::MemoryBlob::Ptr minput1 = InferenceEngine::as<InferenceEngine::MemoryBlob>(input_blob1);
if (minput1) {
    // The locked memory holder must stay alive for as long as its
    // buffer is accessed
    auto minputHolder = minput1->wmap();
    // Original I64 precision was converted to I32
    auto data = minputHolder.as<InferenceEngine::PrecisionTrait<InferenceEngine::Precision::I32>::value_type*>();
    // Fill data ...
}

InferenceEngine::Blob::Ptr input_blob2 = infer_request.GetBlob("data2");
// fill second blob
InferenceEngine::MemoryBlob::Ptr minput2 = InferenceEngine::as<InferenceEngine::MemoryBlob>(input_blob2);
if (minput2) {
    // The locked memory holder must stay alive for as long as its
    // buffer is accessed
    auto minputHolder = minput2->wmap();
    // Original I64 precision was converted to I32
    auto data = minputHolder.as<InferenceEngine::PrecisionTrait<InferenceEngine::Precision::I32>::value_type*>();
    // Fill data ...
}

Python:
infer_request = exec_network.requests[0]
# Get input blobs mapped to input layers names
input_blobs = infer_request.input_blobs
data = input_blobs["data1"].buffer
# Original I64 precision was converted to I32
assert data.dtype == np.int32
# Fill the first blob ...

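OpenVINO™ Runtime API 2.0 fills inputs as I32 precision for IR v10 (matching the old behavior) and as I64 precision (aligned with the original model) for IR v11, ONNX, ov::Model, and Paddle: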

IR v10, C++:
// Get input tensor by index
ov::Tensor input_tensor1 = infer_request.get_input_tensor(0);
// IR v10 works with converted precisions (i64 -> i32)
auto data1 = input_tensor1.data<int32_t>();
// Fill first data ...

// Get input tensor by tensor name
ov::Tensor input_tensor2 = infer_request.get_tensor("data2_t");
// IR v10 works with converted precisions (i64 -> i32)
auto data2 = input_tensor2.data<int32_t>();
// Fill second data ...

IR v10, Python:
# Get input tensor by index
input_tensor1 = infer_request.get_input_tensor(0)
# IR v10 works with converted precisions (i64 -> i32)
assert input_tensor1.data.dtype == np.int32
# Fill the first data ...

# Get input tensor by tensor name
input_tensor2 = infer_request.get_tensor("data2_t")
# IR v10 works with converted precisions (i64 -> i32)
assert input_tensor2.data.dtype == np.int32
# Fill the second data ...

IR v11, ONNX, ov::Model, C++:
// Get input tensor by index
ov::Tensor input_tensor1 = infer_request.get_input_tensor(0);
// Element types, names and layouts are aligned with framework
auto data1 = input_tensor1.data<int64_t>();
// Fill first data ...

// Get input tensor by tensor name
ov::Tensor input_tensor2 = infer_request.get_tensor("data2_t");
// Element types, names and layouts are aligned with framework
auto data2 = input_tensor2.data<int64_t>();
// Fill second data ...

IR v11, ONNX, ov::Model, Python:
# Get input tensor by index
input_tensor1 = infer_request.get_input_tensor(0)
# Element types, names and layouts are aligned with framework
assert input_tensor1.data.dtype == np.int64
# Fill the first data ...

# Get input tensor by tensor name
input_tensor2 = infer_request.get_tensor("data2_t")
assert input_tensor2.data.dtype == np.int64
# Fill the second data ...

6. Start Inference

Inference Engine API:

Synchronous, C++:
infer_request.Infer();

Synchronous, Python:
results = infer_request.infer()

Asynchronous, C++:
// NOTE: For demonstration purposes, the callback below restarts inference
// one more time, so two inferences happen here

// Start inference without blocking current thread
auto restart_once = true;
infer_request.SetCompletionCallback<std::function<void(InferenceEngine::InferRequest, InferenceEngine::StatusCode)>>(
    [&, restart_once](InferenceEngine::InferRequest request, InferenceEngine::StatusCode status) mutable {
        if (status != InferenceEngine::OK) {
            // Process error code
        } else {
            // Extract inference result
            InferenceEngine::Blob::Ptr output_blob = request.GetBlob(outputs.begin()->first);
            // Restart inference if needed
            if (restart_once) {
                request.StartAsync();
                restart_once = false;
            }
        }
    });
infer_request.StartAsync();
// Get inference status immediately
InferenceEngine::StatusCode status = infer_request.Wait(InferenceEngine::InferRequest::STATUS_ONLY);
// Wait for 1 millisecond
status = infer_request.Wait(1);
// Wait for inference completion
infer_request.Wait(InferenceEngine::InferRequest::RESULT_READY);

Asynchronous, Python:
# Start async inference on a single infer request
infer_request.async_infer()
# Wait for 1 millisecond
infer_request.wait(1)
# Wait for inference completion
infer_request.wait()

# Demonstrates async pipeline using ExecutableNetwork

results = []

# Callback to process inference results
def callback(output_blobs, _):
    # Copy the data from output blobs to numpy array
    results_copy = {out_name: out_blob.buffer[:] for out_name, out_blob in output_blobs.items()}
    results.append(process_results(results_copy))

# Set the callback for each infer request
for infer_request in exec_network.requests:
    infer_request.set_completion_callback(callback, py_data=infer_request.output_blobs)

# Async pipeline is managed by ExecutableNetwork
total_frames = 100
for _ in range(total_frames):
    # Wait for at least one free request
    exec_network.wait(num_request=1)
    # Get idle id
    idle_id = exec_network.get_idle_request_id()
    # Start asynchronous inference on idle request
    exec_network.start_async(request_id=idle_id, inputs=next(input_data))
# Wait for all requests to complete
exec_network.wait()

OpenVINO™ Runtime API 2.0:

Synchronous, C++:
infer_request.infer();

Synchronous, Python:
results = infer_request.infer()

Asynchronous, C++:
// NOTE: For demonstration purposes, the callback below restarts inference
// one more time, so two inferences happen here

auto restart_once = true;
infer_request.set_callback([&, restart_once](std::exception_ptr exception_ptr) mutable {
    if (exception_ptr) {
        // Process the exception or rethrow it.
        std::rethrow_exception(exception_ptr);
    } else {
        // Extract inference result
        ov::Tensor output_tensor = infer_request.get_output_tensor();
        // Restart inference if needed
        if (restart_once) {
            infer_request.start_async();
            restart_once = false;
        }
    }
});
// Start inference without blocking current thread
infer_request.start_async();
// Get inference status immediately
bool status = infer_request.wait_for(std::chrono::milliseconds{0});
// Wait for one millisecond
status = infer_request.wait_for(std::chrono::milliseconds{1});
// Wait for inference completion
infer_request.wait();

Asynchronous, Python:
# Start async inference on a single infer request
infer_request.start_async()
# Wait for 1 millisecond
infer_request.wait_for(1)
# Wait for inference completion
infer_request.wait()

# Demonstrates async pipeline using AsyncInferQueue

results = []

def callback(request, frame_id):
    # Copy the data from output tensors to numpy array and process it
    results_copy = {output: data[:] for output, data in request.results.items()}
    results.append(process_results(results_copy, frame_id))

# Create AsyncInferQueue with 4 infer requests
infer_queue = ov.AsyncInferQueue(compiled_model, jobs=4)
# Set callback for each infer request in the queue
infer_queue.set_callback(callback)

total_frames = 100
for i in range(total_frames):
    # Wait for at least one available infer request and start asynchronous inference
    infer_queue.start_async(next(input_data), userdata=i)
# Wait for all requests to complete
infer_queue.wait_all()

7. Process the Inference Results

The Inference Engine API processes outputs as I32 precision (not aligned with the original model):

C++:
InferenceEngine::Blob::Ptr output_blob = infer_request.GetBlob(outputs.begin()->first);
InferenceEngine::MemoryBlob::Ptr moutput = InferenceEngine::as<InferenceEngine::MemoryBlob>(output_blob);
if (moutput) {
    // The locked memory holder must stay alive for as long as its
    // buffer is accessed
    auto moutputHolder = moutput->rmap();
    // Original I64 precision was converted to I32
    auto data =
        moutputHolder.as<const InferenceEngine::PrecisionTrait<InferenceEngine::Precision::I32>::value_type*>();
    // process output data
}

Python:
# Get output blobs mapped to output layers names
output_blobs = infer_request.output_blobs
data = output_blobs["out1"].buffer
# Original I64 precision was converted to I32
assert data.dtype == np.int32
# Process output data

OpenVINO™ Runtime API 2.0 processes outputs:

  • For IR v10 as I32 precision (not aligned with the original model) to match the old behavior.

  • For IR v11, ONNX, ov::Model, Paddle as I64 precision (aligned with the original model) to match the new behavior.

IR v10, C++:
// The model has only one output
ov::Tensor output_tensor = infer_request.get_output_tensor();
// IR v10 works with converted precisions (i64 -> i32)
auto out_data = output_tensor.data<int32_t>();
// process output data

IR v10, Python:
# The model has only one output
output_tensor = infer_request.get_output_tensor()
# IR v10 works with converted precisions (i64 -> i32)
assert output_tensor.data.dtype == np.int32
# process output data ...

IR v11, ONNX, ov::Model, C++:
// The model has only one output
ov::Tensor output_tensor = infer_request.get_output_tensor();
// Element types, names and layouts are aligned with framework
auto out_data = output_tensor.data<int64_t>();
// process output data

IR v11, ONNX, ov::Model, Python:
# The model has only one output
output_tensor = infer_request.get_output_tensor()
# Element types, names and layouts are aligned with framework
assert output_tensor.data.dtype == np.int64
# process output data ...
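
Because the output precision depends on the model format, code that has to handle both IR v10 and newer models may prefer to inspect the tensor's element type at run time instead of hard-coding it. The following is a minimal sketch under that assumption; `output_as_ints` is a hypothetical helper, and `output_tensor` is expected to be an `openvino.runtime.Tensor` whose `data` attribute is a NumPy array.

```python
def output_as_ints(output_tensor):
    """Return the output values as Python ints for I32 or I64 outputs.

    `output_tensor` is expected to expose a NumPy array through `.data`,
    as openvino.runtime.Tensor does (assumption).
    """
    data = output_tensor.data
    dtype_name = str(data.dtype)
    # IR v10 yields int32; IR v11, ONNX, ov::Model and Paddle yield int64
    if dtype_name not in ("int32", "int64"):
        raise TypeError(f"unexpected output precision: {dtype_name}")
    # Flatten the array and convert every element to a plain Python int
    return [int(value) for value in data.ravel()]
```

Converting to Python integers hides the precision difference from the rest of the application, at the cost of a copy; for large outputs it may be preferable to keep the NumPy array and only branch on its `dtype`.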