# Source code for lmcache.experimental.storage_backend.naive_serde.cachegen_basics

from dataclasses import dataclass
from typing import List

from transformers import AutoConfig

from lmcache.logging import init_logger
from lmcache.storage_backend.serde.cachegen_basics import QuantizationSpec

# Module-level logger created through lmcache's logging helper and named
# after this module.
logger = init_logger(__name__)


@dataclass
class CacheGenConfig:
    """Per-model layout of the quantization specs for keys and values.

    Attributes are documented inline below. Instances are normally obtained
    through :meth:`from_model_name`.
    """

    # TODO: move this class to another file like "cachegen_basics.py"

    # Total number of layers covered by the specs.
    nlayers: int
    # Specs applied to keys; each one covers [start_layer, end_layer).
    kspecs: List[QuantizationSpec]
    # Specs applied to values; each one covers [start_layer, end_layer).
    vspecs: List[QuantizationSpec]

    def __getitem__(self, key: str) -> int:
        """Return the attribute named ``key`` (dict-style access).

        Raises:
            AttributeError: if no attribute called ``key`` exists (the lookup
                is a plain ``getattr`` without a default).
        """
        # NOTE(review): the ``-> int`` annotation only matches "nlayers";
        # "kspecs" / "vspecs" return lists. Kept unchanged so that callers'
        # type checking is not affected by this edit.
        return getattr(self, key)

    @staticmethod
    def _two_tier(nlayers: int) -> "CacheGenConfig":
        """Build the standard two-tier config for a model with ``nlayers``.

        Keys use bins=32 for layers [0, 10) and bins=16 for [10, nlayers);
        values use bins=32 for layers [0, 2) and bins=16 for [2, nlayers).
        """
        return CacheGenConfig(
            nlayers=nlayers,
            kspecs=[
                QuantizationSpec(start_layer=0, end_layer=10, bins=32),
                QuantizationSpec(start_layer=10, end_layer=nlayers, bins=16),
            ],
            vspecs=[
                QuantizationSpec(start_layer=0, end_layer=2, bins=32),
                QuantizationSpec(start_layer=2, end_layer=nlayers, bins=16),
            ],
        )

    @staticmethod
    def from_model_name(model_name: str) -> "CacheGenConfig":
        """Return the config for ``model_name``.

        Known model names map to a fixed layer count and the two-tier layout.
        Any other name is resolved with ``AutoConfig.from_pretrained`` and
        its ``num_hidden_layers``: fewer than 10 layers yields a single
        bins=32 spec spanning every layer for both keys and values; otherwise
        the two-tier layout is used.

        Args:
            model_name: Model identifier (or the literal ``"test_model"``).

        Returns:
            A freshly constructed ``CacheGenConfig``.

        Raises:
            ValueError: if the model is not in the known table and its config
                cannot be loaded, has no usable ``num_hidden_layers``, or any
                other error occurs while building the fallback config. The
                underlying exception is chained as ``__cause__``.
        """
        # Layer counts for models with a hand-picked configuration.
        known_layer_counts = {
            # 7B family
            "mistralai/Mistral-7B-Instruct-v0.2": 32,
            "lmsys/longchat-7b-16k": 32,
            "Qwen/Qwen-7B": 32,
            # 8B family
            "meta-llama/Llama-3.1-8B-Instruct": 32,
            # 9B family
            # TODO(Jiayi): needs tuning for better quality
            "THUDM/glm-4-9b-chat": 40,
            # Placeholder name used for testing.
            "test_model": 32,
        }

        nlayers = known_layer_counts.get(model_name)
        if nlayers is not None:
            return CacheGenConfig._two_tier(nlayers)

        # Unknown name: derive the layout from the model's own config. Every
        # failure in here (including the explicit ValueError below) is
        # re-raised as a uniform "not supported" ValueError.
        try:
            config = AutoConfig.from_pretrained(model_name)
            # Default name caught by num_hidden_layers
            nlayers = config.num_hidden_layers
            if nlayers is None:
                raise ValueError(
                    f"num_hidden_layers is None for model {model_name}")

            if nlayers < 10:
                # Too shallow for the 10-layer key split: one spec for all.
                return CacheGenConfig(
                    nlayers=nlayers,
                    kspecs=[
                        QuantizationSpec(start_layer=0,
                                         end_layer=nlayers,
                                         bins=32),
                    ],
                    vspecs=[
                        QuantizationSpec(start_layer=0,
                                         end_layer=nlayers,
                                         bins=32),
                    ],
                )

            return CacheGenConfig._two_tier(nlayers)
        except Exception as e:
            raise ValueError(
                f"Model {model_name} not supported by CacheGenConfig"
            ) from e