Source code for lmcache.experimental.storage_backend.storage_manager

import asyncio
import threading
from collections import OrderedDict
from concurrent.futures import Future
from typing import Dict, List, Optional, Tuple

import torch

from lmcache.config import LMCacheEngineMetadata
from lmcache.experimental.config import LMCacheEngineConfig
from lmcache.experimental.lookup_server import LookupServerInterface
from lmcache.experimental.memory_management import (MemoryAllocatorInterface,
                                                    MemoryFormat, MemoryObj,
                                                    MixedMemoryAllocator)
from lmcache.experimental.storage_backend import CreateStorageBackends
from lmcache.experimental.storage_backend.abstract_backend import \
    StorageBackendInterface
from lmcache.logging import init_logger
from lmcache.utils import CacheEngineKey, _lmcache_nvtx_annotate

logger = init_logger(__name__)


# TODO: extend this class to implement caching policies and eviction policies
class StorageManager:
    """
    The StorageManager is responsible for managing the storage backends.
    """

    def __init__(self,
                 config: LMCacheEngineConfig,
                 metadata: LMCacheEngineMetadata,
                 allocator: MemoryAllocatorInterface,
                 lookup_server: Optional[LookupServerInterface] = None):
        self.memory_allocator = allocator
        self.hot_cache: OrderedDict[CacheEngineKey, MemoryObj] = OrderedDict()
        self.use_hot = config.local_cpu

        self.loop = asyncio.new_event_loop()
        self.thread = threading.Thread(target=self.loop.run_forever)
        self.thread.start()

        # TODO: remove hardcode
        dst_device = "cuda"
        self.storage_backends: OrderedDict[str, StorageBackendInterface] = \
            CreateStorageBackends(config, metadata, self.loop, allocator,
                                  dst_device, lookup_server)

        self.prefetch_tasks: Dict[CacheEngineKey, Future] = {}
        self.put_tasks: Dict[str, Dict[CacheEngineKey,
                                       Tuple[Future, MemoryObj]]] = {}
        for backend_name in self.storage_backends.keys():
            self.put_tasks[backend_name] = {}

        self.manager_lock = threading.Lock()

        self.lookup_server = lookup_server

        self.stream = torch.cuda.Stream()
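    # Example (illustrative, not part of the original module): a cache engine
    # is expected to build one StorageManager per engine instance and close()
    # it on shutdown. The config, metadata, allocator, and lookup_server
    # objects are assumed to be constructed elsewhere.
    #
    #   manager = StorageManager(config, metadata, allocator,
    #                            lookup_server=lookup_server)
    #   ...                      # serve allocate/put/get/prefetch requests
    #   manager.close()          # stops the background event loop thread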
    def allocate(
        self,
        shape: torch.Size,
        dtype: torch.dtype,
        eviction=True,
    ) -> Optional[MemoryObj]:
        """
        Allocate a memory object with the memory allocator.
        Use the LRU evictor if eviction is enabled.
        """
        self.manager_lock.acquire()
        memory_obj = self.memory_allocator.allocate(shape, dtype)
        if not eviction or memory_obj is not None:
            self.manager_lock.release()
            return memory_obj

        assert isinstance(self.memory_allocator, MixedMemoryAllocator)

        evict_keys = []
        for evict_key in self.hot_cache:
            # If the ref_count > 1, we cannot evict it as the hot cache
            # might be used as buffers by other storage backends
            if self.memory_allocator.get_ref_count(
                    self.hot_cache[evict_key]) > 1:
                continue
            evict_keys.append(evict_key)
            self.memory_allocator.ref_count_down(self.hot_cache[evict_key])
            memory_obj = self.memory_allocator.allocate(shape, dtype)
            logger.debug("Evicting 1 chunk from hot cache")
            if memory_obj is not None:
                break
            # TODO(Jiayi): move this before the loop
            # In this way, we don't need to do eviction for big objects
            # TODO(Jiayi): the following code is hacky, please refactor
            if self.memory_allocator.pin_allocator.num_active_allocations == 0:
                break

        for evict_key in evict_keys:
            self.hot_cache.pop(evict_key)

        if self.lookup_server is not None:
            self.lookup_server.batched_remove(evict_keys)

        self.manager_lock.release()
        return memory_obj
    def put(
        self,
        key: CacheEngineKey,
        memory_obj: MemoryObj,
    ) -> None:
        """
        Non-blocking function to put the memory object into the storage
        backends.
        Do not store if the same object is already being stored (handled
        here by the storage manager) or has already been stored (handled
        by the storage backend).
        """
        self.manager_lock.acquire()

        if self.use_hot:
            # During overwrite, we need to free the old memory object
            # to avoid a memory leak.
            # NOTE(Jiayi): overwrite should not happen, at least for
            # prefix caching
            if key in self.hot_cache:
                old_memory_obj = self.hot_cache.pop(key)
                self.memory_allocator.ref_count_down(old_memory_obj)
            self.hot_cache[key] = memory_obj
            self.memory_allocator.ref_count_up(memory_obj)

        # TODO(Jiayi): currently, the entire put task will be cancelled
        # if one of the backends is already storing this cache.
        # This might not be ideal.
        for storage_backend in self.storage_backends.values():
            if storage_backend.exists_in_put_tasks(key):
                self.memory_allocator.ref_count_down(memory_obj)
                self.manager_lock.release()
                return
        self.manager_lock.release()

        #ever_put = False
        for backend_name, backend in self.storage_backends.items():
            put_task = backend.submit_put_task(key, memory_obj)
            if put_task is None:
                continue

        self.manager_lock.acquire()
        self.memory_allocator.ref_count_down(memory_obj)
        self.manager_lock.release()
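    # Example (illustrative, not part of the original module): a typical
    # store path pairs allocate() with put(); put() releases the reference
    # obtained from allocate() once the backends have accepted the object.
    # `manager`, `key`, and `kv_tensor` are assumed to come from the
    # cache engine.
    #
    #   memory_obj = manager.allocate(kv_tensor.shape, kv_tensor.dtype)
    #   if memory_obj is not None:
    #       memory_obj.tensor.copy_(kv_tensor)
    #       manager.put(key, memory_obj)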
    @_lmcache_nvtx_annotate
    def _update_hot_cache(self, key: CacheEngineKey, memory_obj: MemoryObj):
        if memory_obj is None or not self.use_hot:
            return

        if memory_obj.tensor is not None and memory_obj.tensor.is_cuda:
            self.manager_lock.acquire()
            if key in self.hot_cache:
                self.manager_lock.release()
                return
            self.manager_lock.release()

            # Allocate a cpu memory object
            cpu_memory_obj = self.memory_allocator.allocate(
                memory_obj.get_shape(),
                memory_obj.get_dtype(),
                fmt=memory_obj.get_memory_format())

            if cpu_memory_obj is None:
                logger.warning(
                    "Memory allocation failed in cachegen deserializer")
                return None

            # Copy the tensor to the cpu memory object
            assert cpu_memory_obj.tensor is not None
            self.stream.wait_stream(torch.cuda.default_stream())
            with torch.cuda.stream(self.stream):
                cpu_memory_obj.tensor.copy_(memory_obj.tensor,
                                            non_blocking=True)
            memory_obj.tensor.record_stream(self.stream)

            # Update the hot cache
            self.manager_lock.acquire()
            self.hot_cache[key] = cpu_memory_obj
            self.memory_allocator.ref_count_up(cpu_memory_obj)
            self.manager_lock.release()
            logger.debug("Updated hot cache!")
            return
        else:
            self.manager_lock.acquire()
            if self.use_hot and key not in self.hot_cache:
                self.hot_cache[key] = memory_obj
                self.memory_allocator.ref_count_up(memory_obj)
            self.manager_lock.release()
    def get(self, key: CacheEngineKey) -> Optional[MemoryObj]:
        """
        Blocking function to get the memory object from the storage
        backends.
        """
        # Search in prefetch tasks
        self.manager_lock.acquire()
        prefetch_task = self.prefetch_tasks.get(key, None)
        self.manager_lock.release()

        # Wait until the prefetch task finishes.
        # Here, it is assumed that all prefetch tasks load the memory
        # object into the hot cache (pinned cpu buffer).
        if prefetch_task is not None:
            assert self.use_hot is True, \
                "CPU cache must be enabled for prefetching"
            logger.debug("Waiting for prefetching result. "
                         "Optimally, this should not happen.")
            # Calling result() twice (already once in the callback) has
            # no effect.
            # Tune the timeout for better performance.
            prefetch_task.result(timeout=1)

        # Search in the hot cache
        self.manager_lock.acquire()
        memory_obj = self.hot_cache.get(key, None)
        if memory_obj is not None:
            self.memory_allocator.ref_count_up(memory_obj)
            self.hot_cache.move_to_end(key)
            self.manager_lock.release()
            return memory_obj
        self.manager_lock.release()

        # Search all backends with a blocking get
        for backend_name, backend in self.storage_backends.items():
            # Avoid read-write contention
            #if key in self.put_tasks[backend_name]:
            #    continue
            # NOTE(Jiayi): bypass the allocator for now
            memory_obj = backend.get_blocking(key)
            if memory_obj is not None:
                self._update_hot_cache(key, memory_obj)
                return memory_obj

        return None
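    # Example (illustrative, not part of the original module): a retrieval
    # path; `manager` and `key` are assumed to come from the cache engine.
    # A None return means a miss in the hot cache and in every backend.
    #
    #   memory_obj = manager.get(key)
    #   if memory_obj is None:
    #       ...                        # cache miss: recompute the KV chunk
    #   else:
    #       kv_chunk = memory_obj.tensor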
    # TODO(Jiayi): we need to consider eviction in prefetch
    def prefetch_callback(self, future, key):
        """
        Update metadata after prefetch.
        """
        self.manager_lock.acquire()
        prefetch_task = self.prefetch_tasks.pop(key)
        self.manager_lock.release()
        try:
            buffer_memory_obj = prefetch_task.result()
        except Exception as e:
            logger.error(
                f"Exception captured from future in prefetch_callback: {e}")
            raise e
        kv_chunk = buffer_memory_obj.tensor
        kv_shape = kv_chunk.shape
        kv_dtype = kv_chunk.dtype
        memory_obj = self.memory_allocator.allocate(kv_shape, kv_dtype)
        if memory_obj is None:
            logger.warning("Memory allocation failed in prefetch_callback")
            return
        assert memory_obj.tensor is not None, "Encounter invalid tensor"

        # TODO(Jiayi): this part should be done in another process if
        # the cpu->pinned cpu copy is blocking.
        prefetch_stream = torch.cuda.Stream()
        with torch.cuda.stream(prefetch_stream):
            memory_obj.tensor.copy_(kv_chunk, non_blocking=True)
        prefetch_stream.synchronize()

        # TODO(Jiayi): please remove this hardcode
        memory_obj.metadata.fmt = MemoryFormat.KV_BLOB

        # NOTE: no need to ref_count_up here because
        # the memory_obj's ref_count is already 1
        self.manager_lock.acquire()
        self.hot_cache[key] = memory_obj
        self.manager_lock.release()
    def prefetch(self, key: CacheEngineKey) -> None:
        """Launch a prefetch request in the storage backends. Non-blocking.
        """
        # Call contains() on each backend and prefetch from the nearest cache
        self.manager_lock.acquire()
        if key in self.hot_cache:
            self.manager_lock.release()
            return
        if key in self.prefetch_tasks:
            self.manager_lock.release()
            return
        self.manager_lock.release()

        for backend in self.storage_backends.values():
            prefetch_task = backend.submit_prefetch_task(key)
            if prefetch_task is None:
                continue
            lambda_callback = lambda f: \
                self.prefetch_callback(f, key)

            self.manager_lock.acquire()
            self.prefetch_tasks[key] = prefetch_task
            prefetch_task.add_done_callback(lambda_callback)
            self.manager_lock.release()
            break
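    # Example (illustrative, not part of the original module): prefetch()
    # can be issued ahead of time so that a later get() is served from the
    # hot cache; `manager` and `key` are assumed to come from the cache
    # engine.
    #
    #   manager.prefetch(key)           # non-blocking
    #   ...                             # overlap with other work
    #   memory_obj = manager.get(key)   # waits on the prefetch if pending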
    # TODO(Jiayi): Currently, search_range is only used for testing.
    def contains(
        self,
        key: CacheEngineKey,
        search_range: Optional[List[str]] = None,
    ) -> bool:
        """
        Check whether the key exists in the storage backends.

        :param CacheEngineKey key: The key to check.

        :param Optional[List[str]] search_range: The range of storage
            backends to search in. Should be a subset of
            ["Hot", "LocalDiskBackend"] for now. If None, search in all
            backends.

        :return: True if the key exists in the specified storage backends.
        """
        with self.manager_lock:
            if search_range is None or "Hot" in search_range:
                if key in self.hot_cache:
                    return True

            for backend_name, backend in self.storage_backends.items():
                if search_range is not None and \
                        backend_name not in search_range:
                    continue
                if backend.contains(key):
                    return True
            return False
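    # Example (illustrative, not part of the original module): restrict the
    # lookup to the CPU hot cache and fall back to a prefetch on a miss;
    # `manager` and `key` are assumed to come from the cache engine.
    #
    #   if not manager.contains(key, search_range=["Hot"]):
    #       manager.prefetch(key)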
    def close(self):
        if self.lookup_server is not None:
            self.manager_lock.acquire()
            self.lookup_server.batched_remove(list(self.hot_cache.keys()))
            self.manager_lock.release()

        for backend in self.storage_backends.values():
            backend.close()

        # using threadsafe method here as stop modifies
        # the internal state of the loop (in another thread)
        if self.loop.is_running():
            self.loop.call_soon_threadsafe(self.loop.stop)
        if self.thread.is_alive():
            self.thread.join()

        logger.info("Storage manager closed.")