# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import sys
import types
from typing import List, Optional, Tuple
import safetensors
import torch
import torch.nn.functional as F
import transformers
from packaging import version
from peft import PeftModel
from torch.utils.data import DataLoader
from transformers import PreTrainedModel, trainer
from transformers.modeling_utils import unwrap_model
from swift.utils import get_logger, torchacc_trim_graph, use_torchacc
logger = get_logger()
# DataLoader
def get_bucket_sizes(max_length: int) -> List[int]:
"""Get the bucket sizes for TorchAcc.
You can set the environment variable TORCHACC_DATA_BUCKETS to specify
the bucket sizes. If not set, we use a normal distribution bucketing with
8 buckets.
"""
padding_p_base = 2
if os.getenv('TORCHACC_DATA_BUCKETS') is not None:
bucket_sizes = [int(x) for x in os.getenv('TORCHACC_DATA_BUCKETS').split(',')]
bucket_sizes.append(max_length)
else:
if os.getenv('TORCHACC_CACHE_PATH') is not None: # padding strategy when persistent cache is enabled
padding_p_base = 1.4
padding_p_base = os.getenv('TORCHACC_PADDING_P_BASE', padding_p_base)
try:
padding_p_base = float(padding_p_base)
except ValueError as e:
            logger.error(f'Expected TORCHACC_PADDING_P_BASE to be a float number, but encountered {padding_p_base}')
raise e
bucket_sizes = [16, 32, 48, 64, 96, 128]
base_size = 256
while base_size < max_length:
bucket_sizes.append((int(base_size) + 127) // 128 * 128)
base_size *= padding_p_base
bucket_sizes.append(max_length)
return bucket_sizes
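# Example (illustrative, derived from the defaults above): with TORCHACC_DATA_BUCKETS unset
# and the default padding_p_base of 2, get_bucket_sizes(512) returns
# [16, 32, 48, 64, 96, 128, 256, 512]; with TORCHACC_DATA_BUCKETS='128,256' it returns
# [128, 256, 512].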
def _get_closest_bucket(bucket_sizes, data_length):
    """Select the smallest size in bucket_sizes that is not smaller than
    data_length; if none exists, data_length is appended to bucket_sizes and
    returned. This is required for TorchAcc.
    """
closest_length = sys.maxsize
for b in bucket_sizes:
if b == data_length or ((b < closest_length) and (b > data_length)):
closest_length = b
if closest_length == sys.maxsize:
bucket_sizes.append(data_length)
closest_length = data_length
return closest_length
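# Example (illustrative): with buckets [16, 32, 48, 64, 96, 128] and data_length 50, the
# function returns 64, the smallest bucket that can hold the sequence; if every bucket is
# smaller than data_length, data_length itself is appended and returned.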
def pad_and_split_batch(padding_to, input_ids, attention_mask, labels, loss_scale, max_length, tokenizer, rank,
world_size, padding_right):
if padding_to is None:
longest_len = input_ids.shape[-1]
bucket_sizes = get_bucket_sizes(max_length)
        bucket_data_length = _get_closest_bucket(bucket_sizes, longest_len)
padding_length = bucket_data_length - input_ids.shape[1]
pad_tuple = (0, padding_length) if padding_right else (padding_length, 0)
input_ids = F.pad(input_ids, pad_tuple, 'constant', tokenizer.pad_token_id)
attention_mask = F.pad(attention_mask, pad_tuple, 'constant', 0)
        if loss_scale is not None:
loss_scale = F.pad(loss_scale, pad_tuple, 'constant', 0.)
labels = F.pad(labels, pad_tuple, 'constant', -100)
    # Manually split the batch across the different DP ranks.
batch_size = input_ids.shape[0] // world_size
if batch_size > 0:
start = rank * batch_size
end = (rank + 1) * batch_size
input_ids = input_ids[start:end, :]
attention_mask = attention_mask[start:end, :]
labels = labels[start:end, :]
        if loss_scale is not None:
loss_scale = loss_scale[start:end, :]
return input_ids, attention_mask, labels, loss_scale
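# Example (illustrative): with world_size=2, padding_to=None and the default buckets, a
# global batch of input_ids with shape [8, 50] is padded to the closest bucket (64) and
# then split so that rank 0 keeps rows 0..3 and rank 1 keeps rows 4..7, i.e. each rank
# ends up with tensors of shape [4, 64].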
def ta_train_dataloader(train_dataset, data_collator, sampler, args, batch_size):
    # Patch skip_first_batches for the customized TorchAcc dataloader.
def acc_skip_first_batches(dataloader, num_batches=0):
from accelerate.data_loader import SkipBatchSampler
batch_sampler = SkipBatchSampler(dataloader._loader.batch_sampler, skip_batches=num_batches)
try:
dataset = dataloader.dataset
except AttributeError:
dataset = dataloader._loader.dataset
dataloader_params = {
'collate_fn': data_collator,
'num_workers': args.dataloader_num_workers,
'pin_memory': args.dataloader_pin_memory,
'persistent_workers': args.dataloader_persistent_workers,
}
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
dataloader_params['batch_sampler'] = batch_sampler
dataloader_params['worker_init_fn'] = trainer.seed_worker
return ta.AsyncLoader(DataLoader(dataset, **dataloader_params), args.device)
trainer.skip_first_batches = acc_skip_first_batches
# dataloader for TorchAcc.
import torchacc as ta
dataloader_params = {
'batch_size': batch_size,
'collate_fn': data_collator,
'num_workers': args.dataloader_num_workers,
'pin_memory': args.dataloader_pin_memory,
'persistent_workers': args.dataloader_persistent_workers,
}
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
dataloader_params['sampler'] = sampler
dataloader_params['drop_last'] = args.dataloader_drop_last
dataloader_params['worker_init_fn'] = trainer.seed_worker
return ta.AsyncLoader(DataLoader(train_dataset, **dataloader_params), args.device)
def ta_eval_dataloader(eval_dataset, data_collator, sampler, args):
import torchacc as ta
dataloader_params = {
'batch_size': args.eval_batch_size,
'collate_fn': data_collator,
'num_workers': args.dataloader_num_workers,
'pin_memory': args.dataloader_pin_memory,
'persistent_workers': args.dataloader_persistent_workers,
}
if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
dataloader_params['sampler'] = sampler
dataloader_params['drop_last'] = args.dataloader_drop_last
return ta.AsyncLoader(DataLoader(eval_dataset, **dataloader_params), args.device)
def ta_test_dataloader(test_dataset, data_collator, sampler, args):
import torchacc as ta
dataloader_params = {
'batch_size': args.eval_batch_size,
'collate_fn': data_collator,
'num_workers': args.dataloader_num_workers,
'pin_memory': args.dataloader_pin_memory,
'persistent_workers': args.dataloader_persistent_workers,
}
if not isinstance(test_dataset, torch.utils.data.IterableDataset):
dataloader_params['sampler'] = sampler
dataloader_params['drop_last'] = args.dataloader_drop_last
# We use the same batch_size as for eval.
return ta.AsyncLoader(DataLoader(test_dataset, **dataloader_params), args.device)
# Save/load checkpoint
def ta_save_optimizer_and_scheduler(optimizer, lr_scheduler, output_dir):
import torch_xla.core.xla_model as xm
xm.rendezvous('saving_optimizer_states')
xm.save(optimizer.state_dict(), os.path.join(output_dir, f'optimizer_{xm.get_ordinal()}.pt'), master_only=False)
xm.save(lr_scheduler.state_dict(), os.path.join(output_dir, f'scheduler_{xm.get_ordinal()}.pt'), master_only=False)
xm.rendezvous('saving_optimizer_states_done')
def ta_load_optimizer_and_scheduler(optimizer, lr_scheduler, checkpoint, device):
import torch_xla.core.xla_model as xm
optimizer_state = torch.load(os.path.join(checkpoint, f'optimizer_{xm.get_ordinal()}.pt'), map_location='cpu')
lr_scheduler_state = torch.load(os.path.join(checkpoint, f'scheduler_{xm.get_ordinal()}.pt'), map_location='cpu')
xm.send_cpu_data_to_device(optimizer_state, device)
xm.send_cpu_data_to_device(lr_scheduler_state, device)
optimizer.load_state_dict(optimizer_state)
lr_scheduler.load_state_dict(lr_scheduler_state)
return optimizer, lr_scheduler
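# Illustrative pairing (not enforced here): the per-ordinal state files written by
# ta_save_optimizer_and_scheduler into an output directory are expected to be restored
# on the same ordinal by ta_load_optimizer_and_scheduler pointed at that directory.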
def save_ta_ddp_checkpoint(self_model, tokenizer, args, output_dir: Optional[str] = None):
output_dir = output_dir if output_dir is not None else args.output_dir
import torch_xla.core.xla_model as xm
model = self_model
if xm.is_master_ordinal(local=False):
os.makedirs(output_dir, exist_ok=True)
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
xm.mark_step()
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
supported_classes = (PreTrainedModel, PeftModel)
if not isinstance(model, supported_classes):
if isinstance(unwrap_model(model), supported_classes):
unwrap_model(model).save_pretrained(
output_dir,
is_main_process=args.should_save,
state_dict=xm._maybe_convert_to_cpu(model.state_dict()),
save_function=xm.save,
safe_serialization=args.save_safetensors,
)
else:
logger.info('Trainer.model is not a `PreTrainedModel`, only saving its state dict.')
state_dict = xm._maybe_convert_to_cpu(model.state_dict())
if args.save_safetensors:
safetensors.torch.save_file(state_dict, os.path.join(output_dir, 'model.safetensors'))
else:
torch.save(state_dict, os.path.join(output_dir, 'pytorch_model.bin'))
else:
model.save_pretrained(
output_dir,
is_main_process=args.should_save,
save_function=xm.save,
safe_serialization=args.save_safetensors,
state_dict=xm._maybe_convert_to_cpu(model.state_dict()))
if tokenizer is not None and args.should_save:
tokenizer.save_pretrained(output_dir)
def save_ta_fsdp_checkpoint(self_model, tokenizer, args, output_dir):
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints
xm.mark_step()
if xm.is_master_ordinal(local=False):
os.makedirs(output_dir, exist_ok=True)
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
supported_classes = (PreTrainedModel, PeftModel)
model = self_model._get_underlay_model().module.module
unwrapped_model = unwrap_model(model)
xm.rendezvous('saving_checkpoint')
ckpt = {
'model': self_model._get_underlay_model().state_dict(),
'shard_metadata': self_model._get_underlay_model().get_shard_metadata(),
}
if isinstance(model, PeftModel):
ckpt_path = os.path.join(output_dir, f'rank{args.process_index}-of-{args.global_world_size}-adapter_model.bin')
else:
ckpt_path = os.path.join(output_dir, f'rank{args.process_index}-of-{args.global_world_size}-pytorch_model.bin')
xm.save(ckpt, ckpt_path, master_only=False)
# Make sure all ranks have saved checkpoints
xm.rendezvous('save_full_checkpoints')
if tokenizer is not None and args.should_save:
tokenizer.save_pretrained(output_dir, is_main_process=xm.is_master_ordinal(local=False), save_function=xm.save)
# rank 0 consolidates and saves the whole checkpoint.
if xm.is_master_ordinal(local=False):
if isinstance(model, PeftModel):
ckpt_suffix = 'rank*-of-*-adapter_model.bin'
else:
ckpt_suffix = 'rank*-of-*-pytorch_model.bin'
full_state_dict, _ = consolidate_sharded_model_checkpoints(
ckpt_prefix=os.path.join(output_dir, ''), ckpt_suffix=ckpt_suffix, save_model=False)
if isinstance(unwrapped_model, supported_classes):
unwrapped_model.save_pretrained(
output_dir,
state_dict=full_state_dict,
save_function=xm.save,
safe_serialization=args.save_safetensors,
)
else:
logger.info('Trainer.model is not a `PreTrainedModel`, only saving its state dict.')
if args.save_safetensors:
safetensors.torch.save_file(full_state_dict, os.path.join(output_dir, 'model.safetensors'))
else:
torch.save(full_state_dict, os.path.join(output_dir, 'pytorch_model.bin'))
xm.rendezvous('ckpt_consolidation')
# delete the sharded checkpoint.
os.remove(ckpt_path)
def ta_trim_graph():
if use_torchacc() and torchacc_trim_graph():
import torchacc as ta
ta.mark_step()
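# Typical call site (an assumption, not enforced by this module): invoke ta_trim_graph()
# right before materializing lazy tensors (e.g. when logging the loss) so that the XLA
# lazy graph is cut there instead of growing across logging steps.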
# Model patch
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offset position ids when working with a KV-cache.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
        `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
"""
if position_ids is not None:
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
else:
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
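# Shape sketch (illustrative): with q and k of shape [batch_size, heads, seq_len, head_dim]
# and cos/sin of shape [batch_size, seq_len, head_dim], unsqueeze_dim=1 expands cos/sin to
# [batch_size, 1, seq_len, head_dim] so they broadcast over the head dimension.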
def patch_acc_model(args, model):
    if not args.use_flash_attn:
        logger.warning('TorchAcc currently always uses flash attention; `use_flash_attn=False` is ignored.')
if args.model_type.startswith('qwen1half') or args.model_type.startswith('qwen2'):
model = patch_qwen2_model(model)
elif args.model_type.startswith('qwen'):
import torchacc as ta
model = ta.patch_qwen_model(model)
elif args.model_type.startswith('baichuan'):
model = patch_baichuan_model(model)
elif args.model_type.startswith('llama') or args.model_type.startswith('yi'):
model = patch_llama_model(model)
elif args.model_type.startswith('chatglm'):
        model = patch_chatglm_model(model)
return model
def patch_llama_model(model):
def update_causal_mask(self, *args, **kwargs):
# attention_mask is not supported in TorchAcc.
return None
def llama_attn_forward(self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
from torchacc.ops import flash_attn_varlen_xla
import einops
bsz, q_len, _ = hidden_states.size()
query_states = (self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2))
key_states = (
self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2))
value_states = (
self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2))
kv_seq_len = key_states.shape[-2]
assert past_key_value is None, 'past_key_value is not supported'
if version.parse(transformers.__version__) >= version.parse('4.36'):
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
assert not output_attentions, 'output_attentions is not supported'
if past_key_value is not None:
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# See https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
# if attention_mask is not None:
# value_states = value_states * attention_mask.unsqueeze(1).unsqueeze(-1)
q = einops.rearrange(query_states, 'b h s ... -> (b s) h ...')
k = einops.rearrange(key_states, 'b h s ... -> (b s) h ...')
v = einops.rearrange(value_states, 'b h s ... -> (b s) h ...')
max_s = q_len
cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device)
output = flash_attn_varlen_xla(
q, k, v, cu_q_lens, cu_q_lens, max_s, max_s, 0.0, softmax_scale=None, causal=True)
output = einops.rearrange(output, '(b s) ... -> b s ...', b=bsz)
return self.o_proj(einops.rearrange(output, 'b s h d -> b s (h d)')), None, past_key_value
for layer in model.model.layers:
layer.self_attn.forward = types.MethodType(llama_attn_forward, layer.self_attn)
if version.parse(transformers.__version__) >= version.parse('4.38'):
model.model._update_causal_mask = types.MethodType(update_causal_mask, model.model)
return model
def patch_chatglm_model(model):
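    def split_tensor_along_last_dim(tensor, num_partitions):
        # Minimal local stand-in (an assumption): mirrors the `split_tensor_along_last_dim`
        # helper from ChatGLM's modeling_chatglm.py, which is not importable from here but
        # is referenced by chatglm_attn_forward below. It splits `tensor` into
        # `num_partitions` equal chunks along its last dimension.
        last_dim_size = tensor.size(-1) // num_partitions
        return torch.split(tensor, last_dim_size, dim=-1)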
def chatglm_apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
# x: [sq, b, np, hn]
sq, _, np, _ = x.size(0), x.size(1), x.size(2), x.size(3)
rot_dim = rope_cache.shape[-2] * 2
x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
# truncate to support variable sizes
rope_cache = rope_cache[:sq]
xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
x_out2 = torch.stack(
[
xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
],
-1,
)
x_out2 = x_out2.flatten(3)
return torch.cat((x_out2, x_pass), dim=-1)
def chatglm_attn_forward(self,
hidden_states,
attention_mask,
rotary_pos_emb,
kv_cache=None,
use_cache=True,
**kwargs):
# hidden_states: [sq, b, h]
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
# =====================
# Query, Key, and Value
# =====================
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer = self.query_key_value(hidden_states)
if self.multi_query_attention:
(query_layer, key_layer, value_layer) = mixed_x_layer.split(
[
self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
],
dim=-1,
)
query_layer = query_layer.view(query_layer.size()[:-1] + (self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head))
key_layer = key_layer.view(key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition,
self.hidden_size_per_attention_head))
value_layer = value_layer.view(value_layer.size()[:-1] + (self.num_multi_query_groups_per_partition,
self.hidden_size_per_attention_head))
else:
new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
# apply relative positional encoding (rotary embedding)
if rotary_pos_emb is not None:
query_layer = chatglm_apply_rotary_pos_emb(query_layer, rotary_pos_emb)
key_layer = chatglm_apply_rotary_pos_emb(key_layer, rotary_pos_emb)
# adjust key and value for inference
if kv_cache is not None:
cache_k, cache_v = kv_cache
key_layer = torch.cat((cache_k, key_layer), dim=0)
value_layer = torch.cat((cache_v, value_layer), dim=0)
if use_cache:
kv_cache = (key_layer, value_layer)
else:
kv_cache = None
if self.multi_query_attention:
key_layer = key_layer.unsqueeze(-2)
key_layer = key_layer.expand(
-1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1)
key_layer = key_layer.contiguous().view(key_layer.size()[:2] + (self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head))
value_layer = value_layer.unsqueeze(-2)
value_layer = value_layer.expand(
-1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1)
value_layer = value_layer.contiguous().view(value_layer.size()[:2]
+ (self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head))
# ==================================
# core attention computation
# ==================================
from torchacc.ops import flash_attn_varlen_qkvpacked_xla
import einops
query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
bsz, _, q_len, _ = query_layer.size()
qkv = torch.stack([query_layer, key_layer, value_layer], dim=2)
qkv = qkv.transpose(1, 3)
qkv = einops.rearrange(qkv, 'b s ... -> (b s) ...')
cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device)
context_layer = flash_attn_varlen_qkvpacked_xla(
qkv, cu_q_lens, q_len, dropout_p=0.0, softmax_scale=None, causal=True)
context_layer = einops.rearrange(context_layer, '(b s) ... -> b s ...', b=bsz)
context_layer = context_layer.permute(1, 0, 2, 3)
new_context_layer_shape = context_layer.size()[:-2] + (self.core_attention.hidden_size_per_partition, )
context_layer = context_layer.reshape(*new_context_layer_shape)
# =================
# Output. [sq, b, h]
# =================
output = self.dense(context_layer)
return output, kv_cache
def torchacc_swiglu(x):
x = torch.chunk(x, 2, dim=-1)
return F.silu(x[0]).to(x[0].dtype) * x[1]
    # Patch the attention forward and the MLP activation (SwiGLU).
for layer in model.transformer.encoder.layers:
layer.self_attention.forward = types.MethodType(chatglm_attn_forward, layer.self_attention)
layer.mlp.activation_func = torchacc_swiglu
return model
def patch_baichuan_model(model):
def baichuan_attn_forward(self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
import einops
bsz, q_len, _ = hidden_states.size()
proj = self.W_pack(hidden_states)
proj = (proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2))
query_states = (proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2))
key_states = (proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2))
value_states = (proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2))
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
from torchacc.ops import flash_attn_varlen_xla
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
q, k, v = [einops.rearrange(x, 'b s ... -> (b s) ...') for x in [query_states, key_states, value_states]]
cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device)
output = flash_attn_varlen_xla(
q, k, v, cu_q_lens, cu_q_lens, q_len, q_len, 0.0, softmax_scale=None, causal=True)
output = einops.rearrange(output, '(b s) ... -> b s ...', b=bsz)
output = self.o_proj(einops.rearrange(output, 'b s h d -> b s (h d)'))
return output, None, past_key_value
for layer in model.base_model.layers:
layer.self_attn.forward = types.MethodType(baichuan_attn_forward, layer.self_attn)
return model
def patch_qwen2_model(model):
def update_causal_mask(self, *args, **kwargs):
# attention_mask is not supported in TorchAcc.
return None
def qwen2_attn_forward(
self,
hidden_states,
attention_mask=None,
position_ids=None,
past_key_value=None,
output_attentions=False,
use_cache=False,
cache_position=None,
position_embeddings=None,
**kwargs,
):
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f'The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} '
'for auto-regressive decoding with k/v caching, please make sure to initialize the attention class '
'with a layer index.')
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
# Because the input can be padded, the absolute sequence length depends on the max position id.
# rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
rotary_seq_len = kv_seq_len + 1
if version.parse(transformers.__version__) >= version.parse('4.45'):
if position_embeddings is None:
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
else:
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
dropout_rate = 0.0 if not self.training else self.attention_dropout
        # In PEFT, the layer norms are usually cast to float32 for training stability,
        # so the input hidden states may be silently upcast to float32. We therefore
        # cast them back to the expected half-precision dtype before flash attention.
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, '_pre_quantization_dtype'):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
# Reshape to the expected shape for Flash Attention
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
from torchacc.ops import flash_attn_varlen_xla
import einops
q, k, v = [einops.rearrange(x, 'b s ... -> (b s) ...') for x in [query_states, key_states, value_states]]
cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device)
attn_output = flash_attn_varlen_xla(
q, k, v, cu_q_lens, cu_q_lens, q_len, q_len, dropout_rate, softmax_scale=None, causal=True)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
def qwen2_forward(self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
            raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time')
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
            raise ValueError('You have to specify either input_ids or inputs_embeds')
if self.gradient_checkpointing and self.training:
if use_cache:
use_cache = False
past_key_values_length = 0
        if use_cache:
            from transformers.cache_utils import Cache, DynamicCache  # needed for the legacy-cache check
            use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length(seq_length)
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states, )
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1], )
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states, )
next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
from transformers.modeling_outputs import BaseModelOutputWithPast
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
for layer in model.model.layers:
layer.self_attn.forward = types.MethodType(qwen2_attn_forward, layer.self_attn)
if version.parse(transformers.__version__) >= version.parse('4.43'):
model.model._update_causal_mask = types.MethodType(update_causal_mask, model.model)
else:
model.model.forward = types.MethodType(qwen2_forward, model.model)
return model
def patch_clip_grad_norm(accelerator):
from accelerate.utils import DistributedType
from accelerate.optimizer import AcceleratedOptimizer
import torch_xla.core.xla_model as xm
def clip_grad_norm_(self, parameters, max_norm, norm_type=2):
"""
Should be used in place of `torch.nn.utils.clip_grad_norm_`.
Returns:
`torch.Tensor`: Total norm of the parameter gradients (viewed as a single vector).
Example:
```python
>>> from accelerate import Accelerator
>>> accelerator = Accelerator(gradient_accumulation_steps=2)
>>> dataloader, model, optimizer, scheduler = accelerator.prepare(dataloader, model, optimizer, scheduler)
>>> for input, target in dataloader:
... optimizer.zero_grad()
... output = model(input)
... loss = loss_func(output, target)
... accelerator.backward(loss)
... if accelerator.sync_gradients:
... accelerator.clip_grad_norm_(model.parameters(), max_grad_norm)
... optimizer.step()
```
"""
if self.distributed_type == DistributedType.FSDP:
self.unscale_gradients()
parameters = [p for p in parameters]
for model in self._models:
if parameters == [p for p in model.parameters()]:
return model.clip_grad_norm_(max_norm, norm_type)
        elif self.distributed_type == DistributedType.DEEPSPEED:
            # DeepSpeed clips gradients itself inside `accelerator.backward(loss)`, so there is
            # nothing to do here; we also cannot return the gradient norm because DeepSpeed owns it.
            return None
elif self.distributed_type == DistributedType.XLA:
# Reduce gradients first for XLA
for acc_opt in self._optimizers:
if not acc_opt.gradient_state.is_xla_gradients_synced:
opt = acc_opt
while isinstance(opt, AcceleratedOptimizer):
opt = opt.optimizer
gradients = xm._fetch_gradients(opt)
                    # Use xm.all_reduce to perform an in-place all-reduce; reducing the tensors
                    # one by one via self.reduce would not be in-place.
xm.all_reduce('sum', gradients, scale=1.0 / self.num_processes)
# Set is_xla_gradients_synced to True to avoid all-reduce twice in the AcceleratedOptimizer step.
acc_opt.gradient_state.is_xla_gradients_synced = True
if os.environ.get('ACCELERATE_USE_FSDP', 'false') == 'true':
self.unscale_gradients()
parameters = [p for p in parameters]
for model in self._models:
if parameters == [p for p in model.parameters()]:
return model._get_underlay_model().clip_grad_norm_(max_norm, norm_type)
self.unscale_gradients()
return torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=norm_type)
# TODO(baole): This should be removed once accelerate is updated.
accelerator.clip_grad_norm_ = types.MethodType(clip_grad_norm_, accelerator)
return accelerator
def ta_accelerate(model,
fsdp_num,
layer_cls_name,
bf16=True,
fp16=False,
gradient_checkpointing=True,
fsdp_flatten_parameters=False):
""" accelerate LLM training using TorchAcc(only available internally).
"""
import torchacc as ta
assert layer_cls_name is not None
def get_ta_config():
config = ta.Config()
config.compute.fp16 = fp16
config.compute.bf16 = bf16
config.memory.gc = gradient_checkpointing
if config.memory.gc:
config.memory.gc_cls = {layer_cls_name}
config.dist.fsdp.size = fsdp_num
config.dist.fsdp.wrap_layer_cls = {layer_cls_name}
config.dist.fsdp.flatten_parameters = fsdp_flatten_parameters
config.dist.dp.size = 1
if fsdp_num > 1:
os.environ['ACCELERATE_USE_FSDP'] = 'true'
return config
ta_config = get_ta_config()
model = ta.accelerate(model, config=ta_config)
return model
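# Usage sketch (illustrative only; `my_model`, `my_dataset`, `collator`, `sampler`, `args`,
# the fsdp_num value and the layer class name are placeholders, not part of this module):
#
#   model = patch_acc_model(args, my_model)
#   model = ta_accelerate(model, fsdp_num=4, layer_cls_name='Qwen2DecoderLayer',
#                         bf16=True, fp16=False, gradient_checkpointing=True)
#   train_loader = ta_train_dataloader(my_dataset, collator, sampler, args, batch_size=8)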