Source code for lmcache.storage_backend.serde.safe_serde
from typing import Union
import torch
from safetensors.torch import load, save
from lmcache.config import GlobalConfig
from lmcache.logging import init_logger
from lmcache.storage_backend.serde.serde import Deserializer, Serializer
logger = init_logger(__name__)
[docs]
class SafeSerializer(Serializer):
    """Serializer that encodes a single tensor with ``safetensors``."""

    def __init__(self):
        super().__init__()

    def to_bytes(self, t: torch.Tensor) -> bytes:
        """Encode ``t`` as a safetensors byte string.

        Args:
            t: The tensor to serialize.

        Returns:
            The output of ``safetensors.torch.save`` for a one-entry
            mapping that stores the tensor under the key
            ``"tensor_bytes"``.
        """
        # Bring the tensor to the CPU and make it contiguous before
        # handing it over to safetensors.
        host_tensor = t.cpu().contiguous()
        payload = {"tensor_bytes": host_tensor}
        return save(payload)
[docs]
class SafeDeserializer(Deserializer):
    """Deserializer that decodes safetensors bytes back into a tensor."""

    def __init__(self, dtype):
        """Initialise the deserializer.

        Args:
            dtype: Forwarded to the base ``Deserializer`` constructor.
                Presumably stored there as ``self.dtype``, which
                ``from_bytes_normal`` reads -- confirm against the base
                class.
        """
        super().__init__(dtype)
        # Snapshot of the global debug flag taken at construction time.
        self.debug = GlobalConfig.is_debug()

    def from_bytes_normal(self, b: Union[bytearray, bytes]) -> torch.Tensor:
        """Decode a safetensors payload and cast it to ``self.dtype``.

        Args:
            b: Serialized payload; it is converted to ``bytes`` before
                being passed to ``safetensors.torch.load``.

        Returns:
            The tensor stored under the key ``"tensor_bytes"``,
            converted with ``.to(dtype=self.dtype)``.
        """
        raw = bytes(b)
        tensors = load(raw)
        restored = tensors["tensor_bytes"]
        return restored.to(dtype=self.dtype)

    # TODO(Jiayi): please verify the input type
    # bytearray from `receive_all()` in connector?
    def from_bytes(self, b: Union[bytearray, bytes]) -> torch.Tensor:
        """Decode ``b`` into a tensor by delegating to ``from_bytes_normal``.

        Args:
            b: Serialized payload as ``bytes`` or ``bytearray``.

        Returns:
            Whatever ``from_bytes_normal`` returns for ``b``.
        """
        decoded = self.from_bytes_normal(b)
        return decoded