#!/usr/bin/env python3
"""
FastAPI application for FunctionGemma with HuggingFace login support.
This file is designed to be run with: uvicorn app:app --host 0.0.0.0 --port 7860
Fix: add token usage calculation.
"""
import os
import re
import sys
import time
from pathlib import Path
from fastapi import FastAPI
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
# Global variables
model_name = None
pipe = None
tokenizer = None # Add global tokenizer
app = FastAPI(title="FunctionGemma API", version="1.0.0")
def check_and_download_model():
"""Check if model exists in cache, if not download it"""
global model_name, tokenizer # Include tokenizer in global
# Use TinyLlama - a fully public model
# model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
model_name = "unsloth/functiongemma-270m-it"
# model_name = "Qwen/Qwen3-0.6B"
cache_dir = "./my_model_cache"
# Check if model already exists in cache
model_path = Path(cache_dir) / f"models--{model_name.replace('/', '--')}"
snapshot_path = model_path / "snapshots"
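    # huggingface_hub stores downloads under
    # <cache_dir>/models--<org>--<name>/snapshots/<revision>/, so the presence of
    # any snapshot directory is treated as "already downloaded".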
if snapshot_path.exists() and any(snapshot_path.iterdir()):
print(f"✓ Model {model_name} already exists in cache")
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir) # Load tokenizer if model exists
return model_name, cache_dir
print(f"✗ Model {model_name} not found in cache")
print("Downloading model...")
# Login to Hugging Face (optional, for gated models)
token = os.getenv("HUGGINGFACE_TOKEN")
if token:
try:
print("Logging in to Hugging Face...")
login(token=token)
print("✓ HuggingFace login successful!")
except Exception as e:
print(f"⚠ Login failed: {e}")
print("Continuing without login (public models only)")
else:
print("ℹ No HUGGINGFACE_TOKEN set - using public models only")
try:
# Download tokenizer
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
print("✓ Tokenizer loaded successfully!")
        # Download model weights; the returned object is not used directly, this call
        # mainly populates the local cache so later loads can read from disk
        print("Loading model...")
        model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
print("✓ Model loaded successfully!")
print(f"✓ Model and tokenizer downloaded successfully to {cache_dir}")
return model_name, cache_dir
except Exception as e:
print(f"✗ Error downloading model: {e}")
print("\nPossible reasons:")
print("1. Model requires authentication - set HUGGINGFACE_TOKEN in .env")
print("2. Model is gated and you don't have access")
print("3. Network connection issues")
sys.exit(1)
def initialize_pipeline():
"""Initialize the pipeline with the model"""
global pipe, model_name, tokenizer # Include tokenizer in global
if model_name is None:
model_name, _ = check_and_download_model()
if tokenizer is None: # Ensure tokenizer is loaded
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="./my_model_cache")
print(f"Initializing pipeline with {model_name}...")
    pipe = pipeline(
        "text-generation",
        model=model_name,
        tokenizer=tokenizer,  # reuse the already-loaded tokenizer
        model_kwargs={"cache_dir": "./my_model_cache"},  # load weights from the local cache dir
    )
print("✓ Pipeline initialized successfully!")
# API Endpoints
@app.get("/")
def greet_json():
return {
"message": "FunctionGemma API is running!",
"model": model_name,
"status": "ready"
}
@app.get("/health")
def health_check():
return {"status": "healthy", "model": model_name}
@app.get("/generate")
def generate_text(prompt: str = "Who are you?"):
"""Generate text using the model"""
if pipe is None:
initialize_pipeline()
messages = [{"role": "user", "content": prompt}]
result = pipe(messages, max_new_tokens=1000)
return {"response": result[0]["generated_text"]}
@app.post("/chat")
def chat_completion(messages: list):
"""Chat completion endpoint"""
if pipe is None:
initialize_pipeline()
result = pipe(messages, max_new_tokens=200)
return {"response": result[0]["generated_text"]}
@app.post("/v1/chat/completions")
def openai_chat_completions(request: dict):
"""
OpenAI-compatible chat completions endpoint
Expected request format:
{
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"messages": [
{"role": "user", "content": "Hello"}
],
"max_tokens": 100,
"temperature": 0.7
}
"""
if pipe is None:
initialize_pipeline()
messages = request.get("messages", [])
model = request.get("model", model_name)
max_tokens = request.get("max_tokens", 1000)
temperature = request.get("temperature", 0.7)
    # Debug logging of the incoming request
    print(f"\nrequest: {request}")
    print(f"messages: {messages}")
    print(f"model: {model}")
    print(f"max_tokens: {max_tokens}")
    print(f"temperature: {temperature}")
    # Generate response (temperature is accepted for API compatibility but is not
    # currently forwarded to the pipeline)
    result = pipe(
        messages,
        max_new_tokens=max_tokens,
        # temperature=temperature
    )
result = convert_json_format(result)
completion_id = f"chatcmpl-{int(time.time())}"
created = int(time.time())
return_json = {
"id": completion_id,
"object": "chat.completion",
"created": created,
"model": model,
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": result["generations"][0][0]["text"] # Corrected access
},
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
}
    # Calculate prompt tokens (approximate: tokenizes the concatenated message
    # contents rather than the exact chat template the model saw)
if tokenizer:
prompt_text = ""
for message in messages:
prompt_text += message.get("content", "") + " "
prompt_tokens = len(tokenizer.encode(prompt_text.strip()))
return_json["usage"]["prompt_tokens"] = prompt_tokens
# Calculate completion tokens
if tokenizer and result["generations"]:
completion_text = result["generations"][0][0]["text"]
completion_tokens = len(tokenizer.encode(completion_text))
return_json["usage"]["completion_tokens"] = completion_tokens
return_json["usage"]["total_tokens"] = return_json["usage"]["prompt_tokens"] + return_json["usage"]["completion_tokens"]
    print(f"\nreturn_json: {return_json}\n")
return return_json
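# Illustrative response shape of /v1/chat/completions (values are examples only):
#   {"id": "chatcmpl-1700000000", "object": "chat.completion", "created": 1700000000,
#    "model": "unsloth/functiongemma-270m-it",
#    "choices": [{"index": 0,
#                 "message": {"role": "assistant", "content": "..."},
#                 "finish_reason": "stop"}],
#    "usage": {"prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46}}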
# Initialize model on startup
@app.on_event("startup")
async def startup_event():
"""Initialize the model when the app starts"""
print("=" * 60)
print("FunctionGemma FastAPI Server")
print("=" * 60)
print("Initializing model...")
initialize_pipeline()
print("\n" + "=" * 60)
print("Server ready at http://0.0.0.0:7860")
print("Available endpoints:")
print(" GET / - Welcome message")
print(" GET /health - Health check")
print(" GET /generate?prompt=... - Generate text with prompt")
print(" POST /chat - Chat completion")
print(" POST /v1/chat/completions - OpenAI-compatible endpoint")
print("=" * 60 + "\n")
def convert_json_format(input_data):
output_generations = []
for item in input_data:
generated_text_list = item.get('generated_text', [])
assistant_content = ""
for message in generated_text_list:
if message.get('role') == 'assistant':
assistant_content = message.get('content', '')
break # Assuming only one assistant response per generated_text
# Remove <think>...</think> tags
clean_content = re.sub(r'<think>.*?</think>\s*', '', assistant_content, flags=re.DOTALL).strip()
output_generations.append([
{
"text": clean_content,
"generationInfo": {
"finish_reason": "stop"
}
}
])
return {"generations": output_generations}