VAE Model for Swiss Roll
This is a Variational Autoencoder (VAE) model trained on the Swiss Roll dataset.
Model Description
This repository contains a complete implementation of a Variational Autoencoder (VAE) trained on the Swiss Roll 2D manifold dataset. The model learns to encode 2D points from the Swiss Roll into a lower-dimensional latent space and decode them back, enabling both dimensionality reduction and generation of new points that lie on the Swiss Roll manifold.
The architecture follows the approach outlined in Auto-Encoding Variational Bayes by Kingma and Welling, 2013.
Architecture Details
- Model Type: Variational Autoencoder (VAE)
- Framework: PyTorch
- Input: 2-dimensional points from Swiss Roll (x, z coordinates after projection)
- Latent Space: 2 dimensions
- Encoder and Decoder Layers: 2
- Encoder and Decoder Hidden Units: 96 → 48 (encoder), 96 → 48 (decoder)
- Total Parameters: 15,994
- Data type: Binary/Continuous (automatically detected)
- Current Implementation: Continuous (un-normalised)
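The layer sizes above can be sketched as a small PyTorch module. This is an illustrative reconstruction, not the repository's exact code: the activation functions and the split mean/log-variance heads are assumptions, so the parameter count may differ from the 15,994 reported above.

```python
import torch
import torch.nn as nn

class SwissRollVAE(nn.Module):
    """Sketch of the stated architecture: 2-D input, hidden sizes 96 -> 48, 2-D latent."""

    def __init__(self, in_dim=2, h1=96, h2=48, z_dim=2):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(in_dim, h1), nn.ReLU(),
            nn.Linear(h1, h2), nn.ReLU(),
        )
        self.fc_mu = nn.Linear(h2, z_dim)      # mean head
        self.fc_logvar = nn.Linear(h2, z_dim)  # log-variance head
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, h1), nn.ReLU(),
            nn.Linear(h1, h2), nn.ReLU(),
            nn.Linear(h2, in_dim),             # continuous output, so no sigmoid
        )

    def reparameterize(self, mu, logvar):
        # Sample z = mu + sigma * eps with eps ~ N(0, I)
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)

    def forward(self, x):
        h = self.encoder(x)
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        z = self.reparameterize(mu, logvar)
        return self.decoder(z), mu, logvar
```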
Key Components
- Encoder Network: Maps input points to latent distribution parameters (μ, σ²)
- Reparameterization Trick: Enables differentiable sampling from the latent distribution
- Decoder Network: Reconstructs points from latent space samples
- Loss Function: ELBO objective combining the reconstruction loss (Bernoulli: binary cross-entropy; Gaussian: negative log-likelihood) with the KL divergence
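The loss described above can be written compactly. This is a minimal sketch: for continuous data the Gaussian negative log-likelihood is reduced to a summed MSE (equivalent up to constants under unit variance), and the KL term uses the closed form for a diagonal Gaussian against a standard-normal prior from Kingma and Welling (2013).

```python
import torch
import torch.nn.functional as F

def vae_loss(recon_x, x, mu, logvar, data_type="continuous"):
    """ELBO-style loss: reconstruction term + KL divergence to N(0, I)."""
    if data_type == "binary":
        # Bernoulli likelihood -> binary cross-entropy
        rec = F.binary_cross_entropy(recon_x, x, reduction="sum")
    else:
        # Gaussian likelihood (unit variance) -> summed MSE up to constants
        rec = F.mse_loss(recon_x, x, reduction="sum")
    # Closed-form KL( N(mu, sigma^2) || N(0, I) )
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return rec + kl
```

A perfect reconstruction with a posterior equal to the prior (μ = 0, log σ² = 0) gives zero loss, which is a handy sanity check.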
Training Details
- Dataset: Swiss Roll (10,000 points generated using scikit-learn's make_swiss_roll)
- Train/Test Split: 80/20
- Batch Size: 128
- Epochs: 150
- Optimizer: Adam
- Learning Rate: 1e-3
- Gamma: 1e-1
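The training setup above can be sketched as follows. The model and loss are placeholders (any module returning reconstruction, μ, log σ² works), and "Gamma: 1e-1" is assumed to be a learning-rate decay factor, applied here via `ExponentialLR`; the card does not state which scheduler was actually used.

```python
import torch
from torch import optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.datasets import make_swiss_roll

def make_loaders(n=10_000, split=0.8, batch_size=128, seed=0):
    """Generate the Swiss Roll, keep the (x, z) projection, and do an 80/20 split."""
    pts, _ = make_swiss_roll(n_samples=n, random_state=seed)
    data = torch.tensor(pts[:, [0, 2]], dtype=torch.float32)  # drop y, keep (x, z)
    n_train = int(split * n)
    train = DataLoader(TensorDataset(data[:n_train]), batch_size=batch_size, shuffle=True)
    test = DataLoader(TensorDataset(data[n_train:]), batch_size=batch_size)
    return train, test

def train(vae, loss_fn, epochs=150, lr=1e-3, gamma=1e-1):
    """Adam at lr=1e-3; gamma assumed to be an LR decay factor (scheduler choice is a guess)."""
    opt = optim.Adam(vae.parameters(), lr=lr)
    sched = optim.lr_scheduler.ExponentialLR(opt, gamma=gamma)
    train_loader, _ = make_loaders()
    for _ in range(epochs):
        for (x,) in train_loader:
            recon, mu, logvar = vae(x)
            loss = loss_fn(recon, x, mu, logvar)
            opt.zero_grad()
            loss.backward()
            opt.step()
        sched.step()
    return vae
```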
Model Performance
Metrics
- Final Training Loss: ~6.16
  - Reconstruction Loss: ~3.42
  - KL Divergence: ~2.74
- Final Validation Loss: ~5.94
  - Reconstruction Loss: ~3.23
  - KL Divergence: ~2.71
Capabilities
- ✅ High-quality reconstruction of Swiss Roll points
- ✅ Smooth latent space interpolation
- ✅ Generation of new points along the Swiss Roll manifold
- ✅ Well-organized latent space capturing the underlying manifold structure
Usage
Using Transformers
```python
from transformers import AutoModel
import torch

# Load model (custom architecture, so remote code must be trusted)
model = AutoModel.from_pretrained("uday9k/SwissRoll_VAE", trust_remote_code=True)

# Generate samples
with torch.no_grad():
    z = torch.randn(1, 2)  # sample from the 2-D latent prior
    generated = model.generate(z=z)
```
Visualizations Available
- Latent Space Visualization: Plot of the 2D latent space showing manifold structure
- Reconstructions: Original vs. reconstructed Swiss Roll points
- Generated Samples: New points sampled from the latent space
- Interpolations: Smooth transitions between different regions of the Swiss Roll
- Training Curves: Loss components over training epochs
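The interpolation visualization can be sketched with a few lines: blend two latent codes linearly and decode each intermediate point. `decode` stands in for the trained model's decoder; this is an illustration, not the notebook's exact code.

```python
import torch

def interpolate(decode, z_start, z_end, steps=10):
    """Linearly interpolate between two latent codes and decode the path."""
    alphas = torch.linspace(0.0, 1.0, steps).unsqueeze(1)  # (steps, 1)
    zs = (1 - alphas) * z_start + alphas * z_end           # (steps, z_dim) latent path
    with torch.no_grad():
        return decode(zs)                                  # decoded points along the path
```

Feeding the decoded points to a scatter plot reproduces the smooth-transition figure described above.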
Files and Outputs
- SwissRoll_VAE_Train.ipynb: Complete implementation with training and visualization
- customVAE_model.pth: Trained model weights
- generated_samples: Scatter plot of generated samples (in the notebook)
- latent_space_visualization: 2D latent space plot (in the notebook)
- reconstruction_comparison: Original vs reconstructed points (in the notebook)
- latent_interpolation: Interpolation between points (in the notebook)
- comprehensive_training_curves: Training loss curves (in the notebook)
Applications
This VAE implementation can be used for:
- Generative Modeling: Create new points lying on the Swiss Roll manifold
- Dimensionality Reduction: Compress input points to compact latent representations (here both spaces are 2D)
- Manifold Learning: Learn the underlying structure of the Swiss Roll data
- Interpolation: Generate smooth transitions between points on the manifold
- Educational Purposes: Understand VAE concepts and implementation
Research and Educational Value
This implementation serves as an excellent educational resource for:
- Understanding Variational Autoencoders theory and practice
- Visualizing how VAEs learn manifold structures
- Learning PyTorch implementation techniques
- Exploring latent space representations on simple data
- Studying the balance between reconstruction and regularization
Citation
If you use this implementation in your research or projects, please cite:
@misc{vae_swissroll_implementation,
title={Variational Autoencoder Implementation for Swiss Roll},
author={Uday Jain},
year={2026},
url={https://huggingface.co/uday9k/SwissRoll_VAE}
}
License
This project is licensed under the MIT License - see the LICENSE file for details.
Additional Resources
- GitHub Repository: Profile
Tags: deep-learning, generative-ai, pytorch, vae, swiss-roll, unsupervised-learning
Model Card Authors: Uday Jain