from concurrent.futures import Future, ThreadPoolExecutor
from typing import List, Optional, Tuple

import torch

from lmcache.blend.interfaces import (BlendRetriever, BlendRetrieverResult,
                                      BlendRetrieverTask)
from lmcache.cache_engine import LMCacheEngine
from lmcache.config import LMCacheEngineMetadata
from lmcache.logging import init_logger

logger = init_logger(__name__)
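
# KV tensor layout, as consumed by the code below:
#   fmt == "vllm":        [num_layers, 2, num_tokens, num_heads, head_size]  (L2THD)
#   fmt == "huggingface": [num_layers, 2, num_heads, num_tokens, head_size]  (L2HTD)
# The size-2 dimension holds K and V respectively.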


class SPTBlendRetrieverTask(BlendRetrieverTask):

    def __init__(self, token_segments: List[torch.Tensor], tasks: List[Future],
                 fmt: str):
        """Initialize the SPT retriever task from the futures and the
        corresponding token segments.

        The result of each task should be a Tuple[torch.Tensor, torch.Tensor]
        holding the retrieved KV tensor and the retrieval mask, where the KV
        tensor has shape L2HTD or L2THD depending on fmt.
        """
        assert len(token_segments) == len(tasks), \
            "The number of token segments and tasks should match."

        self.token_segments = token_segments
        self.tasks = tasks
        self.fmt = fmt

        self.rebuilt_key: Optional[torch.Tensor] = None
        self.rebuilt_value: Optional[torch.Tensor] = None
        self.valid_mask: Optional[torch.Tensor] = None
        self.rebuilt_positions: Optional[torch.Tensor] = None
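
        # Note: the rebuilt K/V tensors, valid mask, and original positions
        # are populated lazily by _wait_for_result() on the first call to
        # result().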

    @staticmethod
    def _PrepareOutputTensor(
        fmt: str,
        input_tensor: torch.Tensor,
        real_length: int,
        expected_length: int,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Input tensor is L2THD or L2HTD depending on fmt.
        Output tensors are K and V with shape LTHD or LHTD depending on fmt.
        Could also be (None, None) if nothing is retrieved.
        """
        if real_length == expected_length:
            return input_tensor[:, 0, ...], input_tensor[:, 1, ...]

        if real_length == 0:
            return None, None

        ret_shape = list(input_tensor.shape)
        match fmt:
            case "vllm":
                ret_shape[2] = expected_length
            case "huggingface":
                ret_shape[3] = expected_length
            case _:
                raise ValueError(f"Unknown KV format {fmt}")

        ret_tensor = torch.empty(ret_shape,
                                 dtype=input_tensor.dtype,
                                 device=input_tensor.device)
        match fmt:
            case "vllm":
                ret_tensor[:, :, :real_length, ...] = input_tensor
            case "huggingface":
                ret_tensor[:, :, :, :real_length, ...] = input_tensor
            case _:
                raise ValueError(f"Unknown KV format {fmt}")

        return ret_tensor[:, 0, ...], ret_tensor[:, 1, ...]
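
    # Note on _PrepareOutputTensor: when fewer tokens were retrieved than
    # expected, the retrieved KV is copied into the first `real_length`
    # positions of a freshly allocated tensor of the expected length; the
    # remaining positions are uninitialized and are expected to be masked
    # out downstream via the valid mask.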

    def _wait_for_result(self):
        """Wait for the results of the tasks and rebuild the K and V tensors.
        """
        keys = []
        values = []
        valid_masks = []
        all_positions = []

        num_layers = None
        num_heads = None
        head_size = None
        dtype = None
        device = None

        def update_shape(kv, fmt):
            nonlocal num_layers, num_heads, head_size, dtype, device
            num_layers = kv.shape[0]
            head_size = kv.shape[-1]
            num_heads = kv.shape[3] if fmt == "vllm" else kv.shape[2]
            dtype = kv.dtype
            device = kv.device

        for token_segment, task in zip(self.token_segments, self.tasks):
            kv, ret_mask = task.result()
            length = int(torch.sum(ret_mask))
            if length > 0:
                update_shape(kv, self.fmt)

            # Pad the retrieved KV to the full segment length
            # (None, None if nothing was retrieved for this segment)
            k, v = self._PrepareOutputTensor(self.fmt, kv, length,
                                             len(token_segment))

            valid_mask = torch.zeros(len(token_segment),
                                     dtype=torch.int,
                                     device="cpu")
            valid_mask[:length] = 1

            positions = torch.zeros(len(token_segment),
                                    dtype=torch.int,
                                    device="cpu")
            positions[:length] = torch.arange(length)

            keys.append(k)
            values.append(v)
            valid_masks.append(valid_mask)
            all_positions.append(positions)

        # Create valid mask and rebuilt positions before returning
        self.valid_mask = torch.cat(valid_masks, dim=0)
        self.rebuilt_positions = torch.cat(all_positions, dim=0)

        # Return early if nothing is retrieved
        if num_layers is None:
            return

        match self.fmt:
            case "vllm":
                token_dim = 1
                shape_placeholder = [num_layers, 0, num_heads, head_size]
            case "huggingface":
                token_dim = 2
                shape_placeholder = [num_layers, num_heads, 0, head_size]
            case _:
                raise ValueError(f"Unknown KV format {self.fmt}")

        # Replace the None tensors of empty segments with placeholders of the
        # right shape so all segments can be concatenated along the token dim
        for i, (k, v) in enumerate(zip(keys, values)):
            shape_placeholder[token_dim] = len(self.token_segments[i])
            if k is None:
                keys[i] = torch.empty(shape_placeholder,
                                      dtype=dtype,
                                      device=device)
            if v is None:
                values[i] = torch.empty(shape_placeholder,
                                        dtype=dtype,
                                        device=device)

        # NOTE: mypy will complain that the elements of rebuilt_key and
        # rebuilt_value could be None, but that is not the case here
        self.rebuilt_key = torch.cat(keys, dim=token_dim)  # type: ignore
        self.rebuilt_value = torch.cat(values, dim=token_dim)  # type: ignore
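
    # After _wait_for_result(), the rebuilt K/V (if anything was retrieved)
    # contain one padded segment per task, concatenated along the token
    # dimension, so they align position-for-position with the input tokens;
    # valid_mask marks which of those positions actually hold retrieved data.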

    def result(self, layer_id: int) -> BlendRetrieverResult:
        """Blocking function to get a single layer of the K and V tensors.
        The returned K and V tensors should match the length of the input
        tokens passed to the `BlendRetriever.new_request` function.

        :param int layer_id: the layer id

        :return: The K and V tensors of the given layer, together with the
            valid mask and original positions
        :rtype: BlendRetrieverResult
        """
        if self.valid_mask is None:
            self._wait_for_result()

        assert self.valid_mask is not None
        assert self.rebuilt_positions is not None

        ret = BlendRetrieverResult(
            k = self.rebuilt_key[layer_id] \
                if self.rebuilt_key is not None else None,
            v = self.rebuilt_value[layer_id] \
                if self.rebuilt_value is not None else None,
            valid_mask = self.valid_mask,
            original_positions = self.rebuilt_positions)
        return ret


class SPTBlendRetriever(BlendRetriever):
    """Implement the retrieval logic using a "SPecial Token" (SPT) as the
    delimiter.

    This implementation assumes that there MUST be a special token at the end
    of the input text chunk.

    Example:
        Input = [x, x, x, spt, y, y, spt, z, z, z, z]

        Requests sent to LMCache engine:
          - [x, x, x, spt]
          - [y, y, spt]
          - [z, z, z, z]

    Therefore, to use this retriever, it is better if the text chunks also
    end with the special token.
    """

    def __init__(
        self,
        spt: torch.Tensor,
        cache_engine: LMCacheEngine,
        metadata: LMCacheEngineMetadata,
    ):
        """Initialize the SPT retriever.

        :param torch.Tensor spt: The special token to use as delimiter

        :param LMCacheEngine cache_engine: The cache engine to retrieve
            the KV caches

        :param LMCacheEngineMetadata metadata: The metadata of the cache engine
        """
        self.spt = spt
        self.cache_engine = cache_engine
        self.metadata = metadata

    def _split_input_tokens(self, input_tokens_single_query: torch.Tensor):
        """Split the input tokens into multiple requests based on the ROI.
        Returns a list of split tokens for cache_engine to retrieve.
        """
        spt_len = len(self.spt)
        if spt_len == 1:
            indices = (
                input_tokens_single_query == self.spt).nonzero().squeeze()
        else:
            # Slide a window of length spt_len over the tokens and find the
            # positions where the whole window matches the special token
            windows = input_tokens_single_query.unfold(0, spt_len, 1)
            indices = (windows == self.spt).all(dim=1).nonzero().squeeze()

        if indices.dim() == 0:
            indices = indices.unsqueeze(0)

        start = 0
        splitted_tokens = []
        for i in indices:
            splitted_tokens.append(input_tokens_single_query[start:i +
                                                             spt_len])
            start = i + spt_len

        if start < len(input_tokens_single_query):
            splitted_tokens.append(input_tokens_single_query[start:])

        return splitted_tokens
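
    # Note: each returned segment ends with the special token, except possibly
    # the last one, which holds any trailing tokens after the final delimiter
    # (e.g. the [z, z, z, z] segment in the class docstring example).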

    def new_request(
        self,
        input_tokens: torch.Tensor,
        query_start_loc: torch.Tensor,
    ) -> BlendRetrieverTask:
        """Create a new BlendRetrieverTask to retrieve the KV caches.
        It may launch async tasks in the background during the retrieval.

        :param torch.Tensor input_tokens: The input tokens, which could
            include multiple requests in a batch

        :param torch.Tensor query_start_loc: The start location of each query
            if input_tokens has multiple requests in a batch. The length
            should be the number of requests in the batch + 1.

        :return: The retriever task to retrieve the KV caches
        :rtype: BlendRetrieverTask
        """
        with ThreadPoolExecutor(max_workers=1) as executor:
            splitted_tokens = []
            start_loc = query_start_loc[0]
            for loc in query_start_loc[1:]:
                logger.debug(f"Request start loc = {start_loc}")
                splitted_tokens.extend(
                    self._split_input_tokens(input_tokens[start_loc:loc]))
                start_loc = loc

            logger.debug("Split input tokens into %d requests",
                         len(splitted_tokens))

            tasks = [
                executor.submit(
                    self.cache_engine.retrieve,
                    tokens,  # tokens
                    None,  # mask
                    False,  # return_tuple
                ) for tokens in splitted_tokens
            ]

            return SPTBlendRetrieverTask(token_segments=splitted_tokens,
                                         tasks=tasks,
                                         fmt=self.metadata.fmt)
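

if __name__ == "__main__":
    # Minimal, self-contained sketch of the SPT splitting idea used by
    # SPTBlendRetriever._split_input_tokens. It only uses torch; the token
    # ids below (7 as the special token, 1/2/3 as x/y/z) are arbitrary
    # example values, not anything mandated by LMCache.
    spt = torch.tensor([7])
    # Mirrors the class docstring example: [x, x, x, spt, y, y, spt, z, z, z, z]
    input_tokens = torch.tensor([1, 1, 1, 7, 2, 2, 7, 3, 3, 3, 3])

    spt_len = len(spt)
    # Slide a window of length spt_len and mark where it equals the delimiter
    windows = input_tokens.unfold(0, spt_len, 1)
    indices = (windows == spt).all(dim=1).nonzero().squeeze()
    if indices.dim() == 0:
        indices = indices.unsqueeze(0)

    segments = []
    start = 0
    for i in indices:
        segments.append(input_tokens[start:i + spt_len])
        start = i + spt_len
    if start < len(input_tokens):
        segments.append(input_tokens[start:])

    # Expected output:
    #   tensor([1, 1, 1, 7]), tensor([2, 2, 7]), tensor([3, 3, 3, 3])
    for seg in segments:
        print(seg)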