Source code for lmcache.storage_backend.serde.cachegen_encoder

from typing import Dict, Tuple

import torch
import torchac_cuda  # type: ignore

import lmcache.storage_backend.serde.cachegen_basics as CGBasics
from lmcache.config import LMCacheEngineConfig, LMCacheEngineMetadata
from lmcache.logging import init_logger
from lmcache.storage_backend.serde.cachegen_basics import (
    CacheGenConfig, CacheGenGPUBytestream, CacheGenGPUEncoderOutput)
from lmcache.storage_backend.serde.serde import Serializer
from lmcache.utils import _lmcache_nvtx_annotate

logger = init_logger(__name__)


@_lmcache_nvtx_annotate
def torch_quant(bins: int,
                qA: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize a float tensor to a fixed number of bins

    Input:
        bins: number of bins
        qA: the input tensor

    Returns:
        xq: the quantized tensor, in int8
        max1: the per-channel maximum absolute value of the tensor
    """
    MAX = bins // 2 - 1
    C = MAX
    max1 = torch.amax(torch.abs(qA), dim=-1, keepdim=True)
    xq = torch.round(qA * (C / max1)).to(torch.int8)
    return xq, max1
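# Illustrative usage of torch_quant (an editorial sketch, not part of the
# original module). With bins = 32, MAX = 15, so each row is scaled by
# 15 / max(|row|) and rounded to symmetric int8 codes in [-15, 15]:
#
#     qA = torch.tensor([[1.0, -2.0, 4.0]])
#     xq, max1 = torch_quant(32, qA)
#     # max1 == tensor([[4.]])
#     # xq   == tensor([[ 4, -8, 15]], dtype=torch.int8)
#     # (round(-7.5) == -8 under torch's round-half-to-even)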
@_lmcache_nvtx_annotate
def torch_quant_vectorized(
        bins: torch.Tensor,
        input_groups: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize each group of a tensor to a fixed number of bins

    Input:
        bins: number of bins for different layers, with shape [nlayers]
        input_groups: with shape [nlayers, ntokens, nchannels]

    Returns:
        quantized groups: [nlayers, ntokens, nchannels]
        maxes: [nlayers, ntokens, 1]
    """
    MAX = (bins // 2 - 1)[:, None, None]  # shape [nlayers, 1, 1]
    max1 = torch.amax(torch.abs(input_groups), dim=-1,
                      keepdim=True)  # shape [nlayers, ntokens, 1]
    factor = MAX / max1  # shape [nlayers, ntokens, 1]
    xq = torch.round(input_groups * factor + MAX).to(
        torch.int8)  # shape [nlayers, ntokens, nchannels]
    return xq, max1
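# Editorial sketch of torch_quant_vectorized (values verified by hand): each
# layer gets its own bin count, and the "+ MAX" offset shifts the codes into
# the non-negative range [0, bins - 2]:
#
#     bins = torch.tensor([32.0, 16.0])            # per-layer bin counts
#     groups = torch.tensor([[[1.0, -2.0, 4.0]],
#                            [[3.0, -1.0, 2.0]]])  # [2 layers, 1 tok, 3 ch]
#     xq, maxes = torch_quant_vectorized(bins, groups)
#     # layer 0 (MAX=15): round([3.75, -7.5, 15.0] + 15) -> [19,  8, 30]
#     # layer 1 (MAX=7):  round([7.0, -2.33, 4.67] + 7)  -> [14,  5, 12]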
@_lmcache_nvtx_annotate
def concat_max(max1):
    """
    Given a dict of max tensors, concatenate them into a single tensor
    """
    # TODO: this function can be optimized; we don't really need it
    maxes = []
    for i in range(len(max1)):
        maxes.append(max1[i].unsqueeze(0))
    return torch.cat(maxes, dim=0)
def _split_kv(tensor: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    """
    Split a blob KV tensor into K and V tensors with the heads merged

    Input:
        tensor: the KV tensor with shape
                [num_layers, 2, num_tokens, num_heads, head_size]

    Returns:
        K and V tensors with shape [num_layers, num_tokens, num_channels]
    """
    num_layers, _, num_tokens, num_heads, head_size = tensor.shape
    return torch.unbind(tensor.reshape(num_layers, 2, num_tokens,
                                       num_heads * head_size),
                        dim=1)


@_lmcache_nvtx_annotate
def _convert_to_int_and_normalize(cdf_float, needs_normalization):
    """
    Convert a floating-point CDF to integers. See README for more info.

    The idea is the following: When we get the cdf here, it is (assumed to
    be) between 0 and 1, i.e., cdf in [0, 1) (note that 1 should not be
    included). We now want to convert this to int16, but make sure we do
    not get the same value twice, as this would break the arithmetic coder
    (you need a strictly monotonically increasing function). So, if
    needs_normalization == True, we multiply the input CDF by
    2**16 - (Lp - 1). This means that now, cdf is in [0, 2**16 - (Lp - 1)].
    Then, in a final step, we add arange(Lp), which is just a line with
    slope one. This ensures that we get unique, strictly monotonically
    increasing CDFs, which are in [0, 2**16).
    """
    PRECISION = 16
    Lp = cdf_float.shape[-1]
    factor = torch.tensor(2, dtype=torch.float32,
                          device=cdf_float.device).pow_(PRECISION)
    new_max_value = factor
    if needs_normalization:
        new_max_value = new_max_value - (Lp - 1)
    cdf_float = cdf_float.mul(new_max_value)
    cdf_float = cdf_float.round()
    cdf = cdf_float.to(dtype=torch.int16, non_blocking=True)
    if needs_normalization:
        r = torch.arange(Lp, dtype=torch.int16, device=cdf.device)
        cdf.add_(r)
    return cdf
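# Worked example for _convert_to_int_and_normalize (editorial): with Lp = 3
# and needs_normalization=True, the scale factor is 2**16 - 2 = 65534:
#
#     cdf_float = torch.tensor([0.0, 0.2, 0.4])
#     _convert_to_int_and_normalize(cdf_float, True)
#     # round([0, 13106.8, 26213.6]) + [0, 1, 2]
#     #     -> tensor([0, 13108, 26216], dtype=torch.int16)
#
# Values above 2**15 wrap around in int16; the arithmetic coder is assumed
# to reinterpret the bits as unsigned 16-bit integers (torch has no uint16).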
class CacheGenEncoderImpl:

    def __init__(self, **kwargs) -> None:
        """
        Fields:
        - fp_k: should be a tensor of shape (num_layers, num_tokens, num_channels)
        - fp_v: should be a tensor of shape (num_layers, num_tokens, num_channels)
        """
        self.fp_k = kwargs["fp_k"]
        self.fp_v = kwargs["fp_v"]

        self.quantized_key: Dict[int, torch.Tensor] = {}
        self.max_tensors_key: Dict[int, torch.Tensor] = {}
        self.quantized_value: Dict[int, torch.Tensor] = {}
        self.max_tensors_value: Dict[int, torch.Tensor] = {}
        self.config = kwargs["config"]
    @_lmcache_nvtx_annotate
    def quantize(self):
        """
        Quantize the key and value tensors (self.fp_k and self.fp_v)
        """
        for layer in range(len(self.fp_k)):
            if layer < self.config["key_first_layers"]:
                bins = self.config["key_first_bins"]
            elif layer < self.config["key_second_layers"]:
                bins = self.config["key_second_bins"]
            else:
                bins = self.config["key_third_bins"]
            tmp = torch_quant(bins, self.fp_k[layer].float())
            self.quantized_key[layer] = tmp[0] + bins // 2 - 1
            self.max_tensors_key[layer] = tmp[1]

        for layer in range(len(self.fp_v)):
            if layer < self.config["value_first_layers"]:
                bins = self.config["value_first_bins"]
            else:
                bins = self.config["value_second_bins"]
            tmp = torch_quant(bins, self.fp_v[layer].float())
            self.quantized_value[layer] = tmp[0] + bins // 2 - 1
            self.max_tensors_value[layer] = tmp[1]
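    # Editorial note: the "+ bins // 2 - 1" above shifts torch_quant's
    # symmetric codes from [-(bins//2 - 1), bins//2 - 1] into the
    # non-negative range [0, bins - 2] (e.g. [0, 30] for bins = 32), which
    # is what the one-hot counting in compute_cdf below expects.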
    @_lmcache_nvtx_annotate
    def compute_cdf(self, is_key):
        """
        Compute the CDF based on the quantized tensors

        Input:
            is_key: compute the CDF over the quantized keys if True,
                over the quantized values otherwise
        """
        # TODO: Add start_index here
        channels = self.fp_k[0].shape[-1]

        def process_batch(X, max_val):
            """
            Input shape should be [channels, tokens]
            """
            nchannels, ntokens = X.shape
            one_hot = torch.nn.functional.one_hot(
                X.long(), num_classes=max_val + 1).to(
                    torch.float32)  # Use float32 to avoid integer overflow
            counts = one_hot.sum(dim=1) / ntokens
            ret = torch.cumsum(counts, dim=1).roll(1)
            ret[:, 0] = 0
            return ret

        def process_layers(X, max_val):
            """
            X is an iterator over dict values; each element's shape is
            [tokens, channels]
            """
            results = []
            for x in X:
                # Permute to [channels, tokens] before counting
                batch_counts = process_batch(x.cuda().permute(1, 0), max_val)
                results.append(batch_counts)
            final_counts = torch.cat(results, dim=0)
            return final_counts

        if is_key:
            X = self.quantized_key.values()
        else:
            X = self.quantized_value.values()
        value_range = 32
        cdfs = process_layers(X, value_range)
        final_cdf = cdfs.reshape((len(self.fp_k), channels, value_range + 1))
        return final_cdf
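    # Editorial example of the empirical CDF built by process_batch: for a
    # single channel with token codes [0, 0, 1, 3] and max_val = 3, the
    # class probabilities are [0.50, 0.25, 0.00, 0.25], and the rolled
    # cumulative sum gives CDF[k] = P(code < k):
    #
    #     tensor([[0.00, 0.50, 0.75, 0.75]])
    #
    # In actual use max_val = value_range = 32 while the quantized codes
    # are at most bins - 2 <= 30, so the final column reaches 1.0.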
@_lmcache_nvtx_annotate
def collect_bytes(output_buffer, output_lengths) -> torch.Tensor:
    """
    Collect a byte tensor from the output_buffer + output_lengths
    """
    output_buffer_size = output_buffer.shape[-1]
    flattened_lengths = output_lengths.flatten()
    flattened_buffer = output_buffer.flatten()
    summed_length = (output_buffer_size - flattened_lengths).cumsum(0)
    summed_length = summed_length.roll(1)
    summed_length[0] = 0
    indexes = summed_length.repeat_interleave(flattened_lengths)
    indexes = indexes + torch.arange(len(indexes), device=indexes.device)
    return flattened_buffer[indexes]
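# Worked example for collect_bytes (editorial): with two buffer rows of size
# 4 and valid lengths [2, 3], only the leading valid bytes of each row are
# gathered into one contiguous tensor:
#
#     output_buffer  = torch.tensor([[10, 11, 12, 13],
#                                    [20, 21, 22, 23]], dtype=torch.uint8)
#     output_lengths = torch.tensor([2, 3], dtype=torch.int32)
#     collect_bytes(output_buffer, output_lengths)
#     # gather indexes [0, 1, 4, 5, 6] -> tensor([10, 11, 20, 21, 22])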
@_lmcache_nvtx_annotate
def encode_ntokens(cdf_int, encode_input, output_buffer,
                   output_lengths) -> torch.Tensor:
    """Encode a batch of ntokens

    :param cdf_int: int16 tensor on GPU with shape [nlayers, nchannels, Lp]
    :param encode_input: int8 tensor on GPU with shape
        [nlayers, ntokens, nchannels]
    :param output_buffer: uint8 tensor on GPU with shape
        [nlayers, nchannels, BUFFER_SIZE]
    :param output_lengths: int32 tensor on GPU with shape
        [nlayers, nchannels]

    :return: the byte tensor
    """
    torchac_cuda.encode_fast_new(
        cdf_int,
        encode_input,
        output_buffer,
        output_lengths,
    )
    byte_tensor = collect_bytes(output_buffer, output_lengths)
    # return byte_tensor.cpu().numpy().tobytes()
    return byte_tensor
@_lmcache_nvtx_annotate
def encode_function(
    kv: torch.Tensor,
    config: CacheGenConfig,
    key_bins: torch.Tensor,
    value_bins: torch.Tensor,
    chunk_size: int,
) -> CacheGenGPUEncoderOutput:
    """
    Given the original KV cache blob, encode it on the GPU
    """
    num_heads, head_size = kv.shape[-2:]
    fp_k, fp_v = _split_kv(kv)
    nchannels = num_heads * head_size
    nlayers = fp_k.shape[0] + fp_v.shape[0]

    new_key, max_tensors_key = torch_quant_vectorized(key_bins, fp_k)
    new_value, max_tensors_value = torch_quant_vectorized(value_bins, fp_v)

    encode_input = torch.cat((new_key, new_value),
                             dim=0).reshape(nlayers, chunk_size, nchannels)
    new_cdf_key = torchac_cuda.calculate_cdf(new_key, int(key_bins.max()))
    new_cdf_value = torchac_cuda.calculate_cdf(new_value,
                                               int(value_bins.max()))
    cdf_int = torch.cat([new_cdf_key, new_cdf_value])

    output_buffer = torch.zeros(
        (nlayers, nchannels, CGBasics.CACHEGEN_GPU_MAX_TOKENS_PER_CHUNK),
        dtype=torch.uint8,
        device=encode_input.device,
    )
    output_lengths = torch.zeros((nlayers, nchannels),
                                 dtype=torch.int32,
                                 device=encode_input.device)

    data_chunks = []
    for i in range(0, chunk_size, CGBasics.CACHEGEN_GPU_MAX_TOKENS_PER_CHUNK):
        start = i
        end = min(i + CGBasics.CACHEGEN_GPU_MAX_TOKENS_PER_CHUNK, chunk_size)
        bytestream = encode_ntokens(
            cdf_int,
            encode_input[:, start:end, :],
            output_buffer,
            output_lengths,
        )
        data_chunks.append(
            CacheGenGPUBytestream(
                bytestream=bytestream,
                bytestream_lengths=output_lengths.clone(),
                ntokens=end - start,
            ))

    return CacheGenGPUEncoderOutput(
        data_chunks,
        cdf_int,
        max_tensors_key=max_tensors_key,
        max_tensors_value=max_tensors_value,
        num_heads=num_heads,
        head_size=head_size,
    )
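# Hedged usage sketch for encode_function (editorial; shapes and bin values
# are hypothetical, and `config` is a CacheGenConfig such as the one built
# by CacheGenConfig.from_model_name in CacheGenSerializer below):
#
#     nlayers, ntokens, nheads, head_size = 32, 256, 8, 128
#     kv = torch.randn(nlayers, 2, ntokens, nheads, head_size, device="cuda")
#     key_bins = torch.full((nlayers,), 32.0, device="cuda")
#     value_bins = torch.full((nlayers,), 32.0, device="cuda")
#     out = encode_function(kv, config, key_bins, value_bins, ntokens)
#     blob = out.to_bytes()  # CacheGenGPUEncoderOutput -> bytes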
class CacheGenSerializer(Serializer):

    def __init__(self, config: LMCacheEngineConfig,
                 metadata: LMCacheEngineMetadata):
        self.cachegen_config = CacheGenConfig.from_model_name(
            metadata.model_name)
        self.chunk_size = config.chunk_size
        self.fmt = metadata.fmt
        self.key_bins = self.make_key_bins(self.cachegen_config)
        self.value_bins = self.make_value_bins(self.cachegen_config)
    def make_key_bins(self, config: CacheGenConfig) -> torch.Tensor:
        ret = torch.zeros(config.nlayers)
        for spec in config.kspecs:
            ret[spec.start_layer:spec.end_layer] = spec.bins
        return ret.cuda()
    def make_value_bins(self, config: CacheGenConfig) -> torch.Tensor:
        ret = torch.zeros(config.nlayers)
        for spec in config.vspecs:
            ret[spec.start_layer:spec.end_layer] = spec.bins
        return ret.cuda()
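    # Editorial note: each spec in config.kspecs / config.vspecs carries
    # start_layer, end_layer and bins (as used above), so e.g. specs
    # (0, 10, bins=32) and (10, 32, bins=16) yield the per-layer tensor
    # [32]*10 + [16]*22, which torch_quant_vectorized consumes layer-wise.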
    @_lmcache_nvtx_annotate
    def to_bytes(self, tensor: torch.Tensor) -> bytes:
        """
        Serialize a pytorch tensor to bytes. The serialized bytes contain
        both the data and the metadata (shape, dtype, etc.) of the tensor.

        Input:
            tensor: the input pytorch tensor, expected to have shape
                [num_layers, 2, num_tokens, num_heads, head_size]; it can
                be on any device

        Returns:
            bytes: the serialized bytes
        """
        # Temporary fix for issue #83: the encoder gets default device 0 on
        # all the ray workers, so set it to the correct device here.
        # TODO: figure out why this happens.
        # (The original comparison was missing the call parentheses on
        # torch.cuda.current_device, so it was always True.)
        if (tensor.device.type == "cuda"
                and torch.cuda.current_device() != tensor.device.index):
            torch.cuda.set_device(tensor.device)

        if tensor.device != self.key_bins.device:
            self.key_bins = self.key_bins.to(tensor.device)

        if tensor.device != self.value_bins.device:
            self.value_bins = self.value_bins.to(tensor.device)

        # TODO: permute is expensive here; need a better way to do it at a
        # lower level
        if self.fmt == "huggingface":
            tensor = tensor.permute(0, 1, 3, 2, 4)

        ntokens = tensor.shape[2]
        output = encode_function(
            tensor.cuda(),
            self.cachegen_config,
            self.key_bins,
            self.value_bins,
            ntokens,
        )
        return output.to_bytes()
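# Hedged end-to-end sketch (editorial; constructing LMCacheEngineConfig and
# LMCacheEngineMetadata is assumed and not shown in this module):
#
#     serializer = CacheGenSerializer(config, metadata)
#     kv = torch.randn(32, 2, 256, 8, 128)  # [layers, 2, tokens, heads, hs]
#     blob = serializer.to_bytes(kv)        # bytes for the matching
#                                           # CacheGen deserializer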