import asyncio
import multiprocessing
from typing import Dict, List, Optional
import torch
from lmcache.config import LMCacheEngineMetadata
from lmcache.experimental.config import LMCacheEngineConfig
from lmcache.experimental.distributed_server import (
DistributedServerInterface, NaiveDistributedServer)
from lmcache.experimental.gpu_connector import GPUConnectorInterface
from lmcache.experimental.lookup_server import (LookupServerInterface,
RedisLookupServer)
from lmcache.experimental.memory_management import (MemoryAllocatorInterface,
MixedMemoryAllocator)
from lmcache.experimental.storage_backend.storage_manager import StorageManager
from lmcache.experimental.token_database import (ChunkedTokenDatabase,
TokenDatabase)
from lmcache.logging import init_logger
from lmcache.observability import LMCacheStatsLogger, LMCStatsMonitor
from lmcache.usage_context import InitializeUsageContext
from lmcache.utils import _lmcache_nvtx_annotate
logger = init_logger(__name__)
class CacheEngineEndSignal:
pass
class LMCacheEngine:
"""The main class for the cache engine.
When storing the KV caches into the cache engine, it takes GPU KV
caches from the serving engine and convert them into MemoryObjs that
resides in the CPU. The MemoryObjs are then being stored into the
StorageBackends in an asynchronous manner.
When retrieving the KV caches from the cache engine, it fetches the
MemoryObjs from the StorageBackends and convert them into GPU KV caches
by GPUConnectors specialized for the serving engine.
It also supports prefetching the KV caches from the StorageBackends.
It relies on the StorageBackends to manage the requests of prefetching
and real retrieval and avoid the conflicts.
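
A rough usage sketch (illustrative only; ``connector_kwargs`` is a
placeholder for the connector-specific keyword arguments described in
``store`` and ``retrieve``)::

    engine.store(tokens, **connector_kwargs)
    ret_mask = engine.retrieve(tokens, **connector_kwargs)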
"""
def __init__(
self,
config: LMCacheEngineConfig,
metadata: LMCacheEngineMetadata,
memory_allocator: MemoryAllocatorInterface,
token_database: TokenDatabase,
gpu_connector: GPUConnectorInterface,
):
self.config = config
self.metadata = metadata
self.memory_allocator = memory_allocator
self.token_database = token_database
self.gpu_connector = gpu_connector
self.enable_p2p = config.enable_p2p
# NOTE: Unix systems use fork by default
multiprocessing.set_start_method('spawn', force=True)
self.lookup_server: Optional[LookupServerInterface] = None
# TODO(Jiayi): hard-coded for now
if self.enable_p2p:
self.lookup_server = RedisLookupServer(config)
self.storage_manager = StorageManager(config, metadata,
self.memory_allocator,
self.lookup_server)
if self.enable_p2p:
self.distributed_loop = asyncio.get_event_loop()
assert self.lookup_server is not None
self.distributed_server: DistributedServerInterface = \
NaiveDistributedServer(self.storage_manager,
self.lookup_server,
self.memory_allocator,
self.distributed_loop,
config)
InitializeUsageContext(config.to_original_config(), metadata)
self.stats_monitor = LMCStatsMonitor.GetOrCreate()
@_lmcache_nvtx_annotate
@torch.inference_mode()
def store(self,
tokens: torch.Tensor,
mask: Optional[torch.Tensor] = None,
**kwargs) -> None:
"""Store the tokens and mask into the cache engine.
:param torch.Tensor tokens: The tokens of the corresponding KV caches.
:param Optional[torch.Tensor] mask: The mask for the tokens. Should
have the same length as tokens. And the mask should ALWAYS be like
FFFFFTTTTTTT, where True means the tokens needs to be matched,
and the Falses will ALWAYS be at the PREFIX of the tensor.
:param **kwargs: The additional arguments for the storage backend which
will be passed into the gpu_connector.
Should include KV cache specific information (e.g., paged KV buffer
and the page tables).
:raises: ValueError if the number of Falses in the mask is not a
multiple of the chunk size.
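
Example (an illustrative sketch; the exact keyword arguments depend on
the GPU connector in use, so ``connector_kwargs`` and ``num_skip``
below are placeholders)::

    # Skip the first `num_skip` tokens (a multiple of the chunk size)
    # and store the KV caches of the remaining tokens.
    mask = torch.ones_like(tokens, dtype=torch.bool)
    mask[:num_skip] = False
    engine.store(tokens, mask=mask, **connector_kwargs)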
"""
if mask is not None:
monitor_req_id = self.stats_monitor.on_store_request(
torch.sum(mask))
else:
monitor_req_id = self.stats_monitor.on_store_request(len(tokens))
for start, end, key in self.token_database.process_tokens(
tokens, mask):
if self.storage_manager.contains(key):
continue
# Allocate the memory object
num_tokens = end - start
kv_shape = self.gpu_connector.get_shape(num_tokens)
kv_dtype = self.metadata.kv_dtype
memory_obj = self.storage_manager.allocate(kv_shape, kv_dtype)
if memory_obj is None:
logger.warning("Failed to allocate memory for the KV cache.\n"
"The KV cache will not be stored.")
break
# Put the memory object to the storage backend
# Disabling put_queue for now, as it's not necessary
# and it brings a large overhead
# self.put_queue.put((key, memory_obj, start, end, kwargs))
self.gpu_connector.from_gpu(memory_obj, start, end, **kwargs)
self.storage_manager.put(key, memory_obj)
# Update lookup server
if self.lookup_server is not None:
self.lookup_server.insert(key)
self.stats_monitor.on_store_finished(monitor_req_id)
@_lmcache_nvtx_annotate
@torch.inference_mode()
def retrieve(self,
tokens: torch.Tensor,
mask: Optional[torch.Tensor] = None,
**kwargs) -> torch.Tensor:
"""Retrieve the KV caches from the cache engine. And put the retrieved
KV cache to the serving engine via the GPU connector.
:param torch.Tensor tokens: The tokens of the corresponding KV caches.
:param Optional[torch.Tensor] mask: The mask for the tokens. Should
have the same length as tokens. And the mask should ALWAYS be like
FFFFFTTTTTTT, where True means the tokens needs to be matched,
and the Falses will ALWAYS be at the PREFIX of the tensor.
:param **kwargs: The additional arguments for the storage backend which
will be passed into the gpu_connector.
Should include KV cache specific information (e.g., paged KV buffer
and the page tables).
:return: the boolean mask indicating which tokens are retrieved. The
length of the mask should be the same as the tokens. On CPU.
:raises: ValueError if the number of Falses in the mask is not a
multiple of the chunk size.
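
Example (an illustrative sketch; ``connector_kwargs`` is a placeholder
for the connector-specific keyword arguments)::

    ret_mask = engine.retrieve(tokens, mask=mask, **connector_kwargs)
    num_retrieved = torch.sum(ret_mask).item()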
"""
if mask is not None:
monitor_req_id = self.stats_monitor.on_retrieve_request(
torch.sum(mask))
else:
monitor_req_id = self.stats_monitor.on_retrieve_request(
len(tokens))
ret_mask = torch.zeros_like(tokens, dtype=torch.bool, device="cpu")
for start, end, key in self.token_database.process_tokens(
tokens, mask):
# Get the memory object from the storage backend
memory_obj = self.storage_manager.get(key)
if memory_obj is None:
if self.enable_p2p:
future_memory_obj = asyncio.run_coroutine_threadsafe(
self.distributed_server.issue_get(key),
self.distributed_loop)
memory_obj = future_memory_obj.result()
if memory_obj is None:
break
ret_mask[start:end] = True
# NOTE(Jiayi): for performance reasons, memory_obj does not
# have to be a pinned CPU tensor.
# For example, disk->gpu is faster than disk->cpu->gpu;
# RDMA is another example.
self.gpu_connector.to_gpu(memory_obj, start, end, **kwargs)
self.memory_allocator.ref_count_down(memory_obj)
self.stats_monitor.on_retrieve_finished(monitor_req_id,
torch.sum(ret_mask))
return ret_mask
def prefetch(
self,
tokens: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> None:
"""Launch the prefetching process in the storage manager to load the
KV to the local CPU memory
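
Example (an illustrative sketch; prefetching is asynchronous, so a
subsequent ``retrieve`` may still have to wait for the loaded data)::

    engine.prefetch(tokens)
    # ... do other work while the KV caches are loaded into CPU ...
    ret_mask = engine.retrieve(tokens, **connector_kwargs)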
"""
for start, end, key in self.token_database.process_tokens(
tokens, mask):
self.storage_manager.prefetch(key)
# TODO(Jiayi): Currently, search_range is only used for testing.
def lookup(
self,
tokens: torch.Tensor,
search_range: Optional[List[str]] = None,
) -> int:
"""
Check the existence of the KV caches of the tokens in the cache engine.

:param tokens: the input tokens, with shape [seq_len]

:param Optional[List[str]] search_range: The range of storage backends
    to search in. Should be a subset of ["Hot", "LocalDiskBackend"]
    for now. If None, search in all backends.

:return: An int indicating how many prefix tokens are cached.
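
Example (an illustrative sketch)::

    num_cached = engine.lookup(tokens)
    # the KV caches of `tokens[:num_cached]` are available somewhere
    num_hot = engine.lookup(tokens, search_range=["Hot"])
    # only counts the prefix tokens cached in the "Hot" backend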
"""
for start, end, key in self.token_database.process_tokens(tokens):
if not self.storage_manager.contains(key, search_range):
return start
return end
def close(self) -> None:
"""Close the cache engine and free all the resources"""
if self.enable_p2p:
self.distributed_server.close()
self.storage_manager.close()
logger.info("LMCacheEngine closed.")
class LMCacheEngineBuilder:
_instances: Dict[str, LMCacheEngine] = {}
_cfgs: Dict[str, LMCacheEngineConfig] = {}
_metadatas: Dict[str, LMCacheEngineMetadata] = {}
_stat_loggers: Dict[str, LMCacheStatsLogger] = {}
@staticmethod
def _Create_memory_allocator(
config: LMCacheEngineConfig,
metadata: LMCacheEngineMetadata,
) -> MemoryAllocatorInterface:
max_local_cpu_size = config.max_local_cpu_size
return MixedMemoryAllocator(int(max_local_cpu_size * 1024**3))
@staticmethod
def _Create_token_database(
config: LMCacheEngineConfig,
metadata: LMCacheEngineMetadata,
) -> TokenDatabase:
return ChunkedTokenDatabase(config, metadata)
@classmethod
def get_or_create(
cls,
instance_id: str,
config: LMCacheEngineConfig,
metadata: LMCacheEngineMetadata,
gpu_connector: GPUConnectorInterface,  # the GPU connector is provided by the caller
) -> LMCacheEngine:
"""
Build a new LMCacheEngine instance if one does not already exist for
the given instance ID, otherwise return the existing instance.

:raises: ValueError if the instance already exists with a different
    configuration or metadata.
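
Example (an illustrative sketch; ``config``, ``metadata``, and
``gpu_connector`` have to be constructed by the caller)::

    engine = LMCacheEngineBuilder.get_or_create(
        "engine_0", config, metadata, gpu_connector)
    assert LMCacheEngineBuilder.get("engine_0") is engine
    # ... use the engine ...
    LMCacheEngineBuilder.destroy("engine_0")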
"""
logger.info(f"Creating LMCacheEngine instance {instance_id}")
if instance_id not in cls._instances:
memory_allocator = cls._Create_memory_allocator(config, metadata)
token_database = cls._Create_token_database(config, metadata)
stat_logger = LMCacheStatsLogger(metadata, log_interval=10)
engine = LMCacheEngine(config, metadata, memory_allocator,
token_database, gpu_connector)
cls._instances[instance_id] = engine
cls._cfgs[instance_id] = config
cls._metadatas[instance_id] = metadata
cls._stat_loggers[instance_id] = stat_logger
return engine
else:
if (cls._cfgs[instance_id] != config
or cls._metadatas[instance_id] != metadata):
raise ValueError(
f"Instance {instance_id} already exists with a different "
f"configuration or metadata.")
return cls._instances[instance_id]
@classmethod
def get(cls, instance_id: str) -> Optional[LMCacheEngine]:
"""Returns the LMCacheEngine instance associated with the instance ID,
or None if not found."""
return cls._instances.get(instance_id)
@classmethod
def destroy(cls, instance_id: str) -> None:
"""Close and delete the LMCacheEngine instance by the instance ID"""
# TODO: unit test for this
if instance_id in cls._instances:
stat_logger = cls._stat_loggers[instance_id]
stat_logger.shutdown()
engine = cls._instances[instance_id]
engine.close()
cls._instances.pop(instance_id, None)
cls._cfgs.pop(instance_id, None)
cls._metadatas.pop(instance_id, None)
cls._stat_loggers.pop(instance_id, None)
LMCStatsMonitor.DestroyInstance()