Skip to content

vllm.v1.serial_utils

MsgpackDecoder

Decoder with custom torch tensor and numpy array serialization.

Note that unlike vanilla msgspec Decoders, this interface is generally not thread-safe when decoding tensors / numpy arrays.

For multimodal tensors sent via torch.multiprocessing.Queue (when IPC is enabled), they will be retrieved via the tensor_ipc_receiver during decoding. Works for both CUDA and CPU tensors.

Source code in vllm/v1/serial_utils.py
class MsgpackDecoder:
    """Decoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Decoders, this interface is generally
    not thread-safe when decoding tensors / numpy arrays: `decode` stashes the
    auxiliary buffers of the message being decoded on the instance.

    For multimodal tensors sent via torch.multiprocessing.Queue (when IPC
    is enabled), they will be retrieved via the ``tensor_ipc_receiver``
    during decoding.  Works for both CUDA and CPU tensors.
    """

    def __init__(
        self,
        t: Any | None = None,
        share_mem: bool = True,
        tensor_ipc_receiver: TensorIpcReceiver | None = None,
    ):
        """Create a decoder.

        Args:
            t: Optional target type forwarded to ``msgpack.Decoder``; when
                None, no target type is passed.
            share_mem: When True, decoded numpy arrays and tensors backed by
                auxiliary buffers are zero-copy views into the received
                buffers. When False they are copied (tensors into pinned
                memory when pinning is available).
            tensor_ipc_receiver: Receiver used to fetch tensors that were sent
                out-of-band via torch.multiprocessing.Queue.
        """
        self.share_mem = share_mem
        self.pin_tensors = is_pin_memory_available()
        args = () if t is None else (t,)
        self.decoder = msgpack.Decoder(
            *args, ext_hook=self.ext_hook, dec_hook=self.dec_hook
        )
        self.aux_buffers: Sequence[bytestr] = ()
        # Optional receiver for tensor IPC via torch.multiprocessing.Queue
        self.tensor_ipc_receiver = tensor_ipc_receiver
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()

    def decode(self, bufs: bytestr | Sequence[bytestr]) -> Any:
        """Decode a message.

        ``bufs`` is either a single buffer, or a sequence whose first element
        is the main msgpack payload. In the sequence case, the payload may
        reference any element of the sequence by integer index for zero-copy
        array/tensor data.
        """
        if isinstance(bufs, bytestr):  # type: ignore
            return self.decoder.decode(bufs)

        # Stash the buffers so the hooks can resolve integer buffer indices.
        self.aux_buffers = bufs
        try:
            return self.decoder.decode(bufs[0])
        finally:
            self.aux_buffers = ()

    def dec_hook(self, t: type, obj: Any) -> Any:
        # Given native types in `obj`, convert to type `t`.
        if isclass(t):
            if issubclass(t, np.ndarray):
                return self._decode_ndarray(obj)
            if issubclass(t, TensorIpcHandle):
                # msgspec may hand the handle over as a dict of its fields or
                # as a positional list (NamedTuples encode as arrays).
                if isinstance(obj, dict):
                    obj = self._ipc_handle_from_dict(obj)
                elif isinstance(obj, list):
                    obj = TensorIpcHandle(*obj)
                return self._decode_ipc_queue_tensor(obj)
            if issubclass(t, torch.Tensor):
                return self._decode_tensor(obj)
            if t is slice:
                return slice(*obj)
            if issubclass(t, MultiModalKwargsItem):
                return self._decode_mm_item(obj)
            if issubclass(t, MultiModalKwargsItems):
                return self._decode_mm_items(obj)
            if t is UtilityResult:
                return self._decode_utility_result(obj)
        return obj

    def _decode_utility_result(self, obj: Any) -> UtilityResult:
        """Rebuild a UtilityResult from its ``(type_info, result)`` pair."""
        result_type, result = obj
        if result_type is not None:
            if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
                raise TypeError(
                    "VLLM_ALLOW_INSECURE_SERIALIZATION must "
                    "be set to use custom utility result types"
                )
            # Use recursive decoding to handle nested structures
            result = _decode_type_info_recursive(
                result_type, result, self._convert_result
            )
        return UtilityResult(result)

    def _convert_result(self, result_type: Sequence[str], result: Any) -> Any:
        """Convert ``result`` to the type named by ``(module, name)``."""
        if result_type is None:
            return result
        mod_name, name = result_type
        mod = importlib.import_module(mod_name)
        result_type = getattr(mod, name)
        return msgspec.convert(result, result_type, dec_hook=self.dec_hook)

    def _decode_ndarray(self, arr: Any) -> np.ndarray:
        """Rebuild an ndarray from its ``(dtype, shape, data)`` encoding."""
        dtype, shape, data = arr
        # zero-copy decode. We assume the ndarray will not be kept around,
        # as it now locks the whole received message buffer in memory.
        buffer = self.aux_buffers[data] if isinstance(data, int) else data
        arr = np.frombuffer(buffer, dtype=dtype)
        if not self.share_mem:
            arr = arr.copy()
        return arr.reshape(shape)

    @staticmethod
    def _is_ipc_handle_dict(obj: Any) -> bool:
        """Return True if ``obj`` is a dict carrying TensorIpcHandle fields.

        ``request_id`` is deliberately not required, so handles from senders
        that omit it are still recognized.
        """
        return isinstance(obj, dict) and all(
            key in obj for key in ("tensor_id", "shape", "dtype", "device")
        )

    @staticmethod
    def _ipc_handle_from_dict(obj: dict[str, Any]) -> TensorIpcHandle:
        """Build a TensorIpcHandle from its dict form.

        TensorIpcHandle has no default for ``request_id``, so it is filled in
        with None when the dict lacks it (older-format handles).
        """
        return TensorIpcHandle(**{"request_id": None, **obj})

    @staticmethod
    def _is_ipc_handle_list(obj: list[Any]) -> bool:
        """Return True if ``obj`` is a TensorIpcHandle encoded as an array.

        The layout is ``[request_id, tensor_id, shape, dtype, device]``. The
        ``request_id`` may be None, so it cannot be used alone to spot a
        handle. Nested-tensor lists never contain bare strings, so a string
        ``tensor_id`` in slot 1 disambiguates.
        """
        return (
            len(obj) == 5
            and (obj[0] is None or isinstance(obj[0], str))
            and isinstance(obj[1], str)
        )

    def _decode_tensor(self, arr: Any) -> torch.Tensor:
        """Rebuild a tensor from ``(dtype, shape, data)`` or an IPC handle."""
        # A TensorIpcHandle (in any serialized form) can show up here when
        # IPC is enabled for non-multimodal tensor fields.
        if isinstance(arr, TensorIpcHandle):
            return self._decode_ipc_queue_tensor(arr)
        # msgspec may serialize the handle as a dict of its fields.
        if self._is_ipc_handle_dict(arr):
            return self._decode_ipc_queue_tensor(self._ipc_handle_from_dict(arr))
        # msgspec serializes NamedTuples as arrays. Regular tensors are encoded
        # as 3-element (dtype, shape, data) sequences, so 5 elements means
        # a TensorIpcHandle.
        if isinstance(arr, (list, tuple)) and len(arr) == 5:
            return self._decode_ipc_queue_tensor(TensorIpcHandle(*arr))

        # Standard tensor decoding
        dtype, shape, data = arr
        is_aux = isinstance(data, int)
        buffer = self.aux_buffers[data] if is_aux else data
        buffer = buffer if isinstance(buffer, memoryview) else memoryview(buffer)
        torch_dtype = getattr(torch, dtype)
        assert isinstance(torch_dtype, torch.dtype)
        if not buffer.nbytes:  # torch.frombuffer doesn't like empty buffers
            assert 0 in shape
            return torch.empty(shape, dtype=torch_dtype)
        # Create uint8 array
        arr = torch.frombuffer(buffer, dtype=torch.uint8)
        # Clone ensures tensor is backed by pytorch-owned memory for safe
        # future async CPU->GPU transfer.
        # Pin larger tensors for more efficient CPU->GPU transfer.
        if not is_aux:
            arr = arr.clone()
        elif not self.share_mem:
            arr = arr.pin_memory() if self.pin_tensors else arr.clone()
        # Convert back to proper shape & type
        return arr.view(torch_dtype).view(shape)

    def _decode_ipc_queue_tensor(self, handle: TensorIpcHandle) -> torch.Tensor:
        """Retrieve a tensor from torch.multiprocessing.Queue.

        Delegates to the TensorIpcReceiver. Works for CUDA and CPU.
        """
        assert self.tensor_ipc_receiver is not None, "Tensor IPC receiver is not set"
        return self.tensor_ipc_receiver.recv_tensor(handle)

    def _decode_mm_items(self, obj: dict[str, Any]) -> MultiModalKwargsItems:
        """Rebuild MultiModalKwargsItems from ``{modality: [item, ...]}``."""
        return MultiModalKwargsItems(
            {
                modality: [self._decode_mm_item(item) for item in itemlist]
                for modality, itemlist in obj.items()
            }
        )

    def _decode_mm_item(self, obj: dict[str, Any]) -> MultiModalKwargsItem:
        """Rebuild a MultiModalKwargsItem from ``{key: field_elem}``."""
        return MultiModalKwargsItem(
            {key: self._decode_mm_field_elem(elem) for key, elem in obj.items()}
        )

    def _decode_mm_field_elem(self, obj: dict[str, Any]) -> MultiModalFieldElem:
        """Rebuild a MultiModalFieldElem from its ``data``/``field`` dict.

        Note: mutates ``obj`` in place before constructing the element.
        """
        if obj["data"] is not None:
            obj["data"] = self._decode_nested_tensors(obj["data"])

        # Reconstruct the field processor using MultiModalFieldConfig
        factory_meth_name, factory_kw = obj["field"]
        factory_meth = getattr(MultiModalFieldConfig, factory_meth_name)

        # Special case: decode the union "slices" field of
        # MultiModalFlatField
        if factory_meth_name == "flat":
            factory_kw["slices"] = self._decode_nested_slices(factory_kw["slices"])

        obj["field"] = factory_meth("", **factory_kw).field
        return MultiModalFieldElem(**obj)

    def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
        """Recursively decode a NestedTensors structure.

        Leaves may be ints/floats, encoded tensors, or TensorIpcHandles in
        any of their serialized forms.

        Raises:
            TypeError: if an element is none of the supported forms.
        """
        if isinstance(obj, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return obj
        if isinstance(obj, TensorIpcHandle):
            return self._decode_ipc_queue_tensor(obj)
        # msgspec serializes dataclass-like handles as dicts without type
        # info in nested structures. Both new (with request_id) and old
        # (without request_id) formats are accepted.
        if self._is_ipc_handle_dict(obj):
            return self._decode_ipc_queue_tensor(self._ipc_handle_from_dict(obj))
        if not isinstance(obj, list):
            raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")
        # Array-encoded handle. Checked before the str-prefix test below
        # because request_id may be None, which that test would miss.
        if self._is_ipc_handle_list(obj):
            return self._decode_ipc_queue_tensor(TensorIpcHandle(*obj))
        if obj and isinstance(obj[0], str):
            return self._decode_tensor(obj)
        return [self._decode_nested_tensors(x) for x in obj]

    def _decode_nested_slices(self, obj: Any) -> Any:
        """Recursively turn ``[start, stop, step]`` lists back into slices."""
        assert isinstance(obj, (list, tuple))
        if obj and not isinstance(obj[0], (list, tuple)):
            return slice(*obj)
        return [self._decode_nested_slices(x) for x in obj]

    def cleanup_request_tensors(self, request_id: str) -> int:
        """Remove all orphaned tensors associated with a request.

        Pass-through to the TensorIpcReceiver. Returns 0 if no receiver.
        """
        if self.tensor_ipc_receiver is None:
            return 0
        return self.tensor_ipc_receiver.cleanup_request_tensors(request_id)

    def ext_hook(self, code: int, data: memoryview) -> Any:
        """Decode msgpack extension types.

        Raw views are always accepted. Pickle/cloudpickle payloads are only
        accepted when VLLM_ALLOW_INSECURE_SERIALIZATION is set, because
        unpickling can execute arbitrary code.

        Raises:
            NotImplementedError: for any other (or disallowed) extension code.
        """
        if code == CUSTOM_TYPE_RAW_VIEW:
            return data

        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            if code == CUSTOM_TYPE_PICKLE:
                return pickle.loads(data)
            if code == CUSTOM_TYPE_CLOUDPICKLE:
                return cloudpickle.loads(data)

        raise NotImplementedError(f"Extension type code {code} is not supported")

_decode_ipc_queue_tensor

_decode_ipc_queue_tensor(handle: TensorIpcHandle) -> Tensor

Retrieve a tensor from torch.multiprocessing.Queue.

Delegates to the TensorIpcReceiver. Works for CUDA and CPU.

Source code in vllm/v1/serial_utils.py
def _decode_ipc_queue_tensor(self, handle: TensorIpcHandle) -> torch.Tensor:
    """Fetch the tensor referenced by ``handle`` from the IPC queue.

    The lookup itself is delegated to ``self.tensor_ipc_receiver``; the
    tensor may live on CUDA or CPU.
    """
    receiver = self.tensor_ipc_receiver
    # A handle can only be resolved when a receiver was configured.
    assert receiver is not None, "Tensor IPC receiver is not set"
    return receiver.recv_tensor(handle)

cleanup_request_tensors

cleanup_request_tensors(request_id: str) -> int

Remove all orphaned tensors associated with a request.

Pass-through to the TensorIpcReceiver. Returns 0 if no receiver.

Source code in vllm/v1/serial_utils.py
def cleanup_request_tensors(self, request_id: str) -> int:
    """Drop every orphaned tensor tied to ``request_id``.

    Forwards to the configured TensorIpcReceiver and returns its result;
    without a receiver there is nothing to clean and 0 is returned.
    """
    receiver = self.tensor_ipc_receiver
    if receiver is not None:
        return receiver.cleanup_request_tensors(request_id)
    return 0

MsgpackEncoder

Encoder with custom torch tensor and numpy array serialization.

Note that unlike vanilla msgspec Encoders, this interface is generally not thread-safe when encoding tensors / numpy arrays.

By default, arrays below 256B are serialized inline. Larger arrays are sent via dedicated messages. Note that this is a per-tensor limit.

When a tensor_ipc_sender is provided, tensors (CUDA and CPU) will be sent via torch.multiprocessing.Queue for zero-copy IPC instead of serialization.

Source code in vllm/v1/serial_utils.py
class MsgpackEncoder:
    """Encoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Encoders, this interface is generally
    not thread-safe when encoding tensors / numpy arrays.

    By default, arrays below 256B are serialized inline. Larger arrays are
    sent via dedicated messages. Note that this is a per-tensor limit.

    When a ``tensor_ipc_sender`` is provided, tensors (CUDA and CPU) will be
    sent via torch.multiprocessing.Queue for zero-copy IPC instead of
    serialization.
    """

    def __init__(
        self,
        size_threshold: int | None = None,
        tensor_ipc_sender: TensorIpcSender | None = None,
    ):
        """Create an encoder.

        Args:
            size_threshold: Arrays/tensors with fewer bytes than this are
                inlined; larger ones go into separate buffers. Defaults to
                ``envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD``.
            tensor_ipc_sender: Optional sender used to ship tensors
                out-of-band via torch.multiprocessing.Queue.
        """
        if size_threshold is None:
            size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
        self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
        # This is used as a local stash of buffers that we can then access from
        # our custom `msgspec` hook, `enc_hook`. We don't have a way to
        # pass custom data to the hook otherwise.
        self.aux_buffers: list[bytestr] | None = None
        self.size_threshold = size_threshold
        # Optional sender for tensor IPC via torch.multiprocessing.Queue
        self.tensor_ipc_sender = tensor_ipc_sender
        # Counter for generating unique tensor IDs
        self._tensor_id_counter = 0
        # Current request ID being encoded (for associating tensors with requests)
        self._current_request_id: str | None = None
        if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            _log_insecure_serialization_warning()

    def set_request_context(self, request_id: str | None) -> None:
        """Set the current request ID being encoded (for tensor association)."""
        self._current_request_id = request_id

    def encode(self, obj: Any) -> Sequence[bytestr]:
        """Encode ``obj`` into a list of buffers.

        Element 0 is the msgpack payload; any further elements are backing
        buffers of large arrays/tensors, referenced by index from the payload.
        """
        try:
            self.aux_buffers = bufs = [b""]
            bufs[0] = self.encoder.encode(obj)
            # This `bufs` list allows us to collect direct pointers to backing
            # buffers of tensors and np arrays, and return them along with the
            # top-level encoded buffer instead of copying their data into the
            # new buffer.
            return bufs
        finally:
            self.aux_buffers = None

    def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
        """Encode ``obj`` into ``buf``.

        Returns a list whose first element is ``buf``, followed by any
        auxiliary backing buffers referenced by index from the payload.
        """
        try:
            self.aux_buffers = [buf]
            bufs = self.aux_buffers
            self.encoder.encode_into(obj, buf)
            return bufs
        finally:
            self.aux_buffers = None

    def enc_hook(self, obj: Any) -> Any:
        """Convert types msgspec cannot encode natively.

        Raises:
            TypeError: for unsupported types when insecure (pickle-based)
                serialization is not enabled.
        """
        if isinstance(obj, torch.Tensor):
            return self._encode_tensor(obj)

        # Fall back to pickle for object or void kind ndarrays.
        if isinstance(obj, np.ndarray) and obj.dtype.kind not in ("O", "V"):
            return self._encode_ndarray(obj)

        if isinstance(obj, slice):
            # We are assuming only int-based values will be used here.
            return tuple(
                int(v) if v is not None else None
                for v in (obj.start, obj.stop, obj.step)
            )

        if isinstance(obj, MultiModalKwargsItem):
            return self._encode_mm_item(obj)

        if isinstance(obj, MultiModalKwargsItems):
            return self._encode_mm_items(obj)

        if isinstance(obj, UtilityResult):
            result = obj.result
            if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
                return None, result
            # Since utility results are not strongly typed, we recursively
            # encode type information for nested structures of lists/dicts
            # to help with correct msgspec deserialization.
            return _encode_type_info_recursive(result), result

        if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
            raise TypeError(
                f"Object of type {type(obj)} is not serializable. "
                "Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow "
                "fallback to pickle-based serialization."
            )

        if isinstance(obj, FunctionType):
            # `pickle` is generally faster than cloudpickle, but can have
            # problems serializing methods.
            return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))

        return msgpack.Ext(
            CUSTOM_TYPE_PICKLE, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
        )

    def _encode_ndarray(
        self, obj: np.ndarray
    ) -> tuple[str, tuple[int, ...], int | memoryview]:
        """Encode an ndarray as ``(dtype_str, shape, data)``."""
        assert self.aux_buffers is not None
        # If the array is non-contiguous, we need to copy it first
        arr_data = obj.data if obj.flags.c_contiguous else obj.tobytes()
        if not obj.shape or obj.nbytes < self.size_threshold:
            # Encode small arrays and scalars inline. Using this extension type
            # ensures we can avoid copying when decoding.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(arr_data)

        # We serialize the ndarray as a tuple of native types.
        # The data is either inlined if small, or an index into a list of
        # backing buffers that we've stashed in `aux_buffers`.
        return obj.dtype.str, obj.shape, data

    def _encode_tensor(
        self, obj: torch.Tensor
    ) -> (
        tuple[str, tuple[int, ...], int | memoryview] | dict[str, Any] | TensorIpcHandle
    ):
        """Encode a tensor.

        If a ``tensor_ipc_sender`` is configured, the tensor is first offered
        to it and its return value is used. On any exception, a warning is
        logged and in-band serialization is used instead. In-band encoding
        is ``(dtype_name, shape, data)``, and CUDA tensors are copied to CPU
        first.
        """
        assert self.aux_buffers is not None

        # Check if we should use IPC for this tensor
        sender = self.tensor_ipc_sender
        if sender is not None:
            try:
                # Generate unique tensor ID
                tensor_id = f"{id(self)}_{self._tensor_id_counter}"
                self._tensor_id_counter += 1
                return sender.send_tensor(obj, self._current_request_id, tensor_id)
            except Exception as e:
                logger.warning(
                    "Failed to send tensor via IPC queue: %s. "
                    "Falling back to standard serialization.",
                    e,
                )
                # Fall through to standard serialization

        # Standard serialization fallback
        # For CUDA tensors without IPC support, we need to move to CPU first
        if obj.is_cuda:
            obj = obj.cpu()

        # view the tensor as a contiguous 1D array of bytes
        arr_data = tensor_data(obj)
        if obj.nbytes < self.size_threshold:
            # Smaller tensors are encoded inline, just like ndarrays.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(arr_data)
        dtype = str(obj.dtype).removeprefix("torch.")
        return dtype, obj.shape, data

    def _encode_mm_items(self, items: MultiModalKwargsItems) -> dict[str, Any]:
        """Encode as ``{modality: [encoded_item, ...]}``."""
        return {
            modality: [self._encode_mm_item(item) for item in itemlist]
            for modality, itemlist in items.items()
        }

    def _encode_mm_item(self, item: MultiModalKwargsItem) -> dict[str, Any]:
        """Encode as ``{key: encoded_field_elem}``."""
        return {key: self._encode_mm_field_elem(elem) for key, elem in item.items()}

    def _encode_mm_field_elem(self, elem: MultiModalFieldElem) -> dict[str, Any]:
        """Encode a field element as ``{"data": ..., "field": ...}``."""
        return {
            "data": (
                None if elem.data is None else self._encode_nested_tensors(elem.data)
            ),
            "field": self._encode_mm_field(elem.field),
        }

    def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
        """Recursively encode a NestedTensors structure."""
        if isinstance(nt, torch.Tensor):
            return self._encode_tensor(nt)
        if isinstance(nt, (int, float)):
            # Although it violates NestedTensors type, MultiModalKwargs
            # values are sometimes floats.
            return nt
        return [self._encode_nested_tensors(x) for x in nt]

    def _encode_mm_field(self, field: BaseMultiModalField):
        """Encode a field as ``(factory_name, factory_kwargs)``.

        Raises:
            TypeError: if the field class has no registered factory.
        """
        # Figure out the factory name for the field type.
        name = MMF_CLASS_TO_FACTORY.get(field.__class__)
        if not name:
            raise TypeError(f"Unsupported field type: {field.__class__}")

        # We just need to copy all of the field values in order
        # which will be then used to reconstruct the field.
        factory_kw = {f.name: getattr(field, f.name) for f in dataclasses.fields(field)}
        return name, factory_kw

set_request_context

set_request_context(request_id: str | None) -> None

Set the current request ID being encoded (for tensor association).

Source code in vllm/v1/serial_utils.py
def set_request_context(self, request_id: str | None) -> None:
    """Set the current request ID being encoded (for tensor association)."""
    self._current_request_id = request_id

PydanticMsgspecMixin

Source code in vllm/v1/serial_utils.py
class PydanticMsgspecMixin:
    """Mixin that lets msgspec.Struct subclasses be validated by Pydantic."""

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        """
        Make msgspec.Struct compatible with Pydantic, respecting defaults.
        Handle JSON=>msgspec.Struct. Used when exposing msgspec.Struct to the
        API as input or in `/docs`. Note this is cached by Pydantic and not
        called on every validation.
        """
        msgspec_fields = {f.name: f for f in msgspec.structs.fields(source_type)}
        type_hints = get_type_hints(source_type)

        # Build the Pydantic typed_dict_field for each msgspec field
        fields = {}
        for name, hint in type_hints.items():
            msgspec_field = msgspec_fields.get(name)
            if msgspec_field is None:
                # get_type_hints() also reports annotations that msgspec does
                # not treat as struct fields (e.g. ClassVar); skip them
                # instead of failing with a KeyError.
                continue

            # typed_dict_field using the handler to get the schema
            field_schema = handler(hint)

            # Add default value to the schema.
            if msgspec_field.default_factory is not msgspec.NODEFAULT:
                wrapped_schema = core_schema.with_default_schema(
                    schema=field_schema,
                    default_factory=msgspec_field.default_factory,
                )
                fields[name] = core_schema.typed_dict_field(wrapped_schema)
            elif msgspec_field.default is not msgspec.NODEFAULT:
                wrapped_schema = core_schema.with_default_schema(
                    schema=field_schema,
                    default=msgspec_field.default,
                )
                fields[name] = core_schema.typed_dict_field(wrapped_schema)
            else:
                # No default, so Pydantic will treat it as required
                fields[name] = core_schema.typed_dict_field(field_schema)
        return core_schema.no_info_after_validator_function(
            cls._validate_msgspec,
            core_schema.typed_dict_schema(fields),
        )

    @classmethod
    def _validate_msgspec(cls, value: Any) -> Any:
        """Validate and convert input to msgspec.Struct instance."""
        if isinstance(value, cls):
            return value
        if isinstance(value, dict):
            return cls(**value)
        return msgspec.convert(value, type=cls)

__get_pydantic_core_schema__ classmethod

__get_pydantic_core_schema__(
    source_type: Any, handler: GetCoreSchemaHandler
) -> CoreSchema

Make msgspec.Struct compatible with Pydantic, respecting defaults. Handle JSON=>msgspec.Struct. Used when exposing msgspec.Struct to the API as input or in /docs. Note this is cached by Pydantic and not called on every validation.

Source code in vllm/v1/serial_utils.py
@classmethod
def __get_pydantic_core_schema__(
    cls, source_type: Any, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
    """
    Make msgspec.Struct compatible with Pydantic, respecting defaults.
    Handle JSON=>msgspec.Struct. Used when exposing msgspec.Struct to the
    API as input or in `/docs`. Note this is cached by Pydantic and not
    called on every validation.
    """
    msgspec_fields = {f.name: f for f in msgspec.structs.fields(source_type)}
    type_hints = get_type_hints(source_type)

    # Build the Pydantic typed_dict_field for each msgspec field
    fields = {}
    for name, hint in type_hints.items():
        msgspec_field = msgspec_fields.get(name)
        if msgspec_field is None:
            # get_type_hints() also reports annotations that msgspec does
            # not treat as struct fields (e.g. ClassVar); skip them
            # instead of failing with a KeyError.
            continue

        # typed_dict_field using the handler to get the schema
        field_schema = handler(hint)

        # Add default value to the schema.
        if msgspec_field.default_factory is not msgspec.NODEFAULT:
            wrapped_schema = core_schema.with_default_schema(
                schema=field_schema,
                default_factory=msgspec_field.default_factory,
            )
            fields[name] = core_schema.typed_dict_field(wrapped_schema)
        elif msgspec_field.default is not msgspec.NODEFAULT:
            wrapped_schema = core_schema.with_default_schema(
                schema=field_schema,
                default=msgspec_field.default,
            )
            fields[name] = core_schema.typed_dict_field(wrapped_schema)
        else:
            # No default, so Pydantic will treat it as required
            fields[name] = core_schema.typed_dict_field(field_schema)
    return core_schema.no_info_after_validator_function(
        cls._validate_msgspec,
        core_schema.typed_dict_schema(fields),
    )

_validate_msgspec classmethod

_validate_msgspec(value: Any) -> Any

Validate and convert input to msgspec.Struct instance.

Source code in vllm/v1/serial_utils.py
@classmethod
def _validate_msgspec(cls, value: Any) -> Any:
    """Coerce ``value`` into an instance of ``cls``.

    Instances pass through untouched, dicts are splatted into the
    constructor, and anything else is handed to ``msgspec.convert``.
    """
    if not isinstance(value, cls):
        if isinstance(value, dict):
            return cls(**value)
        return msgspec.convert(value, type=cls)
    return value

TensorIpcHandle

Bases: NamedTuple

Handle for a tensor sent via IPC queue (zero-copy transfer).

Contains only metadata about the tensor. This is serialized via msgpack and used by the decoder to retrieve the actual tensor from the queue. The actual tensor is sent separately via torch.multiprocessing.Queue. Works for both CUDA and CPU tensors.

Source code in vllm/v1/serial_utils.py
class TensorIpcHandle(NamedTuple):
    """Handle for a tensor sent via IPC queue (zero-copy transfer).

    Contains only metadata about the tensor. This is serialized via msgpack
    and used by the decoder to retrieve the actual tensor from the queue.
    The actual tensor is sent separately via torch.multiprocessing.Queue.
    Works for both CUDA and CPU tensors.

    NOTE: field order and count are part of the wire format. msgspec encodes
    NamedTuples as arrays, and MsgpackDecoder rebuilds handles positionally
    from 5-element lists, so fields must not be reordered, added, or removed
    without updating the decoder.
    """

    # Request the tensor belongs to. Presumably the encoder's current request
    # context, which is None when unset — confirm in TensorIpcSender.
    request_id: str | None
    # Identifier used to match this handle with the tensor sent over the
    # queue (the encoder generates one per tensor).
    tensor_id: str
    # Metadata describing the sent tensor. NOTE(review): the exact string
    # formats of dtype/device are chosen by TensorIpcSender — confirm there.
    shape: tuple[int, ...]
    dtype: str
    device: str

UtilityResult

Wrapper for special handling when serializing/deserializing.

Source code in vllm/v1/serial_utils.py
class UtilityResult:
    """Box around a utility call's return value.

    Serialization hooks in this module special-case this type so the wrapped
    value can carry extra type information across the wire.
    """

    def __init__(self, r: Any = None):
        # The wrapped value is stored untouched and exposed as `.result`.
        self.result = r

_decode_type_info_recursive

_decode_type_info_recursive(
    type_info: Any,
    data: Any,
    convert_fn: Callable[[Sequence[str], Any], Any],
) -> Any

Recursively decode type information for nested structures of lists/dicts.

Source code in vllm/v1/serial_utils.py
def _decode_type_info_recursive(
    type_info: Any, data: Any, convert_fn: Callable[[Sequence[str], Any], Any]
) -> Any:
    """Walk ``data`` alongside its ``type_info`` tree and convert the leaves.

    - ``None`` type info means "leave as-is".
    - A dict of type info maps key-by-key onto the dict ``data``.
    - A list of type info maps element-wise (zip-truncated) onto the list
      ``data``, except for a 2-element list whose first item is a string.
      Such a list is treated as a leaf type descriptor.
    - Every other leaf descriptor is handed to ``convert_fn(type_info, data)``.
    """
    if type_info is None:
        return data

    if isinstance(type_info, dict):
        assert isinstance(data, dict)
        decoded = {}
        for key, sub_info in type_info.items():
            decoded[key] = _decode_type_info_recursive(sub_info, data[key], convert_fn)
        return decoded

    if isinstance(type_info, list):
        is_leaf_descriptor = len(type_info) == 2 and isinstance(type_info[0], str)
        if not is_leaf_descriptor:
            assert isinstance(data, list)
            return [
                _decode_type_info_recursive(sub_info, item, convert_fn)
                for sub_info, item in zip(type_info, data)
            ]

    return convert_fn(type_info, data)

_encode_type_info_recursive

_encode_type_info_recursive(obj: Any) -> Any

Recursively encode type information for nested structures of lists/dicts.

Source code in vllm/v1/serial_utils.py
def _encode_type_info_recursive(obj: Any) -> Any:
    """Build a type-info tree that mirrors the shape of ``obj``.

    Exact ``list`` and ``dict`` instances (not subclasses) are walked
    recursively, ``None`` maps to ``None``, and every other value is replaced
    by ``_typestr(obj)``.
    """
    if obj is None:
        return None
    obj_type = type(obj)
    if obj_type is list:
        return list(map(_encode_type_info_recursive, obj))
    if obj_type is dict:
        return {key: _encode_type_info_recursive(val) for key, val in obj.items()}
    return _typestr(obj)

run_method

run_method(
    obj: Any,
    method: str | bytes | Callable,
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
) -> Any

Run a method of an object with the given arguments and keyword arguments. If the method is a string, it will be resolved to a method using getattr. If the method is serialized bytes, it will be deserialized using cloudpickle. If the method is a callable, it will be called directly.

Source code in vllm/v1/serial_utils.py
def run_method(
    obj: Any,
    method: str | bytes | Callable,
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
) -> Any:
    """
    Run a method of an object with the given arguments and keyword arguments.
    If the method is string, it will be converted to a method using getattr.
    If the method is serialized bytes and will be deserialized using
    cloudpickle.
    If the method is a callable, it will be called directly.
    """
    if isinstance(method, bytes):
        func = partial(cloudpickle.loads(method), obj)
    elif isinstance(method, str):
        try:
            func = getattr(obj, method)
        except AttributeError:
            raise NotImplementedError(
                f"Method {method!r} is not implemented."
            ) from None
    else:
        func = partial(method, obj)  # type: ignore
    return func(*args, **kwargs)