VAE Model for Swiss Roll

This is a Variational Autoencoder (VAE) model trained on the Swiss Roll dataset.

Model Description

This repository contains a complete implementation of a Variational Autoencoder (VAE) trained on the Swiss Roll 2D manifold dataset. The model learns to encode 2D points from the Swiss Roll into a 2D latent space and decode them back, enabling both manifold learning and generation of new points that lie on the Swiss Roll manifold.

The architecture is based on the approach outlined in Auto-Encoding Variational Bayes by Kingma and Welling, 2013.

Architecture Details

  • Model Type: Variational Autoencoder (VAE)
  • Framework: PyTorch
  • Input: 2-dimensional points from Swiss Roll (x, z coordinates after projection)
  • Latent Space: 2 dimensions
  • Encoder and Decoder Layers: 2
  • Encoder and Decoder Hidden Units: 96 → 48 (encoder), 96 → 48 (decoder)
  • Total Parameters: 15,994
  • Data type: Binary/Continuous (automatically detected)
  • Current Implementation: Continuous (un-normalised)
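A minimal PyTorch sketch consistent with the layer sizes above (2 → 96 → 48 → 2-D latent, mirrored decoder). The class and attribute names are illustrative assumptions, not necessarily those used in the notebook, and this sketch is not guaranteed to reproduce the exact 15,994-parameter count:

```python
import torch
import torch.nn as nn

class SwissRollVAE(nn.Module):
    """Sketch of the 2-layer encoder/decoder VAE described in this card."""

    def __init__(self, in_dim=2, h1=96, h2=48, z_dim=2):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(in_dim, h1), nn.ReLU(),
            nn.Linear(h1, h2), nn.ReLU(),
        )
        # Separate heads for the latent mean and log-variance
        self.fc_mu = nn.Linear(h2, z_dim)
        self.fc_logvar = nn.Linear(h2, z_dim)
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, h2), nn.ReLU(),
            nn.Linear(h2, h1), nn.ReLU(),
            nn.Linear(h1, in_dim),  # un-normalised continuous output
        )

    def forward(self, x):
        h = self.encoder(x)
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        # Reparameterization trick: z = mu + sigma * eps
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)
        return self.decoder(z), mu, logvar
```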

Key Components

  1. Encoder Network: Maps input points to latent distribution parameters (μ, σ²)
  2. Reparameterization Trick: Enables differentiable sampling from the latent distribution
  3. Decoder Network: Reconstructs points from latent space samples
  4. Loss Function: Negative ELBO, i.e. reconstruction loss (Bernoulli: binary cross-entropy; Gaussian: negative log-likelihood) plus KL divergence
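The loss described above can be sketched as follows for the continuous (Gaussian) case; `recon`, `mu`, and `logvar` are assumed to come from a forward pass, and a unit-variance Gaussian likelihood is assumed so the reconstruction term reduces to MSE up to constants:

```python
import torch
import torch.nn.functional as F

def vae_loss(recon, x, mu, logvar):
    # Gaussian likelihood with unit variance: negative log-likelihood
    # reduces to MSE up to additive constants. For binary data, swap in
    # binary cross-entropy instead.
    recon_loss = F.mse_loss(recon, x, reduction="sum") / x.size(0)
    # Analytic KL divergence between N(mu, sigma^2) and the N(0, I) prior
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0)
    return recon_loss + kl, recon_loss, kl
```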

Training Details

  • Dataset: Swiss Roll (10,000 points generated using scikit-learn's make_swiss_roll)
  • Train/Test Split: 80/20
  • Batch Size: 128
  • Epochs: 150
  • Optimizer: Adam
  • Learning Rate: 1e-3
  • Gamma: 1e-1
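The dataset setup above can be reproduced roughly as follows. The point count, (x, z) projection, 80/20 split, and batch size come from this card; the noise level and random seed are assumptions:

```python
import torch
from sklearn.datasets import make_swiss_roll
from torch.utils.data import DataLoader, TensorDataset, random_split

# 10,000 3-D Swiss Roll points; keep the (x, z) coordinates as in this card
points, _ = make_swiss_roll(n_samples=10_000, noise=0.1, random_state=0)
data = torch.tensor(points[:, [0, 2]], dtype=torch.float32)

# 80/20 train/test split, batch size 128
train_set, test_set = random_split(TensorDataset(data), [8_000, 2_000])
train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
test_loader = DataLoader(test_set, batch_size=128)
```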

Model Performance

Metrics

  • Final Training Loss: ~6.16
      • Reconstruction Loss: ~3.42
      • KL Divergence: ~2.74
  • Final Validation Loss: ~5.94
      • Reconstruction Loss: ~3.23
      • KL Divergence: ~2.71

Capabilities

  • ✅ High-quality reconstruction of Swiss Roll points
  • ✅ Smooth latent space interpolation
  • ✅ Generation of new points along the Swiss Roll manifold
  • ✅ Well-organized latent space capturing the underlying manifold structure

Usage

Using Transformers

from transformers import AutoModel
import torch

# Load model
model = AutoModel.from_pretrained("uday9k/SwissRoll_VAE")

# Generate samples
with torch.no_grad():
    z = torch.randn(1, 2)  # Sample from the 2-D latent prior
    generated = model.generate(z=z)

Visualizations Available

  1. Latent Space Visualization: Plot of the 2D latent space showing manifold structure
  2. Reconstructions: Original vs. reconstructed Swiss Roll points
  3. Generated Samples: New points sampled from the latent space
  4. Interpolations: Smooth transitions between different regions of the Swiss Roll
  5. Training Curves: Loss components over training epochs
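The interpolation visualization can be sketched as a linear walk between two latent codes. The untrained decoder below is a stand-in so the snippet is self-contained; in practice you would use the trained model's decoder:

```python
import torch
import torch.nn as nn

# Stand-in decoder with the 2-D latent of this card; replace with the
# trained model's decoder for real results.
decoder = nn.Sequential(nn.Linear(2, 48), nn.ReLU(), nn.Linear(48, 2))

# Linear interpolation between two latent codes z0 and z1
z0, z1 = torch.randn(2), torch.randn(2)
steps = torch.linspace(0, 1, 10).unsqueeze(1)  # 10 interpolation steps
z_path = (1 - steps) * z0 + steps * z1         # shape (10, 2)
with torch.no_grad():
    points = decoder(z_path)                   # decoded 2-D points along the path
```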

Files and Outputs

  • SwissRoll_VAE_Train.ipynb: Complete implementation with training and visualization
  • customVAE_model.pth: Trained model weights
  • generated_samples: Scatter plot of generated samples as part of notebook
  • latent_space_visualization: 2D latent space plot as part of notebook
  • reconstruction_comparison: Original vs reconstructed points as part of notebook
  • latent_interpolation: Interpolation between points as part of notebook
  • comprehensive_training_curves: Training loss curves as part of notebook

Applications

This VAE implementation can be used for:

  • Generative Modeling: Create new points lying on the Swiss Roll manifold
  • Latent Representation: Map 2D points to an organized 2D latent representation
  • Manifold Learning: Learn the underlying structure of the Swiss Roll data
  • Interpolation: Generate smooth transitions between points on the manifold
  • Educational Purposes: Understand VAE concepts and implementation

Research and Educational Value

This implementation serves as an excellent educational resource for:

  • Understanding Variational Autoencoders theory and practice
  • Visualizing how VAEs learn manifold structures
  • Learning PyTorch implementation techniques
  • Exploring latent space representations on simple data
  • Studying the balance between reconstruction and regularization

Citation

If you use this implementation in your research or projects, please cite:

@misc{vae_swissroll_implementation,
  title={Variational Autoencoder Implementation for Swiss Roll},
  author={Uday Jain},
  year={2026},
  url={https://huggingface.co/uday9k/SwissRoll_VAE}
}

License

This project is licensed under the MIT License - see the LICENSE file for details.

Tags: deep-learning, generative-ai, pytorch, vae, swiss-roll, unsupervised-learning

Model Card Authors: Uday Jain
