🧠 Brain Stroke Segmentation - LCNN Model
Figure 1: Sample segmentation results showing precise lesion localization on CT scans.
📋 Model Overview
This repository hosts a trained LCNN (Local-Global Combined Neural Network) for automatic segmentation of brain stroke lesions on CT scans. The model combines local anatomical detail with global semantic context to improve precision, with a particular focus on segmenting small lesions within complex brain structures.
Key Features
- 🎯 SEAN (Symmetry Enhanced Attention Network): Exploits the natural symmetry of the brain to detect anomalies by comparing left and right hemispheres.
- 🌐 Global Context via ResNeXt50: Uses a ResNeXt50 backbone to capture the global context of each slice, which helps suppress false positives.
- 🔄 Dual-Path Fusion: Merges local detailed features with global semantic features for robust segmentation.
- 🏥 Clinical Alignment: Includes an alignment module to automatically correct head rotation, ensuring consistent processing.
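The symmetry idea behind SEAN can be illustrated with a minimal sketch: mirror a slice across its vertical midline and take the absolute difference, so regions that differ from the opposite hemisphere stand out. This is not the SEAN attention module itself. The helper name `symmetry_difference` is hypothetical, and the sketch assumes the slice is already rotation-corrected so the brain midline sits at the image's vertical center, which is what the alignment module is for.

```python
import torch


def symmetry_difference(slice_2d: torch.Tensor) -> torch.Tensor:
    """Return |x - flip(x)| for a (..., H, W) tensor.

    Hypothetical helper: assumes the brain midline is the image's vertical center line.
    """
    mirrored = torch.flip(slice_2d, dims=[-1])  # swap left and right hemispheres
    return (slice_2d - mirrored).abs()


# Toy 4x4 "slice": symmetric background with one bright pixel in the left hemisphere.
x = torch.zeros(1, 1, 4, 4)
x[..., 1, 0] = 1.0
diff = symmetry_difference(x)
print(diff[0, 0])
```

The asymmetric pixel shows up twice in the difference map, at its own position and at its mirrored position, so a learned attention module still has to decide which side holds the lesion.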
📊 Visual Gallery
The model demonstrates robust performance across various stroke types and locations.
Segmentation Samples
📈 Performance & Results
Training Dashboard
Training was monitored continuously to check for stable convergence.
Figure 2: Comprehensive training dashboard showing loss, metrics, and resource usage.
Detailed Metrics
The model was evaluated with standard medical image segmentation metrics, including the Dice score.
Training Progression
We tracked the loss components and the validation Dice score across training epochs. The graphs below show steady convergence and stable learning.
Figure 3: Top Left: Training Loss curve. Top Right: Validation Dice Score. Bottom: Learning Rate schedule.
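For reference, the Dice score measures the overlap between a predicted mask A and a ground-truth mask B as 2|A∩B| / (|A| + |B|). A minimal sketch for binary masks follows; the function name and the `eps` smoothing term are assumptions and not necessarily how the repository computes the metric.

```python
import torch


def dice_score(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> float:
    """Dice = 2|A ∩ B| / (|A| + |B|) for binary masks of equal shape."""
    pred = pred.float()
    target = target.float()
    intersection = (pred * target).sum()
    total = pred.sum() + target.sum()
    # eps keeps the ratio defined when both masks are empty (the score is then 1.0)
    return ((2 * intersection + eps) / (total + eps)).item()


pred = torch.tensor([[1, 1, 0], [0, 1, 0]])
target = torch.tensor([[1, 0, 0], [0, 1, 1]])
print(dice_score(pred, target))  # 2*2 / (3 + 3), about 0.667
```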
🔬 Technical Details
Architecture Components
- Local Path (SEAN): A 3D UNet-like network that processes the center slice together with its neighbors (T=1, i.e. one slice on each side, for 3 slices in total) to capture volumetric continuity. It uses a Symmetry Enhanced Attention mechanism.
- Global Path (ResNeXt50): A 2D CNN backbone that processes the center slice to capture high-level semantics.
- Fusion: The outputs are fused with a weighted average (Local: 0.7, Global: 0.3).
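To make the two paths concrete, here is a minimal sketch of (a) gathering the 2T+1 slice window around a center slice for the local path and (b) the 0.7/0.3 weighted average of the two outputs. The helper names `slice_window` and `fuse`, the edge handling (clamping to the first or last slice), and the assumption that both paths emit outputs of the same shape are illustrative, not taken from the repository.

```python
import torch


def slice_window(volume: torch.Tensor, center: int, T: int = 1) -> torch.Tensor:
    """Return slices center-T .. center+T from a (D, H, W) volume as (2T+1, H, W).

    Indices past the volume edges are clamped (an assumption; the repo may pad differently).
    """
    depth = volume.shape[0]
    idx = torch.arange(center - T, center + T + 1).clamp(0, depth - 1)
    return volume[idx]


def fuse(local_out: torch.Tensor, global_out: torch.Tensor,
         local_impact: float = 0.7, global_impact: float = 0.3) -> torch.Tensor:
    """Weighted average of same-shaped local and global path outputs."""
    return local_impact * local_out + global_impact * global_out


volume = torch.arange(5 * 2 * 2, dtype=torch.float32).reshape(5, 2, 2)
window = slice_window(volume, center=0, T=1)
print(window.shape)      # torch.Size([3, 2, 2])
print(window[:, 0, 0])   # slices 0, 0, 1 -> tensor([0., 0., 4.])

local_logits = torch.ones(1, 2, 4, 4)
global_logits = torch.zeros(1, 2, 4, 4)
print(fuse(local_logits, global_logits)[0, 0, 0, 0])  # tensor(0.7000)
```

Because the weights sum to 1, the fused output stays on the same scale as either path's output.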
Loss Function Analysis
We employ a composite loss function designed to handle class imbalance (Dice), pixel accuracy (Cross-Entropy), and symmetry enforcement (Alignment).
Figure 4: Breakdown of the combined loss function components over time.
```python
# Loss composition
total_loss = 0.7 * dice_loss + 0.3 * ce_loss + 0.05 * alignment_loss
```
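A runnable sketch of that weighting, assuming the model outputs two-class logits of shape (B, 2, H, W) and labels are integer class indices of shape (B, H, W). The soft Dice formulation on the lesion channel is a common choice and an assumption here; `alignment_loss` is passed in as a precomputed scalar because its definition is not given in this card. The function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def soft_dice_loss(logits: torch.Tensor, labels: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """1 - soft Dice on the foreground (lesion) class, assumed to be channel 1."""
    probs = torch.softmax(logits, dim=1)[:, 1]  # (B, H, W) lesion probability
    target = (labels == 1).float()
    intersection = (probs * target).sum()
    denom = probs.sum() + target.sum()
    return 1 - (2 * intersection + eps) / (denom + eps)


def composite_loss(logits, labels, alignment_loss, w_dice=0.7, w_ce=0.3, w_align=0.05):
    """Weighted sum: 0.7 * Dice + 0.3 * CE + 0.05 * alignment (weights from the formula above)."""
    dice_loss = soft_dice_loss(logits, labels)
    ce_loss = F.cross_entropy(logits, labels)  # labels must be int64 class indices
    return w_dice * dice_loss + w_ce * ce_loss + w_align * alignment_loss


logits = torch.randn(2, 2, 8, 8, requires_grad=True)
labels = torch.randint(0, 2, (2, 8, 8))
loss = composite_loss(logits, labels, alignment_loss=torch.tensor(0.1))
loss.backward()  # gradients flow through both the Dice and cross-entropy terms
print(loss.item())
```

The Dice term is insensitive to how many background pixels there are, which is why it helps with small lesions, while the cross-entropy term gives a per-pixel gradient everywhere.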
🚀 Quick Start
Installation
```bash
git clone https://github.com/hoangtung386/brain-stroke-segmentation.git
cd brain-stroke-segmentation
pip install -r requirements.txt
pip install huggingface_hub
```
Inference Example
```python
import os

import torch
from huggingface_hub import hf_hub_download

from models.lcnn import LCNN  # run from the repository root so `models` is importable

# Create the checkpoints folder
os.makedirs("./checkpoints", exist_ok=True)

# Download the weights directly into ./checkpoints (no nested cache structure)
checkpoint_path = hf_hub_download(
    repo_id="hoangtung386/brain-stroke-lcnn",
    filename="best_model.pth",
    local_dir="./checkpoints",
    local_dir_use_symlinks=False,  # deprecated and ignored in recent huggingface_hub releases
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Constructor arguments must match the trained model (fusion weights 0.3/0.7, T=1)
model = LCNN(num_channels=1, num_classes=2, global_impact=0.3, local_impact=0.7, T=1).to(device)

checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # switch to inference mode (e.g. dropout off, batch-norm running statistics)

print("\n--------------------------------------")
print("Model ready for inference")
print(f"Model path: {checkpoint_path}")
print("--------------------------------------\n")
```
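Once the model is loaded, its two-class output can be turned into a binary lesion mask with an argmax over the class dimension. This sketch assumes the forward pass returns logits of shape (B, 2, H, W) with class 1 meaning lesion; check `models/lcnn.py` for the actual input layout and output format before relying on it. The helpers `logits_to_mask` and `lesion_fraction` are hypothetical.

```python
import torch


@torch.no_grad()
def logits_to_mask(logits: torch.Tensor) -> torch.Tensor:
    """Convert (B, 2, H, W) logits into a (B, H, W) uint8 mask (1 = lesion)."""
    return logits.argmax(dim=1).to(torch.uint8)


def lesion_fraction(mask: torch.Tensor) -> float:
    """Fraction of pixels labelled as lesion, handy as a quick sanity check."""
    return mask.float().mean().item()


# Stand-in logits; replace with the model's output once the input layout is confirmed.
fake_logits = torch.zeros(1, 2, 2, 2)
fake_logits[0, 1, 0, 0] = 5.0  # one pixel strongly predicted as lesion
mask = logits_to_mask(fake_logits)
print(mask[0])                 # lesion only at position (0, 0)
print(lesion_fraction(mask))   # 0.25
```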
📄 License
This project is licensed under the MIT License - see LICENSE.txt for details.
🤝 Acknowledgments
Developed by Le Vu Hoang Tung (@hoangtung386).
- Frameworks: PyTorch, MONAI
- Tools: Weights & Biases for experiment tracking.
- AI Assistance: Google Gemini 3.0 Pro & local LLMs for code auditing.