diff --git a/optimum/exporters/ipex/cache_utils.py b/optimum/exporters/ipex/cache_utils.py index dec1e8189..91821989f 100755 --- a/optimum/exporters/ipex/cache_utils.py +++ b/optimum/exporters/ipex/cache_utils.py @@ -49,7 +49,7 @@ def __init__( self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape( batch_size, -1 ) - self.free_blocks = torch.arange(self.num_blocks, device=device) + self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=device) self.max_cache_len = max_cache_len self.num_kv_heads = config.num_key_value_heads self.num_hidden_layers = config.num_hidden_layers @@ -88,12 +88,10 @@ def update_for_prefill( all_slot_offsets = [] num_blocks = (input_lens + self.block_size - 1) // self.block_size for i in range(batch_size): - for b_idx in range(num_blocks[i]): - if self.block_tables[i][b_idx] == -1: - # need a free block - self.block_tables[i][b_idx] = self.free_blocks[0] - self.free_blocks = self.free_blocks[1:] - + nb = num_blocks[i] + block_table = self.free_blocks.nonzero().view(-1)[0:nb] + self.block_tables[i][0:nb] = block_table + self.free_blocks[block_table] = 0 slots_range = torch.arange(input_lens[i], device=key_states.device) block_indices = slots_range // self.block_size slot_offsets = slots_range % self.block_size @@ -103,7 +101,6 @@ def update_for_prefill( all_block_indices = torch.cat(all_block_indices) all_slot_offsets = torch.cat(all_slot_offsets) self.slots = all_block_indices * self.block_size + all_slot_offsets - # Update the cache PagedAttention.reshape_and_cache( key_states, @@ -127,16 +124,16 @@ def update_for_decode( ): if layer_idx == 0: start_block_idx = self._seen_tokens // self.block_size - num_blocks = (self._seen_tokens + self.block_size) // self.block_size slot_offset_in_block = (self._seen_tokens) % self.block_size self.slots = torch.zeros([batch_size], device=key_states.device, dtype=torch.int32) for i in range(batch_size): - for b_idx in range(start_block_idx[i], num_blocks[i]): + if slot_offset_in_block[i] == 0: + # need a new block: + b_idx = start_block_idx[i] if self.block_tables[i][b_idx] == -1: # need a free block - self.block_tables[i][b_idx] = self.free_blocks[0] - self.free_blocks = self.free_blocks[1:] - + self.block_tables[i][b_idx] = self.free_blocks.nonzero().view(-1)[0:1] + self.free_blocks[self.block_tables[i][b_idx]] = 0 self.slots[i] = self.block_tables[i][start_block_idx[i]] * self.block_size + slot_offset_in_block[i] # Update the cache PagedAttention.reshape_and_cache( @@ -196,7 +193,7 @@ def reset(self): """Resets the cache values while preserving the objects""" self._seen_tokens = torch.zeros([self.batch_size], dtype=torch.int32, device=self.block_tables.device) self.block_tables.fill_(-1) - self.free_blocks = torch.arange(self.num_blocks, device=self.block_tables.device) + self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=self.block_tables.device) self.max_seq_len = 0 def reorder_cache(self, beam_idx: torch.LongTensor): @@ -206,16 +203,18 @@ def reorder_cache(self, beam_idx: torch.LongTensor): updated_block_tables = self.block_tables.index_select(0, beam_idx.to(device)) mask = self.block_tables.masked_fill(self.block_tables != -1, 1).masked_fill(self.block_tables == -1, 0) num_blocks = mask.cumsum(-1)[:, -1] - updated_table = [] + updated_table = torch.zeros_like(beam_idx) for i in range(beam_idx.shape[0]): - self.block_tables[i, 0 : num_blocks[i] - 1] = updated_block_tables[i, 0 : num_blocks[i] - 1] - updated_table.append(self.block_tables[i : i + 1, num_blocks[i] - 1 : num_blocks[i]]) - updated_table = torch.cat(tuple(updated_table), dim=0) + nb = num_blocks[i] + self.block_tables[i, 0 : nb - 1] = updated_block_tables[i, 0 : nb - 1] + updated_table[i] = self.block_tables[i][nb - 1] for layer_idx in range(self.num_hidden_layers): self.key_cache[layer_idx][updated_table] = self.key_cache[layer_idx][updated_table[beam_idx]] self.value_cache[layer_idx][updated_table] = self.value_cache[layer_idx][updated_table[beam_idx]] free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1)) - self.free_blocks = torch.cat((self.free_blocks, free_table)) + for i in free_table: + if not (self.block_tables == i).any(): + self.free_blocks[i] = 1 def crop(self, maximum_length: int): """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be @@ -235,4 +234,6 @@ def crop(self, maximum_length: int): self._seen_tokens[bs] = new_tokens self.max_seq_len, _ = self._seen_tokens.max(dim=0) free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1)) - self.free_blocks = torch.cat((self.free_blocks, free_table)) + for i in free_table: + if not (self.block_tables == i).any(): + self.free_blocks[i] = 1