Source code for lmcache.storage_backend.local_backend

import os
import queue
import threading
import time
from collections import OrderedDict
from concurrent.futures import Future, ProcessPoolExecutor
from typing import Dict, Optional, Tuple, Union

import torch
from safetensors import safe_open
from safetensors.torch import save_file

from lmcache.config import LMCacheEngineConfig, LMCacheMemPoolMetadata
from lmcache.logging import init_logger
from lmcache.storage_backend.abstract_backend import LMCBackendInterface
from lmcache.storage_backend.evictor import LRUEvictor
from lmcache.storage_backend.evictor.base_evictor import PutStatus
from lmcache.storage_backend.mem_pool import (KVObj, LocalCPUBufferPool,
                                              LocalCPUPool, LocalGPUPool,
                                              LocalPool)
from lmcache.utils import (CacheEngineKey, DiskCacheMetadata, KVCache,
                           _lmcache_nvtx_annotate)

logger = init_logger(__name__)


class LocalBackendEndSignal:
    pass
class LMCLocalBackend(LMCBackendInterface):
    """
    Cache engine for storing the KV cache of the tokens in the local
    cpu/gpu memory.
    """

    def __init__(self,
                 config: LMCacheEngineConfig,
                 metadata: LMCacheMemPoolMetadata,
                 dst_device: str = "cuda"):
        """
        Throws:
            RuntimeError if the loaded configuration does not match the
            current configuration
        """
        super().__init__(dst_device)

        self.chunk_size = config.chunk_size
        self.config = config
        self.dict: OrderedDict[CacheEngineKey, KVObj] = OrderedDict()
        self.device = config.local_device

        self.put_queue: queue.Queue[
            Union[Tuple[CacheEngineKey, torch.Tensor],
                  LocalBackendEndSignal]] = queue.Queue()
        self.put_thread = threading.Thread(target=self.put_worker, args=())
        self.put_thread.start()
        self.update_lock = threading.Lock()

        # TODO(Jiayi): The storage size and caching policy for both
        # evictor and mpool need to be configured dynamically
        max_cache_size = self.config.max_local_cache_size
        self.evictor = LRUEvictor(max_cache_size)

        self.mpool: LocalPool
        if self.device == "cpu":
            self.mpool = LocalCPUPool(metadata)
        elif self.device == "cuda":
            self.mpool = LocalGPUPool(metadata)

        # TODO(Jiayi): A gpu buffer could speed up `get`
        # self.fix_sized_dst_buffer = torch.tensor()
    def contains(
        self,
        key: CacheEngineKey,
    ) -> bool:
        """
        Check if the cache engine contains the key.

        Input:
            key: the key of the token chunk, including prefix hash and format

        Returns:
            True if the cache engine contains the key, False otherwise
        """
        return key in self.dict
    def remove(
        self,
        key: CacheEngineKey,
    ) -> None:
        """
        Remove the KV cache chunk by the given key

        Input:
            key: the key of the token chunk, including prefix hash and format
        """
        kv_obj = self.dict.pop(key)
        self.mpool.free(kv_obj)
    @_lmcache_nvtx_annotate
    def put_worker(self):
        while True:
            item = self.put_queue.get()
            if isinstance(item, LocalBackendEndSignal):
                break
            key, value = item
            self.put_nonblocking(key, value)
    @_lmcache_nvtx_annotate
    @torch.inference_mode()
    def put_nonblocking(self, key, kv_chunk):
        # Obtain keys to evict
        self.update_lock.acquire()
        evict_keys, put_status = self.evictor.update_on_put(
            self.dict, self.mpool.size_per_chunk)

        if put_status == PutStatus.ILLEGAL:
            self.update_lock.release()
            return

        # Evict caches
        for evict_key in evict_keys:
            self.remove(evict_key)

        # Free the old block to avoid a memory leak
        if key in self.dict:
            self.remove(key)

        # Allocate the kv chunk
        kv_obj = self.mpool.allocate(kv_chunk)
        self.update_lock.release()

        if kv_obj is None:
            return

        put_stream = torch.cuda.Stream()
        if kv_chunk.device.type != "cpu":
            # Wait for operations in the main stream to finish,
            # e.g., view operations on kv_chunk
            put_stream.wait_stream(torch.cuda.default_stream(kv_chunk.device))

        with torch.cuda.stream(put_stream):
            kv_obj.data.copy_(kv_chunk, non_blocking=True)
            kv_chunk.record_stream(put_stream)
        put_stream.synchronize()

        # Store the new chunk
        self.update_lock.acquire()
        self.dict[key] = kv_obj
        self.update_lock.release()
    @torch.inference_mode()
    def put_blocking(self, key, kv_chunk):
        # Obtain keys to evict
        evict_keys, put_status = self.evictor.update_on_put(
            self.dict, self.mpool.size_per_chunk)

        # Abort the put if the cache is too big
        if put_status == PutStatus.ILLEGAL:
            return

        kv_obj = self.mpool.allocate(kv_chunk)
        if kv_obj is None:
            return
        kv_obj.data.copy_(kv_chunk, non_blocking=False)

        # Free the old block to avoid a memory leak
        if key in self.dict:
            self.remove(key)

        # Evict caches
        for evict_key in evict_keys:
            self.remove(evict_key)

        # Store the new chunk
        self.dict[key] = kv_obj
    def put(
        self,
        key: CacheEngineKey,
        kv_chunk: torch.Tensor,
        blocking: bool = True,
    ) -> None:
        """
        Store the KV cache of the tokens into the cache engine.

        Input:
            key: the key of the token chunk, including prefix hash and format
            kv_chunk: the kv cache of the token chunk, as a single tensor

        Returns:
            None

        Note:
            The KV cache should NOT have the "batch" dimension.
        """
        if blocking:
            self.put_blocking(key, kv_chunk)
        else:
            self.put_queue.put((key, kv_chunk))
    @_lmcache_nvtx_annotate
    def get(
        self,
        key: CacheEngineKey,
    ) -> Optional[torch.Tensor]:
        """
        Retrieve the KV cache chunk by the given key

        Input:
            key: the key of the token chunk, including prefix hash and format

        Output:
            the kv cache of the token chunk, as a single tensor;
            None if the key is not found
        """
        kv_chunk = None
        self.update_lock.acquire()
        kv_obj = self.dict.get(key, None)

        # Update cache recency
        if kv_obj is not None:
            self.evictor.update_on_get(key, self.dict)
            kv_chunk = kv_obj.data.to(self.dst_device)
        self.update_lock.release()

        return kv_chunk
    def close(self):
        if self.put_thread is not None and self.put_thread.is_alive():
            self.put_queue.put(LocalBackendEndSignal())
            self.put_thread.join()
            logger.info("Closed the put worker in local backend")
    def __del__(self):
        self.close()
# TODO(Jiayi): need to optimize disk saving/loading
# current impl. with "safetensors" might not be efficient
# but it is better than "torch.save/load"

# TODO(Jiayi): need to support prefetch for disk
@_lmcache_nvtx_annotate
@torch.inference_mode()
def save_disk(
    path: str,
    kv_chunk: torch.Tensor,
):
    save_file({"kv_chunk": kv_chunk.contiguous()}, path)
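
# A minimal sketch (not part of the original module) showing the safetensors
# round trip that `save_disk` and `LMCLocalDiskBackend.get` rely on: the chunk
# is written under the "kv_chunk" name and read back with `safe_open`. The
# function name and default path below are hypothetical and only illustrative.
def _example_safetensors_roundtrip(path: str = "/tmp/example_chunk.pt") -> bool:
    kv_chunk = torch.rand(2, 4, 8)
    # safetensors requires contiguous tensors, hence `.contiguous()`
    save_file({"kv_chunk": kv_chunk.contiguous()}, path)
    with safe_open(path, framework="pt", device="cpu") as f:
        loaded = f.get_tensor("kv_chunk")
    return torch.equal(kv_chunk, loaded)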
class LMCLocalDiskBackend(LMCBackendInterface):
    """
    Cache engine for storing the KV cache of the tokens in the local disk.
    """

    def __init__(self,
                 config: LMCacheEngineConfig,
                 metadata: LMCacheMemPoolMetadata,
                 dst_device: str = "cuda"):
        """
        Throws:
            RuntimeError if the loaded configuration does not match the
            current configuration
        """
        super().__init__(dst_device)

        self.chunk_size = config.chunk_size
        self.config = config
        self.dict: OrderedDict[CacheEngineKey,
                               DiskCacheMetadata] = OrderedDict()
        self.path = config.local_device
        assert self.path is not None, ("Need to specify local path when "
                                       "using LMCLocalDiskBackend")
        if not os.path.exists(self.path):
            os.makedirs(self.path)

        self.update_lock = threading.Lock()

        self.put_queue: queue.Queue[
            Union[Tuple[CacheEngineKey, torch.Tensor],
                  LocalBackendEndSignal]] = queue.Queue()
        self.put_thread = threading.Thread(target=self.put_worker, args=())
        self.put_thread.start()

        self.future_pool: Dict[CacheEngineKey, Tuple[Future, KVObj]] = {}
        self.stop_event = threading.Event()
        self.sweeper_thread = threading.Thread(target=self.buffer_sweeper,
                                               args=())
        self.sweeper_thread.start()

        # TODO(Jiayi): The storage size and caching policy for both
        # evictor and mpool need to be configured dynamically
        self.evictor = LRUEvictor(config.max_local_cache_size)

        # NOTE(Jiayi): This mbufferpool should be smaller than the actual
        # cpu backend but big enough to avoid stalls in save
        # TODO(Jiayi): share the buffer if both cpu and disk backend
        # are enabled
        self.cpu_mbufferpool = LocalCPUBufferPool(metadata)

        self.proc_pool_executor = ProcessPoolExecutor(max_workers=4)
    def contains(
        self,
        key: CacheEngineKey,
    ) -> bool:
        """
        Check if the cache engine contains the key.

        Input:
            key: the key of the token chunk, including prefix hash and format

        Returns:
            True if the cache engine contains the key, False otherwise
        """
        return key in self.dict
    def _key_to_path(
        self,
        key: CacheEngineKey,
    ) -> str:
        """
        Convert key to path_name

        Input:
            key: the key of the token chunk, including prefix hash and format

        Returns:
            returns the path name
        """
        return self.path + key.to_string().replace("/", "-") + ".pt"
    def remove(
        self,
        key: CacheEngineKey,
    ) -> None:
        """
        Remove the KV cache chunk by the given key

        Input:
            key: the key of the token chunk, including prefix hash and format
        """
        path = self.dict[key].path
        self.dict.pop(key)
        os.remove(path)
    @_lmcache_nvtx_annotate
    def put_worker(self):
        while True:
            item = self.put_queue.get()
            if isinstance(item, LocalBackendEndSignal):
                break
            key, value = item
            self.put_nonblocking(key, value)
    def buffer_sweeper(self):
        """
        Sweep the future pool to free up memory.
        """
        while not self.stop_event.is_set():
            logger.debug("Sweeping memory buffer")
            self.update_lock.acquire()
            for key in list(self.future_pool.keys()):
                future = self.future_pool[key][0]
                kv_obj = self.future_pool[key][1]
                if not future.done():
                    continue
                self.cpu_mbufferpool.free(kv_obj)
                del self.future_pool[key]
            self.update_lock.release()

            # Sweep the memory buffer every 30s
            time.sleep(30)
    @_lmcache_nvtx_annotate
    @torch.inference_mode()
    def put_nonblocking(
        self,
        key: CacheEngineKey,
        kv_chunk: torch.Tensor,
    ) -> None:
        path = self._key_to_path(key)
        logger.debug(f"Saving cache to {path}")

        self.update_lock.acquire()

        # Skip the store if the task is already being executed
        # TODO(Jiayi): what if already stored, should we
        # overwrite or skip?
        if key in self.future_pool:
            self.update_lock.release()
            return

        # Obtain keys to evict
        evict_keys, put_status = self.evictor.update_on_put(
            self.dict, self.evictor.get_size(kv_chunk))

        # Abort the put if the cache is too big
        if put_status == PutStatus.ILLEGAL:
            self.update_lock.release()
            return

        # Evict caches
        for evict_key in evict_keys:
            self.remove(evict_key)
        self.update_lock.release()

        # Allocate the kv chunk
        kv_obj = None
        while kv_obj is None:
            self.update_lock.acquire()
            kv_obj = self.cpu_mbufferpool.allocate(kv_chunk)
            self.update_lock.release()
            if kv_obj is None:
                # TODO(Jiayi): Please tune the sleep time
                # for better performance
                time.sleep(0.01)

        put_stream = torch.cuda.Stream()
        put_stream.wait_stream(torch.cuda.default_stream(kv_chunk.device))

        with torch.cuda.stream(put_stream):
            kv_obj.data.copy_(kv_chunk, non_blocking=True)
            kv_chunk.record_stream(put_stream)
        put_stream.synchronize()

        future = self.proc_pool_executor.submit(save_disk, path, kv_obj.data)

        self.update_lock.acquire()
        self.future_pool[key] = (future, kv_obj)
        self.dict[key] = DiskCacheMetadata(path, kv_obj.size)

        # NOTE(Jiayi): the following `free` will result in data corruption.
        # The serialized object (`kv_obj.data` in `submit`) may reference
        # the external memory (cpu tensor might be shared in multiprocessing),
        # and if the tensor is deleted, it might be invalidated.
        # self.cpu_mbufferpool.free(kv_obj)

        self.update_lock.release()
    @_lmcache_nvtx_annotate
    @torch.inference_mode()
    def put_blocking(
        self,
        key: CacheEngineKey,
        kv_chunk: torch.Tensor,
    ) -> None:
        path = self._key_to_path(key)
        logger.debug(f"Saving cache to {path}")

        self.update_lock.acquire()

        # Obtain keys to evict
        evict_keys, put_status = self.evictor.update_on_put(
            self.dict, self.evictor.get_size(kv_chunk))

        # Abort the put if the cache is too big
        if put_status == PutStatus.ILLEGAL:
            self.update_lock.release()
            return

        # Evict caches
        for evict_key in evict_keys:
            self.remove(evict_key)
        self.update_lock.release()

        # The order of `save_file` and the dictionary update matters:
        # the metadata must only become visible after the file is on disk
        save_file({"kv_chunk": kv_chunk}, path)

        self.update_lock.acquire()
        self.dict[key] = DiskCacheMetadata(path,
                                           self.evictor.get_size(kv_chunk))
        self.update_lock.release()
    def put(
        self,
        key: CacheEngineKey,
        kv_chunk: torch.Tensor,
        blocking: bool = True,
    ) -> None:
        """
        Store the KV cache of the tokens into the cache engine.

        Input:
            key: the key of the token chunk, including prefix hash and format
            kv_chunk: the kv cache of the token chunk, as a single tensor

        Returns:
            None

        Note:
            The KV cache should NOT have the "batch" dimension.
        """
        if blocking:
            self.put_blocking(key, kv_chunk)
        else:
            self.put_queue.put((key, kv_chunk))
    @_lmcache_nvtx_annotate
    def get(
        self,
        key: CacheEngineKey,
    ) -> Optional[KVCache]:
        """
        Retrieve the KV cache chunk by the given key

        Input:
            key: the key of the token chunk, including prefix hash and format

        Output:
            the kv cache of the token chunk, as a single tensor;
            None if the key is not found
        """
        self.update_lock.acquire()
        if key not in self.dict:
            self.update_lock.release()
            return None

        if key in self.future_pool:
            future = self.future_pool[key][0]
            kv_obj = self.future_pool[key][1]

            # NOTE(Jiayi): the following code is blocking
            # if future.exception():
            #     raise Exception(f"Task raised an exception: \
            #         {future.exception()}")

            if not future.done():
                self.update_lock.release()
                return None

            self.cpu_mbufferpool.free(kv_obj)
            del self.future_pool[key]

        path = self.dict[key].path
        self.evictor.update_on_get(key, self.dict)

        with safe_open(path, framework="pt",
                       device=self.dst_device) as f:  # type: ignore
            kv_chunk = f.get_tensor("kv_chunk")
        self.update_lock.release()
        return kv_chunk
    def close(self):
        if self.put_thread is not None and self.put_thread.is_alive():
            self.put_queue.put(LocalBackendEndSignal())
            self.put_thread.join()

        if self.sweeper_thread is not None and self.sweeper_thread.is_alive():
            self.stop_event.set()
            self.sweeper_thread.join()

        self.proc_pool_executor.shutdown()
        logger.info("Closed the workers in local disk backend")
    def __del__(self):
        self.close()
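
# A minimal usage sketch (not part of the original module), assuming the
# caller has already constructed a backend (e.g. LMCLocalBackend or
# LMCLocalDiskBackend) with a valid LMCacheEngineConfig and
# LMCacheMemPoolMetadata, and that the abstract interface exposes the
# put/contains/get methods implemented above. The function name is
# hypothetical and only illustrates the store/lookup flow.
def _example_backend_usage(backend: LMCBackendInterface,
                           key: CacheEngineKey,
                           kv_chunk: torch.Tensor) -> Optional[torch.Tensor]:
    # Store the chunk synchronously so it is visible to `contains` right away;
    # with blocking=False the chunk is only enqueued for the put worker.
    backend.put(key, kv_chunk, blocking=True)
    if backend.contains(key):
        # Returns the chunk on the backend's dst_device, or None on a miss.
        return backend.get(key)
    return None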