from typing import Optional
import torch
from lmcache.config import LMCacheEngineMetadata
from lmcache.experimental.config import LMCacheEngineConfig
from lmcache.experimental.memory_management import (BytesBufferMemoryObj,
MemoryAllocatorInterface,
MemoryFormat, MemoryObj,
MemoryObjMetadata,
TensorMemoryObj)
from lmcache.experimental.storage_backend.naive_serde.cachegen_basics import \
CacheGenConfig
from lmcache.experimental.storage_backend.naive_serde.serde import Deserializer
from lmcache.logging import init_logger
from lmcache.storage_backend.serde.cachegen_basics import \
CacheGenGPUEncoderOutput
from lmcache.storage_backend.serde.cachegen_decoder import (
decode_function_gpu, do_dequantize)
from lmcache.utils import _lmcache_nvtx_annotate
logger = init_logger(__name__)
class CacheGenDeserializer(Deserializer):
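    """Deserializer that reverses CacheGen encoding: it decodes the
    compressed byte stream back into quantized key/value tensors on GPU,
    dequantizes them, and repackages the result as a KV-blob MemoryObj.
    """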
def __init__(self, config: LMCacheEngineConfig,
metadata: LMCacheEngineMetadata,
memory_allocator: MemoryAllocatorInterface):
self.dtype = metadata.kv_dtype
self.cachegen_config = CacheGenConfig.from_model_name(
metadata.model_name)
self.chunk_size = config.chunk_size
self.output_buffer: Optional[torch.Tensor] = None
self.fmt = metadata.fmt
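        # Precompute the per-layer quantization bins once so deserialize()
        # does not rebuild them for every chunk.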
self.key_bins = self.make_key_bins(self.cachegen_config)
self.value_bins = self.make_value_bins(self.cachegen_config)
self.memory_allocator = memory_allocator
def make_key_bins(self, config: CacheGenConfig) -> torch.Tensor:
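        """Expand the key quantization bin counts from the CacheGen config
        into one entry per layer, kept on GPU for dequantization."""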
ret = torch.zeros(config.nlayers)
for spec in config.kspecs:
ret[spec.start_layer:spec.end_layer] = spec.bins
return ret.cuda()
def make_value_bins(self, config: CacheGenConfig) -> torch.Tensor:
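        """Expand the value quantization bin counts from the CacheGen config
        into one entry per layer, kept on GPU for dequantization."""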
ret = torch.zeros(config.nlayers)
for spec in config.vspecs:
ret[spec.start_layer:spec.end_layer] = spec.bins
return ret.cuda()
def get_output_buffer(self, nlayers: int, nchannels: int, ntokens: int):
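        """Lazily allocate (and reuse) a chunk-sized uint8 GPU buffer for the
        decoded symbols, reallocating only when the flattened width
        (2 * nlayers * nchannels) changes. Returns a view of the first
        `ntokens` rows."""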
if (self.output_buffer is None
or self.output_buffer.shape[1] != 2 * nlayers * nchannels):
self.output_buffer = torch.zeros(
(self.chunk_size, 2 * nlayers * nchannels),
dtype=torch.uint8).cuda()
return self.output_buffer[:ntokens, :]
# TODO(Jiayi): A lot of memory copies can be avoided in this function.
@_lmcache_nvtx_annotate
def deserialize(
self,
buffer_memory_obj: BytesBufferMemoryObj) -> Optional[MemoryObj]:
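        """Decode a CacheGen-encoded byte buffer back into a KV-blob
        TensorMemoryObj (decode, dequantize, then reshape to the
        serving engine's KV layout)."""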
encoder_output = CacheGenGPUEncoderOutput.from_bytes(
buffer_memory_obj.byte_array)
encoder_output.max_tensors_key = encoder_output.max_tensors_key.cuda()
encoder_output.max_tensors_value = (
encoder_output.max_tensors_value.cuda())
ntokens = encoder_output.max_tensors_key.shape[1]
layers_in_key = encoder_output.max_tensors_key.shape[0]
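        # Decode the compressed data chunks back into quantized key/value
        # symbols on GPU, using the CDF carried in the encoder output.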
key, value = decode_function_gpu(
encoder_output.cdf,
encoder_output.data_chunks,
layers_in_key,
ntokens,
self.get_output_buffer(
encoder_output.cdf.shape[0] // 2,
encoder_output.cdf.shape[1],
ntokens,
),
)
# Temporary fix for #83: change the device of key_bins and value_bins
# to the device of key and value
# This requires a long-term fix in the future. Currently,
# CacheGenGPUEncoderOutput has implicit device in itself.
# More specifically, if the encoder encodes the tensor on GPU0, the
# from_bytes will also return a tensor on GPU0
# We may want to dynamically configure the device based on config and
# metadata in the future
if self.key_bins.device != key.device:
self.key_bins = self.key_bins.to(key.device)
if self.value_bins.device != value.device:
            self.value_bins = self.value_bins.to(value.device)
key = do_dequantize(key, self.key_bins, encoder_output.max_tensors_key)
value = do_dequantize(value, self.value_bins,
encoder_output.max_tensors_value)
""" merge key and value back and reshape """
nlayers, ntokens, nchannels = key.shape
blob = torch.stack([key, value]) # [2, nlayers, ntokens, nchannels]
blob = blob.reshape((
2,
nlayers,
ntokens,
encoder_output.num_heads,
encoder_output.head_size,
))
match self.fmt:
case "vllm":
hidden_dim = blob.shape[-1] * blob.shape[-2]
                kv_chunk = blob.reshape(*blob.shape[:-2], hidden_dim).to(
                    self.dtype)  # [2, nlayers, ntokens, hidden_dim]
case _:
raise RuntimeError("Unknown format %s" % self.fmt)
memory_obj = TensorMemoryObj(
raw_data=kv_chunk,
metadata=MemoryObjMetadata(
shape=kv_chunk.shape,
dtype=kv_chunk.dtype,
address=-1,
phy_size=kv_chunk.numel() * kv_chunk.element_size(),
ref_count=-1, # HACK: avoid mis-free
fmt=MemoryFormat.KV_BLOB))
return memory_obj
#memory_obj = self.memory_allocator.allocate(kv_chunk.shape,
# kv_chunk.dtype,
# fmt=MemoryFormat.KV_BLOB)
#if memory_obj is None:
# logger.warning("Memory allocation failed in cachegen deserializer")
# return None
#assert memory_obj.tensor is not None
#memory_obj.tensor.copy_(kv_chunk)
#return memory_obj