Source code for lmcache.utils

import hashlib
from dataclasses import dataclass
from typing import Tuple

import torch
from nvtx import annotate  # type: ignore

# Type definition: a variable-length tuple whose elements are
# (tensor, tensor) pairs.
# NOTE(review): presumably one (key, value) pair per model layer -- confirm
# against the code that builds or consumes this structure.
KVCache = Tuple[Tuple[torch.Tensor, torch.Tensor], ...]


@dataclass
class DiskCacheMetadata:
    """Record pairing a path with a size expressed in bytes.

    NOTE(review): the name suggests this describes one cache entry stored
    on disk -- confirm against the code that creates these records.
    """

    path: str
    # Size in bytes.
    size: int
TORCH_DTYPE_TO_STR_DTYPE = { torch.half: "half", torch.float16: "half", torch.bfloat16: "bfloat16", torch.float: "float", torch.float32: "float", torch.float64: "double", torch.double: "double", torch.uint8: "fp8", torch.float8_e4m3fn: "fp8_e4m3", torch.float8_e5m2: "fp8_e5m2", }
@dataclass
class CacheEngineKey:
    """Five-field key with a hash and an "@"-separated string form.

    ``to_string`` joins the fields with "@" and ``from_string`` parses that
    form back. Parsing requires exactly five "@"-separated parts, so a
    field value that itself contains "@" cannot be round-tripped.
    """

    fmt: str
    model_name: str
    world_size: int
    worker_id: int
    chunk_hash: str

    def __hash__(self):
        # Defined explicitly so instances are hashable: a plain @dataclass
        # (eq=True, frozen=False) would otherwise set __hash__ to None.
        identity = (
            self.fmt,
            self.model_name,
            self.world_size,
            self.worker_id,
            self.chunk_hash,
        )
        return hash(identity)

    def to_string(self):
        """Return the key as ``fmt@model_name@world_size@worker_id@chunk_hash``."""
        head = f"{self.fmt}@{self.model_name}@{self.world_size}"
        tail = f"@{self.worker_id}@{self.chunk_hash}"
        return head + tail

    @staticmethod
    def from_string(s):
        """Parse a string produced by ``to_string`` back into a key.

        Raises:
            ValueError: if ``s`` does not split into exactly five parts on
                "@", or if the world-size / worker-id parts are not
                integers.
        """
        fields = s.split("@")
        if len(fields) != 5:
            raise ValueError(f"Invalid key string: {s}")

        fmt, model_name, world_size, worker_id, chunk_hash = fields
        return CacheEngineKey(
            fmt,
            model_name,
            int(world_size),
            int(worker_id),
            chunk_hash,
        )
##### NVTX annotation #####
_NVTX_COLORS = ["green", "blue", "purple", "rapids"]


def _get_color_for_nvtx(name):
    """Pick an entry of ``_NVTX_COLORS`` deterministically from ``name``.

    The SHA-256 digest of the UTF-8 encoded name is read as an integer and
    reduced modulo the palette size, so the same name always yields the
    same color.
    """
    digest_hex = hashlib.sha256(name.encode()).hexdigest()
    slot = int(digest_hex, 16) % len(_NVTX_COLORS)
    return _NVTX_COLORS[slot]


def _lmcache_nvtx_annotate(func, domain="lmcache"):
    """Wrap ``func`` with an nvtx annotation and return the result.

    The annotation message is the function's ``__qualname__``; its color is
    derived from that same name via ``_get_color_for_nvtx``.
    """
    label = func.__qualname__
    decorate = annotate(
        message=label,
        color=_get_color_for_nvtx(label),
        domain=domain,
    )
    return decorate(func)