Source code for lmcache.storage_backend.serde.cachegen_basics

import io
import pickle
from dataclasses import dataclass
from typing import List

import torch
from transformers import AutoConfig

from lmcache.logging import init_logger
from lmcache.utils import _lmcache_nvtx_annotate

logger = init_logger(__name__)

CACHEGEN_GPU_MAX_TOKENS_PER_CHUNK = 256


@dataclass
class QuantizationSpec:
    start_layer: int
    end_layer: int
    bins: int

    def __getitem__(self, key: str) -> int:
        return getattr(self, key)
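

# Note (illustrative, not part of the original module): each QuantizationSpec
# describes a layer range with a quantization bin count; judging by how
# from_model_name tiles ranges below (0-10 then 10-32 for a 32-layer model),
# end_layer appears to be exclusive. __getitem__ allows dict-style access:
#
#   spec = QuantizationSpec(start_layer=0, end_layer=10, bins=32)
#   assert spec["bins"] == spec.bins == 32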


@dataclass
class CacheGenConfig:
    # TODO: move this class to another file like "cachegen_basics.py"
    nlayers: int
    kspecs: List[QuantizationSpec]
    vspecs: List[QuantizationSpec]

    def __getitem__(self, key: str) -> int:
        return getattr(self, key)

    @staticmethod
    def from_model_name(model_name: str) -> "CacheGenConfig":
        family_7b = [
            "mistralai/Mistral-7B-Instruct-v0.2",
            "lmsys/longchat-7b-16k",
            "Qwen/Qwen-7B",
        ]
        family_8b = ["meta-llama/Llama-3.1-8B-Instruct"]
        family_9b = ["THUDM/glm-4-9b-chat"]
        if model_name in family_7b:
            return CacheGenConfig(
                nlayers=32,
                kspecs=[
                    QuantizationSpec(start_layer=0, end_layer=10, bins=32),
                    QuantizationSpec(start_layer=10, end_layer=32, bins=16),
                ],
                vspecs=[
                    QuantizationSpec(start_layer=0, end_layer=2, bins=32),
                    QuantizationSpec(start_layer=2, end_layer=32, bins=16),
                ],
            )
        elif model_name in family_8b:
            return CacheGenConfig(
                nlayers=32,
                kspecs=[
                    QuantizationSpec(start_layer=0, end_layer=10, bins=32),
                    QuantizationSpec(start_layer=10, end_layer=32, bins=16),
                ],
                vspecs=[
                    QuantizationSpec(start_layer=0, end_layer=2, bins=32),
                    QuantizationSpec(start_layer=2, end_layer=32, bins=16),
                ],
            )
        # TODO(Jiayi): needs tuning for better quality
        elif model_name in family_9b:
            return CacheGenConfig(
                nlayers=40,
                kspecs=[
                    QuantizationSpec(start_layer=0, end_layer=10, bins=32),
                    QuantizationSpec(start_layer=10, end_layer=40, bins=16),
                ],
                vspecs=[
                    QuantizationSpec(start_layer=0, end_layer=2, bins=32),
                    QuantizationSpec(start_layer=2, end_layer=40, bins=16),
                ],
            )
        else:
            try:
                config = AutoConfig.from_pretrained(model_name)
                # For models not listed above, read the layer count from the
                # standard num_hidden_layers config field
                if config.num_hidden_layers is None:
                    raise ValueError(
                        f"num_hidden_layers is None for model {model_name}")
                if config.num_hidden_layers < 10:
                    return CacheGenConfig(
                        nlayers=config.num_hidden_layers,
                        kspecs=[
                            QuantizationSpec(
                                start_layer=0,
                                end_layer=config.num_hidden_layers,
                                bins=32),
                        ],
                        vspecs=[
                            QuantizationSpec(
                                start_layer=0,
                                end_layer=config.num_hidden_layers,
                                bins=32),
                        ],
                    )
                else:
                    return CacheGenConfig(
                        nlayers=config.num_hidden_layers,
                        kspecs=[
                            QuantizationSpec(start_layer=0, end_layer=10,
                                             bins=32),
                            QuantizationSpec(
                                start_layer=10,
                                end_layer=config.num_hidden_layers,
                                bins=16),
                        ],
                        vspecs=[
                            QuantizationSpec(start_layer=0, end_layer=2,
                                             bins=32),
                            QuantizationSpec(
                                start_layer=2,
                                end_layer=config.num_hidden_layers,
                                bins=16),
                        ],
                    )
            except Exception as e:
                raise ValueError(
                    f"Model {model_name} not supported by CacheGenConfig"
                ) from e
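

# Usage sketch (illustrative, not part of the original module): the model name
# below is one of the entries hard-coded in family_7b above, so no network
# access or AutoConfig lookup is needed.
#
#   config = CacheGenConfig.from_model_name("lmsys/longchat-7b-16k")
#   assert config.nlayers == 32
#   assert config["kspecs"][1].bins == 16  # dict-style access via __getitem__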


@dataclass
class CacheGenEncoderOutput:
    # TODO: maybe use numpy arrays so that we can directly tobytes() and
    # frombuffer() for better performance
    bytestream: bytes
    start_indices: torch.Tensor
    cdf: torch.Tensor
    max_tensors_key: torch.Tensor
    max_tensors_value: torch.Tensor
    num_heads: int
    head_size: int

    def __getitem__(self, key: str) -> int:
        return getattr(self, key)

    def to_bytes(self) -> bytes:
        """Serialize the output to a byte string."""
        with io.BytesIO() as f:
            # torch.save(self, f)
            pickle.dump(self, f)
            return f.getvalue()

    @staticmethod
    def from_bytes(bs: bytes) -> "CacheGenEncoderOutput":
        with io.BytesIO(bs) as f:
            return pickle.load(f)
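

# Round-trip sketch (illustrative, not part of the original module). The
# placeholder tensor shapes below are assumptions; real values come from the
# CacheGen encoder.
#
#   output = CacheGenEncoderOutput(
#       bytestream=b"\x00" * 16,
#       start_indices=torch.zeros(4, dtype=torch.int64),
#       cdf=torch.zeros(2, 8, 17),
#       max_tensors_key=torch.zeros(2, 4),
#       max_tensors_value=torch.zeros(2, 4),
#       num_heads=8,
#       head_size=128,
#   )
#   restored = CacheGenEncoderOutput.from_bytes(output.to_bytes())
#   assert restored.num_heads == 8
#   assert torch.equal(restored.cdf, output.cdf)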


@dataclass
class CacheGenGPUBytestream:
    bytestream: torch.Tensor
    bytestream_lengths: torch.Tensor  # [nlayers, nchannels, bytestream_length]
    ntokens: int

    def __getitem__(self, key: str) -> int:
        return getattr(self, key)


@dataclass
class CacheGenGPUEncoderOutput:
    data_chunks: List[CacheGenGPUBytestream]
    cdf: torch.Tensor
    max_tensors_key: torch.Tensor
    max_tensors_value: torch.Tensor
    num_heads: int
    head_size: int

    def __getitem__(self, key: str) -> int:
        return getattr(self, key)

    @_lmcache_nvtx_annotate
    def to_bytes(self) -> bytes:
        """Serialize the output to a byte string."""
        with io.BytesIO() as f:
            pickle.dump(self, f)
            return f.getvalue()

    @staticmethod
    @_lmcache_nvtx_annotate
    def from_bytes(bs: bytes) -> "CacheGenGPUEncoderOutput":
        with io.BytesIO(bs) as f:
            return pickle.load(f)

    def debug_print_device(self):
        logger.debug(
            f"bytestream device: {self.data_chunks[0].bytestream.device}")
        logger.debug(f"bytestream_lengths device: "
                     f"{self.data_chunks[0].bytestream_lengths.device}")
        logger.debug(f"cdf device: {self.cdf.device}")
        logger.debug(f"max_tensors_key device: {self.max_tensors_key.device}")
        logger.debug(
            f"max_tensors_value device: {self.max_tensors_value.device}")
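

# Debugging sketch (illustrative, not part of the original module): after
# deserializing a payload produced by to_bytes(), debug_print_device logs
# which device each tensor ended up on, which helps verify that a round trip
# placed tensors where the decoder expects them.
#
#   gpu_output = CacheGenGPUEncoderOutput.from_bytes(payload)  # payload: bytes
#   gpu_output.debug_print_device()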