```python
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, Kosmos2_5ForConditionalGeneration

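# Checkpoint, target device, and compute dtype; bfloat16 runs best on Ampere-or-newer GPUs.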
repo = "microsoft/kosmos-2.5-chat"
device = "cuda:0"
dtype = torch.bfloat16

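# Load the model with FlashAttention 2 (requires the flash-attn package);
# drop attn_implementation to fall back to the default attention kernels.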
model = Kosmos2_5ForConditionalGeneration.from_pretrained(
    repo,
    device_map=device,
    torch_dtype=dtype,
    attn_implementation="flash_attention_2",
)
processor = AutoProcessor.from_pretrained(repo)

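# Download a sample receipt image.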
url = "https://huggingface.co/microsoft/kosmos-2.5/resolve/main/receipt_00008.png"
image = Image.open(requests.get(url, stream=True).raw)

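# Wrap the question in the chat template; the leading <md> task token requests
# markdown-style generation (as opposed to the <ocr> task token).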
question = "What is the sub total of the receipt?"
template = "<md>A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {} ASSISTANT:"
prompt = template.format(question)
inputs = processor(text=prompt, images=image, return_tensors="pt")

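# Pop the resized height/width (metadata, not model inputs) and keep the scale
# factors in case the output contains coordinates to map back onto the original image.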
height, width = inputs.pop("height"), inputs.pop("width")
raw_width, raw_height = image.size
scale_height = raw_height / height
scale_width = raw_width / width

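# Move the tensors to the GPU and cast the flattened image patches to the model's dtype.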
inputs = {k: v.to(device) if v is not None else None for k, v in inputs.items()}
inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
generated_ids = model.generate(
    **inputs,
    max_new_tokens=1024,
)

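# Decode the full sequence; the output repeats the prompt, with the model's
# answer following the final "ASSISTANT:" marker.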
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
print(generated_text[0])
```
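Since the decoded string echoes the whole chat template before the reply, something like `generated_text[0].split("ASSISTANT:")[-1].strip()` should isolate just the answer.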