Source code for lmcache.storage_backend.remote_backend

import queue
import threading
from typing import Iterable, Iterator, List, Optional, Tuple, Union

import torch

from lmcache.config import LMCacheEngineConfig, LMCacheEngineMetadata
from lmcache.logging import init_logger
from lmcache.storage_backend.abstract_backend import LMCBackendInterface
from lmcache.storage_backend.connector import CreateConnector
from lmcache.storage_backend.serde import CreateSerde
from lmcache.utils import CacheEngineKey, _lmcache_nvtx_annotate

logger = init_logger(__name__)

# FIXME(Jiayi): Put the following worker function(s) into class
# FIXME(Jiayi): Needs to consider concurrent setting (private queue?)


class RemoteBackendEndSignal:
    pass


class LMCRemoteBackend(LMCBackendInterface):
    """
    Cache engine for storing the KV cache of the tokens in the remote server.
    """

    def __init__(self, config: LMCacheEngineConfig,
                 metadata: LMCacheEngineMetadata):
        """
        Throws:
            RuntimeError if the loaded configuration does not match the
            current configuration
        """
        super().__init__()

        #self.existing_keys: Set[CacheEngineKey] = set()
        self.put_thread = None

        assert config.remote_url is not None, (
            "Need to provide remote_url when"
            " using LMCRemoteBackend")
        self.connection = CreateConnector(config.remote_url)

        assert config.remote_serde is not None, (
            "Need to provide remote_serde "
            "when using LMCRemoteBackend")
        s, d = CreateSerde(config.remote_serde, config, metadata)
        self.serializer = s
        self.deserializer = d

        # For async put
        self.put_queue: queue.Queue[
            Union[Tuple[CacheEngineKey, torch.Tensor],
                  RemoteBackendEndSignal]] = queue.Queue()
        self.put_thread = threading.Thread(target=self.put_worker, args=())
        self.put_thread.start()

        # FIXME(Jiayi): please remove this hard code
        self.dst_device = "cuda"

    @_lmcache_nvtx_annotate
    def put_worker(self, ):
        # put_stream = torch.cuda.Stream()
        while True:
            item = self.put_queue.get()
            if isinstance(item, RemoteBackendEndSignal):
                break
            key, value = item
            # with torch.cuda.stream(put_stream):
            self.put_blocking(key, value)

    def _combine_key(
        self,
        key: CacheEngineKey,
    ) -> str:
        """
        Convert the tuple key to a single key
        """
        return key.to_string()

    def _split_key(
        self,
        key: str,
    ) -> CacheEngineKey:
        """
        Split the single key to a tuple key
        """
        return CacheEngineKey.from_string(key)

    def list(self) -> List[CacheEngineKey]:
        """
        list the remote keys (and also update the 'cached'
        existing keys set)
        """
        keys = self.connection.list()
        #for key in keys:
        #    self.existing_keys.add(self._split_key(key))
        return [self._split_key(key) for key in keys]

    def contains(
        self,
        key: CacheEngineKey,
    ) -> bool:
        """
        Check if the cache engine contains the key.

        Input:
            key: the key of the token chunk, including prefix hash and format

        Returns:
            True if the cache engine contains the key, False otherwise
        """
        #if key in self.existing_keys:
        #    return True
        #else:
        flag = self.connection.exists(self._combine_key(key))
        # if flag:
        #     self.existing_keys.add(key)
        return flag

    def put_blocking(
        self,
        key: CacheEngineKey,
        kv_chunk: torch.Tensor,
    ) -> None:
        """
        Serialize the KV chunk and synchronously write it to the remote server.
        """
        bs = self.serializer.to_bytes(kv_chunk)
        self.connection.set(self._combine_key(key), bs)
        #self.existing_keys.add(key)

    def put(
        self,
        key: CacheEngineKey,
        kv_chunk: torch.Tensor,
        blocking: bool = True,
    ) -> None:
        """
        Store the KV cache of the tokens into the cache engine.

        Input:
            key: the key of the token chunk, including prefix hash and format
            kv_chunk: the kv cache of the token chunk, in a single big tensor
            blocking: whether to block until the put is done

        Returns:
            None

        Note:
            The KV cache should NOT have the "batch" dimension.
        """
        if blocking:
            self.put_blocking(key, kv_chunk)
        else:
            self.put_queue.put((key, kv_chunk))

    @_lmcache_nvtx_annotate
    def get(
        self,
        key: CacheEngineKey,
    ) -> Optional[torch.Tensor]:
        """
        Retrieve the KV cache chunk (in a single big tensor) by the given key
        """
        if not self.contains(key):
            return None
        bs = self.connection.get(self._combine_key(key))
        if bs is None or len(bs) == 0:
            return None
        return self.deserializer.from_bytes(bs).to(self.dst_device)

    def close(self):
        if self.put_thread is not None and self.put_thread.is_alive():
            self.put_queue.put(RemoteBackendEndSignal())
            self.put_thread.join()
            logger.info("Closed the put worker")

        if self.connection is not None:
            self.connection.close()

    def __del__(self):
        self.close()

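
# Usage sketch (illustrative, not part of the original module): how a caller
# might drive LMCRemoteBackend. The config construction below is an
# assumption; consult lmcache.config for the actual factory methods and the
# available serde names.
#
#   config = LMCacheEngineConfig.from_defaults(   # hypothetical factory
#       remote_url="redis://localhost:6379",      # any URL CreateConnector accepts
#       remote_serde="torch",                     # any serde CreateSerde accepts
#   )
#   metadata = LMCacheEngineMetadata(...)         # model/version specific
#   backend = LMCRemoteBackend(config, metadata)
#   backend.put(key, kv_chunk)                    # blocking put
#   backend.put(key, kv_chunk, blocking=False)    # enqueued to put_worker
#   chunk = backend.get(key)                      # None if the key is missing
#   backend.close()
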

class LMCPipelinedRemoteBackend(LMCRemoteBackend):
    """
    Implements the pipelined get functionality for the remote backend.
    """

    def __init__(self, config: LMCacheEngineConfig,
                 metadata: LMCacheEngineMetadata):
        """
        Throws:
            RuntimeError if the loaded configuration does not match the
            current configuration
        """
        super().__init__(config, metadata)

        # Comment out existing_keys for now to avoid consistency issues
        #self.existing_keys = set()
        self.network_thread = None
        self.deserialize_thread = None

        # Initialize network get thread queue
        logger.debug("Initializing network thread queue")
        self.network_queue: queue.Queue[Union[Tuple[
            int, CacheEngineKey], RemoteBackendEndSignal]] = queue.Queue()
        self.network_thread = threading.Thread(target=self.network_worker,
                                               args=())
        self.network_thread.start()

        # Initialize deserialization thread queue
        logger.debug("Initializing deserialize thread queue")
        self.deserialize_queue: queue.Queue[Union[Tuple[
            int, Optional[bytes]], RemoteBackendEndSignal]] = queue.Queue()
        self.deserialize_thread = threading.Thread(
            target=self.deserialize_worker, args=())
        self.deserialize_thread.start()

        self.result_list: List[Optional[torch.Tensor]] = []

    @_lmcache_nvtx_annotate
    def network_worker(self, ):
        while True:
            item = self.network_queue.get()
            if isinstance(item, RemoteBackendEndSignal):
                break
            idx, key = item
            if self.contains(key):
                # Fetch the raw bytes and hand them to the deserialize worker;
                # keys missing from the remote produce no deserialize entry
                data = self.connection.get(self._combine_key(key))
                self.deserialize_queue.put_nowait((idx, data))
            self.network_queue.task_done()

    @_lmcache_nvtx_annotate
    def deserialize_worker(self, ):
        while True:
            item = self.deserialize_queue.get()
            if isinstance(item, RemoteBackendEndSignal):
                break
            idx, data = item
            if data is not None:
                result = self.deserializer.from_bytes(data).to(
                    self.dst_device)
            else:
                result = None
            self.result_list.append(result)
            self.deserialize_queue.task_done()

    @_lmcache_nvtx_annotate
    def batched_get(
        self,
        keys: Iterator[CacheEngineKey],
    ) -> Iterable[Optional[torch.Tensor]]:
        """
        Retrieve multiple KV cache chunks, overlapping the network fetches
        with deserialization via the two worker threads.
        """
        self.result_list = []
        for idx, key in enumerate(keys):
            self.network_queue.put_nowait((idx, key))

        self.network_queue.join()
        self.deserialize_queue.join()
        return self.result_list

    def close(self):
        super().close()

        if self.network_thread is not None and self.network_thread.is_alive():
            self.network_queue.put(RemoteBackendEndSignal())
            self.network_thread.join()
            logger.info("Closed the network worker")

        if (self.deserialize_thread is not None
                and self.deserialize_thread.is_alive()):
            self.deserialize_queue.put(RemoteBackendEndSignal())
            self.deserialize_thread.join()
            logger.info("Closed the deserialize worker")

    def __del__(self):
        self.close()
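
# Usage sketch (illustrative, not part of the original module): batched_get on
# LMCPipelinedRemoteBackend pipelines network fetches and deserialization by
# pushing (index, key) pairs through the two worker queues above. Everything
# below except the backend API itself is assumed for illustration.
#
#   backend = LMCPipelinedRemoteBackend(config, metadata)
#   chunks = backend.batched_get(iter(keys))   # keys: List[CacheEngineKey]
#   for chunk in chunks:
#       if chunk is not None:
#           ...  # use the retrieved KV cache chunk
#   backend.close()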