# smugri4-preview / app.py — Gradio demo for the tartuNLP/smugri4-mt translation model.
# (Hugging Face web-UI header lines removed; they were not valid Python.)
import gradio as gr
import torch
from kuidastaltsutadalaamat.trainllm import load_model, load_tokenizer
from kuidastaltsutadalaamat.inference import llm_generate
from kuidastaltsutadalaamat.data import LazyTokenizingInferenceDataset
from kuidastaltsutadalaamat.promptops import *
# No accelerate wrapper: run single-process on one device (CPU or a single GPU).
accel = None
model_id = "tartuNLP/smugri4-mt"
# Prefer CUDA when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# "eager" attention avoids fused/flash attention kernels that may be unavailable on CPU.
model = load_model(model_id, device, accelerator=accel, attention="eager") #eager for cpu
model.eval()
tokenizer = load_tokenizer(model_id, accelerator=accel)
# Maps the model's internal ("raw") language identifiers — the names used in
# the training-data prompts — to the human-readable labels shown in the UI.
lang_raw_to_label = {"English": "English",
"Erzya": "Erzya",
"Estonian": "Estonian",
"Estonian, Alutaguse, Lüg, dictionary": "Alutaguse",
"Estonian, Hiiu, Rei, dictionary": "Hiiu",
"Estonian, Ida, Kod, dictionary": "Ida",
"Estonian, Kesk, Kjn, dictionary": "Keskmurre",
"Estonian, Kihnu, dictionary": "Kihnu",
"Estonian, Lääne, Mar, dictionary": "Lääne",
"Estonian, Muhu, dictionary": "Muhu",
"Estonian, Ranna, Kuu, dictionary": "Rannakeel",
"Estonian, Saare, Khk, dictionary": "Saare",
"Southern Estonian, Mulgi, Krk, dictionary": "Mulgi",
"Southern Estonian, Seto, dictionary": "Seto",
"Southern Estonian, Tartu, Nõo, dictionary": "Tartu",
"Southern Estonian, Võro, Lei, dictionary": "Leivu",
"Southern Estonian, Võro, Lut, dictionary": "Lutsi",
"Southern Estonian, Võro, Sõnaq": "Võro (Sõnaq orth)",
"Southern Estonian, Võro, Uma": "Võro (Umaleht orth)",
"Finnish": "Finnish",
"Kven": "Kven",
"Meänkieli": "Meänkieli",
"Hill Mari": "Hill Mari",
"Meadow Mari": "Meadow Mari",
"Hungarian": "Hungarian",
"Inari Sami": "Inari Sami",
"Pite Sami": "Pite Sami",
"Kildin Sami, Antonova": "Kildin Sami (Antonova orth)",
"Kildin Sami, Kuruch": "Kildin Sami (Kuruch orth)",
"Lule Sami": "Lule Sami",
"Northern Sami": "Northern Sami",
"Skolt Sami": "Skolt Sami",
"Southern Sami": "Southern Sami",
"Ume Sami": "Ume Sami",
"Izhorian, Mehmet": "Ingrian (Ala-Laukaa / simplified)",
"Izhorian, Alamaluuga, speech": "Ingrian (Ala-Laukaa)",
"Izhorian, Soikkola": "Ingrian (Soikkola)",
"Votic, Standard": "Votic",
"Komi-Permyak": "Komi-Permyak",
"Komi-Zyrian": "Komi-Zyrian",
"Latvian": "Latvian",
"Livonian, Standard": "Livonian",
"Livvi, Newwritten": "Livvi",
"Ludian, Miikul": "Ludian (ü)",
"Ludian, Newwritten": "Ludian (y)",
"Mansi, Unk": "Mansi (Northern)",
"Moksha": "Moksha",
"Norwegian": "Norwegian",
"Kazym Khanty, 2013": "Kazym Khanty",
"Priur Khanty": "Priur Khanty",
"Shur Khanty, 2013": "Shur Khanty",
"Sred Khanty": "Sred Khanty",
"Surgut Khanty, 2013": "Surgut Khanty",
"Vakh Khanty, 2013": "Vakh Khanty",
"Proper Karelian, Newwritten": "Proper Karelian",
"Russian": "Russian",
"Swedish": "Swedish",
"Udmurt": "Udmurt",
"Veps, Newwritten": "Veps"
}
# Inverse mapping (UI label -> raw model identifier).  Relies on the labels
# above being unique; tuple unpacking is clearer than indexing into e[0]/e[1].
label_to_raw = {label: raw for raw, label in lang_raw_to_label.items()}
# Alphabetically sorted labels for the dropdowns; sorted() over the dict
# iterates its keys directly — no intermediate list() needed.
languages_labels = sorted(label_to_raw)
def run_inference(text, from_lang, to_lang, mode):
    """Run the model on *text* and return the generated string.

    When ``mode == "translate"`` the MT prompt format is used and the UI
    language labels are mapped back to the model's raw identifiers; any
    other mode falls back to the language-identification prompt.
    """
    entry = {"src_segm": text, "task": mode}
    if mode == "translate":
        entry["src_lang"] = label_to_raw[from_lang]
        entry["tgt_lang"] = label_to_raw[to_lang]
        fmt = PF_SMUGRI_MT
    else:
        fmt = PF_SMUGRI_LID
    dataset = LazyTokenizingInferenceDataset([entry], tokenizer, fmt)
    generated = llm_generate(model, tokenizer, dataset[0], debug=False, max_len=512)
    return generated[0]
with gr.Blocks() as demo:
    # Input widgets.
    text_input = gr.Textbox(label="Text", lines=6, placeholder="Enter text...")
    with gr.Row():
        from_dropdown = gr.Dropdown(choices=languages_labels, label="From", value=None)
        to_dropdown = gr.Dropdown(choices=languages_labels, label="To", value=None)
    # Button starts disabled; it is enabled once all inputs are filled in.
    translate_btn = gr.Button("Translate", interactive=False)
    output = gr.Textbox(label="Output", lines=6)

    def toggle_translate(text, f, t):
        # Enable the button only when there is non-blank text and both
        # languages have been selected.
        ready = bool(text.strip() and f and t)
        return gr.update(interactive=ready)

    # Any change to the text or either dropdown re-evaluates the button state.
    for widget in (text_input, from_dropdown, to_dropdown):
        widget.change(toggle_translate, [text_input, from_dropdown, to_dropdown], [translate_btn])

    translate_btn.click(
        fn=lambda text, f, t: run_inference(text, f, t, mode="translate"),
        inputs=[text_input, from_dropdown, to_dropdown],
        outputs=[output],
    )
# Launch the Gradio server only when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()