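"""Source code for ``lmcache.utils``: shared helpers such as the KV-cache
type alias, torch dtype names, cache-engine keys and NVTX annotations."""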
import hashlib
from dataclasses import dataclass
from typing import Tuple
import torch
from nvtx import annotate # type: ignore
# Type definition
KVCache = Tuple[Tuple[torch.Tensor, torch.Tensor], ...]
TORCH_DTYPE_TO_STR_DTYPE = {
torch.half: "half",
torch.float16: "half",
torch.bfloat16: "bfloat16",
torch.float: "float",
torch.float32: "float",
torch.float64: "double",
torch.double: "double",
torch.uint8: "fp8",
torch.float8_e4m3fn: "fp8_e4m3",
torch.float8_e5m2: "fp8_e5m2",
}
@dataclass
class CacheEngineKey:
    """Key identifying one chunk of cached data in lmcache.

    A key combines a format tag, the model name, the distributed layout
    (``world_size`` / ``worker_id``) and a chunk hash string. It serializes
    to ``fmt@model_name@world_size@worker_id@chunk_hash`` via
    :meth:`to_string` and parses back via :meth:`from_string`.
    """

    fmt: str
    model_name: str
    world_size: int
    worker_id: int
    chunk_hash: str

    def __hash__(self) -> int:
        # @dataclass with the default eq=True and frozen=False would set
        # __hash__ to None; defining it explicitly keeps keys usable in
        # dicts and sets. Instances are still mutable, so do not modify a
        # key after inserting it into a hashed container.
        return hash((
            self.fmt,
            self.model_name,
            self.world_size,
            self.worker_id,
            self.chunk_hash,
        ))

    def to_string(self) -> str:
        """Serialize as ``fmt@model_name@world_size@worker_id@chunk_hash``."""
        return (f"{self.fmt}@{self.model_name}@{self.world_size}"
                f"@{self.worker_id}@{self.chunk_hash}")

    @staticmethod
    def from_string(s: str) -> "CacheEngineKey":
        """Parse a string produced by :meth:`to_string` back into a key.

        The first ``"@"``-separated field is ``fmt``. The last three fields
        are ``world_size``, ``worker_id`` and ``chunk_hash``. Everything in
        between is the model name, so model names that themselves contain
        ``"@"`` round-trip correctly. ``fmt`` and ``chunk_hash`` must still
        be free of ``"@"``.

        Args:
            s: the serialized key.

        Returns:
            The reconstructed ``CacheEngineKey``.

        Raises:
            ValueError: if ``s`` has fewer than five ``"@"``-separated
                fields, or if ``world_size`` / ``worker_id`` are not
                integers.
        """
        parts = s.split("@")
        if len(parts) < 5:
            raise ValueError(f"Invalid key string: {s}")
        fmt = parts[0]
        # Rejoin the middle fields: the model name may itself contain "@".
        model_name = "@".join(parts[1:-3])
        world_size, worker_id, chunk_hash = parts[-3:]
        return CacheEngineKey(fmt, model_name, int(world_size),
                              int(worker_id), chunk_hash)
##### NVTX annotation #####
_NVTX_COLORS = ["green", "blue", "purple", "rapids"]


def _get_color_for_nvtx(name):
    """Pick a color from ``_NVTX_COLORS`` by hashing *name*.

    SHA-256 is used instead of the built-in ``hash()`` because Python salts
    ``str`` hashing per process. A digest maps the same name to the same
    color on every run.
    """
    digest = hashlib.sha256(name.encode()).digest()
    # Reading the digest as a big-endian integer gives the same value as
    # int(hexdigest, 16).
    slot = int.from_bytes(digest, "big") % len(_NVTX_COLORS)
    return _NVTX_COLORS[slot]
def _lmcache_nvtx_annotate(func, domain="lmcache"):
    """Decorator for applying nvtx annotations to methods in lmcache.

    Wraps *func* with ``nvtx.annotate``. The function's qualified name is the
    range message, and the color comes from ``_get_color_for_nvtx``. That
    helper is deterministic, so a given function always gets the same color.

    Args:
        func: the callable to annotate.
        domain: NVTX domain to record the range in (default ``"lmcache"``).

    Returns:
        The result of applying the ``annotate(...)`` decorator to *func*.
    """
    label = func.__qualname__
    nvtx_decorator = annotate(
        message=label,
        color=_get_color_for_nvtx(label),
        domain=domain,
    )
    return nvtx_decorator(func)