from concurrent.futures import Future, ThreadPoolExecutor
from typing import List, Optional, Tuple
import torch
from lmcache.blend.interfaces import (BlendRetriever, BlendRetrieverResult,
BlendRetrieverTask)
from lmcache.cache_engine import LMCacheEngine
from lmcache.config import LMCacheEngineMetadata
from lmcache.logging import init_logger
logger = init_logger(__name__)
class SPTBlendRetrieverTask(BlendRetrieverTask):
def __init__(self, token_segments: List[torch.Tensor], tasks: List[Future],
fmt: str):
"""Initialize the SBT retriever task by the futures and corresponding
token segments.
The result of tasks should be the Tuple[torch.Tensor, int] and the
shape of the tensor L2HTD or L2THD
"""
assert len(token_segments) == len(tasks), \
"The number of token segments and tasks should match."
self.token_segments = token_segments
self.tasks = tasks
self.fmt = fmt
self.rebuilt_key: Optional[torch.Tensor] = None
self.rebuilt_value: Optional[torch.Tensor] = None
self.valid_mask: Optional[torch.Tensor] = None
self.rebuilt_positions: Optional[torch.Tensor] = None
@staticmethod
def _PrepareOutputTensor(
fmt: str,
input_tensor: torch.Tensor,
real_length: int,
expected_length: int,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
        Input tensor is L2THD or L2HTD depending on fmt.
        Output tensors are K and V with shape LTHD or LHTD depending on fmt.
        Returns (None, None) if nothing was retrieved.
"""
if real_length == expected_length:
return input_tensor[:, 0, ...], input_tensor[:, 1, ...]
if real_length == 0:
return None, None
ret_shape = list(input_tensor.shape)
match fmt:
case "vllm":
ret_shape[2] = expected_length
case "huggingface":
ret_shape[3] = expected_length
case _:
raise ValueError(f"Unknown KV format {fmt}")
ret_tensor = torch.empty(ret_shape,
dtype=input_tensor.dtype,
device=input_tensor.device)
match fmt:
case "vllm":
ret_tensor[:, :, :real_length, ...] = input_tensor
case "huggingface":
ret_tensor[:, :, :, :real_length, ...] = input_tensor
case _:
raise ValueError(f"Unknown KV format {fmt}")
return ret_tensor[:, 0, ...], ret_tensor[:, 1, ...]
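    # Illustrative note, not executed anywhere: with fmt="vllm" and an
    # input tensor of shape [num_layers, 2, real_length, num_heads,
    # head_size], the helper above pads the token dimension up to
    # expected_length and splits dim 1 into K and V, so each returned
    # tensor has shape [num_layers, expected_length, num_heads, head_size];
    # the last expected_length - real_length token positions hold
    # uninitialized padding and are masked out via the valid mask.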
def _wait_for_result(self):
"""Wait for the results of the tasks and rebuild the K and V tensors.
"""
keys = []
values = []
valid_masks = []
all_positions = []
num_layers = None
num_heads = None
head_size = None
dtype = None
device = None
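        # Shape metadata below is filled in from the non-empty retrievals
        # so that segments with no cache hit can later be padded with
        # placeholder tensors of a matching shape.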
def update_shape(kv, fmt):
nonlocal num_layers, num_heads, head_size, dtype, device
num_layers = kv.shape[0]
head_size = kv.shape[-1]
num_heads = kv.shape[3] if fmt == "vllm" else kv.shape[2]
dtype = kv.dtype
device = kv.device
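        # For every token segment: wait for its retrieval task, pad the
        # returned KV to the segment length, and record which leading
        # positions were actually retrieved.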
for token_segment, task in zip(self.token_segments, self.tasks):
kv, ret_mask = task.result()
length = int(torch.sum(ret_mask))
if length > 0:
update_shape(kv, self.fmt)
k, v = self._PrepareOutputTensor(self.fmt, kv, length,
len(token_segment))
valid_mask = torch.zeros(len(token_segment),
dtype=torch.int,
device="cpu")
valid_mask[:length] = 1
positions = torch.zeros(len(token_segment),
dtype=torch.int,
device="cpu")
positions[:length] = torch.arange(length)
keys.append(k)
values.append(v)
valid_masks.append(valid_mask)
all_positions.append(positions)
# Create valid mask and rebuilt positions before returning
self.valid_mask = torch.cat(valid_masks, dim=0)
self.rebuilt_positions = torch.cat(all_positions, dim=0)
# return if nothing is retrieved
if num_layers is None:
return
match self.fmt:
case "vllm":
token_dim = 1
shape_placeholder = [num_layers, 0, num_heads, head_size]
case "huggingface":
token_dim = 2
shape_placeholder = [num_layers, num_heads, 0, head_size]
case _:
raise ValueError(f"Unknown KV format {self.fmt}")
        # Replace the missing (None) K / V segments with empty placeholder
        # tensors so that the concatenation below lines up with the
        # token segments
for i, (k, v) in enumerate(zip(keys, values)):
shape_placeholder[token_dim] = len(self.token_segments[i])
if k is None:
keys[i] = torch.empty(shape_placeholder,
dtype=dtype,
device=device)
if v is None:
values[i] = torch.empty(shape_placeholder,
dtype=dtype,
device=device)
        # NOTE: mypy will complain that the elements of keys and values
        # could still be None, but every None has been replaced by a
        # placeholder tensor above
self.rebuilt_key = torch.cat(keys, dim=token_dim) # type: ignore
self.rebuilt_value = torch.cat(values, dim=token_dim) # type: ignore
def result(self, layer_id: int) -> BlendRetrieverResult:
"""Blocking function to get a single layer of K and V tensor.
The returned the K and V tensor should match the length of the
input tokens passed to the `BlendRetriever.new_request` function.
:param int layer_id: the layer id
:return: Tuple of K and V tensor
:rtype: Tuple[torch.Tensor, torch.Tensor]
"""
if self.valid_mask is None:
self._wait_for_result()
assert self.valid_mask is not None
assert self.rebuilt_positions is not None
ret = BlendRetrieverResult(
k = self.rebuilt_key[layer_id] \
if self.rebuilt_key is not None else None,
v = self.rebuilt_value[layer_id] \
if self.rebuilt_value is not None else None,
valid_mask = self.valid_mask,
original_positions = self.rebuilt_positions)
return ret
class SPTBlendRetriever(BlendRetriever):
"""Implement the retrieval logic using "SPecial Token" (SPT) as delimiter.
This implementation assumes that there MUST be a special token at the end
of the input text chunk.
Example:
Input = [x, x, x, spt, y, y, spt, z, z, z, z]
Requests sent to LMCache engine when using drop_spt_and_get_indices
and new_request:
- [x, x, x]
- [y, y]
- [z, z, z, z]
    Therefore, when using this retriever, the stored text chunks should
    also end with the special token.
"""
def __init__(
self,
spt: List[int],
cache_engine: LMCacheEngine,
metadata: LMCacheEngineMetadata,
):
"""Initialize the SPT retriever.
:param List[int] spt: The special token to use as delimiter
:param LMCacheEngine cache_engine: The cache engine to retrieve
the KV caches
:param LMCacheEngineMetadata metadata: The metadata of the cache engine
"""
self.spt = spt
self.cache_engine = cache_engine
self.metadata = metadata
def drop_spt_and_get_indices(
self, full_prompt: List[int]) -> Tuple[List[int], List[int]]:
"""Drop the special token and get the indices of the split requests.
:param List[int] full_prompt: The full prompt after tokenization.
        :return: The new prompt without the special tokens and the indices
            of the split segments.
            The indices record the start of every segment except the first,
            i.e. the split points later passed to `torch.tensor_split`,
            e.g. [index_of_segment2, index_of_segment3, ...]
"""
spt_len = len(self.spt)
assert spt_len >= 1
i = 0
splitted_tokens = []
start = 0
while True:
next_len = i + spt_len
if next_len > len(full_prompt):
break
if full_prompt[i:next_len] == self.spt:
splitted_tokens.append(full_prompt[start:i])
start = next_len
i = next_len
else:
i += 1
if start < len(full_prompt):
splitted_tokens.append(full_prompt[start:])
new_prompt = []
new_indices = []
this_seg_start = 0
for split in splitted_tokens:
new_prompt.extend(split)
new_indices.append(this_seg_start + len(split))
this_seg_start = new_indices[-1]
if len(new_indices) > 0:
new_indices.pop()
return new_prompt, new_indices
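    # A small worked example (hypothetical token ids, assuming spt = [100]):
    #   drop_spt_and_get_indices([1, 2, 3, 100, 4, 5, 100, 6])
    #   returns ([1, 2, 3, 4, 5, 6], [3, 5]); torch.tensor_split with
    #   indices [3, 5] then recovers the segments [1, 2, 3], [4, 5], [6].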
def new_request(
self,
full_prompts: List[torch.Tensor],
indices: List[List[int]],
) -> BlendRetrieverTask:
"""Create a new BlendRetrieverTask to retrieve the KV caches.
It may launch async tasks in the background during the retrieval.
:param List[torch.Tensor] full_prompts: The full prompts for each
request in this batch, which will contain the tokens
hitting the vLLM's internal prefix caching.
        :param List[List[int]] indices: The indices of where the
            segmented requests start in the full prompts.
:return: The retriever task to retrieve the KV caches
:rtype: BlendRetrieverTask
"""
assert len(full_prompts) == len(indices)
with ThreadPoolExecutor(max_workers=1) as executor:
splitted_tokens: List[torch.Tensor] = []
for prompt_idx, prompt in enumerate(full_prompts):
prompt_indices = indices[prompt_idx]
splitted_tokens.extend(
torch.tensor_split(prompt, prompt_indices))
logger.debug("Split input tokens into %d requests",
len(splitted_tokens))
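            # Launch one retrieval per non-empty token segment.  Because
            # the executor is used as a context manager, all submitted
            # tasks will have completed by the time this block exits
            # (ThreadPoolExecutor.shutdown(wait=True)).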
tasks = [
executor.submit(
self.cache_engine.retrieve,
tokens, # tokens
None, # mask
False, # return_tuple
) for tokens in splitted_tokens if len(tokens) > 0
]
return SPTBlendRetrieverTask(token_segments=splitted_tokens,
tasks=tasks,
fmt=self.metadata.fmt)
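# A minimal usage sketch (illustrative only): `engine`, `metadata`, and the
# special-token id 32001 below are assumptions, not defined in this module.
#
#   retriever = SPTBlendRetriever(spt=[32001], cache_engine=engine,
#                                 metadata=metadata)
#   new_prompt, indices = retriever.drop_spt_and_get_indices(token_ids)
#   task = retriever.new_request([torch.tensor(new_prompt)], [indices])
#   res = task.result(layer_id=0)  # BlendRetrieverResult with .k, .v,
#                                  # .valid_mask and .original_positions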