import json import numpy as np from datasets import Dataset, DatasetDict from transformers import ( AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, EvalPrediction ) from sklearn.metrics import mean_squared_error, mean_absolute_error from huggingface_hub import HfFolder, notebook_login MODEL_NAME = "roberta-base" DATASET_PATH = "/content/data/dataset_for_scorer.json" MODEL_OUTPUT_DIR = "./ielts_grader_model" HUB_MODEL_ID = "diminch/ielts-grader-ai" def load_and_prepare_data(dataset_path): print(f"Đang tải dữ liệu từ {dataset_path}...") with open(dataset_path, "r", encoding="utf-8") as f: raw_data = json.load(f) processed_data = [] for item in raw_data: text = item['prompt_text'] + " [SEP] " + item['essay_text'] labels = [ float(item['scores']['task_response']), float(item['scores']['coherence_cohesion']), float(item['scores']['lexical_resource']), float(item['scores']['grammatical_range']) ] processed_data.append({"text": text, "label": labels}) print(f"Tổng cộng {len(processed_data)} mẫu.") dataset = Dataset.from_list(processed_data) train_test_split = dataset.train_test_split(test_size=0.1) dataset_dict = DatasetDict({ 'train': train_test_split['train'], 'test': train_test_split['test'] }) return dataset_dict def tokenize_data(dataset_dict, tokenizer): print("Đang tokenize dữ liệu...") def tokenize_function(examples): return tokenizer( examples['text'], padding="max_length", truncation=True, max_length=512 ) tokenized_datasets = dataset_dict.map(tokenize_function, batched=True) return tokenized_datasets def compute_metrics(p: EvalPrediction): preds = p.predictions labels = p.label_ids rmse_tr = np.sqrt(mean_squared_error(labels[:, 0], preds[:, 0])) rmse_cc = np.sqrt(mean_squared_error(labels[:, 1], preds[:, 1])) rmse_lr = np.sqrt(mean_squared_error(labels[:, 2], preds[:, 2])) rmse_gra = np.sqrt(mean_squared_error(labels[:, 3], preds[:, 3])) mae_tr = mean_absolute_error(labels[:, 0], preds[:, 0]) mae_cc = mean_absolute_error(labels[:, 1], preds[:, 1]) mae_lr = mean_absolute_error(labels[:, 2], preds[:, 2]) mae_gra = mean_absolute_error(labels[:, 3], preds[:, 3]) avg_rmse = np.mean([rmse_tr, rmse_cc, rmse_lr, rmse_gra]) return { "avg_rmse": avg_rmse, "rmse_task_response": rmse_tr, "rmse_coherence_cohesion": rmse_cc, "rmse_lexical_resource": rmse_lr, "rmse_grammatical_range": rmse_gra, "mae_task_response": mae_tr, "mae_coherence_cohesion": mae_cc, # ... có thể thêm các MAE khác } def main(): print("Vui lòng dán token Hugging Face (quyền 'write') của bạn:") # (Nếu chạy trên Colab, nó sẽ hiện ô input) # notebook_login() # Hoặc nếu chạy local, dùng 'huggingface-cli login' trước tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) dataset_dict = load_and_prepare_data(DATASET_PATH) tokenized_datasets = tokenize_data(dataset_dict, tokenizer) print("Đang tải mô hình nền tảng...") model = AutoModelForSequenceClassification.from_pretrained( MODEL_NAME, num_labels=4, problem_type="regression" ) training_args = TrainingArguments( output_dir=MODEL_OUTPUT_DIR, learning_rate=2e-5, per_device_train_batch_size=8, per_device_eval_batch_size=8, num_train_epochs=3, weight_decay=0.01, eval_strategy="epoch", # Changed evaluation_strategy to eval_strategy save_strategy="epoch", load_best_model_at_end=True, metric_for_best_model="avg_rmse", greater_is_better=False, push_to_hub=True, hub_model_id=HUB_MODEL_ID, hub_strategy="end", ) trainer = Trainer( model=model, args=training_args, train_dataset=tokenized_datasets["train"], eval_dataset=tokenized_datasets["test"], compute_metrics=compute_metrics, tokenizer=tokenizer, ) print("--- BẮT ĐẦU HUẤN LUYỆN ---") trainer.train() print("--- HUẤN LUYỆN HOÀN TẤT ---") print("--- ĐÁNH GIÁ TRÊN TẬP TEST ---") eval_results = trainer.evaluate() print(json.dumps(eval_results, indent=2)) print("Đang đẩy model tốt nhất lên Hugging Face Hub...") trainer.push_to_hub() print(f"Hoàn tất! Model của bạn đã ở trên Hub: https://huggingface.co/{HUB_MODEL_ID}") if __name__ == "__main__": import os if not os.path.exists(DATASET_PATH): print(f"LỖI: Không tìm thấy file {DATASET_PATH}.") else: main()