|
|
|
|
|
import os |
|
|
import sys |
|
|
import types |
|
|
from typing import List, Optional, Tuple |
|
|
|
|
|
import safetensors |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import transformers |
|
|
from packaging import version |
|
|
from peft import PeftModel |
|
|
from torch.utils.data import DataLoader |
|
|
from transformers import PreTrainedModel, trainer |
|
|
from transformers.modeling_utils import unwrap_model |
|
|
|
|
|
from swift.utils import get_logger, torchacc_trim_graph, use_torchacc |
|
|
|
|
|
logger = get_logger() |
|
|
|
|
|
|
|
|
|
|
|
def get_bucket_sizes(max_length: int) -> List[int]: |
|
|
"""Get the bucket sizes for TorchAcc. |
|
|
You can set the environment variable TORCHACC_DATA_BUCKETS to specify |
|
|
the bucket sizes. If not set, we use a normal distribution bucketing with |
|
|
8 buckets. |
|
|
""" |
|
|
padding_p_base = 2 |
|
|
if os.getenv('TORCHACC_DATA_BUCKETS') is not None: |
|
|
bucket_sizes = [int(x) for x in os.getenv('TORCHACC_DATA_BUCKETS').split(',')] |
|
|
bucket_sizes.append(max_length) |
|
|
else: |
|
|
if os.getenv('TORCHACC_CACHE_PATH') is not None: |
|
|
padding_p_base = 1.4 |
|
|
padding_p_base = os.getenv('TORCHACC_PADDING_P_BASE', padding_p_base) |
|
|
try: |
|
|
padding_p_base = float(padding_p_base) |
|
|
except ValueError as e: |
|
|
            logger.error(f'Expected TORCHACC_PADDING_P_BASE to be a float, but got {padding_p_base}')
|
|
raise e |
|
|
bucket_sizes = [16, 32, 48, 64, 96, 128] |
|
|
base_size = 256 |
|
|
while base_size < max_length: |
|
|
bucket_sizes.append((int(base_size) + 127) // 128 * 128) |
|
|
base_size *= padding_p_base |
|
|
bucket_sizes.append(max_length) |
|
|
|
|
|
return bucket_sizes |
|
|
|
|
|
|
|
|
def _get_closest_bucket(bucket_sizes, data_length):
    """Select the smallest bucket in bucket_sizes that can hold data_length
    (i.e. the closest bucket size >= data_length). This is required for TorchAcc.
    """
|
|
closest_length = sys.maxsize |
|
|
for b in bucket_sizes: |
|
|
if b == data_length or ((b < closest_length) and (b > data_length)): |
|
|
closest_length = b |
|
|
|
|
|
if closest_length == sys.maxsize: |
|
|
bucket_sizes.append(data_length) |
|
|
closest_length = data_length |
|
|
|
|
|
return closest_length |
|
|
|
|
|
|
|
|
def pad_and_split_batch(padding_to, input_ids, attention_mask, labels, loss_scale, max_length, tokenizer, rank, |
|
|
world_size, padding_right): |
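    # Pad the batch up to the closest TorchAcc bucket size (to limit XLA recompilation),
    # then split the global batch evenly across data-parallel ranks.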
|
|
if padding_to is None: |
|
|
longest_len = input_ids.shape[-1] |
|
|
bucket_sizes = get_bucket_sizes(max_length) |
|
|
        bucket_data_length = _get_closest_bucket(bucket_sizes, longest_len)
|
|
padding_length = bucket_data_length - input_ids.shape[1] |
|
|
pad_tuple = (0, padding_length) if padding_right else (padding_length, 0) |
|
|
input_ids = F.pad(input_ids, pad_tuple, 'constant', tokenizer.pad_token_id) |
|
|
attention_mask = F.pad(attention_mask, pad_tuple, 'constant', 0) |
|
|
        if loss_scale is not None:
|
|
loss_scale = F.pad(loss_scale, pad_tuple, 'constant', 0.) |
|
|
labels = F.pad(labels, pad_tuple, 'constant', -100) |
|
|
|
|
|
|
|
|
batch_size = input_ids.shape[0] // world_size |
|
|
if batch_size > 0: |
|
|
start = rank * batch_size |
|
|
end = (rank + 1) * batch_size |
|
|
input_ids = input_ids[start:end, :] |
|
|
attention_mask = attention_mask[start:end, :] |
|
|
labels = labels[start:end, :] |
|
|
        if loss_scale is not None:
|
|
loss_scale = loss_scale[start:end, :] |
|
|
return input_ids, attention_mask, labels, loss_scale |
|
|
|
|
|
|
|
|
def ta_train_dataloader(train_dataset, data_collator, sampler, args, batch_size): |
|
|
|
|
|
def acc_skip_first_batches(dataloader, num_batches=0): |
|
|
from accelerate.data_loader import SkipBatchSampler |
|
|
batch_sampler = SkipBatchSampler(dataloader._loader.batch_sampler, skip_batches=num_batches) |
|
|
try: |
|
|
dataset = dataloader.dataset |
|
|
except AttributeError: |
|
|
dataset = dataloader._loader.dataset |
|
|
dataloader_params = { |
|
|
'collate_fn': data_collator, |
|
|
'num_workers': args.dataloader_num_workers, |
|
|
'pin_memory': args.dataloader_pin_memory, |
|
|
'persistent_workers': args.dataloader_persistent_workers, |
|
|
} |
|
|
|
|
|
if not isinstance(train_dataset, torch.utils.data.IterableDataset): |
|
|
dataloader_params['batch_sampler'] = batch_sampler |
|
|
dataloader_params['worker_init_fn'] = trainer.seed_worker |
|
|
|
|
|
return ta.AsyncLoader(DataLoader(dataset, **dataloader_params), args.device) |
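    # Patch transformers' skip_first_batches so that resuming skips batches on the
    # DataLoader wrapped inside TorchAcc's AsyncLoader.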
|
|
|
|
|
trainer.skip_first_batches = acc_skip_first_batches |
|
|
|
|
|
|
|
|
import torchacc as ta |
|
|
|
|
|
dataloader_params = { |
|
|
'batch_size': batch_size, |
|
|
'collate_fn': data_collator, |
|
|
'num_workers': args.dataloader_num_workers, |
|
|
'pin_memory': args.dataloader_pin_memory, |
|
|
'persistent_workers': args.dataloader_persistent_workers, |
|
|
} |
|
|
|
|
|
if not isinstance(train_dataset, torch.utils.data.IterableDataset): |
|
|
dataloader_params['sampler'] = sampler |
|
|
dataloader_params['drop_last'] = args.dataloader_drop_last |
|
|
dataloader_params['worker_init_fn'] = trainer.seed_worker |
|
|
|
|
|
return ta.AsyncLoader(DataLoader(train_dataset, **dataloader_params), args.device) |
|
|
|
|
|
|
|
|
def ta_eval_dataloader(eval_dataset, data_collator, sampler, args): |
|
|
import torchacc as ta |
|
|
|
|
|
dataloader_params = { |
|
|
'batch_size': args.eval_batch_size, |
|
|
'collate_fn': data_collator, |
|
|
'num_workers': args.dataloader_num_workers, |
|
|
'pin_memory': args.dataloader_pin_memory, |
|
|
'persistent_workers': args.dataloader_persistent_workers, |
|
|
} |
|
|
|
|
|
if not isinstance(eval_dataset, torch.utils.data.IterableDataset): |
|
|
dataloader_params['sampler'] = sampler |
|
|
dataloader_params['drop_last'] = args.dataloader_drop_last |
|
|
|
|
|
return ta.AsyncLoader(DataLoader(eval_dataset, **dataloader_params), args.device) |
|
|
|
|
|
|
|
|
def ta_test_dataloader(test_dataset, data_collator, sampler, args): |
|
|
import torchacc as ta |
|
|
|
|
|
dataloader_params = { |
|
|
'batch_size': args.eval_batch_size, |
|
|
'collate_fn': data_collator, |
|
|
'num_workers': args.dataloader_num_workers, |
|
|
'pin_memory': args.dataloader_pin_memory, |
|
|
'persistent_workers': args.dataloader_persistent_workers, |
|
|
} |
|
|
|
|
|
if not isinstance(test_dataset, torch.utils.data.IterableDataset): |
|
|
dataloader_params['sampler'] = sampler |
|
|
dataloader_params['drop_last'] = args.dataloader_drop_last |
|
|
|
|
|
|
|
|
return ta.AsyncLoader(DataLoader(test_dataset, **dataloader_params), args.device) |
|
|
|
|
|
|
|
|
|
|
|
def ta_save_optimizer_and_scheduler(optimizer, lr_scheduler, output_dir): |
|
|
import torch_xla.core.xla_model as xm |
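    # Each process (XLA ordinal) saves its own optimizer/scheduler state; rendezvous keeps the ranks in step.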
|
|
xm.rendezvous('saving_optimizer_states') |
|
|
xm.save(optimizer.state_dict(), os.path.join(output_dir, f'optimizer_{xm.get_ordinal()}.pt'), master_only=False) |
|
|
xm.save(lr_scheduler.state_dict(), os.path.join(output_dir, f'scheduler_{xm.get_ordinal()}.pt'), master_only=False) |
|
|
xm.rendezvous('saving_optimizer_states_done') |
|
|
|
|
|
|
|
|
def ta_load_optimizer_and_scheduler(optimizer, lr_scheduler, checkpoint, device): |
|
|
import torch_xla.core.xla_model as xm |
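    # Load the saved states on CPU first, then transfer them to the XLA device before restoring.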
|
|
optimizer_state = torch.load(os.path.join(checkpoint, f'optimizer_{xm.get_ordinal()}.pt'), map_location='cpu') |
|
|
lr_scheduler_state = torch.load(os.path.join(checkpoint, f'scheduler_{xm.get_ordinal()}.pt'), map_location='cpu') |
|
|
xm.send_cpu_data_to_device(optimizer_state, device) |
|
|
xm.send_cpu_data_to_device(lr_scheduler_state, device) |
|
|
|
|
|
optimizer.load_state_dict(optimizer_state) |
|
|
lr_scheduler.load_state_dict(lr_scheduler_state) |
|
|
return optimizer, lr_scheduler |
|
|
|
|
|
|
|
|
def save_ta_ddp_checkpoint(self_model, tokenizer, args, output_dir: Optional[str] = None): |
|
|
output_dir = output_dir if output_dir is not None else args.output_dir |
|
|
import torch_xla.core.xla_model as xm |
|
|
|
|
|
model = self_model |
|
|
|
|
|
if xm.is_master_ordinal(local=False): |
|
|
os.makedirs(output_dir, exist_ok=True) |
|
|
torch.save(args, os.path.join(output_dir, 'training_args.bin')) |
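    # xm.mark_step() below flushes pending XLA computation before the checkpoint is gathered.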
|
|
|
|
|
xm.mark_step() |
|
|
|
|
|
|
|
|
supported_classes = (PreTrainedModel, PeftModel) |
|
|
if not isinstance(model, supported_classes): |
|
|
if isinstance(unwrap_model(model), supported_classes): |
|
|
unwrap_model(model).save_pretrained( |
|
|
output_dir, |
|
|
is_main_process=args.should_save, |
|
|
state_dict=xm._maybe_convert_to_cpu(model.state_dict()), |
|
|
save_function=xm.save, |
|
|
safe_serialization=args.save_safetensors, |
|
|
) |
|
|
else: |
|
|
logger.info('Trainer.model is not a `PreTrainedModel`, only saving its state dict.') |
|
|
state_dict = xm._maybe_convert_to_cpu(model.state_dict()) |
|
|
if args.save_safetensors: |
|
|
safetensors.torch.save_file(state_dict, os.path.join(output_dir, 'model.safetensors')) |
|
|
else: |
|
|
torch.save(state_dict, os.path.join(output_dir, 'pytorch_model.bin')) |
|
|
else: |
|
|
model.save_pretrained( |
|
|
output_dir, |
|
|
is_main_process=args.should_save, |
|
|
save_function=xm.save, |
|
|
safe_serialization=args.save_safetensors, |
|
|
state_dict=xm._maybe_convert_to_cpu(model.state_dict())) |
|
|
if tokenizer is not None and args.should_save: |
|
|
tokenizer.save_pretrained(output_dir) |
|
|
|
|
|
|
|
|
def save_ta_fsdp_checkpoint(self_model, tokenizer, args, output_dir): |
|
|
import torch_xla.core.xla_model as xm |
|
|
from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints |
|
|
|
|
|
xm.mark_step() |
|
|
|
|
|
if xm.is_master_ordinal(local=False): |
|
|
os.makedirs(output_dir, exist_ok=True) |
|
|
torch.save(args, os.path.join(output_dir, 'training_args.bin')) |
|
|
|
|
|
supported_classes = (PreTrainedModel, PeftModel) |
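    # Unwrap the TorchAcc/FSDP wrappers to reach the underlying (possibly PEFT-wrapped) model.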
|
|
model = self_model._get_underlay_model().module.module |
|
|
unwrapped_model = unwrap_model(model) |
|
|
|
|
|
xm.rendezvous('saving_checkpoint') |
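    # Every rank saves its own parameter shard together with the shard metadata needed for consolidation.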
|
|
ckpt = { |
|
|
'model': self_model._get_underlay_model().state_dict(), |
|
|
'shard_metadata': self_model._get_underlay_model().get_shard_metadata(), |
|
|
} |
|
|
if isinstance(model, PeftModel): |
|
|
ckpt_path = os.path.join(output_dir, f'rank{args.process_index}-of-{args.global_world_size}-adapter_model.bin') |
|
|
else: |
|
|
ckpt_path = os.path.join(output_dir, f'rank{args.process_index}-of-{args.global_world_size}-pytorch_model.bin') |
|
|
xm.save(ckpt, ckpt_path, master_only=False) |
|
|
|
|
|
xm.rendezvous('save_full_checkpoints') |
|
|
|
|
|
if tokenizer is not None and args.should_save: |
|
|
tokenizer.save_pretrained(output_dir, is_main_process=xm.is_master_ordinal(local=False), save_function=xm.save) |
|
|
|
|
|
|
|
|
if xm.is_master_ordinal(local=False): |
|
|
if isinstance(model, PeftModel): |
|
|
ckpt_suffix = 'rank*-of-*-adapter_model.bin' |
|
|
else: |
|
|
ckpt_suffix = 'rank*-of-*-pytorch_model.bin' |
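        # The master rank merges the per-rank shards into a single full state dict.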
|
|
full_state_dict, _ = consolidate_sharded_model_checkpoints( |
|
|
ckpt_prefix=os.path.join(output_dir, ''), ckpt_suffix=ckpt_suffix, save_model=False) |
|
|
|
|
|
if isinstance(unwrapped_model, supported_classes): |
|
|
unwrapped_model.save_pretrained( |
|
|
output_dir, |
|
|
state_dict=full_state_dict, |
|
|
save_function=xm.save, |
|
|
safe_serialization=args.save_safetensors, |
|
|
) |
|
|
else: |
|
|
logger.info('Trainer.model is not a `PreTrainedModel`, only saving its state dict.') |
|
|
if args.save_safetensors: |
|
|
safetensors.torch.save_file(full_state_dict, os.path.join(output_dir, 'model.safetensors')) |
|
|
else: |
|
|
torch.save(full_state_dict, os.path.join(output_dir, 'pytorch_model.bin')) |
|
|
|
|
|
xm.rendezvous('ckpt_consolidation') |
|
|
|
|
|
os.remove(ckpt_path) |
|
|
|
|
|
|
|
|
def ta_trim_graph(): |
|
|
if use_torchacc() and torchacc_trim_graph(): |
|
|
import torchacc as ta |
|
|
ta.mark_step() |
|
|
|
|
|
|
|
|
|
|
|
def rotate_half(x): |
|
|
"""Rotates half the hidden dims of the input.""" |
|
|
x1 = x[..., :x.shape[-1] // 2] |
|
|
x2 = x[..., x.shape[-1] // 2:] |
|
|
return torch.cat((-x2, x1), dim=-1) |
|
|
|
|
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): |
|
|
"""Applies Rotary Position Embedding to the query and key tensors. |
|
|
|
|
|
Args: |
|
|
q (`torch.Tensor`): The query tensor. |
|
|
k (`torch.Tensor`): The key tensor. |
|
|
cos (`torch.Tensor`): The cosine part of the rotary embedding. |
|
|
sin (`torch.Tensor`): The sine part of the rotary embedding. |
|
|
position_ids (`torch.Tensor`): |
|
|
The position indices of the tokens corresponding to the query and key tensors. For example, this can be |
|
|
used to pass offsetted position ids when working with a KV-cache. |
|
|
unsqueeze_dim (`int`, *optional*, defaults to 1): |
|
|
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and |
|
|
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note |
|
|
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and |
|
|
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes |
|
|
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have |
|
|
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. |
|
|
Returns: |
|
|
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. |
|
|
""" |
|
|
if position_ids is not None: |
|
|
cos = cos[position_ids].unsqueeze(unsqueeze_dim) |
|
|
sin = sin[position_ids].unsqueeze(unsqueeze_dim) |
|
|
else: |
|
|
cos = cos.unsqueeze(unsqueeze_dim) |
|
|
sin = sin.unsqueeze(unsqueeze_dim) |
|
|
q_embed = (q * cos) + (rotate_half(q) * sin) |
|
|
k_embed = (k * cos) + (rotate_half(k) * sin) |
|
|
return q_embed, k_embed |
|
|
|
|
|
|
|
|
def patch_acc_model(args, model): |
|
|
if not args.use_flash_attn: |
|
|
        logger.warning('TorchAcc currently requires flash attention; flash attn will be used.')
|
|
if args.model_type.startswith('qwen1half') or args.model_type.startswith('qwen2'): |
|
|
model = patch_qwen2_model(model) |
|
|
elif args.model_type.startswith('qwen'): |
|
|
import torchacc as ta |
|
|
model = ta.patch_qwen_model(model) |
|
|
elif args.model_type.startswith('baichuan'): |
|
|
model = patch_baichuan_model(model) |
|
|
elif args.model_type.startswith('llama') or args.model_type.startswith('yi'): |
|
|
model = patch_llama_model(model) |
|
|
elif args.model_type.startswith('chatglm'): |
|
|
        model = patch_chatglm_model(model)
|
|
return model |
|
|
|
|
|
|
|
|
def patch_llama_model(model): |
|
|
|
|
|
def update_causal_mask(self, *args, **kwargs): |
|
|
|
|
|
return None |
|
|
|
|
|
def llama_attn_forward(self, |
|
|
hidden_states: torch.Tensor, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
position_ids: Optional[torch.Tensor] = None, |
|
|
past_key_value: Optional[Tuple[torch.Tensor]] = None, |
|
|
output_attentions: bool = False, |
|
|
use_cache: bool = False, |
|
|
cache_position: Optional[torch.LongTensor] = None, |
|
|
**kwargs) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
|
|
from torchacc.ops import flash_attn_varlen_xla |
|
|
import einops |
|
|
|
|
|
bsz, q_len, _ = hidden_states.size() |
|
|
|
|
|
query_states = (self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)) |
|
|
key_states = ( |
|
|
self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)) |
|
|
value_states = ( |
|
|
self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)) |
|
|
|
|
|
kv_seq_len = key_states.shape[-2] |
|
|
assert past_key_value is None, 'past_key_value is not supported' |
|
|
|
|
|
if version.parse(transformers.__version__) >= version.parse('4.36'): |
|
|
cos, sin = self.rotary_emb(value_states, position_ids) |
|
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) |
|
|
else: |
|
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) |
|
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) |
|
|
|
|
|
assert not output_attentions, 'output_attentions is not supported' |
|
|
|
|
|
if past_key_value is not None: |
|
|
key_states = torch.cat([past_key_value[0], key_states], dim=2) |
|
|
value_states = torch.cat([past_key_value[1], value_states], dim=2) |
|
|
past_key_value = (key_states, value_states) if use_cache else None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
q = einops.rearrange(query_states, 'b h s ... -> (b s) h ...') |
|
|
k = einops.rearrange(key_states, 'b h s ... -> (b s) h ...') |
|
|
v = einops.rearrange(value_states, 'b h s ... -> (b s) h ...') |
|
|
max_s = q_len |
|
|
cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device) |
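        # (batch, seq) is packed into one varlen dimension above; cu_q_lens holds the cumulative
        # sequence boundaries expected by flash_attn_varlen_xla (all sequences share length q_len).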
|
|
output = flash_attn_varlen_xla( |
|
|
q, k, v, cu_q_lens, cu_q_lens, max_s, max_s, 0.0, softmax_scale=None, causal=True) |
|
|
output = einops.rearrange(output, '(b s) ... -> b s ...', b=bsz) |
|
|
|
|
|
return self.o_proj(einops.rearrange(output, 'b s h d -> b s (h d)')), None, past_key_value |
|
|
|
|
|
for layer in model.model.layers: |
|
|
layer.self_attn.forward = types.MethodType(llama_attn_forward, layer.self_attn) |
|
|
|
|
|
if version.parse(transformers.__version__) >= version.parse('4.38'): |
|
|
model.model._update_causal_mask = types.MethodType(update_causal_mask, model.model) |
|
|
|
|
|
return model |
|
|
|
|
|
|
|
|
def patch_chatglm_model(model):
|
|
|
|
|
def chatglm_apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: |
|
|
|
|
|
        sq, _, np, _ = x.size()
|
|
rot_dim = rope_cache.shape[-2] * 2 |
|
|
x, x_pass = x[..., :rot_dim], x[..., rot_dim:] |
|
|
|
|
|
rope_cache = rope_cache[:sq] |
|
|
xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2) |
|
|
rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2) |
|
|
x_out2 = torch.stack( |
|
|
[ |
|
|
xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1], |
|
|
xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1], |
|
|
], |
|
|
-1, |
|
|
) |
|
|
x_out2 = x_out2.flatten(3) |
|
|
return torch.cat((x_out2, x_pass), dim=-1) |
|
|
|
|
|
def chatglm_attn_forward(self, |
|
|
hidden_states, |
|
|
attention_mask, |
|
|
rotary_pos_emb, |
|
|
kv_cache=None, |
|
|
use_cache=True, |
|
|
**kwargs): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mixed_x_layer = self.query_key_value(hidden_states) |
|
|
|
|
|
if self.multi_query_attention: |
|
|
(query_layer, key_layer, value_layer) = mixed_x_layer.split( |
|
|
[ |
|
|
self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, |
|
|
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, |
|
|
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, |
|
|
], |
|
|
dim=-1, |
|
|
) |
|
|
query_layer = query_layer.view(query_layer.size()[:-1] + (self.num_attention_heads_per_partition, |
|
|
self.hidden_size_per_attention_head)) |
|
|
key_layer = key_layer.view(key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, |
|
|
self.hidden_size_per_attention_head)) |
|
|
value_layer = value_layer.view(value_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, |
|
|
self.hidden_size_per_attention_head)) |
|
|
else: |
|
|
new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads_per_partition, |
|
|
3 * self.hidden_size_per_attention_head) |
|
|
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) |
|
|
|
|
|
|
|
|
            # split_tensor_along_last_dim from the ChatGLM modeling code is not importable here;
            # chunking the fused QKV tensor into three equal parts along the last dim is equivalent.
            (query_layer, key_layer, value_layer) = mixed_x_layer.chunk(3, dim=-1)
|
|
|
|
|
|
|
|
if rotary_pos_emb is not None: |
|
|
query_layer = chatglm_apply_rotary_pos_emb(query_layer, rotary_pos_emb) |
|
|
key_layer = chatglm_apply_rotary_pos_emb(key_layer, rotary_pos_emb) |
|
|
|
|
|
|
|
|
if kv_cache is not None: |
|
|
cache_k, cache_v = kv_cache |
|
|
key_layer = torch.cat((cache_k, key_layer), dim=0) |
|
|
value_layer = torch.cat((cache_v, value_layer), dim=0) |
|
|
if use_cache: |
|
|
kv_cache = (key_layer, value_layer) |
|
|
else: |
|
|
kv_cache = None |
|
|
|
|
|
if self.multi_query_attention: |
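            # Expand the shared multi-query K/V heads so they match the number of attention heads.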
|
|
key_layer = key_layer.unsqueeze(-2) |
|
|
key_layer = key_layer.expand( |
|
|
-1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1) |
|
|
key_layer = key_layer.contiguous().view(key_layer.size()[:2] + (self.num_attention_heads_per_partition, |
|
|
self.hidden_size_per_attention_head)) |
|
|
value_layer = value_layer.unsqueeze(-2) |
|
|
value_layer = value_layer.expand( |
|
|
-1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1) |
|
|
value_layer = value_layer.contiguous().view(value_layer.size()[:2] |
|
|
+ (self.num_attention_heads_per_partition, |
|
|
self.hidden_size_per_attention_head)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from torchacc.ops import flash_attn_varlen_qkvpacked_xla |
|
|
import einops |
|
|
|
|
|
query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] |
|
|
bsz, _, q_len, _ = query_layer.size() |
|
|
qkv = torch.stack([query_layer, key_layer, value_layer], dim=2) |
|
|
qkv = qkv.transpose(1, 3) |
|
|
qkv = einops.rearrange(qkv, 'b s ... -> (b s) ...') |
|
|
cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device) |
|
|
context_layer = flash_attn_varlen_qkvpacked_xla( |
|
|
qkv, cu_q_lens, q_len, dropout_p=0.0, softmax_scale=None, causal=True) |
|
|
context_layer = einops.rearrange(context_layer, '(b s) ... -> b s ...', b=bsz) |
|
|
context_layer = context_layer.permute(1, 0, 2, 3) |
|
|
new_context_layer_shape = context_layer.size()[:-2] + (self.core_attention.hidden_size_per_partition, ) |
|
|
context_layer = context_layer.reshape(*new_context_layer_shape) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
output = self.dense(context_layer) |
|
|
|
|
|
return output, kv_cache |
|
|
|
|
|
def torchacc_swiglu(x): |
|
|
x = torch.chunk(x, 2, dim=-1) |
|
|
return F.silu(x[0]).to(x[0].dtype) * x[1] |
|
|
|
|
|
|
|
|
for layer in model.transformer.encoder.layers: |
|
|
layer.self_attention.forward = types.MethodType(chatglm_attn_forward, layer.self_attention) |
|
|
layer.mlp.activation_func = torchacc_swiglu |
|
|
|
|
|
return model |
|
|
|
|
|
|
|
|
def patch_baichuan_model(model): |
|
|
|
|
|
def baichuan_attn_forward(self, |
|
|
hidden_states: torch.Tensor, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
past_key_value: Optional[Tuple[torch.Tensor]] = None, |
|
|
output_attentions: bool = False, |
|
|
use_cache: bool = False, |
|
|
**kwargs) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
|
|
|
|
|
import einops |
|
|
|
|
|
bsz, q_len, _ = hidden_states.size() |
|
|
|
|
|
proj = self.W_pack(hidden_states) |
|
|
proj = (proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)) |
|
|
query_states = (proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)) |
|
|
key_states = (proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)) |
|
|
value_states = (proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)) |
|
|
|
|
|
kv_seq_len = key_states.shape[-2] |
|
|
if past_key_value is not None: |
|
|
kv_seq_len += past_key_value[0].shape[-2] |
|
|
|
|
|
if past_key_value is not None: |
|
|
|
|
|
key_states = torch.cat([past_key_value[0], key_states], dim=2) |
|
|
value_states = torch.cat([past_key_value[1], value_states], dim=2) |
|
|
|
|
|
past_key_value = (key_states, value_states) if use_cache else None |
|
|
|
|
|
from torchacc.ops import flash_attn_varlen_xla |
|
|
query_states = query_states.transpose(1, 2) |
|
|
key_states = key_states.transpose(1, 2) |
|
|
value_states = value_states.transpose(1, 2) |
|
|
q, k, v = [einops.rearrange(x, 'b s ... -> (b s) ...') for x in [query_states, key_states, value_states]] |
|
|
cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device) |
|
|
output = flash_attn_varlen_xla( |
|
|
q, k, v, cu_q_lens, cu_q_lens, q_len, q_len, 0.0, softmax_scale=None, causal=True) |
|
|
output = einops.rearrange(output, '(b s) ... -> b s ...', b=bsz) |
|
|
output = self.o_proj(einops.rearrange(output, 'b s h d -> b s (h d)')) |
|
|
return output, None, past_key_value |
|
|
|
|
|
for layer in model.base_model.layers: |
|
|
layer.self_attn.forward = types.MethodType(baichuan_attn_forward, layer.self_attn) |
|
|
|
|
|
return model |
|
|
|
|
|
|
|
|
def patch_qwen2_model(model): |
|
|
|
|
|
def update_causal_mask(self, *args, **kwargs): |
|
|
|
|
|
return None |
|
|
|
|
|
def qwen2_attn_forward( |
|
|
self, |
|
|
hidden_states, |
|
|
attention_mask=None, |
|
|
position_ids=None, |
|
|
past_key_value=None, |
|
|
output_attentions=False, |
|
|
use_cache=False, |
|
|
cache_position=None, |
|
|
position_embeddings=None, |
|
|
**kwargs, |
|
|
): |
|
|
|
|
|
bsz, q_len, _ = hidden_states.size() |
|
|
|
|
|
query_states = self.q_proj(hidden_states) |
|
|
key_states = self.k_proj(hidden_states) |
|
|
value_states = self.v_proj(hidden_states) |
|
|
|
|
|
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) |
|
|
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
|
|
|
|
kv_seq_len = key_states.shape[-2] |
|
|
if past_key_value is not None: |
|
|
if self.layer_idx is None: |
|
|
raise ValueError( |
|
|
f'The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} ' |
|
|
'for auto-regressive decoding with k/v caching, please make sure to initialize the attention class ' |
|
|
'with a layer index.') |
|
|
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) |
|
|
|
|
|
|
|
|
|
|
|
rotary_seq_len = kv_seq_len + 1 |
|
|
|
|
|
if version.parse(transformers.__version__) >= version.parse('4.45'): |
|
|
if position_embeddings is None: |
|
|
cos, sin = self.rotary_emb(value_states, position_ids) |
|
|
else: |
|
|
cos, sin = position_embeddings |
|
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) |
|
|
else: |
|
|
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) |
|
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) |
|
|
|
|
|
dropout_rate = 0.0 if not self.training else self.attention_dropout |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
input_dtype = query_states.dtype |
|
|
if input_dtype == torch.float32: |
|
|
if torch.is_autocast_enabled(): |
|
|
target_dtype = torch.get_autocast_gpu_dtype() |
|
|
|
|
|
elif hasattr(self.config, '_pre_quantization_dtype'): |
|
|
target_dtype = self.config._pre_quantization_dtype |
|
|
else: |
|
|
target_dtype = self.q_proj.weight.dtype |
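            # Flash-attention kernels require float16/bfloat16, so cast fp32 activations accordingly.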
|
|
|
|
|
query_states = query_states.to(target_dtype) |
|
|
key_states = key_states.to(target_dtype) |
|
|
value_states = value_states.to(target_dtype) |
|
|
|
|
|
|
|
|
query_states = query_states.transpose(1, 2) |
|
|
key_states = key_states.transpose(1, 2) |
|
|
value_states = value_states.transpose(1, 2) |
|
|
|
|
|
from torchacc.ops import flash_attn_varlen_xla |
|
|
import einops |
|
|
|
|
|
q, k, v = [einops.rearrange(x, 'b s ... -> (b s) ...') for x in [query_states, key_states, value_states]] |
|
|
cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device) |
|
|
|
|
|
attn_output = flash_attn_varlen_xla( |
|
|
q, k, v, cu_q_lens, cu_q_lens, q_len, q_len, dropout_rate, softmax_scale=None, causal=True) |
|
|
|
|
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() |
|
|
attn_output = self.o_proj(attn_output) |
|
|
|
|
|
if not output_attentions: |
|
|
attn_weights = None |
|
|
|
|
|
return attn_output, attn_weights, past_key_value |
|
|
|
|
|
def qwen2_forward(self, |
|
|
input_ids: torch.LongTensor = None, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
position_ids: Optional[torch.LongTensor] = None, |
|
|
past_key_values: Optional[List[torch.FloatTensor]] = None, |
|
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
|
use_cache: Optional[bool] = None, |
|
|
output_attentions: Optional[bool] = None, |
|
|
output_hidden_states: Optional[bool] = None, |
|
|
return_dict: Optional[bool] = None, |
|
|
**kwargs): |
|
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
|
output_hidden_states = ( |
|
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states) |
|
|
use_cache = use_cache if use_cache is not None else self.config.use_cache |
|
|
|
|
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
|
|
|
|
|
if input_ids is not None and inputs_embeds is not None: |
|
|
raise ValueError('You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time') |
|
|
elif input_ids is not None: |
|
|
batch_size, seq_length = input_ids.shape |
|
|
elif inputs_embeds is not None: |
|
|
batch_size, seq_length, _ = inputs_embeds.shape |
|
|
else: |
|
|
raise ValueError('You have to specify either decoder_input_ids or decoder_inputs_embeds') |
|
|
|
|
|
if self.gradient_checkpointing and self.training: |
|
|
if use_cache: |
|
|
use_cache = False |
|
|
|
|
|
past_key_values_length = 0 |
|
|
|
|
|
        from transformers.cache_utils import Cache, DynamicCache  # needed for the cache handling below

        if use_cache:
|
|
use_legacy_cache = not isinstance(past_key_values, Cache) |
|
|
if use_legacy_cache: |
|
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values) |
|
|
past_key_values_length = past_key_values.get_usable_length(seq_length) |
|
|
|
|
|
if position_ids is None: |
|
|
device = input_ids.device if input_ids is not None else inputs_embeds.device |
|
|
position_ids = torch.arange( |
|
|
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device) |
|
|
position_ids = position_ids.unsqueeze(0).view(-1, seq_length) |
|
|
else: |
|
|
position_ids = position_ids.view(-1, seq_length).long() |
|
|
|
|
|
if inputs_embeds is None: |
|
|
inputs_embeds = self.embed_tokens(input_ids) |
|
|
|
|
|
hidden_states = inputs_embeds |
|
|
|
|
|
|
|
|
all_hidden_states = () if output_hidden_states else None |
|
|
all_self_attns = () if output_attentions else None |
|
|
next_decoder_cache = None |
|
|
|
|
|
for decoder_layer in self.layers: |
|
|
if output_hidden_states: |
|
|
all_hidden_states += (hidden_states, ) |
|
|
|
|
|
if self.gradient_checkpointing and self.training: |
|
|
layer_outputs = self._gradient_checkpointing_func( |
|
|
decoder_layer.__call__, |
|
|
hidden_states, |
|
|
attention_mask, |
|
|
position_ids, |
|
|
past_key_values, |
|
|
output_attentions, |
|
|
use_cache, |
|
|
) |
|
|
else: |
|
|
layer_outputs = decoder_layer( |
|
|
hidden_states, |
|
|
attention_mask=attention_mask, |
|
|
position_ids=position_ids, |
|
|
past_key_value=past_key_values, |
|
|
output_attentions=output_attentions, |
|
|
use_cache=use_cache, |
|
|
) |
|
|
|
|
|
hidden_states = layer_outputs[0] |
|
|
|
|
|
if use_cache: |
|
|
next_decoder_cache = layer_outputs[2 if output_attentions else 1] |
|
|
|
|
|
if output_attentions: |
|
|
all_self_attns += (layer_outputs[1], ) |
|
|
|
|
|
hidden_states = self.norm(hidden_states) |
|
|
|
|
|
|
|
|
if output_hidden_states: |
|
|
all_hidden_states += (hidden_states, ) |
|
|
|
|
|
next_cache = None |
|
|
if use_cache: |
|
|
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache |
|
|
|
|
|
if not return_dict: |
|
|
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) |
|
|
from transformers.modeling_outputs import BaseModelOutputWithPast |
|
|
return BaseModelOutputWithPast( |
|
|
last_hidden_state=hidden_states, |
|
|
past_key_values=next_cache, |
|
|
hidden_states=all_hidden_states, |
|
|
attentions=all_self_attns, |
|
|
) |
|
|
|
|
|
for layer in model.model.layers: |
|
|
layer.self_attn.forward = types.MethodType(qwen2_attn_forward, layer.self_attn) |
|
|
|
|
|
if version.parse(transformers.__version__) >= version.parse('4.43'): |
|
|
model.model._update_causal_mask = types.MethodType(update_causal_mask, model.model) |
|
|
else: |
|
|
model.model.forward = types.MethodType(qwen2_forward, model.model) |
|
|
return model |
|
|
|
|
|
|
|
|
def patch_clip_grad_norm(accelerator): |
|
|
from accelerate.utils import DistributedType |
|
|
from accelerate.optimizer import AcceleratedOptimizer |
|
|
import torch_xla.core.xla_model as xm |
|
|
|
|
|
def clip_grad_norm_(self, parameters, max_norm, norm_type=2): |
|
|
""" |
|
|
Should be used in place of `torch.nn.utils.clip_grad_norm_`. |
|
|
|
|
|
Returns: |
|
|
`torch.Tensor`: Total norm of the parameter gradients (viewed as a single vector). |
|
|
|
|
|
Example: |
|
|
|
|
|
```python |
|
|
>>> from accelerate import Accelerator |
|
|
|
|
|
>>> accelerator = Accelerator(gradient_accumulation_steps=2) |
|
|
>>> dataloader, model, optimizer, scheduler = accelerator.prepare(dataloader, model, optimizer, scheduler) |
|
|
|
|
|
>>> for input, target in dataloader: |
|
|
... optimizer.zero_grad() |
|
|
... output = model(input) |
|
|
... loss = loss_func(output, target) |
|
|
... accelerator.backward(loss) |
|
|
... if accelerator.sync_gradients: |
|
|
... accelerator.clip_grad_norm_(model.parameters(), max_grad_norm) |
|
|
... optimizer.step() |
|
|
``` |
|
|
""" |
|
|
if self.distributed_type == DistributedType.FSDP: |
|
|
self.unscale_gradients() |
|
|
parameters = [p for p in parameters] |
|
|
for model in self._models: |
|
|
if parameters == [p for p in model.parameters()]: |
|
|
return model.clip_grad_norm_(max_norm, norm_type) |
|
|
elif self.distributed_type == DistributedType.DEEPSPEED: |
|
|
|
|
|
|
|
|
return None |
|
|
elif self.distributed_type == DistributedType.XLA: |
|
|
|
|
|
for acc_opt in self._optimizers: |
|
|
if not acc_opt.gradient_state.is_xla_gradients_synced: |
|
|
opt = acc_opt |
|
|
while isinstance(opt, AcceleratedOptimizer): |
|
|
opt = opt.optimizer |
|
|
gradients = xm._fetch_gradients(opt) |
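                    # Reduce gradients across all XLA processes once before clipping, then mark them as synced.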
|
|
|
|
|
|
|
|
xm.all_reduce('sum', gradients, scale=1.0 / self.num_processes) |
|
|
|
|
|
acc_opt.gradient_state.is_xla_gradients_synced = True |
|
|
if os.environ.get('ACCELERATE_USE_FSDP', 'false') == 'true': |
|
|
self.unscale_gradients() |
|
|
parameters = [p for p in parameters] |
|
|
for model in self._models: |
|
|
if parameters == [p for p in model.parameters()]: |
|
|
return model._get_underlay_model().clip_grad_norm_(max_norm, norm_type) |
|
|
self.unscale_gradients() |
|
|
return torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=norm_type) |
|
|
|
|
|
|
|
|
accelerator.clip_grad_norm_ = types.MethodType(clip_grad_norm_, accelerator) |
|
|
return accelerator |
|
|
|
|
|
|
|
|
def ta_accelerate(model, |
|
|
fsdp_num, |
|
|
layer_cls_name, |
|
|
bf16=True, |
|
|
fp16=False, |
|
|
gradient_checkpointing=True, |
|
|
fsdp_flatten_parameters=False): |
|
|
""" accelerate LLM training using TorchAcc(only available internally). |
|
|
""" |
|
|
import torchacc as ta |
|
|
assert layer_cls_name is not None |
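    # Illustrative call (argument values are assumptions):
    #   ta_accelerate(model, fsdp_num=4, layer_cls_name='LlamaDecoderLayer')
    # shards the model across 4 FSDP workers and gradient-checkpoints each decoder layer.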
|
|
|
|
|
def get_ta_config(): |
|
|
config = ta.Config() |
|
|
config.compute.fp16 = fp16 |
|
|
config.compute.bf16 = bf16 |
|
|
|
|
|
config.memory.gc = gradient_checkpointing |
|
|
if config.memory.gc: |
|
|
config.memory.gc_cls = {layer_cls_name} |
|
|
|
|
|
config.dist.fsdp.size = fsdp_num |
|
|
config.dist.fsdp.wrap_layer_cls = {layer_cls_name} |
|
|
config.dist.fsdp.flatten_parameters = fsdp_flatten_parameters |
|
|
config.dist.dp.size = 1 |
|
|
|
|
|
if fsdp_num > 1: |
|
|
os.environ['ACCELERATE_USE_FSDP'] = 'true' |
|
|
|
|
|
return config |
|
|
|
|
|
ta_config = get_ta_config() |
|
|
model = ta.accelerate(model, config=ta_config) |
|
|
return model |
|
|
|