Refactor code and fix warnings
Luodian committed Nov 6, 2023
1 parent ce9a22e commit 70f06bf
Showing 9 changed files with 30 additions and 105 deletions.
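Nearly every hunk below applies the same mechanical change: calls whose string arguments were split across several lines with `+` or implicit concatenation are collapsed onto a single line, leaving the emitted message and behavior unchanged. A minimal before/after sketch of the pattern, using a made-up warning message rather than code from any one file:

    import warnings

    # Before: the message is assembled from fragments spread over several lines.
    warnings.warn(
        "Example warning about a configuration choice. "
        + "The fragments are concatenated before the warning is emitted."
    )

    # After: the same call collapsed onto one line; the warning text is identical.
    warnings.warn("Example warning about a configuration choice. " + "The fragments are concatenated before the warning is emitted.")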
5 changes: 1 addition & 4 deletions src/otter_ai/models/flamingo/falcon/modelling_RW.py
@@ -853,10 +853,7 @@ def forward(
sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1) - 1
else:
sequence_lengths = -1
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
logger.warning(f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " "unexpected if using padding tokens in conjunction with `inputs_embeds.`")

pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

12 changes: 2 additions & 10 deletions src/otter_ai/models/flamingo/mpt/attention.py
@@ -273,11 +273,7 @@ def __init__(
elif self.attn_impl == "torch":
self.attn_fn = scaled_multihead_dot_product_attention
if torch.cuda.is_available() and verbose:
warnings.warn(
"Using `attn_impl: torch`. If your model does not use `alibi` or "
+ "`prefix_lm` we recommend using `attn_impl: flash` otherwise "
+ "we recommend using `attn_impl: triton`."
)
warnings.warn("Using `attn_impl: torch`. If your model does not use `alibi` or " + "`prefix_lm` we recommend using `attn_impl: flash` otherwise " + "we recommend using `attn_impl: triton`.")
else:
raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
@@ -362,11 +358,7 @@ def __init__(
elif self.attn_impl == "torch":
self.attn_fn = scaled_multihead_dot_product_attention
if torch.cuda.is_available() and verbose:
warnings.warn(
"Using `attn_impl: torch`. If your model does not use `alibi` or "
+ "`prefix_lm` we recommend using `attn_impl: flash` otherwise "
+ "we recommend using `attn_impl: triton`."
)
warnings.warn("Using `attn_impl: torch`. If your model does not use `alibi` or " + "`prefix_lm` we recommend using `attn_impl: flash` otherwise " + "we recommend using `attn_impl: triton`.")
else:
raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
4 changes: 1 addition & 3 deletions src/otter_ai/models/flamingo/mpt/flash_attn_triton.py
@@ -801,9 +801,7 @@ def backward(ctx, do):
with torch.inference_mode():
dq = torch.empty_like(q)
dkv = torch.empty_like(kv)
_flash_attn_backward(
do, q, kv[:, :, 0], kv[:, :, 1], o, lse, dq, dkv[:, :, 0], dkv[:, :, 1], bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale
)
_flash_attn_backward(do, q, kv[:, :, 0], kv[:, :, 1], o, lse, dq, dkv[:, :, 0], dkv[:, :, 1], bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale)
return (dq, dkv, None, None, None)


44 changes: 11 additions & 33 deletions src/otter_ai/models/flamingo/mpt/hf_prefixlm_converter.py
@@ -174,9 +174,7 @@ def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCa
assert isinstance(model, BloomForCausalLM)
assert model.config.add_cross_attention == False, "Only supports BLOOM decoder-only models"

def _prepare_attn_mask(
self: BloomModel, attention_mask: torch.Tensor, bidirectional_mask: Optional[torch.Tensor], input_shape: Tuple[int, int], past_key_values_length: int
) -> torch.BoolTensor:
def _prepare_attn_mask(self: BloomModel, attention_mask: torch.Tensor, bidirectional_mask: Optional[torch.Tensor], input_shape: Tuple[int, int], past_key_values_length: int) -> torch.BoolTensor:
combined_attention_mask = None
device = attention_mask.device
(_, src_length) = input_shape
@@ -226,9 +224,7 @@ def forward(
**deprecated_arguments,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
if deprecated_arguments.pop("position_ids", False) is not False:
warnings.warn(
"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. " + "You can safely ignore passing `position_ids`.", FutureWarning
)
warnings.warn("`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. " + "You can safely ignore passing `position_ids`.", FutureWarning)
if len(deprecated_arguments) > 0:
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -262,12 +258,8 @@ def forward(
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
else:
attention_mask = attention_mask.to(hidden_states.device)
alibi = self._build_alibi_tensor(
batch_size=batch_size, query_length=seq_length, key_length=seq_length_with_past, dtype=hidden_states.dtype, device=hidden_states.device
)
causal_mask = self._prepare_attn_mask(
attention_mask, bidirectional_mask, input_shape=(batch_size, seq_length), past_key_values_length=past_key_values_length
)
alibi = self._build_alibi_tensor(batch_size=batch_size, query_length=seq_length, key_length=seq_length_with_past, dtype=hidden_states.dtype, device=hidden_states.device)
causal_mask = self._prepare_attn_mask(attention_mask, bidirectional_mask, input_shape=(batch_size, seq_length), past_key_values_length=past_key_values_length)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states:
hst = (hidden_states,)
@@ -306,9 +298,7 @@ def custom_forward(*inputs):
all_hidden_states = all_hidden_states + hst
if not return_dict:
return tuple((v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None))
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attentions
)
return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attentions)

setattr(model.transformer, "_prepare_attn_mask", MethodType(_prepare_attn_mask, model.transformer))
setattr(model.transformer, "_build_alibi_tensor", MethodType(_build_alibi_tensor, model.transformer))
@@ -332,9 +322,7 @@ def forward(
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
"""Replacement forward method for BloomCausalLM."""
if deprecated_arguments.pop("position_ids", False) is not False:
warnings.warn(
"`position_ids` have no functionality in BLOOM and will be removed " + "in v5.0.0. You can safely ignore passing `position_ids`.", FutureWarning
)
warnings.warn("`position_ids` have no functionality in BLOOM and will be removed " + "in v5.0.0. You can safely ignore passing `position_ids`.", FutureWarning)
if len(deprecated_arguments) > 0:
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -370,9 +358,7 @@ def forward(
attentions=transformer_outputs.attentions,
)

def prepare_inputs_for_generation(
self: BloomForCausalLM, input_ids: torch.LongTensor, past: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, **kwargs
) -> dict:
def prepare_inputs_for_generation(self: BloomForCausalLM, input_ids: torch.LongTensor, past: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> dict:
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
bidirectional_mask = None
@@ -409,18 +395,12 @@ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_em
if input_shape[-1] > 1:
if self.bidirectional_mask == "g":
(bsz, src_length) = input_shape
combined_attention_mask = torch.zeros(
(bsz, 1, src_length, src_length + past_key_values_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device
)
combined_attention_mask = torch.zeros((bsz, 1, src_length, src_length + past_key_values_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device)
else:
combined_attention_mask = _make_causal_mask_opt(input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length).to(
inputs_embeds.device
)
combined_attention_mask = _make_causal_mask_opt(input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length).to(inputs_embeds.device)
if self.bidirectional_mask is not None:
assert attention_mask.shape == self.bidirectional_mask.shape
expanded_bidirectional_mask = _expand_mask_opt(self.bidirectional_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
inputs_embeds.device
)
expanded_bidirectional_mask = _expand_mask_opt(self.bidirectional_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)
combined_attention_mask = torch.maximum(expanded_bidirectional_mask, combined_attention_mask)
if attention_mask is not None:
expanded_attn_mask = _expand_mask_opt(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)
@@ -568,8 +548,6 @@ def add_bidirectional_mask_if_missing(batch: Dict[str, Any]):
for i, continuation_indices in enumerate(batch["continuation_indices"]):
batch["bidirectional_mask"][i, continuation_indices] = 0
elif "labels" in batch and "attention_mask" in batch:
batch["bidirectional_mask"] = torch.logical_and(torch.eq(batch["attention_mask"], 1), torch.eq(batch["labels"], -100)).type_as(
batch["attention_mask"]
)
batch["bidirectional_mask"] = torch.logical_and(torch.eq(batch["attention_mask"], 1), torch.eq(batch["labels"], -100)).type_as(batch["attention_mask"])
else:
raise KeyError("No bidirectional_mask in batch and not sure how to construct one.")
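For context on the collapsed `bidirectional_mask` line above: it marks prompt positions, i.e. tokens that are attended to (`attention_mask == 1`) but excluded from the loss (`labels == -100`). A small self-contained check of what that expression produces, using made-up tensors rather than anything from the repository's tests:

    import torch

    # Hypothetical 1x6 batch: four prompt tokens (labels masked to -100) followed by
    # two continuation targets; every position is attended to.
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]])
    labels = torch.tensor([[-100, -100, -100, -100, 42, 7]])

    bidirectional_mask = torch.logical_and(torch.eq(attention_mask, 1), torch.eq(labels, -100)).type_as(attention_mask)

    print(bidirectional_mask)  # tensor([[1, 1, 1, 1, 0, 0]]) -- prompt positions get bidirectional attention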
27 changes: 6 additions & 21 deletions src/otter_ai/models/flamingo/mpt/modeling_mpt.py
@@ -60,9 +60,7 @@ def __init__(self, config: MPTConfig):
self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
self.norm_f = norm_class(config.d_model, device=config.init_device)
if config.init_device != "meta":
print(
f'You are using config.init_device={config.init_device!r}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.'
)
print(f'You are using config.init_device={config.init_device!r}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.')
self.apply(self.param_init_fn)
self.is_causal = not self.prefix_lm
self._attn_bias_initialized = False
@@ -146,11 +144,7 @@ def _attn_bias(
def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor):
(s_k, s_q) = attn_bias.shape[-2:]
if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len:
raise ValueError(
"attn_bias does not match the expected shape. "
+ f"The last two dimensions should both be {self.config.max_length} "
+ f"but are {s_k} and {s_q}."
)
raise ValueError("attn_bias does not match the expected shape. " + f"The last two dimensions should both be {self.config.max_length} " + f"but are {s_k} and {s_q}.")
seq_len = prefix_mask.shape[-1]
if seq_len > self.config.max_seq_len:
raise ValueError(f"prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}")
@@ -215,13 +209,10 @@ def forward(

if self.training:
if self.attn_uses_sequence_id and sequence_id is None:
raise ValueError(
"sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True " + "and the model is in train mode."
)
raise ValueError("sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True " + "and the model is in train mode.")
elif (self.attn_uses_sequence_id is False) and (sequence_id is not None):
warnings.warn(
"MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. "
+ "This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True."
"MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. " + "This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True."
)

S = input_ids.size(1)
@@ -235,10 +226,7 @@ def forward(
past_position = 0
if past_key_values is not None:
if len(past_key_values) != self.config.n_layers:
raise ValueError(
f"past_key_values must provide a past_key_value for each attention "
+ f"layer in the network ({len(past_key_values)=}; {self.config.n_layers=})."
)
raise ValueError(f"past_key_values must provide a past_key_value for each attention " + f"layer in the network ({len(past_key_values)=}; {self.config.n_layers=}).")
# For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim).
# For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq).
# Here we shift position embedding using the `seq` dim of the past key
@@ -247,10 +235,7 @@ def forward(
past_position = past_key_values[0][0].size(3)

if S + past_position > self.config.max_seq_len:
raise ValueError(
f"Cannot forward input with past sequence length {past_position} and current sequence length "
f"{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}."
)
raise ValueError(f"Cannot forward input with past sequence length {past_position} and current sequence length " f"{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.")
pos = torch.arange(
past_position,
S + past_position,
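The comments in the hunk above explain why the position offset is read from different axes of the cached key tensor: the torch implementation stores keys as (batch, heads, head_dim, seq), while triton/flash store (batch, seq, dim). A simplified, standalone sketch of that bookkeeping, with a hypothetical cache and omitting whatever trailing arguments of `torch.arange` are truncated above:

    import torch

    def position_ids_with_cache(S, past_key_values, attn_impl):
        """Position ids for S new tokens, offset by the cached sequence length."""
        past_position = 0
        if past_key_values is not None:
            if attn_impl == "torch":
                # torch impl: keys are (batch, heads, head_dim, seq), so seq is dim 3.
                past_position = past_key_values[0][0].size(3)
            else:
                # triton / flash: keys are (batch, seq, dim), so seq is dim 1.
                past_position = past_key_values[0][0].size(1)
        return torch.arange(past_position, S + past_position)

    # Hypothetical single-layer cache in the torch layout: keys of shape (1, 8, 64, 5).
    cache = [(torch.zeros(1, 8, 64, 5), torch.zeros(1, 8, 5, 64))]
    print(position_ids_with_cache(3, cache, "torch"))  # tensor([5, 6, 7])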
5 changes: 1 addition & 4 deletions src/otter_ai/models/flamingo/mpt/param_init_fns.py
@@ -56,10 +56,7 @@ def generic_param_init_fn_(
raise ValueError(f"Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}")
if init_div_is_residual is not False:
if verbose > 1:
warnings.warn(
f"Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. "
+ f"Set `init_div_is_residual: false` in init config to disable this."
)
warnings.warn(f"Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. " + f"Set `init_div_is_residual: false` in init config to disable this.")
if isinstance(module, nn.Linear):
if hasattr(module, "_fused"):
fused_init_helper_(module, init_fn_)
6 changes: 1 addition & 5 deletions src/otter_ai/models/flamingo/mpt_redpajama/attention.py
@@ -265,11 +265,7 @@ def __init__(
elif self.attn_impl == "torch":
self.attn_fn = scaled_multihead_dot_product_attention
if torch.cuda.is_available():
warnings.warn(
"Using `attn_impl: torch`. If your model does not use `alibi` or "
+ "`prefix_lm` we recommend using `attn_impl: flash` otherwise "
+ "we recommend using `attn_impl: triton`."
)
warnings.warn("Using `attn_impl: torch`. If your model does not use `alibi` or " + "`prefix_lm` we recommend using `attn_impl: flash` otherwise " + "we recommend using `attn_impl: triton`.")
else:
raise ValueError(f"{attn_impl=} is an invalid setting.")

