import asyncio
import os
import threading
from collections import OrderedDict
from concurrent.futures import Future
from typing import List, Optional
import aiofiles
import torch
from lmcache.experimental.config import LMCacheEngineConfig
from lmcache.experimental.lookup_server import LookupServerInterface
from lmcache.experimental.memory_management import (MemoryAllocatorInterface,
MemoryObj)
from lmcache.experimental.storage_backend.abstract_backend import \
StorageBackendInterface
from lmcache.experimental.storage_backend.evictor import LRUEvictor, PutStatus
from lmcache.logging import init_logger
from lmcache.utils import (CacheEngineKey, DiskCacheMetadata,
_lmcache_nvtx_annotate)
logger = init_logger(__name__)
class LocalDiskBackend(StorageBackendInterface):
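    """Local-disk storage backend for LMCache.

    KV cache chunks are persisted as ".pt" files under ``config.local_disk``.
    An OrderedDict maps each CacheEngineKey to its DiskCacheMetadata
    (path, size, shape, dtype); the recency order of that dict drives the
    LRU eviction policy implemented by LRUEvictor.
    """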
def __init__(
self,
config: LMCacheEngineConfig,
loop: asyncio.AbstractEventLoop,
memory_allocator: MemoryAllocatorInterface,
dst_device: str = "cuda",
lookup_server: Optional[LookupServerInterface] = None,
):
self.dict: OrderedDict[CacheEngineKey,
DiskCacheMetadata] = OrderedDict()
self.dst_device = dst_device
self.disk_lock = threading.Lock()
assert config.local_disk is not None
self.path: str = config.local_disk
if not os.path.exists(self.path):
os.makedirs(self.path)
logger.info(f"Created local disk cache directory: {self.path}")
self.lookup_server = lookup_server
# Initialize the evictor
self.evictor = LRUEvictor(max_cache_size=config.max_local_disk_size)
self.loop = loop
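        # Keys whose asynchronous disk writes are still in flight
        # (see exists_in_put_tasks).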
self.put_tasks: List[CacheEngineKey] = []
self.memory_allocator = memory_allocator
def __str__(self):
return self.__class__.__name__
def _key_to_path(
self,
key: CacheEngineKey,
) -> str:
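        # Flatten the serialized key into a single filename, e.g. a key that
        # serializes to "a/b/c" (illustrative format) becomes "a-b-c.pt"
        # under self.path; os.path.join keeps this correct whether or not
        # self.path ends with a separator.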
        return os.path.join(self.path,
                            key.to_string().replace("/", "-") + ".pt")
def contains(self, key: CacheEngineKey) -> bool:
with self.disk_lock:
return key in self.dict
def exists_in_put_tasks(self, key: CacheEngineKey) -> bool:
with self.disk_lock:
return key in self.put_tasks
    def remove(
        self,
        key: CacheEngineKey,
    ) -> None:
        # Pop the metadata while holding the lock so that a concurrent
        # lookup cannot observe a key whose file is being deleted.
        with self.disk_lock:
            path = self.dict.pop(key).path
        os.remove(path)
def insert_key(self, key: CacheEngineKey, memory_obj: MemoryObj) -> None:
path = self._key_to_path(key)
size = memory_obj.get_size()
shape = memory_obj.metadata.shape
dtype = memory_obj.metadata.dtype
with self.disk_lock:
# Need to do reinsert to update cache recency
if key in self.dict:
self.dict.pop(key)
self.dict[key] = DiskCacheMetadata(path, size, shape, dtype)
def submit_put_task(
self,
key: CacheEngineKey,
memory_obj: MemoryObj,
) -> Optional[Future]:
assert memory_obj.tensor is not None
# Update cache recency
evict_keys, put_status = self.evictor.update_on_put(
self.dict, memory_obj.get_physical_size())
if put_status == PutStatus.ILLEGAL:
return None
# evict caches
for evict_key in evict_keys:
self.remove(evict_key)
if self.lookup_server is not None:
self.lookup_server.batched_remove(evict_keys)
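        # Hold an extra reference so the memory object stays alive until the
        # asynchronous write finishes; the matching ref_count_down happens in
        # async_save_bytes_to_disk.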
self.memory_allocator.ref_count_up(memory_obj)
        with self.disk_lock:
            self.put_tasks.append(key)
future = asyncio.run_coroutine_threadsafe(
self.async_save_bytes_to_disk(key, memory_obj), self.loop)
return future
def submit_prefetch_task(
self,
key: CacheEngineKey,
) -> Optional[Future]:
        with self.disk_lock:
            if key not in self.dict:
                return None
            # Update cache recency
            self.evictor.update_on_hit(key, self.dict)
            path = self.dict[key].path
            dtype = self.dict[key].dtype
            shape = self.dict[key].shape
logger.info(f"Prefetching {key} from disk.")
assert dtype is not None
assert shape is not None
future = asyncio.run_coroutine_threadsafe(
self.async_load_bytes_from_disk(path, dtype, shape), self.loop)
return future
def get_blocking(
self,
key: CacheEngineKey,
) -> Optional[MemoryObj]:
"""
Blocking get function.
"""
        with self.disk_lock:
            if key not in self.dict:
                return None
            # Update cache recency
            self.evictor.update_on_hit(key, self.dict)
            path = self.dict[key].path
            dtype = self.dict[key].dtype
            shape = self.dict[key].shape
            assert dtype is not None
            assert shape is not None
            memory_obj = self.load_bytes_from_disk(path,
                                                   dtype=dtype,
                                                   shape=shape)
return memory_obj
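    # Note: get_blocking reads from disk synchronously while holding the
    # lock, whereas submit_prefetch_task schedules the read on the backend's
    # event loop and returns a Future that resolves to the MemoryObj.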
@_lmcache_nvtx_annotate
@torch.inference_mode()
async def async_save_bytes_to_disk(
self,
key: CacheEngineKey,
memory_obj: MemoryObj,
) -> None:
"""
        Convert KV to bytes and asynchronously store them to disk.
"""
kv_chunk = memory_obj.tensor
assert kv_chunk is not None
byte_array = memory_obj.byte_array
path = self._key_to_path(key)
async with aiofiles.open(path, 'wb') as f:
await f.write(byte_array)
self.insert_key(key, memory_obj)
self.memory_allocator.ref_count_down(memory_obj)
        with self.disk_lock:
            self.put_tasks.remove(key)
    # TODO(Jiayi): use `bytes_read = await f.readinto(buffer)`
    # for better performance (i.e., fewer copies)
async def async_load_bytes_from_disk(
self,
path: str,
dtype: torch.dtype,
shape: torch.Size,
) -> Optional[MemoryObj]:
"""
        Asynchronously load a bytearray from disk.
"""
memory_obj = self.memory_allocator.allocate(shape, dtype)
if memory_obj is None:
logger.debug("Memory allocation failed during async disk load.")
return None
buffer = memory_obj.byte_array
async with aiofiles.open(path, 'rb') as f:
await f.readinto(buffer)
return memory_obj
    # TODO(Jiayi): use memory allocator to reduce cpu buffer allocation
# TODO(Jiayi): the pinned cpu memory_obj should directly be passed into
# gpu connector; this gpu buffer could be avoided
def load_bytes_from_disk(
self,
path: str,
dtype: torch.dtype,
shape: torch.Size,
) -> Optional[MemoryObj]:
"""
        Synchronously load a bytearray from disk.
"""
memory_obj = self.memory_allocator.allocate(shape, dtype)
if memory_obj is None:
logger.debug("Memory allocation failed during async disk load.")
return None
buffer = memory_obj.byte_array
with open(path, 'rb') as f:
f.readinto(buffer)
return memory_obj
@_lmcache_nvtx_annotate
@torch.inference_mode()
def load_disk(
self,
path: str,
backend: str = "bytes",
dtype: Optional[torch.dtype] = None,
shape: Optional[torch.Size] = None,
) -> Optional[MemoryObj]:
"""
Load KV from disk.
"""
if backend == "bytes":
assert dtype is not None
assert shape is not None
memory_obj = self.load_bytes_from_disk(path, dtype, shape)
else:
raise ValueError(f"Invalid backend: {backend}")
return memory_obj
def close(self) -> None:
if self.lookup_server is not None:
            with self.disk_lock:
                self.lookup_server.batched_remove(list(self.dict.keys()))
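

# A minimal usage sketch (illustrative only, not part of this module). It
# assumes a `config` with `local_disk` and `max_local_disk_size` set, an
# event loop already running in a background thread, and a memory allocator,
# all of which LMCache constructs elsewhere:
#
#     backend = LocalDiskBackend(config, loop, memory_allocator)
#     future = backend.submit_put_task(key, memory_obj)  # async write
#     if future is not None:
#         future.result()  # block until the chunk is flushed to disk
#     assert backend.contains(key)
#     obj = backend.get_blocking(key)  # synchronous read back
#     backend.close()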