# Source code for lmcache.experimental.storage_backend.local_disk_backend

import asyncio
import os
import threading
from collections import OrderedDict
from concurrent.futures import Future
from typing import List, Optional

import aiofiles
import torch

from lmcache.experimental.config import LMCacheEngineConfig
from lmcache.experimental.lookup_server import LookupServerInterface
from lmcache.experimental.memory_management import (MemoryAllocatorInterface,
                                                    MemoryObj)
from lmcache.experimental.storage_backend.abstract_backend import \
    StorageBackendInterface
from lmcache.experimental.storage_backend.evictor import LRUEvictor, PutStatus
from lmcache.logging import init_logger
from lmcache.utils import (CacheEngineKey, DiskCacheMetadata,
                           _lmcache_nvtx_annotate)

logger = init_logger(__name__)


class LocalDiskBackend(StorageBackendInterface):

    def __init__(
        self,
        config: LMCacheEngineConfig,
        loop: asyncio.AbstractEventLoop,
        memory_allocator: MemoryAllocatorInterface,
        dst_device: str = "cuda",
        lookup_server: Optional[LookupServerInterface] = None,
    ):
        self.dict: OrderedDict[CacheEngineKey,
                               DiskCacheMetadata] = OrderedDict()
        self.dst_device = dst_device
        self.disk_lock = threading.Lock()
        assert config.local_disk is not None
        self.path: str = config.local_disk
        if not os.path.exists(self.path):
            os.makedirs(self.path)
            logger.info(f"Created local disk cache directory: {self.path}")
        self.lookup_server = lookup_server

        # Initialize the evictor
        self.evictor = LRUEvictor(max_cache_size=config.max_local_disk_size)

        self.loop = loop
        self.put_tasks: List[CacheEngineKey] = []

        self.memory_allocator = memory_allocator

    def __str__(self):
        return self.__class__.__name__

    def _key_to_path(
        self,
        key: CacheEngineKey,
    ) -> str:
        return self.path + key.to_string().replace("/", "-") + ".pt"
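    # Note (added commentary, not in the original source): `_key_to_path`
    # concatenates `self.path` and the flattened key directly, so
    # `config.local_disk` is expected to end with a path separator. For
    # instance, a key whose `to_string()` is, say, "vllm/model/chunk0" with
    # path "/tmp/cache/" would map to "/tmp/cache/vllm-model-chunk0.pt".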
    def contains(self, key: CacheEngineKey) -> bool:
        with self.disk_lock:
            return key in self.dict
    def exists_in_put_tasks(self, key: CacheEngineKey) -> bool:
        with self.disk_lock:
            return key in self.put_tasks
    def remove(
        self,
        key: CacheEngineKey,
    ) -> None:
        with self.disk_lock:
            # Read the path under the lock so a concurrent writer cannot
            # mutate the entry between lookup and removal.
            path = self.dict[key].path
            self.dict.pop(key)
        os.remove(path)
    def insert_key(self, key: CacheEngineKey, memory_obj: MemoryObj) -> None:
        path = self._key_to_path(key)
        size = memory_obj.get_size()
        shape = memory_obj.metadata.shape
        dtype = memory_obj.metadata.dtype
        with self.disk_lock:
            # Need to do reinsert to update cache recency
            if key in self.dict:
                self.dict.pop(key)
            self.dict[key] = DiskCacheMetadata(path, size, shape, dtype)
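    # Note (added commentary): `self.dict` is an OrderedDict whose insertion
    # order doubles as recency order; popping and re-inserting on every put
    # presumably lets the LRUEvictor treat the front of the dict as the
    # least recently used entry.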
    def submit_put_task(
        self,
        key: CacheEngineKey,
        memory_obj: MemoryObj,
    ) -> Optional[Future]:
        assert memory_obj.tensor is not None

        # Update cache recency
        evict_keys, put_status = self.evictor.update_on_put(
            self.dict, memory_obj.get_physical_size())
        if put_status == PutStatus.ILLEGAL:
            return None

        # evict caches
        for evict_key in evict_keys:
            self.remove(evict_key)
        if self.lookup_server is not None:
            self.lookup_server.batched_remove(evict_keys)

        self.memory_allocator.ref_count_up(memory_obj)

        with self.disk_lock:
            self.put_tasks.append(key)

        future = asyncio.run_coroutine_threadsafe(
            self.async_save_bytes_to_disk(key, memory_obj), self.loop)
        return future
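    # Usage sketch (added, hypothetical caller code): `key` and `mem_obj`
    # would normally come from the cache engine.
    #
    #   if not backend.exists_in_put_tasks(key):
    #       future = backend.submit_put_task(key, mem_obj)
    #       if future is not None:       # None means the put was illegal
    #           future.result()          # optionally block on the write
    #
    # The object's refcount is bumped before scheduling, so the memory stays
    # alive until `async_save_bytes_to_disk` finishes and decrements it.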
    def submit_prefetch_task(
        self,
        key: CacheEngineKey,
    ) -> Optional[Future]:
        with self.disk_lock:
            if key not in self.dict:
                return None

            # Update cache recency
            self.evictor.update_on_hit(key, self.dict)

            path = self.dict[key].path
            dtype = self.dict[key].dtype
            shape = self.dict[key].shape

        logger.info(f"Prefetching {key} from disk.")
        assert dtype is not None
        assert shape is not None
        future = asyncio.run_coroutine_threadsafe(
            self.async_load_bytes_from_disk(path, dtype, shape), self.loop)
        return future
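    # Usage sketch (added, hypothetical caller code): the returned
    # `concurrent.futures.Future` resolves to a MemoryObj, or to None if
    # the allocator is out of memory.
    #
    #   future = backend.submit_prefetch_task(key)
    #   if future is not None:
    #       mem_obj = future.result()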
    def get_blocking(
        self,
        key: CacheEngineKey,
    ) -> Optional[MemoryObj]:
        """
        Blocking get function.
        """
        with self.disk_lock:
            if key not in self.dict:
                return None

            # Update cache recency
            self.evictor.update_on_hit(key, self.dict)

            path = self.dict[key].path
            dtype = self.dict[key].dtype
            shape = self.dict[key].shape
            assert dtype is not None
            assert shape is not None
            memory_obj = self.load_bytes_from_disk(path,
                                                   dtype=dtype,
                                                   shape=shape)
        return memory_obj
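    # Note (added commentary): `get_blocking` performs the disk read while
    # holding `disk_lock`, so concurrent blocking gets are serialized. The
    # prefetch path above avoids this by releasing the lock and doing the
    # read on the event loop instead.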
    @_lmcache_nvtx_annotate
    @torch.inference_mode()
    async def async_save_bytes_to_disk(
        self,
        key: CacheEngineKey,
        memory_obj: MemoryObj,
    ) -> None:
        """
        Convert KV to bytes and async store bytes to disk.
        """
        kv_chunk = memory_obj.tensor
        assert kv_chunk is not None
        byte_array = memory_obj.byte_array
        path = self._key_to_path(key)
        async with aiofiles.open(path, 'wb') as f:
            await f.write(byte_array)
        self.insert_key(key, memory_obj)

        self.memory_allocator.ref_count_down(memory_obj)

        with self.disk_lock:
            self.put_tasks.remove(key)
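    # Note (added commentary): the key is inserted into `self.dict` only
    # after the bytes are fully on disk, so `contains()` never observes a
    # half-written entry; `put_tasks` tracks the in-flight window in between.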
    # TODO(Jiayi): use `bytes_read = await f.readinto(buffer)`
    # for better performance (i.e., fewer copies)
    async def async_load_bytes_from_disk(
        self,
        path: str,
        dtype: torch.dtype,
        shape: torch.Size,
    ) -> Optional[MemoryObj]:
        """
        Async load bytearray from disk.
        """
        memory_obj = self.memory_allocator.allocate(shape, dtype)
        if memory_obj is None:
            logger.debug("Memory allocation failed during async disk load.")
            return None
        buffer = memory_obj.byte_array
        async with aiofiles.open(path, 'rb') as f:
            await f.readinto(buffer)
        return memory_obj
    # TODO(Jiayi): use memory allocator to reduce cpu buffer allocation
    # TODO(Jiayi): the pinned cpu memory_obj should directly be passed into
    # gpu connector; this gpu buffer could be avoided
    def load_bytes_from_disk(
        self,
        path: str,
        dtype: torch.dtype,
        shape: torch.Size,
    ) -> Optional[MemoryObj]:
        """
        Load bytearray from disk.
        """
        memory_obj = self.memory_allocator.allocate(shape, dtype)
        if memory_obj is None:
            logger.debug("Memory allocation failed during disk load.")
            return None
        buffer = memory_obj.byte_array
        with open(path, 'rb') as f:
            f.readinto(buffer)
        return memory_obj
    @_lmcache_nvtx_annotate
    @torch.inference_mode()
    def load_disk(
        self,
        path: str,
        backend: str = "bytes",
        dtype: Optional[torch.dtype] = None,
        shape: Optional[torch.Size] = None,
    ) -> Optional[MemoryObj]:
        """
        Load KV from disk.
        """
        if backend == "bytes":
            assert dtype is not None
            assert shape is not None
            memory_obj = self.load_bytes_from_disk(path, dtype, shape)
        else:
            raise ValueError(f"Invalid backend: {backend}")
        return memory_obj
    def close(self) -> None:
        if self.lookup_server is not None:
            with self.disk_lock:
                self.lookup_server.batched_remove(list(self.dict.keys()))
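
# ---------------------------------------------------------------------------
# End-to-end usage sketch (added, not part of the original module). The
# `config`, `allocator`, `key`, and `mem_obj` names below are hypothetical
# stand-ins for objects the cache engine normally wires up.
#
#   import asyncio
#   import threading
#
#   loop = asyncio.new_event_loop()
#   threading.Thread(target=loop.run_forever, daemon=True).start()
#
#   backend = LocalDiskBackend(config, loop, allocator)   # assumed inputs
#
#   future = backend.submit_put_task(key, mem_obj)
#   if future is not None:
#       future.result()                  # wait for the async disk write
#
#   assert backend.contains(key)
#   obj = backend.get_blocking(key)      # synchronous read path
#
#   backend.close()
#   loop.call_soon_threadsafe(loop.stop)
# ---------------------------------------------------------------------------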