|
|
|
|
|
import dataclasses |
|
|
import os |
|
|
import sys |
|
|
import time |
|
|
import typing |
|
|
from collections import OrderedDict |
|
|
from dataclasses import fields |
|
|
from datetime import datetime |
|
|
from functools import wraps |
|
|
from typing import Any, Dict, List, Type |
|
|
|
|
|
import gradio as gr |
|
|
import json |
|
|
from gradio import Accordion, Audio, Button, Checkbox, Dropdown, File, Image, Slider, Tab, TabItem, Textbox, Video |
|
|
from modelscope.hub.utils.utils import get_cache_dir |
|
|
|
|
|
from swift.llm import TEMPLATE_MAPPING, BaseArguments, get_matched_model_meta |
|
|
|
|
|
all_langs = ['zh', 'en'] |
|
|
builder: Type['BaseUI'] = None |
|
|
base_builder: Type['BaseUI'] = None |
|
|
|
|
|
|
|
|
def update_data(fn): |
|
|
|
|
|
@wraps(fn) |
|
|
def wrapper(*args, **kwargs): |
|
|
elem_id = kwargs.get('elem_id', None) |
|
|
self = args[0] |
|
|
|
|
|
if builder is not None: |
|
|
choices = base_builder.choice(elem_id) |
|
|
if choices: |
|
|
choices = [str(choice) if choice is not None else None for choice in choices] |
|
|
kwargs['choices'] = choices |
|
|
|
|
|
if not isinstance(self, (Tab, TabItem, Accordion)) and 'interactive' not in kwargs: |
|
|
kwargs['interactive'] = True |
|
|
|
|
|
if 'is_list' in kwargs: |
|
|
self.is_list = kwargs.pop('is_list') |
|
|
|
|
|
if base_builder and base_builder.default(elem_id) is not None and not kwargs.get('value'): |
|
|
kwargs['value'] = base_builder.default(elem_id) |
|
|
|
|
|
if builder is not None: |
|
|
if elem_id in builder.locales(builder.lang): |
|
|
values = builder.locale(elem_id, builder.lang) |
|
|
if 'info' in values: |
|
|
kwargs['info'] = values['info'] |
|
|
if 'value' in values: |
|
|
kwargs['value'] = values['value'] |
|
|
if 'label' in values: |
|
|
kwargs['label'] = values['label'] |
|
|
if hasattr(builder, 'visible'): |
|
|
kwargs['visible'] = builder.visible |
|
|
argument = base_builder.argument(elem_id) |
|
|
if argument and 'label' in kwargs: |
|
|
kwargs['label'] = kwargs['label'] + f'({argument})' |
|
|
|
|
|
kwargs['elem_classes'] = 'align' |
|
|
ret = fn(self, **kwargs) |
|
|
self.constructor_args.update(kwargs) |
|
|
|
|
|
if builder is not None: |
|
|
builder.element_dict[elem_id] = self |
|
|
return ret |
|
|
|
|
|
return wrapper |
|
|
|
|
|
|
|
|
Textbox.__init__ = update_data(Textbox.__init__) |
|
|
Dropdown.__init__ = update_data(Dropdown.__init__) |
|
|
Checkbox.__init__ = update_data(Checkbox.__init__) |
|
|
Slider.__init__ = update_data(Slider.__init__) |
|
|
TabItem.__init__ = update_data(TabItem.__init__) |
|
|
Accordion.__init__ = update_data(Accordion.__init__) |
|
|
Button.__init__ = update_data(Button.__init__) |
|
|
File.__init__ = update_data(File.__init__) |
|
|
Image.__init__ = update_data(Image.__init__) |
|
|
Video.__init__ = update_data(Video.__init__) |
|
|
Audio.__init__ = update_data(Audio.__init__) |
|
|
|
|
|
|
|
|
class BaseUI: |
|
|
|
|
|
choice_dict: Dict[str, List] = {} |
|
|
default_dict: Dict[str, Any] = {} |
|
|
locale_dict: Dict[str, Dict] = {} |
|
|
element_dict: Dict[str, Dict] = {} |
|
|
arguments: Dict[str, str] = {} |
|
|
sub_ui: List[Type['BaseUI']] = [] |
|
|
group: str = None |
|
|
lang: str = all_langs[0] |
|
|
int_regex = r'^[-+]?[0-9]+$' |
|
|
float_regex = r'[-+]?(?:\d*\.*\d+)' |
|
|
bool_regex = r'^(T|t)rue$|^(F|f)alse$' |
|
|
cache_dir = os.path.join(get_cache_dir(), 'swift-web-ui') |
|
|
os.makedirs(cache_dir, exist_ok=True) |
|
|
quote = '\'' if sys.platform != 'win32' else '"' |
|
|
visible = True |
|
|
_locale = { |
|
|
'local_dir_alert': { |
|
|
'value': { |
|
|
'zh': '无法识别model_type和template,请手动选择', |
|
|
'en': 'Cannot recognize the model_type and template, please choose manually' |
|
|
} |
|
|
}, |
|
|
} |
|
|
|
|
|
@classmethod |
|
|
def build_ui(cls, base_tab: Type['BaseUI']): |
|
|
"""Build UI""" |
|
|
global builder, base_builder |
|
|
cls.element_dict = {} |
|
|
old_builder = builder |
|
|
old_base_builder = base_builder |
|
|
builder = cls |
|
|
base_builder = base_tab |
|
|
cls.do_build_ui(base_tab) |
|
|
builder = old_builder |
|
|
base_builder = old_base_builder |
|
|
if cls is base_tab: |
|
|
for ui in cls.sub_ui: |
|
|
ui.after_build_ui(base_tab) |
|
|
|
|
|
@classmethod |
|
|
def after_build_ui(cls, base_tab: Type['BaseUI']): |
|
|
pass |
|
|
|
|
|
@classmethod |
|
|
def do_build_ui(cls, base_tab: Type['BaseUI']): |
|
|
"""Build UI""" |
|
|
pass |
|
|
|
|
|
@classmethod |
|
|
def save_cache(cls, key, value): |
|
|
timestamp = str(int(time.time())) |
|
|
key = key.replace('/', '-') |
|
|
filename = os.path.join(cls.cache_dir, key + '-' + timestamp) |
|
|
with open(filename, 'w', encoding='utf-8') as f: |
|
|
json.dump(value, f) |
|
|
|
|
|
@classmethod |
|
|
def list_cache(cls, key): |
|
|
files = [] |
|
|
key = key.replace('/', '-') |
|
|
for _, _, filenames in os.walk(cls.cache_dir): |
|
|
for filename in filenames: |
|
|
if filename.startswith(key): |
|
|
idx = filename.rfind('-') |
|
|
key, ts = filename[:idx], filename[idx + 1:] |
|
|
dt_object = datetime.fromtimestamp(int(ts)) |
|
|
formatted_time = dt_object.strftime('%Y/%m/%d %H:%M:%S') |
|
|
files.append(formatted_time) |
|
|
return sorted(files, reverse=True) |
|
|
|
|
|
@classmethod |
|
|
def load_cache(cls, key, timestamp) -> BaseArguments: |
|
|
dt_object = datetime.strptime(timestamp, '%Y/%m/%d %H:%M:%S') |
|
|
timestamp = int(dt_object.timestamp()) |
|
|
key = key.replace('/', '-') |
|
|
filename = key + '-' + str(timestamp) |
|
|
with open(os.path.join(cls.cache_dir, filename), 'r', encoding='utf-8') as f: |
|
|
return json.load(f) |
|
|
|
|
|
@classmethod |
|
|
def clear_cache(cls, key): |
|
|
key = key.replace('/', '-') |
|
|
for _, _, filenames in os.walk(cls.cache_dir): |
|
|
for filename in filenames: |
|
|
if filename.startswith(key): |
|
|
os.remove(os.path.join(cls.cache_dir, filename)) |
|
|
|
|
|
@classmethod |
|
|
def choice(cls, elem_id): |
|
|
"""Get choice by elem_id""" |
|
|
for sub_ui in BaseUI.sub_ui: |
|
|
_choice = sub_ui.choice(elem_id) |
|
|
if _choice: |
|
|
return _choice |
|
|
return cls.choice_dict.get(elem_id, []) |
|
|
|
|
|
@classmethod |
|
|
def default(cls, elem_id): |
|
|
"""Get choice by elem_id""" |
|
|
if elem_id in cls.default_dict: |
|
|
return cls.default_dict.get(elem_id) |
|
|
for sub_ui in BaseUI.sub_ui: |
|
|
_choice = sub_ui.default(elem_id) |
|
|
if _choice: |
|
|
return _choice |
|
|
return None |
|
|
|
|
|
@classmethod |
|
|
def locale(cls, elem_id, lang): |
|
|
"""Get locale by elem_id""" |
|
|
return cls.locales(lang)[elem_id] |
|
|
|
|
|
@classmethod |
|
|
def locales(cls, lang): |
|
|
"""Get locale by lang""" |
|
|
locales = OrderedDict() |
|
|
for sub_ui in cls.sub_ui: |
|
|
_locales = sub_ui.locales(lang) |
|
|
locales.update(_locales) |
|
|
for key, value in cls.locale_dict.items(): |
|
|
locales[key] = {k: v[lang] for k, v in value.items()} |
|
|
return locales |
|
|
|
|
|
@classmethod |
|
|
def elements(cls): |
|
|
"""Get all elements""" |
|
|
elements = OrderedDict() |
|
|
elements.update(cls.element_dict) |
|
|
for sub_ui in cls.sub_ui: |
|
|
_elements = sub_ui.elements() |
|
|
elements.update(_elements) |
|
|
return elements |
|
|
|
|
|
@classmethod |
|
|
def valid_elements(cls): |
|
|
valid_elements = OrderedDict() |
|
|
elements = cls.elements() |
|
|
for key, value in elements.items(): |
|
|
if isinstance(value, (Textbox, Dropdown, Slider, Checkbox)) and key != 'train_record': |
|
|
valid_elements[key] = value |
|
|
return valid_elements |
|
|
|
|
|
@classmethod |
|
|
def element_keys(cls): |
|
|
return list(cls.elements().keys()) |
|
|
|
|
|
@classmethod |
|
|
def valid_element_keys(cls): |
|
|
return [ |
|
|
key for key, value in cls.elements().items() |
|
|
if isinstance(value, (Textbox, Dropdown, Slider, Checkbox)) and key != 'train_record' |
|
|
] |
|
|
|
|
|
@classmethod |
|
|
def element(cls, elem_id): |
|
|
"""Get element by elem_id""" |
|
|
elements = cls.elements() |
|
|
return elements[elem_id] |
|
|
|
|
|
@classmethod |
|
|
def argument(cls, elem_id): |
|
|
"""Get argument by elem_id""" |
|
|
return cls.arguments.get(elem_id) |
|
|
|
|
|
@classmethod |
|
|
def set_lang(cls, lang): |
|
|
cls.lang = lang |
|
|
for sub_ui in cls.sub_ui: |
|
|
sub_ui.lang = lang |
|
|
|
|
|
@staticmethod |
|
|
def get_choices_from_dataclass(dataclass): |
|
|
choice_dict = {} |
|
|
for f in fields(dataclass): |
|
|
default_value = f.default |
|
|
if 'MISSING_TYPE' in str(default_value): |
|
|
default_value = None |
|
|
if 'choices' in f.metadata: |
|
|
choice_dict[f.name] = list(f.metadata['choices']) |
|
|
if 'Literal' in str(f.type) and typing.get_args(f.type): |
|
|
choice_dict[f.name] = list(typing.get_args(f.type)) |
|
|
if f.name in choice_dict and default_value not in choice_dict[f.name]: |
|
|
choice_dict[f.name].insert(0, default_value) |
|
|
return choice_dict |
|
|
|
|
|
@staticmethod |
|
|
def get_default_value_from_dataclass(dataclass): |
|
|
default_dict = {} |
|
|
for f in fields(dataclass): |
|
|
if f.default.__class__ is dataclasses._MISSING_TYPE: |
|
|
default_dict[f.name] = f.default_factory() |
|
|
else: |
|
|
default_dict[f.name] = f.default |
|
|
if isinstance(default_dict[f.name], list): |
|
|
try: |
|
|
default_dict[f.name] = ' '.join(default_dict[f.name]) |
|
|
except TypeError: |
|
|
default_dict[f.name] = None |
|
|
if not default_dict[f.name]: |
|
|
default_dict[f.name] = None |
|
|
return default_dict |
|
|
|
|
|
@staticmethod |
|
|
def get_argument_names(dataclass): |
|
|
arguments = {} |
|
|
for f in fields(dataclass): |
|
|
arguments[f.name] = f'--{f.name}' |
|
|
return arguments |
|
|
|
|
|
@classmethod |
|
|
def update_input_model(cls, model, allow_keys=None, has_record=True, arg_cls=BaseArguments, is_ref_model=False): |
|
|
keys = cls.valid_element_keys() |
|
|
if allow_keys: |
|
|
keys = [key for key in keys if key in allow_keys] |
|
|
|
|
|
if not model: |
|
|
ret = [gr.update()] * (len(keys) + int(has_record)) |
|
|
if len(ret) == 1: |
|
|
return ret[0] |
|
|
else: |
|
|
return ret |
|
|
|
|
|
model_meta = get_matched_model_meta(model) |
|
|
local_args_path = os.path.join(model, 'args.json') |
|
|
if model_meta is None and not os.path.exists(local_args_path): |
|
|
gr.Info(cls._locale['local_dir_alert']['value'][cls.lang]) |
|
|
ret = [gr.update()] * (len(keys) + int(has_record)) |
|
|
if len(ret) == 1: |
|
|
return ret[0] |
|
|
else: |
|
|
return ret |
|
|
|
|
|
if os.path.exists(local_args_path): |
|
|
try: |
|
|
if hasattr(arg_cls, 'resume_from_checkpoint'): |
|
|
try: |
|
|
args = arg_cls(resume_from_checkpoint=model, load_data_args=True) |
|
|
except Exception as e: |
|
|
if 'using `--model`' in str(e): |
|
|
args = arg_cls(model=model, load_data_args=True) |
|
|
else: |
|
|
raise e |
|
|
else: |
|
|
args = arg_cls(ckpt_dir=model, load_data_args=True) |
|
|
except ValueError: |
|
|
return [gr.update()] * (len(keys) + int(has_record)) |
|
|
values = [] |
|
|
for key in keys: |
|
|
arg_value = getattr(args, key, None) |
|
|
if arg_value and key != 'model': |
|
|
if key in ('torch_dtype', 'bnb_4bit_compute_dtype'): |
|
|
arg_value = str(arg_value).split('.')[1] |
|
|
if isinstance(arg_value, list) and key != 'dataset': |
|
|
try: |
|
|
arg_value = ' '.join(arg_value) |
|
|
except Exception: |
|
|
arg_value = None |
|
|
values.append(gr.update(value=arg_value)) |
|
|
else: |
|
|
values.append(gr.update()) |
|
|
ret = [gr.update(choices=[])] * int(has_record) + values |
|
|
if len(ret) == 1: |
|
|
return ret[0] |
|
|
else: |
|
|
return ret |
|
|
else: |
|
|
values = [] |
|
|
for key in keys: |
|
|
if key not in ('template', 'model_type', 'ref_model_type', 'system'): |
|
|
values.append(gr.update()) |
|
|
elif key in ('template', 'model_type', 'ref_model_type'): |
|
|
if key == 'ref_model_type': |
|
|
if is_ref_model: |
|
|
values.append(gr.update(value=getattr(model_meta, 'model_type'))) |
|
|
else: |
|
|
values.append(gr.update()) |
|
|
else: |
|
|
values.append(gr.update(value=getattr(model_meta, key))) |
|
|
else: |
|
|
values.append(gr.update(value=TEMPLATE_MAPPING[model_meta.template].default_system)) |
|
|
|
|
|
if has_record: |
|
|
return [gr.update(choices=cls.list_cache(model))] + values |
|
|
else: |
|
|
if len(values) == 1: |
|
|
return values[0] |
|
|
return values |
|
|
|
|
|
@classmethod |
|
|
def update_all_settings(cls, model, train_record, base_tab): |
|
|
if not train_record: |
|
|
return [gr.update()] * len(cls.elements()) |
|
|
cache = cls.load_cache(model, train_record) |
|
|
updates = [] |
|
|
for key, value in base_tab.valid_elements().items(): |
|
|
if key in cache: |
|
|
updates.append(gr.update(value=cache[key])) |
|
|
else: |
|
|
updates.append(gr.update()) |
|
|
return updates |
|
|
|