import re

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from typing import Dict

from rich.console import Console

console = Console()


class HuggingFaceChat:
    """Interface for chatting with Hugging Face models."""

    def __init__(self, model_name: str = "microsoft/DialoGPT-medium"):
        """
        Initialize the chat interface.

        Args:
            model_name: The Hugging Face model to use
        """
        self.model_name = model_name
        self.device = 0 if torch.cuda.is_available() else -1  # Use GPU if available
        self.tokenizer = None
        self.model = None
        self.chatbot = None

        # Try loading the model with safetensors weights first
        try:
            console.print(f"[blue]Loading model: {model_name}[/blue]")
            # use_safetensors applies to model weights only; the tokenizer
            # does not take that argument, so it is loaded normally.
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForCausalLM.from_pretrained(model_name, use_safetensors=True)
            self.chatbot = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                device=self.device,
            )
            console.print("[green]✓ Model loaded successfully[/green]")
        except Exception as e:
            console.print(f"[red]Error loading model with safetensors: {e}[/red]")
            try:
                # Fall back to the default weight format
                console.print("[yellow]Trying regular loading...[/yellow]")
                self.tokenizer = AutoTokenizer.from_pretrained(model_name)
                self.model = AutoModelForCausalLM.from_pretrained(model_name)
                self.chatbot = pipeline(
                    "text-generation",
                    model=self.model,
                    tokenizer=self.tokenizer,
                    device=self.device,
                )
                console.print("[green]✓ Model loaded successfully[/green]")
            except Exception as e2:
                console.print(f"[red]Error loading model: {e2}[/red]")
                console.print("[yellow]Falling back to simple text generation...[/yellow]")
                self.chatbot = None

    def generate_response(self, prompt: str, max_length: int = 1000) -> str:
        """
        Generate a response to a prompt.

        Args:
            prompt: The input prompt
            max_length: Maximum total length (prompt + completion) in tokens

        Returns:
            The generated response
        """
        if not self.chatbot:
            return (
                "I'm sorry, but I couldn't load the AI model. This might be due to:\n"
                "1. Model loading issues\n"
                "2. Internet connection problems\n"
                "3. Server maintenance\n\n"
                "Please try again in a few minutes."
            )

        try:
            # Generate a response with sampling parameters tuned for quality
            response = self.chatbot(
                prompt,
                max_length=max_length,
                do_sample=True,
                temperature=0.8,          # Higher for more creativity
                top_p=0.95,               # Higher for more diverse responses
                top_k=50,                 # Limit sampling to the top 50 tokens
                repetition_penalty=1.1,   # Mild penalty for a more natural flow
                no_repeat_ngram_size=3,   # Avoid repeating 3-word phrases
                pad_token_id=self.tokenizer.eos_token_id,
                truncation=True,          # Truncate over-long prompts to prevent errors
            )

            # Extract the generated text
            generated_text = response[0]['generated_text']

            # Remove the prompt from the response if it's included
            if generated_text.startswith(prompt):
                generated_text = generated_text[len(prompt):].strip()

            # Clean up incomplete sentences and hanging punctuation
            generated_text = self._clean_response(generated_text)

            # If the response is empty or too short, provide a helpful message
            if not generated_text or len(generated_text.strip()) < 5:
                return (
                    "I'm processing your message. Could you please try again "
                    "or rephrase your question?"
                )

            return generated_text

        except Exception as e:
            console.print(f"[red]Error generating response: {e}[/red]")
            return f"I'm experiencing technical difficulties. Error: {str(e)[:100]}..."
    def _clean_response(self, text: str) -> str:
        """Clean up the generated response."""
        # If the text already ends on sentence-final punctuation, keep it as-is
        if text.endswith(('.', '!', '?')):
            return text

        # Otherwise drop the trailing incomplete sentence
        # (re is imported at module level)
        sentences = re.split(r'(?<=[.!?])\s+', text)
        if len(sentences) > 1:
            text = ' '.join(sentences[:-1])

        return text.strip()

    def check_model_availability(self) -> bool:
        """Check if the model is available."""
        return self.chatbot is not None

    def get_model_info(self) -> Dict:
        """Get information about the loaded model."""
        return {
            "model_name": self.model_name,
            "device": "GPU" if self.device == 0 else "CPU",
            "available": self.chatbot is not None,
        }
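

# --- Usage sketch (illustrative, not part of the class) ---
# A minimal sketch of how HuggingFaceChat might be exercised when this module
# is run directly. The REPL loop and the "quit"/"exit" commands below are
# assumptions for demonstration, not an API the class defines.
if __name__ == "__main__":
    chat = HuggingFaceChat()

    if not chat.check_model_availability():
        console.print("[red]Model unavailable; exiting.[/red]")
    else:
        console.print(chat.get_model_info())
        while True:
            try:
                user_input = input("You: ").strip()
            except (EOFError, KeyboardInterrupt):
                break
            if not user_input or user_input.lower() in {"quit", "exit"}:
                break
            reply = chat.generate_response(user_input)
            console.print(f"[cyan]Bot:[/cyan] {reply}")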