"""Torch-based tensor serializer/deserializer (lmcache.storage_backend.serde.torch_serde)."""

import io

import torch

from lmcache.logging import init_logger
from lmcache.storage_backend.serde.serde import Deserializer, Serializer

logger = init_logger(__name__)


class TorchSerializer(Serializer):
    """Serialize tensors to bytes using ``torch.save``."""

    def __init__(self):
        super().__init__()

    def to_bytes(self, t: torch.Tensor) -> bytes:
        """Serialize ``t`` into a standalone ``torch.save`` payload.

        The tensor is detached from autograd and copied to CPU memory into
        freshly allocated storage before saving. The copy matters because
        ``torch.save`` writes a tensor's entire underlying storage, so saving
        a view of a larger tensor directly would serialize the whole base.

        Args:
            t: The tensor to serialize; it may live on any device.

        Returns:
            The bytes produced by ``torch.save`` for the CPU copy of ``t``.
        """
        # detach() first so autograd does not record the copy. A single forced
        # copy to CPU (rather than ``.cpu().clone()``) avoids copying device
        # tensors twice (device->host, then host->host) while still giving a
        # tensor with its own compact storage, like clone() does.
        cpu_copy = t.detach().to(device="cpu", copy=True)
        with io.BytesIO() as f:
            torch.save(cpu_copy, f)
            return f.getvalue()
class TorchDeserializer(Deserializer):
    """Deserialize ``torch.save`` payloads and cast them to a target dtype."""

    def __init__(self, dtype):
        # The target dtype is handed to the base class; ``from_bytes`` later
        # reads it as ``self.dtype`` (presumably set by ``Deserializer`` --
        # confirm against the base class).
        super().__init__(dtype)

    def from_bytes_normal(self, b: bytes) -> torch.Tensor:
        """Load the tensor stored in ``b`` without any dtype conversion.

        Args:
            b: Bytes previously produced by ``torch.save``.

        Returns:
            The tensor returned by ``torch.load``.
        """
        # NOTE(review): torch.load unpickles its input. If these bytes can come
        # from an untrusted or remote store, consider passing weights_only=True
        # (only the default from PyTorch 2.6 onward) -- confirm the minimum
        # supported torch version before changing it.
        buffer = io.BytesIO(b)
        with buffer:
            loaded = torch.load(buffer)
        return loaded

    def from_bytes(self, b: bytes) -> torch.Tensor:
        """Load the tensor stored in ``b`` and cast it to ``self.dtype``.

        Args:
            b: Bytes previously produced by ``torch.save``.

        Returns:
            The deserialized tensor converted via ``.to(dtype=self.dtype)``.
        """
        tensor = self.from_bytes_normal(b)
        return tensor.to(dtype=self.dtype)