Source code for lmcache.experimental.token_database
import abc
import hashlib
from typing import Iterable, Optional, Tuple
import torch
from lmcache.config import LMCacheEngineMetadata
from lmcache.experimental.config import LMCacheEngineConfig
from lmcache.utils import CacheEngineKey
[docs]
class TokenDatabase(metaclass=abc.ABCMeta):
    """Abstract interface that maps input tokens to cache engine keys.

    Concrete strategies mentioned for this interface:

    - ChunkedTokenDatabase: cuts the tokens into chunks and derives one
      cache engine key per chunk from a prefix hash.
    - RadixTokenDatabase: a more advanced variant built on a radix tree.
    """

    @abc.abstractmethod
    def process_tokens(
        self,
        tokens: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Iterable[Tuple[int, int, CacheEngineKey]]:
        """Translate tokens into ``(start, end, key)`` triples.

        :param torch.Tensor tokens: Tokens to process, as a 1-D CPU tensor.
        :param Optional[torch.Tensor] mask: Optional mask with the same
            length as ``tokens``. It must always look like FFFFFTTTTTTT:
            True marks tokens that need to be matched, and every False
            sits at the PREFIX of the tensor.

        :returns: An iterable of 3-tuples. Element one is the start index
            of the tokens covered by the key, element two is the end index
            of those tokens, and element three is the cache engine key.
        """
        raise NotImplementedError
[docs]
class ChunkedTokenDatabase(TokenDatabase):
    """Token database that keys fixed-size token chunks by a prefix hash.

    The token sequence is cut into consecutive chunks of ``chunk_size``
    tokens (the final chunk may be shorter). The hash of each chunk is
    computed from the previous chunk's hash plus the chunk's own bytes, so
    every key depends on all tokens up to the end of its chunk.
    """

    def __init__(self, config: LMCacheEngineConfig,
                 metadata: LMCacheEngineMetadata):
        """Store the chunk size from ``config`` and the engine metadata.

        :param config: Engine configuration; only ``chunk_size`` is read.
        :param metadata: Engine metadata whose ``fmt``, ``model_name``,
            ``world_size`` and ``worker_id`` fields are put into each key.
        """
        self.chunk_size = config.chunk_size
        self.metadata = metadata

    def _make_key_by_hash(self, chunk_hash: str) -> CacheEngineKey:
        """Build a cache engine key from the metadata and a chunk hash."""
        return CacheEngineKey(self.metadata.fmt, self.metadata.model_name,
                              self.metadata.world_size,
                              self.metadata.worker_id, chunk_hash)

    def _get_init_hash(self) -> str:
        """Return the seed hash used before the first chunk."""
        return ""

    def _hash(
        self,
        tokens: torch.Tensor,
        prefix_hash: str,
    ) -> str:
        """Return the SHA-256 hex digest of ``prefix_hash`` plus ``tokens``.

        The digest covers the raw bytes of the tensor, so it depends on the
        tensor's element type as well as on its values.

        :param tokens: Tokens of the current chunk.
        :param prefix_hash: Hash of everything before this chunk.
        """
        # TODO: change it to a more efficient hash function
        return hashlib.sha256(
            prefix_hash.encode("ascii") +
            tokens.cpu().numpy().tobytes()).hexdigest()

    def _chunk_tokens(
        self,
        tokens: torch.Tensor,
    ) -> Iterable[torch.Tensor]:
        """Chunk the tokens into chunks of size ``self.chunk_size``.

        :param tokens: The input tokens, with shape [seq_len].

        :return: A generator of token chunks. Each has shape [chunk_size],
            except the last one, which holds whatever tokens remain.
        """
        for i in range(0, len(tokens), self.chunk_size):
            yield tokens[i:i + self.chunk_size]

    def _prefix_hash(
        self,
        token_chunks: Iterable[torch.Tensor],
    ) -> Iterable[str]:
        """Yield one chained hash per chunk, in chunk order.

        :param token_chunks: The chunks produced from the token sequence.
        """
        prefix_hash = self._get_init_hash()
        for token_chunk in token_chunks:
            prefix_hash = self._hash(token_chunk, prefix_hash)
            yield prefix_hash

    def process_tokens(
        self,
        tokens: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Iterable[Tuple[int, int, CacheEngineKey]]:
        """Process the tokens and return the corresponding cache engine keys.

        This method is a generator: nothing is computed, and no error is
        raised, until the result is iterated.

        :param torch.Tensor tokens: The tokens to process, in 1-D CPU tensor.
        :param Optional[torch.Tensor] mask: The mask for the tokens. Should
            have the same length as tokens. And the mask should ALWAYS be
            like FFFFFTTTTTTT, where True means the tokens needs to be
            matched, and the Falses will ALWAYS be at the PREFIX of the
            tensor.

        :returns: A iterable of tuples with three elements. The first
            element is the start index of the tokens for the key. The second
            element is the end index of the tokens for the key. The third
            element is the cache engine key for the tokens.

        :raises: ValueError if the number of Falses in the mask is not a
            multiple of the chunk size.
        """
        if mask is not None:
            # Reduce the count to a plain Python int once, so the checks
            # below are integer operations instead of tensor operations
            # repeated for every chunk.
            num_falses = mask.numel() - int(mask.long().sum().item())
        else:
            num_falses = 0

        if num_falses % self.chunk_size != 0:
            raise ValueError("The number of Falses in the mask is not a "
                             "multiple of the chunk size.")

        total_len = len(tokens)
        prefix_hashes = self._prefix_hash(self._chunk_tokens(tokens))
        for chunk_id, hash_val in enumerate(prefix_hashes):
            start_idx = chunk_id * self.chunk_size
            if start_idx < num_falses:
                # Chunk lies in the masked-out prefix: its hash has already
                # been folded into the chain, but no key is emitted for it.
                continue
            end_idx = min(start_idx + self.chunk_size, total_len)
            yield start_idx, end_idx, self._make_key_by_hash(hash_val)