# NOTE: Hugging Face Spaces page residue removed (Space status: "Sleeping").
import os
import re
from typing import Dict, List, Optional

import torch
from rich.console import Console
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# Shared Rich console used for all status/error output in this module.
console = Console()
class HuggingFaceChat:
    """Interface for chatting with Hugging Face causal-LM models.

    Loads a tokenizer/model pair into a ``text-generation`` pipeline,
    preferring safetensors weights and falling back to the default
    weight format. If both attempts fail, ``self.chatbot`` is ``None``
    and ``generate_response`` returns a canned apology instead of
    raising.
    """

    def __init__(self, model_name: str = "microsoft/DialoGPT-medium"):
        """
        Initialize the chat interface.

        Args:
            model_name: The Hugging Face model to use
        """
        self.model_name = model_name
        # pipeline() takes a device index: 0 = first GPU, -1 = CPU.
        self.device = 0 if torch.cuda.is_available() else -1
        # Initialize up front so a failed load never leaves these unset.
        self.tokenizer = None
        self.model = None
        self.chatbot = None

        console.print(f"[blue]Loading model: {model_name}[/blue]")
        try:
            # Prefer the safetensors weight format when available.
            self._load_pipeline(model_name, use_safetensors=True)
            console.print("[green]✓ Model loaded successfully[/green]")
        except Exception as e:
            console.print(f"[red]Error loading model with safetensors: {e}[/red]")
            try:
                # Fallback to regular (non-safetensors) loading.
                console.print("[yellow]Trying regular loading...[/yellow]")
                self._load_pipeline(model_name, use_safetensors=False)
                console.print("[green]✓ Model loaded successfully[/green]")
            except Exception as e2:
                console.print(f"[red]Error loading model: {e2}[/red]")
                console.print("[yellow]Falling back to simple text generation...[/yellow]")
                self.chatbot = None

    def _load_pipeline(self, model_name: str, use_safetensors: bool) -> None:
        """Load tokenizer, model, and text-generation pipeline.

        ``use_safetensors`` applies only to the model weights; tokenizers
        have no safetensors variant, so the flag is not forwarded to
        ``AutoTokenizer.from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if use_safetensors:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name, use_safetensors=True
            )
        else:
            self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.chatbot = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            device=self.device,
        )

    def generate_response(self, prompt: str, max_length: int = 1000) -> str:
        """
        Generate a response to a prompt.

        Args:
            prompt: The input prompt
            max_length: Maximum total length (prompt + completion) in tokens

        Returns:
            The generated response
        """
        if not self.chatbot:
            return "I'm sorry, but I couldn't load the AI model. This might be due to:\n1. Model loading issues\n2. Internet connection problems\n3. Server maintenance\n\nPlease try again in a few minutes."
        try:
            # Generate response with improved parameters for better quality
            response = self.chatbot(
                prompt,
                max_length=max_length,
                do_sample=True,
                temperature=0.8,  # Higher for more creativity
                top_p=0.95,  # Higher for more diverse responses
                top_k=50,  # Limit to top 50 tokens
                repetition_penalty=1.1,  # Lower penalty for more natural flow
                no_repeat_ngram_size=3,  # Avoid repeating 3-word phrases
                pad_token_id=self.tokenizer.eos_token_id,
                truncation=True,  # Enable truncation to prevent errors
            )
            # The pipeline echoes the prompt; keep only the completion.
            generated_text = response[0]['generated_text']
            if generated_text.startswith(prompt):
                generated_text = generated_text[len(prompt):].strip()
            # Trim any trailing incomplete sentence.
            generated_text = self._clean_response(generated_text)
            # If response is empty or too short, provide a helpful message
            if not generated_text or len(generated_text.strip()) < 5:
                return "I'm processing your message. Could you please try again or rephrase your question?"
            return generated_text
        except Exception as e:
            console.print(f"[red]Error generating response: {e}[/red]")
            return f"I'm experiencing technical difficulties. Error: {str(e)[:100]}..."

    def _clean_response(self, text: str) -> str:
        """Clean up the generated response.

        If the text already ends with terminal punctuation it is returned
        as-is; otherwise the trailing incomplete sentence is dropped
        (unless it is the only sentence, in which case the text is kept).
        """
        if text.endswith(('.', '!', '?')):
            return text
        # Split on whitespace that follows sentence-ending punctuation.
        sentences = re.split(r'(?<=[.!?])\s+', text)
        if len(sentences) > 1:
            # Remove the last incomplete sentence
            text = ' '.join(sentences[:-1])
        return text.strip()

    def check_model_availability(self) -> bool:
        """Check if the model is available."""
        return self.chatbot is not None

    def get_model_info(self) -> Dict:
        """Get information about the loaded model."""
        return {
            "model_name": self.model_name,
            "device": "GPU" if self.device == 0 else "CPU",
            "available": self.chatbot is not None,
        }