ScaledDotProductAttention

Versioned name: ScaledDotProductAttention-13

Category: Sequence processing

Short description: ScaledDotProductAttention partially implements torch.nn.functional.scaled_dot_product_attention, omitting training-related parameters.

Detailed description:

ScaledDotProductAttention behaves according to the following pseudo-code, which is expressed with other operations from the OpenVINO opset and NumPy:

def ScaledDotProductAttention(query, key, value, attn_mask=None, scale=None, *, causal):
    L, S = Gather(ShapeOf(query), -2), Gather(ShapeOf(key), -2)
    if scale is None:
        scale = 1.0 / Sqrt(ConvertLike(Gather(ShapeOf(query), -1), query))
    attn_bias = Broadcast(ConvertLike(0, query), [L, S])
    if causal:
        attn_bias = numpy.triu(Broadcast(ConvertLike(-inf, query), [L, S]), k=1)
    elif attn_mask is not None:
        if attn_mask.element_type == boolean:
            attn_bias = Select(LogicalNot(attn_mask), ConvertLike(-inf, query), ConvertLike(0, query))
        else:
            attn_bias += attn_mask
    attn_weight = MatMul(query, Transpose(key, [-2, -1])) * scale
    attn_weight += attn_bias
    attn_weight = Softmax(attn_weight, axis=-1)
    return MatMul(attn_weight, value)
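
The pseudo-code above can be mirrored in plain NumPy to experiment with the semantics outside of OpenVINO. The following is a minimal sketch under stated assumptions, not the actual kernel: the function name is chosen for illustration, np.swapaxes stands in for the Transpose of the last two axes, and the softmax subtracts the row maximum for numerical stability, which does not change the result.

```python
import math

import numpy as np


def scaled_dot_product_attention(query, key, value, attn_mask=None, scale=None, *, causal=False):
    """Illustrative NumPy counterpart of the pseudo-code (not the OpenVINO implementation)."""
    L, S = query.shape[-2], key.shape[-2]
    if scale is None:
        # Default scale: 1 / sqrt(E), where E is the last dimension of query.
        scale = 1.0 / math.sqrt(query.shape[-1])
    attn_bias = np.zeros((L, S), dtype=query.dtype)
    if causal:
        # -inf strictly above the main diagonal, 0 elsewhere; attn_mask is ignored.
        attn_bias = np.triu(np.full((L, S), -np.inf, dtype=query.dtype), k=1)
    elif attn_mask is not None:
        if attn_mask.dtype == np.bool_:
            # True keeps a position, False masks it out.
            attn_bias = np.where(attn_mask, 0.0, -np.inf).astype(query.dtype)
        else:
            attn_bias = attn_bias + attn_mask
    attn_weight = np.matmul(query, np.swapaxes(key, -2, -1)) * scale + attn_bias
    # Softmax over the last axis, shifted by the row maximum for stability.
    attn_weight = np.exp(attn_weight - attn_weight.max(axis=-1, keepdims=True))
    attn_weight = attn_weight / attn_weight.sum(axis=-1, keepdims=True)
    return np.matmul(attn_weight, value)


# Usage with hypothetical sizes: N = 1, L = 3, S = 5, E = 4, Ev = 2.
rng = np.random.default_rng(0)
q = rng.standard_normal((1, 3, 4)).astype(np.float32)
k = rng.standard_normal((1, 5, 4)).astype(np.float32)
v = rng.standard_normal((1, 5, 2)).astype(np.float32)
print(scaled_dot_product_attention(q, k, v).shape)  # (1, 3, 2)
```

With causal=True, the first query position can only attend to the first key position, so the first output row equals the first row of value.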

Attributes

  • causal

    • Description: If true, applies causal attention masking as described in the pseudo-code. In this case, the attention_mask input described below is ignored.

    • Range of values: a boolean value

    • Type: bool

    • Required: yes
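
To make the effect of causal concrete, the bias that the pseudo-code builds in the causal branch can be printed for small sizes. This is only an illustration; the values of L and S below are hypothetical.

```python
import numpy as np

L, S = 3, 4  # hypothetical target and source sequence lengths

# Same construction as the causal branch of the pseudo-code:
# -inf strictly above the main diagonal, 0 on and below it.
causal_bias = np.triu(np.full((L, S), -np.inf, dtype=np.float32), k=1)
print(causal_bias)
# [[  0. -inf -inf -inf]
#  [  0.   0. -inf -inf]
#  [  0.   0.   0. -inf]]
```

Row i of the bias leaves positions 0..i untouched and sets the rest to -inf, so after Softmax each target position only receives weight from source positions at or before its own index.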

Inputs

  • 1: query - a tensor of type T with at least 3 dimensions and shape [N, ..., L, E]. Required.

  • 2: key - a tensor of type T with at least 3 dimensions and shape [N, ..., S, E]. Required.

  • 3: value - a tensor of type T with at least 3 dimensions and shape [N, ..., S, Ev]. Required.

  • 4: attention_mask - one of the two options below. attention_mask is ignored if causal is set to true. Optional.

    • a tensor of type T or boolean with at least 3 dimensions and shape [N, ..., L, S].

    • a scalar of type T with value 0. A scalar zero signals that no attention mask needs to be applied (similar to specifying attn_mask=None in the provided pseudo-code).

  • 5: scale - a scalar tensor of type T, used as the scale factor instead of the default 1/sqrt(query.shape[-1]) from the pseudo-code above. Optional.
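
The accepted forms of attention_mask all reduce to an additive bias on the attention weights. The sketch below illustrates that reduction in NumPy, following the boolean and floating-point branches of the pseudo-code; mask_to_bias is a hypothetical helper name and not part of any OpenVINO API.

```python
import numpy as np


def mask_to_bias(attention_mask, dtype=np.float32):
    """Hypothetical helper: convert an attention_mask into an additive bias."""
    mask = np.asarray(attention_mask)
    if mask.dtype == np.bool_:
        # Boolean mask: True keeps a position, False masks it out with -inf.
        return np.where(mask, 0.0, -np.inf).astype(dtype)
    # Floating-point mask (including the scalar 0): added to the weights as-is.
    return mask.astype(dtype)


bool_mask = np.array([[True, False], [True, True]])
print(mask_to_bias(bool_mask))
# [[  0. -inf]
#  [  0.   0.]]
print(mask_to_bias(0.0))  # 0.0, equivalent to applying no mask
```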

Outputs

  • 1: the result of scaled dot-product attention, a tensor of type T and shape [N, ..., L, Ev].

Types

  • T: any supported floating-point type.

Dimensions

  • N, ... - one or more batch dimensions. Each batch dimension must either have the same size in all input tensors (query, key, and value) or be broadcastable to a common value.

  • S - source sequence length

  • L - target sequence length

  • E - embedding dimension of the query and key

  • Ev - embedding dimension of the value

At least one batch dimension N is required in query, key and value inputs. Other batch dimensions ... are optional.
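
The relationship between input and output shapes can be sketched as a small shape-inference helper. This is an illustration under assumptions: sdpa_output_shape is a hypothetical name, shapes are static, and batch dimensions are combined with NumPy-style broadcasting (size-1 dimensions stretch). It does not attempt to cover every combination that the operation itself may accept.

```python
import numpy as np


def sdpa_output_shape(query_shape, key_shape, value_shape):
    """Hypothetical helper: infer [N, ..., L, Ev] from static input shapes."""
    if min(len(query_shape), len(key_shape), len(value_shape)) < 3:
        raise ValueError("query, key and value need at least one batch dimension")
    *q_batch, L, E = query_shape
    *k_batch, S, E_key = key_shape
    *v_batch, S_value, Ev = value_shape
    if E != E_key:
        raise ValueError("query and key must share the embedding dimension E")
    if S != S_value:
        raise ValueError("key and value must share the source sequence length S")
    # Assumption: NumPy-style broadcasting of the batch dimensions.
    batch = np.broadcast_shapes(tuple(q_batch), tuple(k_batch), tuple(v_batch))
    return (*batch, L, Ev)


# Matching batch dimensions N1 = 1, N2 = 2, N3 = 3 with L = 7, S = 9, E = Ev = 80.
print(sdpa_output_shape((1, 2, 3, 7, 80), (1, 2, 3, 9, 80), (1, 2, 3, 9, 80)))
# (1, 2, 3, 7, 80)
```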

Examples

Example 1: One batch dimension, dynamic dimensions support

 <layer id="285" name="aten::scaled_dot_product_attention_0" type="ScaledDotProductAttention" version="opset13">
                     <data causal="false" />
                     <input>
                             <!-- Example with simple dimensions, with N = 1, L = -1, S = -1, E = 80, Ev = 80-->
                             <port id="0" precision="FP32"> <!-- query -->
                                     <dim>1</dim> <!-- N -->
                                     <dim>-1</dim> <!-- L -->
                                     <dim>80</dim> <!-- E -->
                             </port>
                             <port id="1" precision="FP32"> <!-- key -->
                                     <dim>1</dim> <!-- N -->
                                     <dim>-1</dim> <!-- S -->
                                     <dim>80</dim> <!-- E -->
                             </port>
                             <port id="2" precision="FP32"> <!-- value -->
                                     <dim>1</dim> <!-- N -->
                                     <dim>-1</dim> <!-- S -->
                                     <dim>80</dim> <!-- Ev -->
                             </port>
                             <port id="3" precision="FP32"> <!-- attention_mask -->
                                     <dim>1</dim> <!-- N -->
                                     <dim>-1</dim> <!-- L -->
                                     <dim>-1</dim> <!-- S -->
                             </port>
                     </input>
                     <output>
                             <port id="4" precision="FP32">
                                     <dim>1</dim> <!-- N -->
                                     <dim>-1</dim> <!-- L -->
                                     <dim>80</dim> <!-- Ev -->
                             </port>
                     </output>
             </layer>

Example 2: Matching multiple batch dimensions

 <layer id="286" name="aten::scaled_dot_product_attention_0" type="ScaledDotProductAttention" version="opset13">
                     <data causal="false" />
                     <input>
                             <!-- Multiple batch dimensions: N1 = 1, N2 = 2, N3 = 3-->
                             <port id="0" precision="FP32"> <!-- query -->
                                     <dim>1</dim> <!-- N1 -->
                                     <dim>2</dim> <!-- N2 -->
                                     <dim>3</dim> <!-- N3 -->
                                     <dim>-1</dim> <!-- L -->
                                     <dim>80</dim> <!-- E -->
                             </port>
                             <port id="1" precision="FP32"> <!-- key -->
                                     <dim>1</dim> <!-- N1 -->
                                     <dim>2</dim> <!-- N2 -->
                                     <dim>3</dim> <!-- N3 -->
                                     <dim>-1</dim> <!-- S -->
                                     <dim>80</dim> <!-- E -->
                             </port>
                             <port id="2" precision="FP32"> <!-- value -->
                                     <dim>1</dim> <!-- N1 -->
                                     <dim>2</dim> <!-- N2 -->
                                     <dim>3</dim> <!-- N3 -->
                                     <dim>-1</dim> <!-- S -->
                                     <dim>80</dim> <!-- Ev -->
                             </port>
                             <port id="3" precision="FP32"> <!-- attention_mask -->
                                     <dim>1</dim> <!-- N1 -->
                                     <dim>2</dim> <!-- N2 -->
                                     <dim>3</dim> <!-- N3 -->
                                     <dim>-1</dim> <!-- L -->
                                     <dim>-1</dim> <!-- S -->
                             </port>
                     </input>
                     <output>
                             <port id="4" precision="FP32">
                                     <dim>1</dim> <!-- N1 -->
                                     <dim>2</dim> <!-- N2 -->
                                     <dim>3</dim> <!-- N3 -->
                                     <dim>-1</dim> <!-- L -->
                                     <dim>80</dim> <!-- Ev -->
                             </port>
                     </output>
             </layer>

Example 3: With batch dimensions broadcasting

 <layer id="287" name="aten::scaled_dot_product_attention_0" type="ScaledDotProductAttention" version="opset13">
                     <data causal="false" />
                     <input>
                             <!-- Multiple batch dimensions, broadcastable to the following values: N1 = 4, N2 = 6, N3 = 10-->
                             <port id="0" precision="FP32"> <!-- query -->
                                     <dim>1</dim> <!-- N1 (repeat 4 times) -->
                                     <dim>6</dim> <!-- N2 (repeat 1 time)-->
                                     <dim>5</dim> <!-- N3 (repeat 2 times)-->
                                     <dim>-1</dim> <!-- L -->
                                     <dim>80</dim> <!-- E -->
                             </port>
                             <port id="1" precision="FP32"> <!-- key -->
                                     <dim>2</dim> <!-- N1 (repeat 2 times) -->
                                     <dim>2</dim> <!-- N2 (repeat 3 times) -->
                                     <dim>2</dim> <!-- N3 (repeat 5 times) -->
                                     <dim>-1</dim> <!-- S -->
                                     <dim>80</dim> <!-- E -->
                             </port>
                             <port id="2" precision="FP32"> <!-- value -->
                                     <dim>4</dim> <!-- N1 (repeat 1 time)-->
                                     <dim>3</dim> <!-- N2 (repeat 2 times)-->
                                     <dim>10</dim> <!-- N3 (repeat 1 time)-->
                                     <dim>-1</dim> <!-- S -->
                                     <dim>80</dim> <!-- Ev -->
                             </port>
                             <port id="3" precision="FP32"> <!-- attention_mask -->
                                     <dim>1</dim> <!-- N1 (repeat 4 times)-->
                                     <dim>2</dim> <!-- N2 (repeat 3 times)-->
                                     <dim>1</dim> <!-- N3 (repeat 10 times)-->
                                     <dim>-1</dim> <!-- L -->
                                     <dim>-1</dim> <!-- S -->
                             </port>
                     </input>
                     <output>
                             <!-- Output contains broadcasted dimensions N1 = 4, N2 = 6, N3 = 10-->
                             <port id="4" precision="FP32">
                                     <dim>4</dim> <!-- N1 -->
                                     <dim>6</dim> <!-- N2 -->
                                     <dim>10</dim> <!-- N3 -->
                                     <dim>-1</dim> <!-- L -->
                                     <dim>80</dim> <!-- Ev -->
                             </port>
                     </output>
             </layer>