# Copyright (c) Alibaba, Inc. and its affiliates. import os import re import signal import sys import time from copy import deepcopy from datetime import datetime from functools import partial from typing import List, Type import gradio as gr import json import torch from json import JSONDecodeError from transformers.utils import is_torch_cuda_available, is_torch_npu_available from swift.llm import DeployArguments, InferArguments, InferClient, InferRequest, RequestConfig from swift.ui.base import BaseUI from swift.ui.llm_infer.model import Model from swift.ui.llm_infer.runtime import Runtime from swift.utils import get_device_count, get_logger logger = get_logger() class LLMInfer(BaseUI): group = 'llm_infer' is_multimodal = True sub_ui = [Model, Runtime] locale_dict = { 'generate_alert': { 'value': { 'zh': '请先部署模型', 'en': 'Please deploy model first', } }, 'port': { 'label': { 'zh': '端口', 'en': 'port' }, }, 'llm_infer': { 'label': { 'zh': 'LLM推理', 'en': 'LLM Inference', } }, 'load_alert': { 'value': { 'zh': '部署中,请点击"展示部署状态"查看', 'en': 'Start to deploy model, ' 'please Click "Show running ' 'status" to view details', } }, 'loaded_alert': { 'value': { 'zh': '模型加载完成', 'en': 'Model loaded' } }, 'port_alert': { 'value': { 'zh': '该端口已被占用', 'en': 'The port has been occupied' } }, 'chatbot': { 'value': { 'zh': '对话框', 'en': 'Chat bot' }, }, 'infer_model_type': { 'label': { 'zh': 'Lora模块', 'en': 'Lora module' }, 'info': { 'zh': '发送给server端哪个LoRA,默认为`default`', 'en': 'Which LoRA to use on server, default value is `default`' } }, 'prompt': { 'label': { 'zh': '请输入:', 'en': 'Input:' }, }, 'clear_history': { 'value': { 'zh': '清除对话信息', 'en': 'Clear history' }, }, 'submit': { 'value': { 'zh': '🚀 发送', 'en': '🚀 Send' }, }, 'gpu_id': { 'label': { 'zh': '选择可用GPU', 'en': 'Choose GPU' }, 'info': { 'zh': '选择训练使用的GPU号,如CUDA不可用只能选择CPU', 'en': 'Select GPU to train' } }, } choice_dict = BaseUI.get_choices_from_dataclass(InferArguments) default_dict = BaseUI.get_default_value_from_dataclass(InferArguments) arguments = BaseUI.get_argument_names(InferArguments) @classmethod def do_build_ui(cls, base_tab: Type['BaseUI']): with gr.TabItem(elem_id='llm_infer', label=''): default_device = 'cpu' device_count = get_device_count() if device_count > 0: default_device = '0' with gr.Blocks(): infer_request = gr.State(None) Model.build_ui(base_tab) Runtime.build_ui(base_tab) with gr.Row(): gr.Dropdown( elem_id='gpu_id', multiselect=True, choices=[str(i) for i in range(device_count)] + ['cpu'], value=default_device, scale=8) infer_model_type = gr.Textbox(elem_id='infer_model_type', scale=4) gr.Textbox(elem_id='port', lines=1, value='8000', scale=4) chatbot = gr.Chatbot(elem_id='chatbot', elem_classes='control-height') with gr.Row(): prompt = gr.Textbox(elem_id='prompt', lines=1, interactive=True) with gr.Tabs(visible=cls.is_multimodal): with gr.TabItem(label='Image'): image = gr.Image(type='filepath') with gr.TabItem(label='Video'): video = gr.Video() with gr.TabItem(label='Audio'): audio = gr.Audio(type='filepath') with gr.Row(): clear_history = gr.Button(elem_id='clear_history') submit = gr.Button(elem_id='submit') cls.element('load_checkpoint').click( cls.deploy_model, list(base_tab.valid_elements().values()), [cls.element('runtime_tab'), cls.element('running_tasks')]) submit.click( cls.send_message, inputs=[ cls.element('running_tasks'), cls.element('template'), prompt, image, video, audio, infer_request, infer_model_type, cls.element('system'), cls.element('max_new_tokens'), cls.element('temperature'), cls.element('top_k'), cls.element('top_p'), cls.element('repetition_penalty') ], outputs=[prompt, chatbot, image, video, audio, infer_request], queue=True) clear_history.click( fn=cls.clear_session, inputs=[], outputs=[prompt, chatbot, image, video, audio, infer_request]) base_tab.element('running_tasks').change( partial(Runtime.task_changed, base_tab=base_tab), [base_tab.element('running_tasks')], list(cls.valid_elements().values()) + [cls.element('log')]) Runtime.element('kill_task').click( Runtime.kill_task, [Runtime.element('running_tasks')], [Runtime.element('running_tasks')] + [Runtime.element('log')], ) @classmethod def deploy(cls, *args): deploy_args = cls.get_default_value_from_dataclass(DeployArguments) kwargs = {} kwargs_is_list = {} other_kwargs = {} more_params = {} more_params_cmd = '' keys = cls.valid_element_keys() for key, value in zip(keys, args): compare_value = deploy_args.get(key) compare_value_arg = str(compare_value) if not isinstance(compare_value, (list, dict)) else compare_value compare_value_ui = str(value) if not isinstance(value, (list, dict)) else value if key in deploy_args and compare_value_ui != compare_value_arg and value: if isinstance(value, str) and re.fullmatch(cls.int_regex, value): value = int(value) elif isinstance(value, str) and re.fullmatch(cls.float_regex, value): value = float(value) elif isinstance(value, str) and re.fullmatch(cls.bool_regex, value): value = True if value.lower() == 'true' else False kwargs[key] = value if not isinstance(value, list) else ' '.join(value) kwargs_is_list[key] = isinstance(value, list) or getattr(cls.element(key), 'is_list', False) else: other_kwargs[key] = value if key == 'more_params' and value: try: more_params = json.loads(value) except (JSONDecodeError or TypeError): more_params_cmd = value kwargs.update(more_params) model = kwargs.get('model') if os.path.exists(model) and os.path.exists(os.path.join(model, 'args.json')): kwargs['ckpt_dir'] = kwargs.pop('model') with open(os.path.join(kwargs['ckpt_dir'], 'args.json'), 'r', encoding='utf-8') as f: _json = json.load(f) kwargs['model_type'] = _json['model_type'] kwargs['train_type'] = _json['train_type'] deploy_args = DeployArguments( **{ key: value.split(' ') if key in kwargs_is_list and kwargs_is_list[key] else value for key, value in kwargs.items() }) if deploy_args.port in Runtime.get_all_ports(): raise gr.Error(cls.locale('port_alert', cls.lang)['value']) params = '' sep = f'{cls.quote} {cls.quote}' for e in kwargs: if isinstance(kwargs[e], list): params += f'--{e} {cls.quote}{sep.join(kwargs[e])}{cls.quote} ' elif e in kwargs_is_list and kwargs_is_list[e]: all_args = [arg for arg in kwargs[e].split(' ') if arg.strip()] params += f'--{e} {cls.quote}{sep.join(all_args)}{cls.quote} ' else: params += f'--{e} {cls.quote}{kwargs[e]}{cls.quote} ' if 'port' not in kwargs: params += f'--port "{deploy_args.port}" ' params += more_params_cmd + ' ' devices = other_kwargs['gpu_id'] devices = [d for d in devices if d] assert (len(devices) == 1 or 'cpu' not in devices) gpus = ','.join(devices) cuda_param = '' if gpus != 'cpu': if is_torch_npu_available(): cuda_param = f'ASCEND_RT_VISIBLE_DEVICES={gpus}' elif is_torch_cuda_available(): cuda_param = f'CUDA_VISIBLE_DEVICES={gpus}' else: cuda_param = '' now = datetime.now() time_str = f'{now.year}{now.month}{now.day}{now.hour}{now.minute}{now.second}' file_path = f'output/{deploy_args.model_type}-{time_str}' if not os.path.exists(file_path): os.makedirs(file_path, exist_ok=True) log_file = os.path.join(os.getcwd(), f'{file_path}/run_deploy.log') deploy_args.log_file = log_file params += f'--log_file "{log_file}" ' params += '--ignore_args_error true ' if sys.platform == 'win32': if cuda_param: cuda_param = f'set {cuda_param} && ' run_command = f'{cuda_param}start /b swift deploy {params} > {log_file} 2>&1' else: run_command = f'{cuda_param} nohup swift deploy {params} > {log_file} 2>&1 &' return run_command, deploy_args, log_file @classmethod def deploy_model(cls, *args): run_command, deploy_args, log_file = cls.deploy(*args) logger.info(f'Running deployment command: {run_command}') os.system(run_command) gr.Info(cls.locale('load_alert', cls.lang)['value']) time.sleep(2) running_task = Runtime.refresh_tasks(log_file) return gr.update(open=True), running_task @classmethod def register_clean_hook(cls): signal.signal(signal.SIGINT, LLMInfer.signal_handler) if os.name != 'nt': signal.signal(signal.SIGTERM, LLMInfer.signal_handler) @staticmethod def signal_handler(*args, **kwargs): LLMInfer.clean_deployment() sys.exit(0) @classmethod def clear_session(cls): return '', [], gr.update(value=None), gr.update(value=None), gr.update(value=None), [] @classmethod def _replace_tag_with_media(cls, infer_request: InferRequest): total_history = [] messages = deepcopy(infer_request.messages) if messages[0]['role'] == 'system': messages.pop(0) for i in range(0, len(messages), 2): slices = messages[i:i + 2] if len(slices) == 2: user, assistant = slices else: user = slices[0] assistant = {'role': 'assistant', 'content': None} user['content'] = (user['content'] or '').replace('', '').replace('