from typing import Callable, Optional, Tuple
import torch
from lmcache.blend.interfaces import BlendExecutor, BlendOutput
from lmcache.logging import init_logger
logger = init_logger(__name__)
# TODO: add configuration item
def mask_to_indices(mask: torch.Tensor) -> torch.Tensor:
    """Return the indices of the non-zero entries of a 1-D mask."""
    return mask.nonzero(as_tuple=True)[0]
def indices_to_mask(indices: torch.Tensor, size: int) -> torch.Tensor:
    """Build a 1-D 0/1 mask of length `size` with ones at `indices`."""
    # Keep the mask on the same device as the indices to avoid a
    # cross-device indexed assignment
    mask = torch.zeros(size, dtype=torch.long, device=indices.device)
    mask[indices] = 1
    return mask
def create_index(ndims: int, target_dim: int, index) -> tuple:
    """Build an indexing tuple that applies `index` along `target_dim` of
    an `ndims`-dimensional tensor, keeping all other dimensions whole."""
    index_obj = [slice(None)] * ndims
    index_obj[target_dim] = index
    return tuple(index_obj)
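
# Illustrative example (not part of the library): how the helpers above
# compose. `mask_to_indices` and `indices_to_mask` are inverses for 0/1
# masks, and `create_index` builds a tuple usable for advanced indexing:
#
#   mask = torch.tensor([1, 0, 1, 1, 0], dtype=torch.long)
#   idx = mask_to_indices(mask)        # tensor([0, 2, 3])
#   indices_to_mask(idx, 5)            # tensor([1, 0, 1, 1, 0])
#   x = torch.randn(5, 8)
#   x[create_index(x.dim(), 0, idx)]   # same as x[idx, :], shape [3, 8]
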
PositionalEncoder = Callable[[torch.Tensor, torch.Tensor, torch.Tensor],
Tuple[torch.Tensor, torch.Tensor]]
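
# A PositionalEncoder maps (positions, q, k) to re-encoded (q, k). As a
# minimal sketch (illustrative only, not part of LMCache), a conforming
# no-op callable looks like this:
def _identity_positional_encoder(
        positions: torch.Tensor, q: torch.Tensor,
        k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # A real encoder would rotate q/k according to `positions` (e.g. a
    # rotary embedding); this stand-in leaves both tensors unchanged.
    return q, k
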
class CacheBlendImpl(BlendExecutor):
def __init__(
self,
recompute_ratio: float,
):
self.recompute_ratio = recompute_ratio
        # Indices into retrieved_kv of the tokens selected from fresh_q
        self.indexes_in_kv = torch.tensor([], dtype=torch.long, device="cpu")
        self.positional_encoder: Optional[PositionalEncoder] = None
        self.reverse_positional_encoder: Optional[PositionalEncoder] = None
def set_positional_encoder(self, positional_encoder: PositionalEncoder):
self.positional_encoder = positional_encoder
def set_reverse_positional_encoder(
self, reverse_positional_encoder: PositionalEncoder):
self.reverse_positional_encoder = reverse_positional_encoder
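    # Note: callers are expected to wire these two hooks to the serving
    # engine's positional-embedding ops (applying and un-applying, e.g.,
    # rotary embeddings) so that blend() can strip and re-apply positions.
    # This is an assumption about the caller, not something enforced here.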
def _select_tokens_single_query(self, rk: torch.Tensor, rv: torch.Tensor,
valid: torch.Tensor, fq: torch.Tensor,
fk: torch.Tensor, fv: torch.Tensor,
token_dim: int) -> torch.Tensor:
"""
Input: retrieved KV, valid_mask, and fresh QKV for a single query
Output: selected tokens indices
"""
        # We compare the retrieved KVs with the fresh KVs and keep the
        # following tokens:
        # 1. Invalid tokens
        # 2. Valid tokens whose fresh KV differs most from the retrieved
        #    KV. Following the original CacheBlend implementation, only V
        #    is used to measure the difference. The number of tokens to
        #    keep is controlled by `recompute_ratio`.
assert fk.shape == rk.shape
assert fv.shape == rv.shape
        # Find the tokens with the largest V difference
        dims_to_average = [i for i in range(fv.dim()) if i != token_dim]
        diff_per_token = torch.mean((fv - rv)**2, dims_to_average)
        # Zero out invalid tokens so that topk only picks valid ones
        diff_per_token = diff_per_token * valid.to(diff_per_token.device)
        num_valid_tokens = valid.sum()
        num_selected_tokens = int(num_valid_tokens * self.recompute_ratio)
        # Move the indices to CPU so the mask math below matches the
        # CPU-resident valid mask
        top_indices = torch.topk(diff_per_token,
                                 num_selected_tokens).indices.cpu()
        logger.debug(f"Indices of the top differences: {top_indices}")

        # Merge the top-difference positions with the invalid tokens
        # (use .long() since bool masks do not support subtraction)
        top_mask = indices_to_mask(top_indices, valid.shape[0])
        total_selected_mask = (1 - valid.long()) + top_mask
        local_indices = mask_to_indices(total_selected_mask)
        logger.debug(f"Local indices of the selected tokens: {local_indices}")
        return local_indices
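
    # Worked example (illustrative): with recompute_ratio = 0.5 and
    # valid = [0, 1, 1, 1], suppose the per-token V differences are
    # [0.0, 0.9, 0.1, 0.5] after invalid tokens are zeroed. topk keeps
    # token 1, and the invalid token 0 is always kept, so the method
    # returns tensor([0, 1]).
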
    def _build_positions(self, query_start_loc: torch.Tensor,
                         device) -> torch.Tensor:
        """Rebuild the per-request positions from the query start locations.
        """
        ret = torch.arange(int(query_start_loc[-1]), device=device)
        # Make positions restart at 0 for every request in the batch
        for start, end in zip(query_start_loc[:-1], query_start_loc[1:]):
            ret[start:end] -= start
        return ret.long()
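
    # Example (illustrative): for query_start_loc = [0, 3, 5] the rebuilt
    # positions are [0, 1, 2, 0, 1] -- positions restart at 0 per request.
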
def blend(
self,
layer_id: int,
retrieved_k: torch.Tensor,
retrieved_v: torch.Tensor,
valid_mask: torch.Tensor,
original_positions: torch.Tensor,
fresh_q: torch.Tensor,
fresh_k: torch.Tensor,
fresh_v: torch.Tensor,
positions: torch.Tensor,
query_start_loc: torch.Tensor,
token_dim: int,
) -> BlendOutput:
"""This function blends the retrieved KV with fresh KVs, and
returns the short Q + long KV (blended) + positions of the tokens in Q
:param int layer_id: The layer id
:param torch.Tensor retrieved_k: The retrieved K layer, in shape
[num_tokens, hidden_dims]
:param torch.Tensor retrieved_v: The retrieved V layer, in shape
[num_tokens, hidden_dims]
:param torch.Tensor valid_mask: A CPU tensor returned from the
retriever indicating whether the KV is valid.
:param torch.Tensor original_positions: The original positions of the
tokens in the retrieved KV
:param torch.Tensor fresh_q: The fresh Q tensor from QKV split,
in shape [num_tokens, hidden_dims]
:param torch.Tensor fresh_k: The fresh K tensor from QKV split,
in shape [num_tokens, hidden_dims]
:param torch.Tensor fresh_v: The fresh V tensor from QKV split,
in shape [num_tokens, hidden_dims]
:param torch.Tensor positions: The positions in the input of the
tokens in the fresh_q
        :param torch.Tensor query_start_loc: The start locations of the
            queries when the input batch contains multiple requests. The
            length should be the number of requests in the batch + 1. Note
            this will NOT be changed after token selection.
:param int token_dim: The token dimension
:return: The blended Q, K, V, and positions
"""
        # The KV tensors here are expected in shape [num_tokens, hidden_dims]
assert valid_mask.is_cpu, "valid_mask should be on CPU"
        if layer_id == 0:
            # Layer 0 is always recomputed in full so that layer 1 can
            # compare fresh KVs against the retrieved ones. Reset the
            # indices accumulated during the previous forward pass.
            self.indexes_in_kv = torch.tensor([],
                                              dtype=torch.long,
                                              device="cpu")
            return BlendOutput(fresh_q,
                               fresh_k,
                               fresh_v,
                               positions,
                               torch.arange(fresh_q.shape[token_dim],
                                            device="cpu",
                                            dtype=torch.long),
                               query_start_loc=None)
        elif layer_id == 1:
            # Compare fresh and retrieved KVs per request and pick the
            # tokens that must be recomputed
            new_query_start_locs = [0]
for qstart, qend in zip(query_start_loc[:-1], query_start_loc[1:]):
# Select the tokens for each query
local_indices = self._select_tokens_single_query(
retrieved_k[qstart:qend], retrieved_v[qstart:qend],
valid_mask[qstart:qend], fresh_q[qstart:qend],
fresh_k[qstart:qend], fresh_v[qstart:qend], token_dim)
new_query_start_locs.append(new_query_start_locs[-1] +
len(local_indices))
self.indexes_in_kv = torch.cat(
(self.indexes_in_kv, local_indices + int(qstart)))
new_q = fresh_q[self.indexes_in_kv]
new_positions = positions[self.indexes_in_kv]
query_start_locs_tensor = torch.tensor(
new_query_start_locs,
device=query_start_loc.device,
dtype=query_start_loc.dtype)
logger.info(f"Selected {len(self.indexes_in_kv)} tokens out of "
f"{len(retrieved_k)} tokens to blend")
return BlendOutput(new_q, fresh_k, fresh_v, new_positions,
self.indexes_in_kv, query_start_locs_tensor)
        else:
            # From layer 2 onwards only the selected tokens carry fresh
            # KVs, so they must line up with the stored indices
            assert len(self.indexes_in_kv) == fresh_k.shape[token_dim]
index_obj = create_index(fresh_k.dim(), token_dim,
self.indexes_in_kv)
if self.positional_encoder is not None and \
self.reverse_positional_encoder is not None:
                # Clear the positional encoding from the retrieved K.
                # dumb_q is a placeholder: only K needs to be re-encoded
                dumb_q = torch.zeros(retrieved_k.shape,
                                     device=fresh_q.device,
                                     dtype=fresh_q.dtype)
dumb_q, rk_no_position = self.reverse_positional_encoder(
original_positions.to(device=retrieved_k.device,
dtype=torch.long), dumb_q,
retrieved_k)
# Re-apply positional encodings based on query_start_loc
new_positions = self._build_positions(query_start_loc,
device=fresh_q.device)
dumb_q, rk_with_position = self.positional_encoder(
new_positions, dumb_q, rk_no_position)
            else:
                logger.warning("Positional encoder and reverse positional "
                               "encoder are not set. This may lead to "
                               "incorrect results.")
                rk_with_position = retrieved_k
            # Overwrite the selected token positions with the fresh KVs
            rk_with_position[index_obj] = fresh_k
            retrieved_v[index_obj] = fresh_v
return BlendOutput(fresh_q, rk_with_position, retrieved_v,
positions, self.indexes_in_kv, None)
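

# Minimal usage sketch (illustrative; the tiny shapes, the single request,
# and `_example_blend_usage` itself are assumptions, not LMCache API). In
# real use the tensors below come from the model's QKV split.
def _example_blend_usage() -> None:
    impl = CacheBlendImpl(recompute_ratio=0.15)
    num_tokens, hidden = 8, 16
    retrieved_k = torch.randn(num_tokens, hidden)
    retrieved_v = torch.randn(num_tokens, hidden)
    fresh_q = torch.randn(num_tokens, hidden)
    fresh_k = torch.randn(num_tokens, hidden)
    fresh_v = torch.randn(num_tokens, hidden)
    valid_mask = torch.ones(num_tokens, dtype=torch.long)  # all tokens valid
    positions = torch.arange(num_tokens)
    query_start_loc = torch.tensor([0, num_tokens])
    # Layer 0 passes the fresh tensors through and resets internal state
    out0 = impl.blend(0, retrieved_k, retrieved_v, valid_mask, positions,
                      fresh_q, fresh_k, fresh_v, positions, query_start_loc,
                      token_dim=0)
    # Layer 1 picks the top-difference tokens and shrinks Q accordingly
    out1 = impl.blend(1, retrieved_k, retrieved_v, valid_mask, positions,
                      fresh_q, fresh_k, fresh_v, positions, query_start_loc,
                      token_dim=0)
    del out0, out1  # BlendOutput carries the blended Q/K/V and positions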