# Source code for lmcache.storage_backend.hybrid_backend
import time
from typing import Iterable, List, Optional, Union
import torch
from lmcache.config import LMCacheEngineConfig, LMCacheEngineMetadata
from lmcache.logging import init_logger
from lmcache.storage_backend.abstract_backend import LMCBackendInterface
from lmcache.storage_backend.local_backend import LMCLocalBackend
from lmcache.storage_backend.remote_backend import (LMCPipelinedRemoteBackend,
                                                    LMCRemoteBackend)
from lmcache.utils import CacheEngineKey, _lmcache_nvtx_annotate
logger = init_logger(__name__)
class LMCHybridBackend(LMCBackendInterface):
"""
A hybrid backend that uses both local and remote backend to store and
retrieve data.
It implements write-through and read-through caching.
"""
# TODO: LRU eviction policy
def __init__(self, config: LMCacheEngineConfig,
metadata: LMCacheEngineMetadata):
self.local_store = LMCLocalBackend(config)
self.remote_store: Union[LMCPipelinedRemoteBackend, LMCRemoteBackend]
if config.pipelined_backend:
self.remote_store = LMCPipelinedRemoteBackend(config, metadata)
else:
self.remote_store = LMCRemoteBackend(config, metadata)
        # NOTE: no need to add `dst_device` in the hybrid backend,
        # as the logic is handled in the local/remote backends
        # TODO: add a configuration item to do this
self._prefetch(metadata)
def _prefetch(self, metadata: LMCacheEngineMetadata):
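        # Warm the local cache: copy every key in the remote store that
        # belongs to this engine instance into the local store.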
keys = self.remote_store.list()
nfetched = 0
logger.info("Found %d keys in remote backend", len(keys))
        logger.debug("Metadata is %s", metadata)
start = time.perf_counter()
for key in keys:
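            # Skip keys that belong to a different model, worker, or
            # world size than this engine instance.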
if (key.model_name != metadata.model_name
or key.worker_id != metadata.worker_id
or key.world_size != metadata.world_size):
continue
            retrieved_data = self.remote_store.get(key)
            if retrieved_data is not None:
                self.local_store.put(key, retrieved_data)
nfetched += 1
end = time.perf_counter()
logger.info(
"Pre-fetched %d keys from remote backend, used %.2f sec",
nfetched,
end - start,
)
def contains(
self,
key: CacheEngineKey,
) -> bool:
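        # Short-circuits: the remote store is only queried when the key
        # is not present locally.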
        return (self.local_store.contains(key)
                or self.remote_store.contains(key))
def put(
self,
key: CacheEngineKey,
value: torch.Tensor,
blocking: bool = True,
):
        # HACK(Jiayi): skip the local CPU cache for now;
        # the local CPU cache can be activated with prefetching
        # TODO(Jiayi): write-back/write-through should be determined by config
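        # Write-through: the value is stored in the local backend first
        # (synchronously), then forwarded to the remote backend.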
self.local_store.put(key, value, blocking=True)
self.remote_store.put(key, value, blocking)
@_lmcache_nvtx_annotate
def get(
self,
key: CacheEngineKey,
) -> Optional[torch.Tensor]:
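        # Read-through: try the local store first; on a miss, fetch from
        # the remote store and populate the local store with the result.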
value = self.local_store.get(key)
if value is None:
value = self.remote_store.get(key)
if value is not None:
self.local_store.put(key, value)
return value
@_lmcache_nvtx_annotate
def batched_get(
self,
keys: Iterable[CacheEngineKey],
) -> Iterable[Optional[torch.Tensor]]:
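        # Pass 1: serve what we can from the local store, recording the
        # index and key of every miss so the misses can be fetched from
        # the remote store in a single batched request.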
ret: List[Optional[torch.Tensor]] = []
remote_queries = []
remote_query_idxs = []
for idx, key in enumerate(keys):
value = self.local_store.get(key)
ret.append(value)
if value is None:
remote_queries.append(key)
remote_query_idxs.append(idx)
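        # Pass 2: batch-fetch all local misses from the remote store and
        # backfill both the result list and the local store.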
remote_query_results = self.remote_store.batched_get(remote_queries)
        for idx, key, result in zip(remote_query_idxs, remote_queries,
                                    remote_query_results):
if result is not None:
self.local_store.put(key, result)
ret[idx] = result
return ret
def close(self):
self.local_store.close()
self.remote_store.close()