# Source code for lmcache.storage_backend.serde.serde
import abc
import time
import torch
from lmcache.logging import init_logger
from lmcache.utils import _lmcache_nvtx_annotate
logger = init_logger(__name__)
class Serializer(metaclass=abc.ABCMeta):
    """Abstract interface for converting a pytorch tensor into bytes."""

    @abc.abstractmethod
    def to_bytes(self, t: torch.Tensor) -> bytes:
        """
        Serialize a pytorch tensor to bytes.

        The serialized bytes should carry the tensor's data together with
        its metadata (shape, dtype, etc.).

        Input:
            t: the pytorch tensor to serialize; it can be on any device,
               in any shape, with any dtype

        Returns:
            bytes: the serialized bytes

        Raises:
            NotImplementedError: always in this abstract base; concrete
                subclasses must override this method
        """
        raise NotImplementedError
class SerializerDebugWrapper(Serializer):
    """Serializer that delegates to another serializer and logs the time taken."""

    def __init__(self, s: Serializer):
        """
        Input:
            s: the serializer that performs the actual serialization
        """
        self.s = s

    def to_bytes(self, t: torch.Tensor) -> bytes:
        """
        Serialize a pytorch tensor with the wrapped serializer, logging the
        elapsed wall-clock time at debug level.

        Input:
            t: the pytorch tensor to serialize

        Returns:
            bytes: whatever the wrapped serializer's `to_bytes` returned
        """
        start = time.perf_counter()
        bs = self.s.to_bytes(t)
        end = time.perf_counter()
        # Report milliseconds, as DeserializerDebugWrapper does: with two
        # decimals in seconds, any call shorter than 5 ms logs as "0.00".
        logger.debug(f"Serialization took {(end - start) * 1000:.2f} ms")
        return bs
class Deserializer(metaclass=abc.ABCMeta):
    """Abstract interface for rebuilding a pytorch tensor from bytes."""

    @abc.abstractmethod
    def from_bytes(self, bs: bytes) -> torch.Tensor:
        """
        Deserialize a pytorch tensor from bytes.

        Input:
            bs: a stream of bytes

        Output:
            torch.Tensor: the deserialized pytorch tensor

        Raises:
            NotImplementedError: always in this abstract base; concrete
                subclasses must override this method
        """
        raise NotImplementedError
class DeserializerDebugWrapper(Deserializer):
    """Deserializer that delegates to another deserializer and logs the time taken."""

    def __init__(self, d: Deserializer):
        """
        Input:
            d: the deserializer that performs the actual deserialization
        """
        self.d = d

    @_lmcache_nvtx_annotate
    def from_bytes(self, t: bytes) -> torch.Tensor:
        """
        Deserialize a pytorch tensor with the wrapped deserializer, logging
        the elapsed wall-clock time in milliseconds at debug level.

        Input:
            t: the bytes to deserialize

        Output:
            torch.Tensor: whatever the wrapped deserializer's `from_bytes`
            returned
        """
        start = time.perf_counter()
        tensor = self.d.from_bytes(t)
        end = time.perf_counter()
        # perf_counter yields seconds; the message reports milliseconds.
        logger.debug(f"Deserialization took {(end-start)*1000:.2f} ms")
        return tensor