# memorychat / chat_interface.py
import re
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from typing import Dict
from rich.console import Console

console = Console()


class HuggingFaceChat:
    """Interface for chatting with Hugging Face models."""

    def __init__(self, model_name: str = "microsoft/DialoGPT-medium"):
        """
        Initialize the chat interface.

        Args:
            model_name: The Hugging Face model to use.
        """
        self.model_name = model_name
        self.device = 0 if torch.cuda.is_available() else -1  # Use GPU if available

        try:
            console.print(f"[blue]Loading model: {model_name}[/blue]")
            # Prefer the safetensors weight format for the model; the tokenizer
            # does not take this argument, so it is loaded normally.
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForCausalLM.from_pretrained(model_name, use_safetensors=True)
            self.chatbot = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                device=self.device,
            )
            console.print("[green]✓ Model loaded successfully[/green]")
        except Exception as e:
            console.print(f"[red]Error loading model with safetensors: {e}[/red]")
            try:
                # Fall back to whatever weight format the repo provides
                console.print("[yellow]Trying regular loading...[/yellow]")
                self.tokenizer = AutoTokenizer.from_pretrained(model_name)
                self.model = AutoModelForCausalLM.from_pretrained(model_name)
                self.chatbot = pipeline(
                    "text-generation",
                    model=self.model,
                    tokenizer=self.tokenizer,
                    device=self.device,
                )
                console.print("[green]✓ Model loaded successfully[/green]")
            except Exception as e2:
                console.print(f"[red]Error loading model: {e2}[/red]")
                console.print("[yellow]Falling back to simple text generation...[/yellow]")
                self.chatbot = None

    def generate_response(self, prompt: str, max_length: int = 1000) -> str:
        """
        Generate a response to a prompt.

        Args:
            prompt: The input prompt.
            max_length: Maximum total length in tokens (prompt plus response).

        Returns:
            The generated response.
        """
        if not self.chatbot:
            return (
                "I'm sorry, but I couldn't load the AI model. This might be due to:\n"
                "1. Model loading issues\n"
                "2. Internet connection problems\n"
                "3. Server maintenance\n\n"
                "Please try again in a few minutes."
            )

        try:
            # Generate a response with sampling parameters tuned for quality
            response = self.chatbot(
                prompt,
                max_length=max_length,
                do_sample=True,
                temperature=0.8,          # Higher for more creativity
                top_p=0.95,               # Higher for more diverse responses
                top_k=50,                 # Limit sampling to the top 50 tokens
                repetition_penalty=1.1,   # Mild penalty for a more natural flow
                no_repeat_ngram_size=3,   # Avoid repeating 3-token phrases
                pad_token_id=self.tokenizer.eos_token_id,
                truncation=True,          # Truncate over-long prompts to prevent errors
            )

            # Extract the generated text
            generated_text = response[0]["generated_text"]

            # The pipeline echoes the prompt; strip it from the response
            if generated_text.startswith(prompt):
                generated_text = generated_text[len(prompt):].strip()

            # Drop any trailing incomplete sentence or hanging punctuation
            generated_text = self._clean_response(generated_text)

            # If the response is empty or too short, return a helpful message
            if not generated_text or len(generated_text.strip()) < 5:
                return "I'm processing your message. Could you please try again or rephrase your question?"

            return generated_text
        except Exception as e:
            console.print(f"[red]Error generating response: {e}[/red]")
            return f"I'm experiencing technical difficulties. Error: {str(e)[:100]}..."

    def _clean_response(self, text: str) -> str:
        """Clean up the generated response."""
        # Already ends on sentence-final punctuation: nothing to trim
        if text.endswith((".", "!", "?")):
            return text

        # Otherwise split on sentence boundaries and drop the trailing fragment
        sentences = re.split(r"(?<=[.!?])\s+", text)
        if len(sentences) > 1:
            text = " ".join(sentences[:-1])
        return text.strip()
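
    # Example: _clean_response("Hello there. How are") returns "Hello there.",
    # since the trailing fragment without closing punctuation is dropped.
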
    def check_model_availability(self) -> bool:
        """Check if the model is available."""
        return self.chatbot is not None

    def get_model_info(self) -> Dict:
        """Get information about the loaded model."""
        return {
            "model_name": self.model_name,
            "device": "GPU" if self.device == 0 else "CPU",
            "available": self.chatbot is not None,
        }
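

# A minimal usage sketch, assuming the default DialoGPT checkpoint can be
# downloaded; the prompt below is illustrative only.
if __name__ == "__main__":
    chat = HuggingFaceChat()
    console.print(chat.get_model_info())
    if chat.check_model_availability():
        reply = chat.generate_response("Hello! How are you today?")
        console.print(f"[cyan]Bot:[/cyan] {reply}")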