""" |
|
|
FastAPI application for FunctionGemma with HuggingFace login support. |
|
|
This file is designed to be run with: uvicorn app:app --host 0.0.0.0 --port 7860 |
|
|
修复:增加token计算 |
|
|
""" |
|
|
|
|
|

import os
import re
import sys
import time
from pathlib import Path

from fastapi import FastAPI
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login

model_name = None
pipe = None
tokenizer = None
app = FastAPI(title="FunctionGemma API", version="1.0.0")


def check_and_download_model():
    """Check if the model exists in the local cache; if not, download it."""
    global model_name, tokenizer

    model_name = "unsloth/functiongemma-270m-it"
    cache_dir = "./my_model_cache"

    # Hugging Face hub cache layout: models--{org}--{repo}/snapshots/<revision>
    model_path = Path(cache_dir) / f"models--{model_name.replace('/', '--')}"
    snapshot_path = model_path / "snapshots"

    if snapshot_path.exists() and any(snapshot_path.iterdir()):
        print(f"✓ Model {model_name} already exists in cache")
        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
        return model_name, cache_dir

    print(f"✗ Model {model_name} not found in cache")
    print("Downloading model...")

    # Log in to Hugging Face if a token is provided (needed for gated models).
    token = os.getenv("HUGGINGFACE_TOKEN")
    if token:
        try:
            print("Logging in to Hugging Face...")
            login(token=token)
            print("✓ HuggingFace login successful!")
        except Exception as e:
            print(f"⚠ Login failed: {e}")
            print("Continuing without login (public models only)")
    else:
        print("ℹ No HUGGINGFACE_TOKEN set - using public models only")

    try:
        print("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
        print("✓ Tokenizer loaded successfully!")

        print("Loading model...")
        model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
        print("✓ Model loaded successfully!")

        print(f"✓ Model and tokenizer downloaded successfully to {cache_dir}")
        return model_name, cache_dir

    except Exception as e:
        print(f"✗ Error downloading model: {e}")
        print("\nPossible reasons:")
        print("1. Model requires authentication - set HUGGINGFACE_TOKEN in .env")
        print("2. Model is gated and you don't have access")
        print("3. Network connection issues")
        sys.exit(1)


def initialize_pipeline():
    """Initialize the pipeline with the model"""
    global pipe, model_name, tokenizer

    if model_name is None:
        model_name, _ = check_and_download_model()

    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="./my_model_cache")

    print(f"Initializing pipeline with {model_name}...")
    pipe = pipeline("text-generation", model=model_name, tokenizer=tokenizer)
    print("✓ Pipeline initialized successfully!")
@app.get("/") |
|
|
def greet_json(): |
|
|
return { |
|
|
"message": "FunctionGemma API is running!", |
|
|
"model": model_name, |
|
|
"status": "ready" |
|
|
} |
|
|
|
|
|
@app.get("/health") |
|
|
def health_check(): |
|
|
return {"status": "healthy", "model": model_name} |
|
|
|
|
|
@app.get("/generate") |
|
|
def generate_text(prompt: str = "Who are you?"): |
|
|
"""Generate text using the model""" |
|
|
if pipe is None: |
|
|
initialize_pipeline() |
|
|
|
|
|
messages = [{"role": "user", "content": prompt}] |
|
|
result = pipe(messages, max_new_tokens=1000) |
|
|
return {"response": result[0]["generated_text"]} |
@app.post("/chat") |
|
|
def chat_completion(messages: list): |
|
|
"""Chat completion endpoint""" |
|
|
if pipe is None: |
|
|
initialize_pipeline() |
|
|
|
|
|
result = pipe(messages, max_new_tokens=200) |
|
|
return {"response": result[0]["generated_text"]} |
@app.post("/v1/chat/completions") |
|
|
def openai_chat_completions(request: dict): |
|
|
""" |
|
|
OpenAI-compatible chat completions endpoint |
|
|
Expected request format: |
|
|
{ |
|
|
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", |
|
|
"messages": [ |
|
|
{"role": "user", "content": "Hello"} |
|
|
], |
|
|
"max_tokens": 100, |
|
|
"temperature": 0.7 |
|
|
} |
|
|
""" |
|
|
if pipe is None: |
|
|
initialize_pipeline() |
|
|
|
|
|
import time |
|
|
|
|
|
messages = request.get("messages", []) |
|
|
model = request.get("model", model_name) |
|
|
max_tokens = request.get("max_tokens", 1000) |
|
|
temperature = request.get("temperature", 0.7) |
|
|
|
|
|
print('\n\n request') |
|
|
print(request) |
|
|
print('\n\n messages') |
|
|
print(messages) |
|
|
print('\n\n model') |
|
|
print(model) |
|
|
print('\n\n max_tokens') |
|
|
print(max_tokens) |
|
|
print('\n\n temperature') |
|
|
print(temperature) |
|
|
|
|
|
|
|
|
result = pipe( |
|
|
messages, |
|
|
max_new_tokens=max_tokens, |
|
|
|
|
|
) |
|
|
|
|
|
result = convert_json_format(result) |
|
|
|
|
|
|
|
|
completion_id = f"chatcmpl-{int(time.time())}" |
|
|
created = int(time.time()) |
|
|
|
|
|
return_json = { |
|
|
"id": completion_id, |
|
|
"object": "chat.completion", |
|
|
"created": created, |
|
|
"model": model, |
|
|
"choices": [ |
|
|
{ |
|
|
"index": 0, |
|
|
"message": { |
|
|
"role": "assistant", |
|
|
"content": result["generations"][0][0]["text"] |
|
|
}, |
|
|
"finish_reason": "stop" |
|
|
} |
|
|
], |
|
|
"usage": { |
|
|
"prompt_tokens": 0, |
|
|
"completion_tokens": 0, |
|
|
"total_tokens": 0 |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if tokenizer: |
|
|
prompt_text = "" |
|
|
for message in messages: |
|
|
prompt_text += message.get("content", "") + " " |
|
|
prompt_tokens = len(tokenizer.encode(prompt_text.strip())) |
|
|
return_json["usage"]["prompt_tokens"] = prompt_tokens |
|
|
|
|
|
|
|
|
if tokenizer and result["generations"]: |
|
|
completion_text = result["generations"][0][0]["text"] |
|
|
completion_tokens = len(tokenizer.encode(completion_text)) |
|
|
return_json["usage"]["completion_tokens"] = completion_tokens |
|
|
|
|
|
return_json["usage"]["total_tokens"] = return_json["usage"]["prompt_tokens"] + return_json["usage"]["completion_tokens"] |
|
|
|
|
|
print('\n\n return_json') |
|
|
print(return_json) |
|
|
print('return over! \n\n') |
|
|
|
|
|
return return_json |
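
# Illustrative request for the OpenAI-compatible endpoint above (an added example,
# not part of the original file); note that the "model" value is only echoed back in
# the response, and host/port assume the uvicorn command in the module docstring:
#
#   curl -X POST http://localhost:7860/v1/chat/completions \
#        -H "Content-Type: application/json" \
#        -d '{"model": "unsloth/functiongemma-270m-it",
#             "messages": [{"role": "user", "content": "Hello"}],
#             "max_tokens": 100}'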
@app.on_event("startup") |
|
|
async def startup_event(): |
|
|
"""Initialize the model when the app starts""" |
|
|
print("=" * 60) |
|
|
print("FunctionGemma FastAPI Server") |
|
|
print("=" * 60) |
|
|
print("Initializing model...") |
|
|
initialize_pipeline() |
|
|
print("\n" + "=" * 60) |
|
|
print("Server ready at http://0.0.0.0:7860") |
|
|
print("Available endpoints:") |
|
|
print(" GET / - Welcome message") |
|
|
print(" GET /health - Health check") |
|
|
print(" GET /generate?prompt=... - Generate text with prompt") |
|
|
print(" POST /chat - Chat completion") |
|
|
print(" POST /v1/chat/completions - OpenAI-compatible endpoint") |
|
|
print("=" * 60 + "\n") |
|
|
|
|
|


def convert_json_format(input_data):
    """Convert the pipeline's chat output into a {"generations": [...]} structure."""
    output_generations = []
    for item in input_data:
        # Each pipeline item carries the chat transcript under "generated_text".
        generated_text_list = item.get('generated_text', [])

        # Take the first assistant message from the transcript.
        assistant_content = ""
        for message in generated_text_list:
            if message.get('role') == 'assistant':
                assistant_content = message.get('content', '')
                break

        # Strip any <think>...</think> reasoning block from the reply.
        clean_content = re.sub(r'<think>.*?</think>\s*', '', assistant_content, flags=re.DOTALL).strip()

        output_generations.append([
            {
                "text": clean_content,
                "generationInfo": {
                    "finish_reason": "stop"
                }
            }
        ])

    return {"generations": output_generations}