Source code for lmcache.storage_backend.connector.base_connector

import abc
import time
from enum import Enum
from typing import List, Optional

import torch

from lmcache.logging import init_logger
from lmcache.utils import _lmcache_nvtx_annotate

logger = init_logger(__name__)


class ConnectorType(Enum):
    BYTES = 1
    TENSOR = 2


class RemoteConnector(metaclass=abc.ABCMeta):
    """
    Interface for remote connector
    """

    @abc.abstractmethod
    def exists(self, key: str) -> bool:
        """
        Check if the remote server contains the key

        Input:
            key: a string

        Returns:
            True if the cache engine contains the key, False otherwise
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get(self, key: str) -> Optional[bytes | torch.Tensor]:
        """
        Get the object (bytes or Tensor) of the corresponding key

        Input:
            key: the key of the corresponding object

        Returns:
            The object (bytes or Tensor) of the corresponding key,
            or None if the key does not exist
        """
        raise NotImplementedError

    @abc.abstractmethod
    def set(self, key: str, obj: bytes | torch.Tensor) -> None:
        """
        Send the object (bytes or Tensor) with the corresponding key
        directly to the remote server

        Input:
            key: the key of the corresponding object
            obj: the object (bytes or Tensor) of the corresponding key
        """
        raise NotImplementedError

    @abc.abstractmethod
    def list(self) -> List[str]:
        """
        List all keys in the remote server

        Returns:
            A list of keys in the remote server
        """
        raise NotImplementedError

    @abc.abstractmethod
    def close(self) -> None:
        """
        Close the connection to the remote server
        """
        raise NotImplementedError


# Marker subclasses: concrete connectors inherit from one of these to
# declare whether they transfer raw bytes or torch Tensors.
class RemoteBytesConnector(RemoteConnector):
    pass


class RemoteTensorConnector(RemoteConnector):
    pass
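

# Illustrative sketch, not part of the original module: a minimal in-memory
# connector satisfying the RemoteConnector interface. The class name
# InMemoryBytesConnector and its dict-backed store are hypothetical; a real
# connector would talk to an actual remote server.
class InMemoryBytesConnector(RemoteBytesConnector):

    def __init__(self):
        self._store: dict[str, bytes] = {}

    def exists(self, key: str) -> bool:
        return key in self._store

    def get(self, key: str) -> Optional[bytes | torch.Tensor]:
        # dict.get already returns None for a missing key
        return self._store.get(key)

    def set(self, key: str, obj: bytes | torch.Tensor) -> None:
        # A bytes connector only carries raw bytes payloads
        assert isinstance(obj, bytes)
        self._store[key] = obj

    def list(self) -> List[str]:
        return list(self._store.keys())

    def close(self) -> None:
        self._store.clear()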


class RemoteConnectorDebugWrapper(RemoteConnector):
    """
    Wraps another RemoteConnector and logs the payload size and latency
    of get/set calls at DEBUG level.
    """

    def __init__(self, connector: RemoteConnector):
        self.connector = connector

    def exists(self, key: str) -> bool:
        return self.connector.exists(key)

    @_lmcache_nvtx_annotate
    def get(self, key: str) -> Optional[bytes | torch.Tensor]:
        start = time.perf_counter()
        ret = self.connector.get(key)
        end = time.perf_counter()

        if ret is None or len(ret) == 0:
            logger.debug(
                "Didn't get any data from the remote backend, key is %s", key)
            return None

        if check_connector_type(self.connector) == ConnectorType.BYTES:
            assert isinstance(ret, bytes)
            logger.debug(
                "Get %.2f MBytes data from the remote backend takes %.2f ms",
                len(ret) / 1e6,
                (end - start) * 1e3,
            )
        elif check_connector_type(self.connector) == ConnectorType.TENSOR:
            assert isinstance(ret, torch.Tensor)
            logger.debug(
                "Get %.2f MBytes data from the remote backend takes %.2f ms",
                (ret.element_size() * ret.numel()) / 1e6,
                (end - start) * 1e3,
            )
        return ret

    def set(self, key: str, obj: bytes | torch.Tensor) -> None:
        start = time.perf_counter()
        self.connector.set(key, obj)
        end = time.perf_counter()

        if isinstance(self.connector, RemoteBytesConnector):
            assert isinstance(obj, bytes)
            logger.debug(
                "Put %.2f MBytes data to the remote backend takes %.2f ms",
                len(obj) / 1e6,
                (end - start) * 1e3,
            )
        elif isinstance(self.connector, RemoteTensorConnector):
            assert isinstance(obj, torch.Tensor)
            logger.debug(
                "Put %.2f MBytes data to the remote backend takes %.2f ms",
                (obj.element_size() * obj.numel()) / 1e6,
                (end - start) * 1e3,
            )

    def list(self) -> List[str]:
        return self.connector.list()

    def close(self) -> None:
        return self.connector.close()


def check_connector_type(connector: RemoteConnector) -> ConnectorType:
    if isinstance(connector, RemoteBytesConnector):
        return ConnectorType.BYTES
    elif isinstance(connector, RemoteTensorConnector):
        return ConnectorType.TENSOR
    elif isinstance(connector, RemoteConnectorDebugWrapper):
        # Unwrap the debug wrapper and inspect the underlying connector
        # TODO: avoid possible infinite recursion on nested wrappers
        return check_connector_type(connector.connector)
    raise ValueError("Unsupported connector type")
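

# Usage sketch, not part of the original module: wrapping any connector in
# RemoteConnectorDebugWrapper adds size/latency logging to get() and set()
# without changing the interface. InMemoryBytesConnector is the hypothetical
# example class defined above.
if __name__ == "__main__":
    connector = RemoteConnectorDebugWrapper(InMemoryBytesConnector())
    connector.set("layer0:chunk0", b"\x00" * 1024)
    assert connector.exists("layer0:chunk0")
    data = connector.get("layer0:chunk0")  # logs size and latency at DEBUG
    assert data == b"\x00" * 1024
    connector.close()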