File size: 2,209 Bytes
73e457a
 
 
 
 
 
 
324acb1
73e457a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5316e5b
73e457a
 
 
 
 
 
 
 
 
 
 
 
 
5316e5b
73e457a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# /// script
# dependencies = [
#   "trl>=0.12.0",
#   "peft>=0.7.0",
#   "trackio",
#   "transformers>=4.45.0",
#   "torch",
#   "torchvision",
#   "datasets",
#   "pillow",
#   "qwen-vl-utils"
# ]
# ///

from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
import trackio
import torch

# ---------------------------------------------------------------------------
# Fine-tune Qwen2.5-VL-3B-Instruct on 1% of llava-instruct-mix with LoRA.
#
# Fix: the base model is Qwen2.5-VL, but the original script labelled its
# artifacts "qwen3-vl" (output_dir, trackio project, run name). Those labels
# are now consistent with the actual model family; the hub_model_id already
# said "qwen2.5-vl" and is unchanged.
# ---------------------------------------------------------------------------

# Load 1% of the train split — small enough for a quick demo run.
print("Loading dataset...")
dataset = load_dataset("trl-lib/llava-instruct-mix", split="train[:1%]")

print(f"Dataset size: {len(dataset)} examples")

# Create a small eval split (10% of the 1%); fixed seed for reproducibility.
dataset_split = dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = dataset_split["train"]
eval_dataset = dataset_split["test"]

print(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")

# Configure trainer with VL-specific settings.
trainer = SFTTrainer(
    # Passing the model id as a string lets SFTTrainer instantiate the
    # correct model/processor classes for this VLM itself.
    model="Qwen/Qwen2.5-VL-3B-Instruct",
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # LoRA adapters on the attention projections only — cheap to train,
    # base weights stay frozen.
    peft_config=LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    ),
    args=SFTConfig(
        output_dir="qwen2.5-vl-3b-llava-instruct",  # was "qwen3-vl-..." — wrong model family
        push_to_hub=True,
        hub_model_id="merve/qwen2.5-vl-3b-llava-instruct",
        num_train_epochs=3,
        # Effective batch size 8 via accumulation; gradient checkpointing
        # trades compute for memory so the 3B model fits on one GPU.
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        gradient_checkpointing=True,
        learning_rate=2e-4,
        warmup_steps=100,
        logging_steps=10,
        eval_strategy="steps",
        eval_steps=50,
        save_strategy="steps",
        save_steps=100,
        save_total_limit=2,  # keep only the two most recent checkpoints
        bf16=True,
        report_to="trackio",
        project="qwen2.5-vl-finetuning",  # was "qwen3-vl-finetuning"
        run_name="qwen2.5-vl-3b-llava-1pct",  # was "qwen3-vl-3b-llava-1pct"
        max_length=None,  # Important for VL models - don't truncate image tokens
        hub_strategy="every_save",  # mirror each checkpoint to the Hub
        remove_unused_columns=False,  # Keep all columns for VL processing
    ),
)

print("Starting training...")
trainer.train()

# Final push covers anything after the last every_save checkpoint.
print("Pushing final model to Hub...")
trainer.push_to_hub()

print("Training complete!")