#!/usr/bin/env python3
"""
FastAPI application for FunctionGemma with HuggingFace login support.

This file is designed to be run with:
    uvicorn app:app --host 0.0.0.0 --port 7860

Fix: add prompt/completion token counting to the usage field.
"""
import os
import re
import sys
import time
from pathlib import Path

from fastapi import FastAPI
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login

# Global variables
model_name = None
pipe = None
tokenizer = None  # Global tokenizer, shared across endpoints

app = FastAPI(title="FunctionGemma API", version="1.0.0")


def check_and_download_model():
    """Check if the model exists in the local cache; download it if not."""
    global model_name, tokenizer

    # Use FunctionGemma; the commented-out alternatives are fully public models
    # model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    model_name = "unsloth/functiongemma-270m-it"
    # model_name = "Qwen/Qwen3-0.6B"
    cache_dir = "./my_model_cache"

    # Check if the model already exists in the cache
    model_path = Path(cache_dir) / f"models--{model_name.replace('/', '--')}"
    snapshot_path = model_path / "snapshots"

    if snapshot_path.exists() and any(snapshot_path.iterdir()):
        print(f"✓ Model {model_name} already exists in cache")
        # Load the tokenizer from the existing cache
        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
        return model_name, cache_dir

    print(f"✗ Model {model_name} not found in cache")
    print("Downloading model...")

    # Login to Hugging Face (optional, for gated models)
    token = os.getenv("HUGGINGFACE_TOKEN")
    if token:
        try:
            print("Logging in to Hugging Face...")
            login(token=token)
            print("✓ HuggingFace login successful!")
        except Exception as e:
            print(f"⚠ Login failed: {e}")
            print("Continuing without login (public models only)")
    else:
        print("ℹ No HUGGINGFACE_TOKEN set - using public models only")

    try:
        # Download tokenizer
        print("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
        print("✓ Tokenizer loaded successfully!")

        # Download model weights so they land in the local cache
        print("Loading model...")
        model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
        print("✓ Model loaded successfully!")

        print(f"✓ Model and tokenizer downloaded successfully to {cache_dir}")
        return model_name, cache_dir
    except Exception as e:
        print(f"✗ Error downloading model: {e}")
        print("\nPossible reasons:")
        print("1. Model requires authentication - set HUGGINGFACE_TOKEN in .env")
        print("2. Model is gated and you don't have access")
        print("3. Network connection issues")
        sys.exit(1)
def initialize_pipeline():
    """Initialize the text-generation pipeline with the model."""
    global pipe, model_name, tokenizer

    if model_name is None:
        model_name, _ = check_and_download_model()

    if tokenizer is None:  # Ensure the tokenizer is loaded
        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="./my_model_cache")

    print(f"Initializing pipeline with {model_name}...")
    # Pass the tokenizer explicitly so the pipeline reuses the cached instance
    pipe = pipeline("text-generation", model=model_name, tokenizer=tokenizer)
    print("✓ Pipeline initialized successfully!")


# API Endpoints
@app.get("/")
def greet_json():
    return {
        "message": "FunctionGemma API is running!",
        "model": model_name,
        "status": "ready"
    }


@app.get("/health")
def health_check():
    return {"status": "healthy", "model": model_name}


@app.get("/generate")
def generate_text(prompt: str = "Who are you?"):
    """Generate text using the model."""
    if pipe is None:
        initialize_pipeline()

    messages = [{"role": "user", "content": prompt}]
    result = pipe(messages, max_new_tokens=1000)
    return {"response": result[0]["generated_text"]}


@app.post("/chat")
def chat_completion(messages: list):
    """Chat completion endpoint."""
    if pipe is None:
        initialize_pipeline()

    result = pipe(messages, max_new_tokens=200)
    return {"response": result[0]["generated_text"]}


@app.post("/v1/chat/completions")
def openai_chat_completions(request: dict):
    """
    OpenAI-compatible chat completions endpoint.

    Expected request format:
    {
        "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        "messages": [
            {"role": "user", "content": "Hello"}
        ],
        "max_tokens": 100,
        "temperature": 0.7
    }
    """
    if pipe is None:
        initialize_pipeline()

    messages = request.get("messages", [])
    model = request.get("model", model_name)
    max_tokens = request.get("max_tokens", 1000)
    temperature = request.get("temperature", 0.7)

    # Debug logging of the incoming request
    print("\n\nrequest:", request)
    print("messages:", messages)
    print("model:", model)
    print("max_tokens:", max_tokens)
    print("temperature:", temperature)

    # Generate response
    result = pipe(
        messages,
        max_new_tokens=max_tokens,
        # temperature=temperature
    )
    result = convert_json_format(result)

    completion_id = f"chatcmpl-{int(time.time())}"
    created = int(time.time())

    return_json = {
        "id": completion_id,
        "object": "chat.completion",
        "created": created,
        "model": model,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": result["generations"][0][0]["text"]
                },
                "finish_reason": "stop"
            }
        ],
        "usage": {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0
        }
    }

    # Calculate prompt tokens
    if tokenizer:
        prompt_text = " ".join(message.get("content", "") for message in messages)
        return_json["usage"]["prompt_tokens"] = len(tokenizer.encode(prompt_text.strip()))

    # Calculate completion tokens
    if tokenizer and result["generations"]:
        completion_text = result["generations"][0][0]["text"]
        return_json["usage"]["completion_tokens"] = len(tokenizer.encode(completion_text))

    return_json["usage"]["total_tokens"] = (
        return_json["usage"]["prompt_tokens"] + return_json["usage"]["completion_tokens"]
    )

    print("\n\nreturn_json:", return_json)
    print("return over!\n\n")
    return return_json
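
# Example request against the OpenAI-compatible endpoint (a minimal sketch;
# the host and port below assume the uvicorn command from the module docstring):
#
#   curl http://0.0.0.0:7860/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "unsloth/functiongemma-270m-it",
#          "messages": [{"role": "user", "content": "Hello"}],
#          "max_tokens": 100}'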
# Initialize model on startup
@app.on_event("startup")
async def startup_event():
    """Initialize the model when the app starts."""
    print("=" * 60)
    print("FunctionGemma FastAPI Server")
    print("=" * 60)
    print("Initializing model...")
    initialize_pipeline()
    print("\n" + "=" * 60)
    print("Server ready at http://0.0.0.0:7860")
    print("Available endpoints:")
    print("  GET  /                    - Welcome message")
    print("  GET  /health              - Health check")
    print("  GET  /generate?prompt=... - Generate text with prompt")
    print("  POST /chat                - Chat completion")
    print("  POST /v1/chat/completions - OpenAI-compatible endpoint")
    print("=" * 60 + "\n")


def convert_json_format(input_data):
    """Convert the pipeline output into a {"generations": [...]} structure."""
    output_generations = []

    for item in input_data:
        generated_text_list = item.get("generated_text", [])
        assistant_content = ""
        for message in generated_text_list:
            if message.get("role") == "assistant":
                assistant_content = message.get("content", "")
                break  # Assuming only one assistant response per generated_text

        # Remove <think>...</think> reasoning tags emitted by some chat models
        clean_content = re.sub(
            r"<think>.*?</think>\s*", "", assistant_content, flags=re.DOTALL
        ).strip()

        output_generations.append([
            {
                "text": clean_content,
                "generationInfo": {"finish_reason": "stop"}
            }
        ])

    return {"generations": output_generations}
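
# Optional local entry point (a minimal sketch, assuming uvicorn is installed):
# running `python app.py` starts the same server as the uvicorn command in the
# module docstring.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)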