GatedDeltaNet
Versioned name: GatedDeltaNet
Category: Sequence processing
Short description: GatedDeltaNet is a linear recurrent sequence model that combines the delta-rule memory update with a per-token gating mechanism.
Detailed description: GatedDeltaNet implements the recurrence from the paper
"Gated Delta Networks: Improving Mamba2 with Delta Rule" (arXiv:2412.06464). It processes a sequence of
query, key, and value vectors using the delta rule to update a hidden state matrix,
controlled by a per-token forget gate (applied as exp(g)) and a per-token
write gate beta. Queries are scaled by 1 / sqrt(key_head_dim) before being used
to compute the output. The following PyTorch-equivalent code illustrates the full
computation:
import torch

def torch_recurrent_gated_delta_rule(
    query, key, value, recurrent_state, gate, beta,
):
    batch_size, sequence_length, num_heads, k_head_dim = key.shape
    v_head_dim = value.shape[-1]
    # Scale queries by 1 / sqrt(key_head_dim).
    scale = 1 / (query.shape[-1] ** 0.5)
    query = query * scale
    output_attn = torch.zeros(
        batch_size, sequence_length, num_heads, v_head_dim
    ).to(value)
    output_recurrent_state = recurrent_state
    for i in range(sequence_length):
        q_t = query[:, i]
        k_t = key[:, i]
        v_t = value[:, i]
        # Forget gate: decay the whole state matrix by exp(g).
        g_t = gate[:, i].exp().unsqueeze(-1).unsqueeze(-1)
        # Write gate.
        beta_t = beta[:, i].unsqueeze(-1)
        output_recurrent_state = output_recurrent_state * g_t
        # Read the memory content currently stored for this key.
        kv_mem = (output_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
        # Delta-rule correction, scaled by the write gate.
        delta = (v_t - kv_mem) * beta_t
        # Rank-1 update writing the correction back into the state.
        output_recurrent_state = output_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
        # Read-out: apply the state matrix to the scaled query.
        output_attn[:, i] = (output_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)
    return output_attn, output_recurrent_state
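For readers without PyTorch at hand, the same recurrence can be sketched in plain NumPy; this is an illustrative re-implementation, not part of the operator specification, and the function name and tensor sizes below are chosen only for demonstration:

```python
import numpy as np

def gated_delta_rule_numpy(query, key, value, recurrent_state, gate, beta):
    # NumPy sketch of the recurrence. Shapes follow the spec:
    # query/key: [B, S, H, Dk], value: [B, S, H, Dv],
    # recurrent_state: [B, H, Dk, Dv], gate/beta: [B, S, H].
    batch, seq_len, heads, k_dim = key.shape
    v_dim = value.shape[-1]
    query = query / np.sqrt(k_dim)                  # scale by 1 / sqrt(key_head_dim)
    out = np.zeros((batch, seq_len, heads, v_dim), dtype=value.dtype)
    state = recurrent_state
    for i in range(seq_len):
        g_t = np.exp(gate[:, i])[..., None, None]   # forget gate, applied as exp(g)
        b_t = beta[:, i][..., None]                 # write gate
        k_t = key[:, i]
        state = state * g_t                         # decay the memory
        kv = (state * k_t[..., None]).sum(axis=-2)  # read memory for this key
        delta = (value[:, i] - kv) * b_t            # delta-rule correction
        state = state + k_t[..., None] * delta[..., None, :]  # rank-1 write
        out[:, i] = (state * query[:, i][..., None]).sum(axis=-2)
    return out, state

# Shape check with the example sizes used later in this page.
rng = np.random.default_rng(0)
B, S, H, Dk, Dv = 1, 16, 8, 64, 128
q = rng.standard_normal((B, S, H, Dk))
k = rng.standard_normal((B, S, H, Dk))
v = rng.standard_normal((B, S, H, Dv))
g = -rng.random((B, S, H))   # log-space gate; negative values give decay < 1
b = rng.random((B, S, H))
out, final_state = gated_delta_rule_numpy(q, k, v, np.zeros((B, H, Dk, Dv)), g, b)
```

With a zero initial state, gate fixed at 0 (so exp(g) = 1) and beta fixed at 1, a single step reduces to the classic delta rule: the state becomes the outer product of the key and value.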
Inputs
1: query - 4D tensor of type T and shape [batch_size, seq_len, num_heads, key_head_dim], the query vectors for each token and head. Scaled internally by 1 / sqrt(key_head_dim) before computing the output. Required.
2: key - 4D tensor of type T and shape [batch_size, seq_len, num_heads, key_head_dim], the key vectors for each token and head. Required.
3: value - 4D tensor of type T and shape [batch_size, seq_len, num_heads, value_head_dim], the value vectors for each token and head. Required.
4: recurrent_state - 4D tensor of type T and shape [batch_size, num_heads, key_head_dim, value_head_dim], the recurrent (initially all-zeros) hidden state matrix. Required.
5: gate - 3D tensor of type T and shape [batch_size, seq_len, num_heads], the forget gate in log-space. Applied as exp(g) at each time step to decay the hidden state before the delta update. Required.
6: beta - 3D tensor of type T and shape [batch_size, seq_len, num_heads], the write gate controlling how much of the delta correction is applied to the hidden state. Required.
Outputs
1: output_attn - 4D tensor of type T and shape [batch_size, seq_len, num_heads, value_head_dim], the output vectors at each time step produced by applying the state matrix to the (scaled) query.
2: output_recurrent_state - 4D tensor of type T and shape [batch_size, num_heads, key_head_dim, value_head_dim], the hidden state matrix after processing the last token in the sequence.
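Because output_recurrent_state captures the full memory after the last token, a long sequence can be processed in chunks by feeding each chunk's final state into the next call, yielding bitwise-identical results to a single full pass. A self-contained NumPy sketch of this property (all names and sizes here are illustrative, not part of the spec):

```python
import numpy as np

def step(state, q, k, v, g, b):
    # One recurrence step per the spec: decay, delta correction, rank-1 write, read-out.
    state = state * np.exp(g)[..., None, None]
    delta = (v - (state * k[..., None]).sum(-2)) * b[..., None]
    state = state + k[..., None] * delta[..., None, :]
    return (state * q[..., None]).sum(-2), state

def run(qs, ks, vs, state, gs, bs):
    # Sequential scan over time; queries scaled by 1 / sqrt(key_head_dim).
    qs = qs / np.sqrt(qs.shape[-1])
    outs = []
    for i in range(qs.shape[1]):
        o, state = step(state, qs[:, i], ks[:, i], vs[:, i], gs[:, i], bs[:, i])
        outs.append(o)
    return np.stack(outs, axis=1), state

rng = np.random.default_rng(0)
B, S, H, Dk, Dv = 1, 8, 2, 4, 6
q = rng.standard_normal((B, S, H, Dk))
k = rng.standard_normal((B, S, H, Dk))
v = rng.standard_normal((B, S, H, Dv))
g = -rng.random((B, S, H))
b = rng.random((B, S, H))
zero = np.zeros((B, H, Dk, Dv))

# Full pass vs. two chunks stitched through the recurrent state.
out_full, st_full = run(q, k, v, zero, g, b)
out_a, st_a = run(q[:, :4], k[:, :4], v[:, :4], zero, g[:, :4], b[:, :4])
out_b, st_b = run(q[:, 4:], k[:, 4:], v[:, 4:], st_a, g[:, 4:], b[:, 4:])
```

The equality holds exactly because the recurrence is strictly sequential: each step depends only on the current token and the incoming state.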
Types
T: any supported floating-point type.
Example
<layer ... type="GatedDeltaNet" ...>
<input>
<port id="0"> <!-- `query` -->
<dim>1</dim>
<dim>16</dim>
<dim>8</dim>
<dim>64</dim>
</port>
<port id="1"> <!-- `key` -->
<dim>1</dim>
<dim>16</dim>
<dim>8</dim>
<dim>64</dim>
</port>
<port id="2"> <!-- `value` -->
<dim>1</dim>
<dim>16</dim>
<dim>8</dim>
<dim>128</dim>
</port>
<port id="3"> <!-- `recurrent_state` -->
<dim>1</dim>
<dim>8</dim>
<dim>64</dim>
<dim>128</dim>
</port>
<port id="4"> <!-- `gate` -->
<dim>1</dim>
<dim>16</dim>
<dim>8</dim>
</port>
<port id="5"> <!-- `beta` -->
<dim>1</dim>
<dim>16</dim>
<dim>8</dim>
</port>
</input>
<output>
<port id="6"> <!-- `output_attn` -->
<dim>1</dim>
<dim>16</dim>
<dim>8</dim>
<dim>128</dim>
</port>
<port id="7"> <!-- `output_recurrent_state` -->
<dim>1</dim>
<dim>8</dim>
<dim>64</dim>
<dim>128</dim>
</port>
</output>
</layer>