Skip to content

vllm.model_executor.layers.attention.mm_encoder_attention

MMEncoderAttention

Bases: CustomOp

Multi-headed attention without any cache, used for multimodal encoder.

Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
@CustomOp.register("mm_encoder_attn")
class MMEncoderAttention(CustomOp):
    """Cache-free multi-head attention layer for multimodal encoders.

    The attention backend is picked once in ``__init__`` through
    ``get_vit_attn_backend``; the platform entry points (``forward_native``,
    ``forward_cuda``, ``forward_cpu``, ``forward_xpu``) then route to the
    matching private ``_forward_*`` implementation.
    """

    # --8<-- [end:mm_encoder_attn]
    @classmethod
    def compute_max_seqlen(
        cls,
        attn_backend: AttentionBackendEnum,
        cu_seqlens: np.ndarray,
    ) -> int:
        """Return the longest sequence length described by ``cu_seqlens``.

        Args:
            attn_backend: attention backend the value is computed for.
            cu_seqlens: cumulative sequence boundaries; adjacent differences
                are taken as the individual sequence lengths.

        Returns:
            The largest adjacent difference for the FLASH_ATTN, ROCM_AITER_FA,
            TRITON_ATTN and FLASHINFER backends when ``cu_seqlens`` has at
            least two entries, otherwise 0. For FLASHINFER the value is
            additionally passed through ``bucket_flashinfer_max_seqlen``.
        """
        backends_needing_max = (
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
            AttentionBackendEnum.TRITON_ATTN,
            AttentionBackendEnum.FLASHINFER,
        )
        longest = 0
        if attn_backend in backends_needing_max and len(cu_seqlens) >= 2:
            per_sequence = cu_seqlens[1:] - cu_seqlens[:-1]
            longest = int(per_sequence.max())
        if attn_backend == AttentionBackendEnum.FLASHINFER:
            # FlashInfer gets the bucketed value, even when it is still 0.
            return bucket_flashinfer_max_seqlen(longest)
        return longest

    @classmethod
    def maybe_compute_sequence_lengths(
        cls,
        attn_backend: AttentionBackendEnum,
        cu_seqlens: np.ndarray,
    ) -> np.ndarray | None:
        """Derive padded per-sequence lengths for the FlashInfer backend.

        Args:
            attn_backend: attention backend in use.
            cu_seqlens: cumulative sequence boundaries.

        Returns:
            ``None`` for every backend except FLASHINFER. For FLASHINFER, the
            adjacent differences of ``cu_seqlens`` after being passed to
            ``add_padding_to_seqlens`` with a pad value of 0.
        """
        if attn_backend == AttentionBackendEnum.FLASHINFER:
            lengths = cu_seqlens[1:] - cu_seqlens[:-1]
            return add_padding_to_seqlens(lengths, len(lengths), 0)
        return None

    @classmethod
    def maybe_recompute_cu_seqlens(
        cls,
        attn_backend: AttentionBackendEnum,
        cu_seqlens: np.ndarray,
        hidden_size: int,
        tp_size: int,
    ) -> np.ndarray:
        """Rescale ``cu_seqlens`` into the layout used by the FlashInfer path.

        Args:
            attn_backend: attention backend in use.
            cu_seqlens: cumulative sequence boundaries.
            hidden_size: multiplied into the boundaries after integer division
                by ``tp_size``.
            tp_size: divisor applied to ``hidden_size``.

        Returns:
            ``cu_seqlens`` unchanged for every backend except FLASHINFER. For
            FLASHINFER, the concatenation of two padded arrays: the boundaries
            scaled by ``hidden_size // tp_size`` and the same values times 3.
        """
        if attn_backend != AttentionBackendEnum.FLASHINFER:
            return cu_seqlens

        num_sequences = len(cu_seqlens) - 1
        factor = hidden_size // tp_size

        # The first array is the scaled boundaries; the second is 3x those.
        # NOTE(review): presumably element offsets for q/k/o vs. v -- confirm
        # against vit_flashinfer_wrapper.
        qko_offsets = cu_seqlens * factor
        v_offsets = qko_offsets * 3

        # Each array is padded with its own final value.
        padded_parts = [
            add_padding_to_seqlens(offsets, num_sequences, offsets[-1])
            for offsets in (qko_offsets, v_offsets)
        ]
        return np.concatenate(padded_parts)

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float | None = None,
        num_kv_heads: int | None = None,
        prefix: str = "",
    ) -> None:
        """
        Args:
            num_heads: number of attention heads per partition.
            head_size: hidden_size per attention head.
            scale: scale factor; defaults to ``1 / sqrt(head_size)``.
            num_kv_heads: number of kv heads; defaults to ``num_heads``.
            prefix: stored as ``layer_name``; kept to make it easier to
                    swap between Attention and MultiHeadAttention
        """
        super().__init__()

        if scale is None:
            scale = 1.0 / (head_size**0.5)
        if num_kv_heads is None:
            num_kv_heads = num_heads

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_kv_heads
        self.layer_name = prefix
        assert self.num_heads % self.num_kv_heads == 0, (
            f"num_heads ({self.num_heads}) is not "
            f"divisible by num_kv_heads ({self.num_kv_heads})"
        )
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # The default dtype during model initialization is the model weight
        # and activation dtype, so it is what the device-specific vision
        # attention backend is selected with.
        self.attn_backend = get_vit_attn_backend(
            head_size=head_size,
            dtype=torch.get_default_dtype(),
        )

        self.is_flash_attn_backend = self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }

        # A flash-attention version is only looked up for the FA backends.
        if self.is_flash_attn_backend:
            self._fa_version = get_flash_attn_version(head_size=head_size)
        else:
            self._fa_version = None

        if self.attn_backend == AttentionBackendEnum.FLASHINFER:
            # NOTE(review): the result is discarded; presumably this creates
            # the workspace buffer ahead of the first forward -- confirm.
            _get_flashinfer_workspace_buffer()

        logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")

    @classmethod
    def enabled(cls) -> bool:
        """Report this custom op as always enabled."""
        return True

    def view_qkv_to_4d(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        bsz: int,
        q_len: int,
        kv_len: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        View query, key, value as 4D tensors.

        query becomes (bsz, q_len, num_heads, head_size); key and value
        become (bsz, kv_len, num_kv_heads, head_size).
        """
        q_shape = (bsz, q_len, self.num_heads, self.head_size)
        kv_shape = (bsz, kv_len, self.num_kv_heads, self.head_size)
        return query.view(*q_shape), key.view(*kv_shape), value.view(*kv_shape)

    def _forward_sdpa(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run attention through ``vit_torch_sdpa_wrapper``.

        Input shape:
        (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size)

        The output is flattened back to three dimensions only when the query
        did not arrive as a 4D tensor.
        """
        batch, q_tokens = query.shape[:2]
        kv_tokens = key.shape[1]
        flatten_output = query.dim() != 4

        q4, k4, v4 = self.view_qkv_to_4d(
            query, key, value, batch, q_tokens, kv_tokens
        )
        attn_out = vit_torch_sdpa_wrapper(
            q=q4,
            k=k4,
            v=v4,
            scale=self.scale,
            cu_seqlens=cu_seqlens,
            # Grouped-query mode whenever there are fewer kv heads than heads.
            enable_gqa=self.num_heads > self.num_kv_heads,
        )
        return attn_out.reshape(batch, q_tokens, -1) if flatten_output else attn_out

    def _forward_fa(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run attention through ``vit_flash_attn_wrapper``.

        Input shape:
        (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size)

        ``cu_seqlens`` and ``max_seqlen`` must be given together or both left
        as ``None``. The output is flattened back to three dimensions only
        when the query did not arrive as a 4D tensor.
        """
        assert (cu_seqlens is None) == (
            max_seqlen is None
        ), "cu_seqlens and max_seqlen should be both set or both None."

        batch, q_tokens = query.shape[:2]
        kv_tokens = key.shape[1]
        flatten_output = query.dim() != 4

        q4, k4, v4 = self.view_qkv_to_4d(
            query, key, value, batch, q_tokens, kv_tokens
        )
        use_rocm_aiter = self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
        attn_out = vit_flash_attn_wrapper(
            q=q4,
            k=k4,
            v=v4,
            batch_size=batch,
            is_rocm_aiter=use_rocm_aiter,
            fa_version=self._fa_version,
            scale=self.scale,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        return attn_out.reshape(batch, q_tokens, -1) if flatten_output else attn_out

    def _forward_triton(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run attention through ``vit_triton_attn_wrapper``.

        Input shape:
        (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size)

        ``cu_seqlens`` and ``max_seqlen`` must be given together or both left
        as ``None``. The output is flattened back to three dimensions only
        when the query did not arrive as a 4D tensor.
        """
        assert (cu_seqlens is None) == (
            max_seqlen is None
        ), "cu_seqlens and max_seqlen should be both set or both None."

        batch, q_tokens = query.shape[:2]
        kv_tokens = key.shape[1]
        flatten_output = query.dim() != 4

        q4, k4, v4 = self.view_qkv_to_4d(
            query, key, value, batch, q_tokens, kv_tokens
        )
        attn_out = vit_triton_attn_wrapper(
            q=q4,
            k=k4,
            v=v4,
            batch_size=batch,
            scale=self.scale,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        return attn_out.reshape(batch, q_tokens, -1) if flatten_output else attn_out

    def _forward_flashinfer(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,
        sequence_lengths: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run attention through ``vit_flashinfer_wrapper``.

        Unlike the other ``_forward_*`` paths, the inputs are handed to the
        wrapper without any reshaping here. ``sequence_lengths`` is only used
        for the FlashInfer CuDNN backend.
        """
        workspace = _get_flashinfer_workspace_buffer()
        return vit_flashinfer_wrapper(
            q=query,
            k=key,
            v=value,
            scale=self.scale,
            workspace_buffer=workspace,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            sequence_lengths=sequence_lengths,
        )

    def forward_native(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,
        sequence_lengths: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Native entry point: always takes the SDPA path.

        ``max_seqlen`` and ``sequence_lengths`` are accepted so every
        ``forward_*`` shares one signature; they are not used here.
        """
        return self._forward_sdpa(query, key, value, cu_seqlens)

    def forward_cuda(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,
        sequence_lengths: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """CUDA entry point: dispatch on the backend selected in ``__init__``.

        ``max_seqlen`` is forwarded to the flash-attention, Triton and
        FlashInfer paths; ``sequence_lengths`` only to the FlashInfer path.

        Raises:
            ValueError: if the selected backend has no CUDA implementation
                here.
        """
        # Covers both FLASH_ATTN and ROCM_AITER_FA.
        if self.is_flash_attn_backend:
            return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
        if self.attn_backend == AttentionBackendEnum.TRITON_ATTN:
            return self._forward_triton(query, key, value, cu_seqlens, max_seqlen)
        if self.attn_backend == AttentionBackendEnum.FLASHINFER:
            return self._forward_flashinfer(
                query, key, value, cu_seqlens, max_seqlen, sequence_lengths
            )
        if self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
            return self._forward_sdpa(query, key, value, cu_seqlens)
        raise ValueError(
            f"Unsupported multi-modal encoder attention backend for CUDA: "
            f"{self.attn_backend}."
        )

    def forward_cpu(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,
        sequence_lengths: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """CPU entry point: always takes the SDPA path.

        ``max_seqlen`` and ``sequence_lengths`` are accepted so every
        ``forward_*`` shares one signature; they are not used here.
        """
        return self._forward_sdpa(query, key, value, cu_seqlens)

    def forward_xpu(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,
        sequence_lengths: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """XPU entry point: dispatch on the backend selected in ``__init__``.

        Only FLASH_ATTN, TRITON_ATTN and TORCH_SDPA are handled;
        ``sequence_lengths`` is not used on this platform.

        Raises:
            ValueError: if the selected backend has no XPU implementation
                here.
        """
        if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
            return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
        if self.attn_backend == AttentionBackendEnum.TRITON_ATTN:
            return self._forward_triton(query, key, value, cu_seqlens, max_seqlen)
        if self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
            return self._forward_sdpa(query, key, value, cu_seqlens)
        raise ValueError(
            f"Unsupported multi-modal encoder attention backend for XPU: "
            f"{self.attn_backend}."
        )

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float | None = None,
    num_kv_heads: int | None = None,
    prefix: str = "",
) -> None

Parameters:

Name Type Description Default
num_heads int

number of attention heads per partition.

required
head_size int

hidden_size per attention head.

required
scale float | None

scale factor.

None
num_kv_heads int | None

number of kv heads.

None
prefix str

This is stored as layer_name and has no other effect in this class; it is only here to make it easier to swap between Attention and MultiHeadAttention.

''
Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float | None = None,
    num_kv_heads: int | None = None,
    prefix: str = "",
) -> None:
    """
    Args:
        num_heads: number of attention heads per partition.
        head_size: hidden_size per attention head.
        scale: scale factor; defaults to ``1 / sqrt(head_size)``.
        num_kv_heads: number of kv heads; defaults to ``num_heads``.
        prefix: stored as ``layer_name``; kept to make it easier to
                swap between Attention and MultiHeadAttention
    """
    super().__init__()

    if scale is None:
        scale = 1.0 / (head_size**0.5)
    if num_kv_heads is None:
        num_kv_heads = num_heads

    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = scale
    self.num_kv_heads = num_kv_heads
    self.layer_name = prefix
    assert self.num_heads % self.num_kv_heads == 0, (
        f"num_heads ({self.num_heads}) is not "
        f"divisible by num_kv_heads ({self.num_kv_heads})"
    )
    self.num_queries_per_kv = self.num_heads // self.num_kv_heads

    # The default dtype during model initialization is the model weight
    # and activation dtype, so it is what the device-specific vision
    # attention backend is selected with.
    self.attn_backend = get_vit_attn_backend(
        head_size=head_size,
        dtype=torch.get_default_dtype(),
    )

    self.is_flash_attn_backend = self.attn_backend in {
        AttentionBackendEnum.FLASH_ATTN,
        AttentionBackendEnum.ROCM_AITER_FA,
    }

    # A flash-attention version is only looked up for the FA backends.
    if self.is_flash_attn_backend:
        self._fa_version = get_flash_attn_version(head_size=head_size)
    else:
        self._fa_version = None

    if self.attn_backend == AttentionBackendEnum.FLASHINFER:
        # NOTE(review): the result is discarded; presumably this creates
        # the workspace buffer ahead of the first forward -- confirm.
        _get_flashinfer_workspace_buffer()

    logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")

_forward_fa

_forward_fa(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    cu_seqlens: Tensor | None = None,
    max_seqlen: Tensor | None = None,
) -> Tensor

Input shape: (batch_size x seq_len x hidden_size) or (batch_size x seq_len x num_heads x head_size)

Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def _forward_fa(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,
) -> torch.Tensor:
    """Run attention through ``vit_flash_attn_wrapper``.

    Input shape:
    (batch_size x seq_len x hidden_size) or
    (batch_size x seq_len x num_heads x head_size)

    ``cu_seqlens`` and ``max_seqlen`` must be given together or both left
    as ``None``. The output is flattened back to three dimensions only
    when the query did not arrive as a 4D tensor.
    """
    assert (cu_seqlens is None) == (
        max_seqlen is None
    ), "cu_seqlens and max_seqlen should be both set or both None."

    batch, q_tokens = query.shape[:2]
    kv_tokens = key.shape[1]
    flatten_output = query.dim() != 4

    q4, k4, v4 = self.view_qkv_to_4d(
        query, key, value, batch, q_tokens, kv_tokens
    )
    use_rocm_aiter = self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
    attn_out = vit_flash_attn_wrapper(
        q=q4,
        k=k4,
        v=v4,
        batch_size=batch,
        is_rocm_aiter=use_rocm_aiter,
        fa_version=self._fa_version,
        scale=self.scale,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
    )
    return attn_out.reshape(batch, q_tokens, -1) if flatten_output else attn_out

_forward_sdpa

_forward_sdpa(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    cu_seqlens: Tensor | None = None,
) -> Tensor

Input shape: (batch_size x seq_len x hidden_size) or (batch_size x seq_len x num_heads x head_size)

Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def _forward_sdpa(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
) -> torch.Tensor:
    """Run attention through ``vit_torch_sdpa_wrapper``.

    Input shape:
    (batch_size x seq_len x hidden_size) or
    (batch_size x seq_len x num_heads x head_size)

    The output is flattened back to three dimensions only when the query
    did not arrive as a 4D tensor.
    """
    batch, q_tokens = query.shape[:2]
    kv_tokens = key.shape[1]
    flatten_output = query.dim() != 4

    q4, k4, v4 = self.view_qkv_to_4d(
        query, key, value, batch, q_tokens, kv_tokens
    )
    attn_out = vit_torch_sdpa_wrapper(
        q=q4,
        k=k4,
        v=v4,
        scale=self.scale,
        cu_seqlens=cu_seqlens,
        # Grouped-query mode whenever there are fewer kv heads than heads.
        enable_gqa=self.num_heads > self.num_kv_heads,
    )
    return attn_out.reshape(batch, q_tokens, -1) if flatten_output else attn_out

_forward_triton

_forward_triton(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    cu_seqlens: Tensor | None = None,
    max_seqlen: Tensor | None = None,
) -> Tensor

Input shape: (batch_size x seq_len x hidden_size) or (batch_size x seq_len x num_heads x head_size)

Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def _forward_triton(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,
) -> torch.Tensor:
    """Run attention through ``vit_triton_attn_wrapper``.

    Input shape:
    (batch_size x seq_len x hidden_size) or
    (batch_size x seq_len x num_heads x head_size)

    ``cu_seqlens`` and ``max_seqlen`` must be given together or both left
    as ``None``. The output is flattened back to three dimensions only
    when the query did not arrive as a 4D tensor.
    """
    assert (cu_seqlens is None) == (
        max_seqlen is None
    ), "cu_seqlens and max_seqlen should be both set or both None."

    batch, q_tokens = query.shape[:2]
    kv_tokens = key.shape[1]
    flatten_output = query.dim() != 4

    q4, k4, v4 = self.view_qkv_to_4d(
        query, key, value, batch, q_tokens, kv_tokens
    )
    attn_out = vit_triton_attn_wrapper(
        q=q4,
        k=k4,
        v=v4,
        batch_size=batch,
        scale=self.scale,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
    )
    return attn_out.reshape(batch, q_tokens, -1) if flatten_output else attn_out

view_qkv_to_4d

view_qkv_to_4d(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    bsz: int,
    q_len: int,
    kv_len: int,
) -> tuple[Tensor, Tensor, Tensor]

Reshape query, key, value to 4D tensors: (batch_size, seq_len, num_heads, head_size)

Source code in vllm/model_executor/layers/attention/mm_encoder_attention.py
def view_qkv_to_4d(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    bsz: int,
    q_len: int,
    kv_len: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    View query, key, value as 4D tensors.

    query becomes (bsz, q_len, num_heads, head_size); key and value
    become (bsz, kv_len, num_kv_heads, head_size).
    """
    q_shape = (bsz, q_len, self.num_heads, self.head_size)
    kv_shape = (bsz, kv_len, self.num_kv_heads, self.head_size)
    return query.view(*q_shape), key.view(*kv_shape), value.view(*kv_shape)