| | import gradio as gr |
| | from transformers import AutoModelForCausalLM, AutoTokenizer |
| | import spaces |
| | import torch |
| |
|
# HuggingFace model id for the Sarvam translation model.
model_name = "sarvamai/sarvam-translate"

# Load tokenizer and model once at import time so every request reuses them.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# NOTE(review): assumes a CUDA device is present at startup — this raises if not.
model = model.to(torch.device("cuda:0"))
| |
|
@spaces.GPU
def generate(tgt_lang, input_txt):
    """Translate *input_txt* into *tgt_lang* with the sarvam-translate model.

    Args:
        tgt_lang: Target language name (e.g. "Hindi"), interpolated into the
            system prompt.
        input_txt: Source text to translate.

    Returns:
        The decoded translation as a plain string (special tokens stripped).
    """
    # Robustness: skip a full GPU generation pass for empty/whitespace input.
    if not input_txt or not input_txt.strip():
        return ""

    messages = [
        {"role": "system", "content": f"Translate the following sentence into {tgt_lang}."},
        {"role": "user", "content": input_txt},
    ]

    # Render the chat turns into the model's expected prompt string.
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

    # Fix: the original passed do_sample=True with temperature=0.01 — that is
    # effectively greedy decoding, but slower and not strictly deterministic.
    # Use true greedy decoding instead (num_return_sequences=1 is the default).
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=1024,
        do_sample=False,
    )

    # Keep only the newly generated continuation, dropping the prompt tokens.
    prompt_len = model_inputs.input_ids.shape[1]
    output_ids = generated_ids[0][prompt_len:].tolist()
    return tokenizer.decode(output_ids, skip_special_tokens=True)
| |
|
# Target languages offered in the UI radio selector.
TARGET_LANGUAGES = [
    "Hindi", "Bengali", "Marathi", "Telugu", "Tamil", "Gujarati", "Urdu",
    "Kannada", "Odia", "Malayalam", "Punjabi", "Assamese", "Maithili",
    "Santali", "Kashmiri", "Nepali", "Sindhi", "Dogri", "Konkani",
    "Manipuri (Meitei)", "Bodo", "Sanskrit",
]

# Wire the translation function into a simple two-input Gradio interface.
language_selector = gr.Radio(TARGET_LANGUAGES, label="Target Language", value="Hindi")
source_textbox = gr.Textbox(label="Input Text", value="Be the change you wish to see in the world.")

demo = gr.Interface(
    fn=generate,
    inputs=[language_selector, source_textbox],
    outputs=gr.Textbox(label="Translation"),
    title="translate",
)
demo.launch()