# Source code for lmcache.experimental.storage_backend.naive_serde.cachegen_encoder

import torch

from lmcache.config import LMCacheEngineMetadata
from lmcache.experimental.config import LMCacheEngineConfig
from lmcache.experimental.memory_management import (BytesBufferMemoryObj,
                                                    MemoryAllocatorInterface,
                                                    MemoryObj)
from lmcache.experimental.storage_backend.naive_serde.cachegen_basics import \
    CacheGenConfig
from lmcache.experimental.storage_backend.naive_serde.serde import Serializer
from lmcache.logging import init_logger
from lmcache.storage_backend.serde.cachegen_encoder import encode_function
from lmcache.utils import _lmcache_nvtx_annotate

logger = init_logger(__name__)


class CacheGenSerializer(Serializer):
    """Serializer that compresses a KV-cache tensor with the CacheGen codec.

    Converts a KV_BLOB ``MemoryObj`` into a CACHEGEN_BINARY
    ``BytesBufferMemoryObj`` using per-layer quantization bins derived from
    the model-specific :class:`CacheGenConfig`.
    """

    def __init__(self, config: LMCacheEngineConfig,
                 metadata: LMCacheEngineMetadata,
                 memory_allocator: MemoryAllocatorInterface):
        """Initialize the serializer.

        :param config: engine configuration (supplies ``chunk_size``).
        :param metadata: engine metadata (model name, kv shape, format).
        :param memory_allocator: allocator kept for buffer management.
        """
        self.cachegen_config = CacheGenConfig.from_model_name(
            metadata.model_name)
        self.chunk_size = config.chunk_size
        self.fmt = metadata.fmt
        # Per-layer quantization bin counts, precomputed once on the GPU.
        self.key_bins = self.make_key_bins(self.cachegen_config)
        self.value_bins = self.make_value_bins(self.cachegen_config)
        self.kv_shape = metadata.kv_shape
        self.memory_allocator = memory_allocator

    def make_key_bins(self, config: CacheGenConfig) -> torch.Tensor:
        """Build a ``[nlayers]`` tensor of key quantization bin counts.

        Each ``kspec`` assigns its ``bins`` value to the layer range
        ``[start_layer, end_layer)``. The result lives on the CUDA device.
        """
        ret = torch.zeros(config.nlayers)
        for spec in config.kspecs:
            ret[spec.start_layer:spec.end_layer] = spec.bins
        return ret.cuda()

    def make_value_bins(self, config: CacheGenConfig) -> torch.Tensor:
        """Build a ``[nlayers]`` tensor of value quantization bin counts.

        Same layout as :meth:`make_key_bins`, but driven by ``vspecs``.
        """
        ret = torch.zeros(config.nlayers)
        for spec in config.vspecs:
            ret[spec.start_layer:spec.end_layer] = spec.bins
        return ret.cuda()

    # TODO(Jiayi): A lot of memory copies can be avoided in this function.
    @_lmcache_nvtx_annotate
    def serialize(self, memory_obj: MemoryObj) -> BytesBufferMemoryObj:
        """Serialize a KV_BLOB MemoryObj to a CACHEGEN_BINARY MemoryObj.

        :param memory_obj: the memory object to be serialized; its
            ``tensor`` must be non-None with layout
            ``[2, num_layers, num_tokens, hidden_size]``.
        :return: the serialized binary memory object.
        """
        # TODO(Jiayi): please avoid this copy by directly performing
        # serialization inside gpu connector.
        assert memory_obj.tensor is not None
        tensor = memory_obj.tensor.cuda()

        # Temporary fix for issue #83: encoder will have the default device 0
        # on all the ray workers. Need to set it to the correct device.
        # Also need to figure out why this happens.
        # BUG FIX: the original compared the *function object*
        # `torch.cuda.current_device` (never called) against `tensor.device`,
        # which is always unequal, so set_device ran on every call. Call the
        # function and compare integer device indices instead.
        if torch.cuda.current_device() != tensor.device.index:
            torch.cuda.set_device(tensor.device)

        # Move the precomputed bins to the tensor's device if needed.
        if tensor.device != self.key_bins.device:
            self.key_bins = self.key_bins.to(tensor.device)
        if tensor.device != self.value_bins.device:
            self.value_bins = self.value_bins.to(tensor.device)

        # tensor is [2, num_layers, num_tokens, hidden_size]; split the
        # hidden dimension into (num_heads, head_size).
        tensor = tensor.view(*tensor.shape[:-1], self.kv_shape[-2],
                             self.kv_shape[-1])
        # TODO(Jiayi): remove hardcoded "2"
        # Expecting a tensor of shape
        # [num_layers, 2, num_tokens, num_heads, head_size] after permute.
        tensor = tensor.permute([1, 0, 2, 3, 4])
        ntokens = tensor.shape[2]
        output_dict = encode_function(
            tensor,
            self.cachegen_config,
            self.key_bins,
            self.value_bins,
            ntokens,
        )
        return BytesBufferMemoryObj(output_dict.to_bytes())