Source code for lmcache.storage_backend.serde.cachegen_decoder

from typing import List, Optional

import torch
import torchac_cuda  # type: ignore

import lmcache.storage_backend.serde.cachegen_basics as CGBasics
from lmcache.config import LMCacheEngineConfig, LMCacheEngineMetadata
from lmcache.logging import init_logger
from lmcache.storage_backend.serde.cachegen_basics import (
    CacheGenConfig, CacheGenGPUBytestream, CacheGenGPUEncoderOutput)
from lmcache.storage_backend.serde.serde import Deserializer
from lmcache.utils import _lmcache_nvtx_annotate

logger = init_logger(__name__)


@_lmcache_nvtx_annotate
def quant(bins: int, xq: torch.Tensor, max1: float):
    C = bins // 2 - 1
    x = xq / C * max1
    return x


def do_dequantize(t: torch.Tensor, bins: torch.Tensor,
                  maxtensors: torch.Tensor):
    """
    t: [nlayers, ntokens, nchannels]
    bins: [nlayers]
    maxtensors: [nlayers, ntokens, 1]
    """
    C = (bins // 2 - 1)[:, None, None]
    t = t - C
    t = t / C
    t = t * maxtensors
    return t


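# Illustrative only (not part of the original module): a tiny CPU sketch of
# the dequantization formula in `do_dequantize`. The shapes and values below
# are made up; the real tensors come from the GPU decode path.
def _demo_do_dequantize() -> torch.Tensor:
    nlayers, ntokens, nchannels = 2, 3, 4
    bins = torch.tensor([32.0, 16.0])             # per-layer bin counts
    maxtensors = torch.ones(nlayers, ntokens, 1)  # per-token max scale
    # Quantized symbols in [0, 2C]; the encoder centered them around C.
    t = torch.randint(0, 31, (nlayers, ntokens, nchannels)).float()
    # Undo the centering and scaling: (t - C) / C * max
    return do_dequantize(t, bins, maxtensors)

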
@_lmcache_nvtx_annotate
def recombine_bytes(bytes_tensor, output_lengths) -> torch.Tensor:
    output_buffer_size = CGBasics.CACHEGEN_GPU_MAX_TOKENS_PER_CHUNK
    offsets = (output_lengths.flatten().cumsum(0).roll(1).reshape(
        output_lengths.shape))
    offsets[0][0] = 0
    indexes = torch.arange(output_buffer_size, device=offsets.device).tile(
        (output_lengths.shape[0], output_lengths.shape[1], 1))
    final_indexes = (indexes +
                     offsets[:, :, None]).clamp(max=len(bytes_tensor) - 1)
    return bytes_tensor[final_indexes]


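# Illustrative only (not part of the original module): a CPU sketch of the
# gather performed by `recombine_bytes`, with a small stand-in width instead
# of CACHEGEN_GPU_MAX_TOKENS_PER_CHUNK. Values are made up for the sketch.
def _demo_recombine_bytes_cpu() -> torch.Tensor:
    bytes_tensor = torch.arange(10)                  # flat bytestream
    output_lengths = torch.tensor([[3, 2], [4, 1]])  # per-(row, col) lengths
    width = 4                                        # stand-in buffer width
    # Rolled cumulative sums give the start offset of each run in the stream.
    offsets = (output_lengths.flatten().cumsum(0).roll(1).reshape(
        output_lengths.shape))
    offsets[0][0] = 0                                # offsets: [[0, 3], [5, 9]]
    indexes = torch.arange(width).tile((2, 2, 1))
    final_indexes = (indexes +
                     offsets[:, :, None]).clamp(max=len(bytes_tensor) - 1)
    # Each [row, col] slot now holds a fixed-width window starting at its offset.
    return bytes_tensor[final_indexes]               # shape [2, 2, 4]

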
@_lmcache_nvtx_annotate
def decode_chunk(
    cdf: torch.Tensor,
    data_chunk: CacheGenGPUBytestream,
    target_buffer: torch.Tensor,
) -> None:
    """
    Write the decode output into target_buffer
    Expected shape: [nlayers (kv in total), ntokens, nchannels]
    """
    bytes_tensor = data_chunk.bytestream
    length_prefsum = (
        data_chunk.bytestream_lengths.flatten().cumsum(0).reshape(
            data_chunk.bytestream_lengths.shape))
    torchac_cuda.decode_fast_prefsum(cdf, bytes_tensor, length_prefsum,
                                     target_buffer)


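# Illustrative only (not part of the original module): how per-(layer, token)
# bytestream lengths become the prefix sums handed to
# torchac_cuda.decode_fast_prefsum. Values are made up for the sketch.
def _demo_length_prefsum() -> torch.Tensor:
    bytestream_lengths = torch.tensor([[2, 5], [3, 4]])
    length_prefsum = (bytestream_lengths.flatten().cumsum(0).reshape(
        bytestream_lengths.shape))
    # length_prefsum: [[2, 7], [10, 14]] -- the end offset of each run.
    return length_prefsum

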
@_lmcache_nvtx_annotate
def decode_function_gpu(
    cdf: torch.Tensor,
    data_chunks: List[CacheGenGPUBytestream],
    layers_in_key: int,
    chunk_size: int,
    output: torch.Tensor,
):
    # TODO: dtype and shape -- still have 128 and 8
    """
    Given the encoded KV bytestream chunks, decode the KV cache.

    Inputs:
        cdf: the cdf tensor, in shape [2 * nlayers, nchannels, bins + 1]
        data_chunks: the data chunks in the encoder's output
        layers_in_key: number of layers in K (or V)
            (K and V should have the same number of layers)
        chunk_size: the chunk size
        output: output buffer, in shape [ntokens, 2 * nlayers * nchannels]

    Outputs:
        key: the decoded key tensor, in shape [layers, tokens, nchannels]
        value: the decoded value tensor, in shape [layers, tokens, nchannels]
    """
    nlayers, nchannels, _ = cdf.shape
    output = output.reshape((nlayers, chunk_size, nchannels))
    start = 0
    for data_chunk in data_chunks:
        end = start + data_chunk.ntokens
        decode_chunk(cdf, data_chunk, output[:, start:end, :])
        start = end
    out = output.reshape((2, layers_in_key, chunk_size, nchannels))
    key, value = out.float()
    return key, value


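# Illustrative only (not part of the original module): a CPU-only sketch of
# the chunk walk in `decode_function_gpu`, with the torchac_cuda kernel
# replaced by a placeholder fill. It only shows how each chunk's tokens land
# in consecutive token slices of the output buffer; shapes are made up.
def _demo_chunk_walk() -> torch.Tensor:
    nlayers, chunk_size, nchannels = 4, 6, 2
    chunk_ntokens = [2, 3, 1]  # stand-ins for data_chunk.ntokens
    output = torch.zeros(nlayers, chunk_size, nchannels)
    start = 0
    for i, ntokens in enumerate(chunk_ntokens):
        end = start + ntokens
        output[:, start:end, :] = i + 1  # real code: decode_chunk(...)
        start = end
    return output

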
class CacheGenDeserializer(Deserializer):

    def __init__(self, config: LMCacheEngineConfig,
                 metadata: LMCacheEngineMetadata, dtype):
        self.dtype = dtype
        self.cachegen_config = CacheGenConfig.from_model_name(
            metadata.model_name)
        self.chunk_size = config.chunk_size
        self.output_buffer: Optional[torch.Tensor] = None
        self.fmt = metadata.fmt
        self.key_bins = self.make_key_bins(self.cachegen_config)
        self.value_bins = self.make_value_bins(self.cachegen_config)

    def make_key_bins(self, config: CacheGenConfig) -> torch.Tensor:
        ret = torch.zeros(config.nlayers)
        for spec in config.kspecs:
            ret[spec.start_layer:spec.end_layer] = spec.bins
        return ret.cuda()

    def make_value_bins(self, config: CacheGenConfig) -> torch.Tensor:
        ret = torch.zeros(config.nlayers)
        for spec in config.vspecs:
            ret[spec.start_layer:spec.end_layer] = spec.bins
        return ret.cuda()

    def get_output_buffer(self, nlayers: int, nchannels: int, ntokens: int):
        # Lazily (re)allocate a reusable uint8 GPU buffer sized for a full
        # chunk, then return a view over the first `ntokens` rows.
        if (self.output_buffer is None
                or self.output_buffer.shape[1] != 2 * nlayers * nchannels):
            self.output_buffer = torch.zeros(
                (self.chunk_size, 2 * nlayers * nchannels),
                dtype=torch.uint8).cuda()
        return self.output_buffer[:ntokens, :]

    @_lmcache_nvtx_annotate
    def from_bytes(self, bs: bytes) -> torch.Tensor:
        encoder_output = CacheGenGPUEncoderOutput.from_bytes(bs)
        encoder_output.max_tensors_key = encoder_output.max_tensors_key.cuda()
        encoder_output.max_tensors_value = (
            encoder_output.max_tensors_value.cuda())

        ntokens = encoder_output.max_tensors_key.shape[1]
        layers_in_key = encoder_output.max_tensors_key.shape[0]
        key, value = decode_function_gpu(
            encoder_output.cdf,
            encoder_output.data_chunks,
            layers_in_key,
            ntokens,
            self.get_output_buffer(
                encoder_output.cdf.shape[0] // 2,
                encoder_output.cdf.shape[1],
                ntokens,
            ),
        )

        # Temporary fix for #83: move key_bins and value_bins to the device of
        # key and value. This needs a long-term fix: CacheGenGPUEncoderOutput
        # currently carries an implicit device -- if the encoder encoded the
        # tensor on GPU0, from_bytes will also return a tensor on GPU0. We may
        # want to configure the device from config and metadata in the future.
        if self.key_bins.device != key.device:
            self.key_bins = self.key_bins.to(key.device)

        if self.value_bins.device != value.device:
            self.value_bins = self.value_bins.to(value.device)

        key = do_dequantize(key, self.key_bins,
                            encoder_output.max_tensors_key)
        value = do_dequantize(value, self.value_bins,
                              encoder_output.max_tensors_value)

        # Merge key and value back together and reshape into per-head layout.
        nlayers, ntokens, nchannels = key.shape
        blob = torch.stack([key, value])  # [2, nlayers, ntokens, nchannels]
        blob = blob.reshape((
            2,
            nlayers,
            ntokens,
            encoder_output.num_heads,
            encoder_output.head_size,
        ))

        match self.fmt:
            case "vllm":
                return blob.permute((1, 0, 2, 3, 4)).to(
                    self.dtype)  # [nlayers, 2, ntokens, num_heads, head_size]
            case "huggingface":
                return blob.permute((1, 0, 3, 2, 4)).to(
                    self.dtype)  # [nlayers, 2, num_heads, ntokens, head_size]
            case _:
                raise RuntimeError("Unknown format %s" % self.fmt)


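# Illustrative only (not part of the original module): the two output layouts
# produced at the end of CacheGenDeserializer.from_bytes, demonstrated on a
# random CPU tensor. The dimension sizes below are made up; the real ones come
# from the encoder output.
def _demo_output_layouts():
    nlayers, ntokens, num_heads, head_size = 2, 5, 4, 8
    blob = torch.rand(2, nlayers, ntokens, num_heads, head_size)
    vllm_layout = blob.permute((1, 0, 2, 3, 4))
    hf_layout = blob.permute((1, 0, 3, 2, 4))
    assert vllm_layout.shape == (nlayers, 2, ntokens, num_heads, head_size)
    assert hf_layout.shape == (nlayers, 2, num_heads, ntokens, head_size)
    return vllm_layout, hf_layout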