PlantRNA-FM Sparse Autoencoders

Here is a series of sparse autoencoders (SAEs) for the PlantRNA-FM model detailed here:

These were trained using `dictionary_learning` and are TopK SAEs with k=64. The dictionary size is 15,360 (a 32x expansion of the model dimension). All layers were trained with the same learning rate, 1e-4. I chose the parameters from a hyperparameter sweep; they give the Pareto-optimal trade-off between fraction of variance explained (FVE) and dead features, evaluated on layer 6. Here are the metrics for all layers:

| layer | dead_feature_pct | fve      |
|------:|-----------------:|---------:|
| 0     | 70.110677        | 0.999368 |
| 1     | 3.046875         | 0.991308 |
| 2     | 1.197917         | 0.982633 |
| 3     | 0.208333         | 0.971918 |
| 4     | 0.078125         | 0.958037 |
| 5     | 0.0              | 0.950458 |
| 6     | 0.00651          | 0.948665 |
| 7     | 0.0              | 0.952286 |
| 8     | 0.0              | 0.963454 |
| 9     | 0.00651          | 0.974378 |
| 10    | 0.143229         | 0.982892 |
| 11    | 0.188802         | 0.990729 |

There's some evidence here that the earlier layers are overparameterised by this dictionary (note layer 0's high dead-feature percentage), but those layers are probably not that interesting anyway.
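For reference, the two columns in the table can be computed as follows. This is a sketch on synthetic data using one common definition of each metric (the real evaluation runs on held-out model activations, and the exact definitions in the training repo may differ slightly):

```python
import numpy as np

# FVE compares reconstruction error to the total variance of the
# activations; a feature counts as dead if it never fires on the batch.
# All arrays here are random stand-ins, not real model activations.
rng = np.random.default_rng(0)
X = rng.standard_normal((1000, 16))                    # activations (tokens x d_model)
recon = X + 0.1 * rng.standard_normal(X.shape)         # stand-in SAE reconstructions
acts = np.maximum(rng.standard_normal((1000, 64)), 0)  # stand-in feature activations

fve = 1.0 - ((X - recon) ** 2).sum() / ((X - X.mean(axis=0)) ** 2).sum()
dead_feature_pct = 100.0 * (acts.max(axis=0) == 0).mean()
print(round(fve, 3), dead_feature_pct)
```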

I load the SAEs for investigation like this:

```python
import logging
import pathlib

from transformers import AutoTokenizer, AutoModel
from huggingface_hub import hf_hub_download
from dictionary_learning import utils

logger = logging.getLogger(__name__)

SAE_MODEL_CACHE = {}
BASE_MODEL_CACHE = {}


def get_base_model(base_model_name: str, device: str):
    """Loads and caches the base PlantRNA-FM model and tokenizer."""
    if "model" not in BASE_MODEL_CACHE:
        print("Loading base model and tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        model = AutoModel.from_pretrained(base_model_name).to(device)
        model.eval()
        BASE_MODEL_CACHE["model"] = model
        BASE_MODEL_CACHE["tokenizer"] = tokenizer
        print("Base model and tokenizer loaded.")
    return BASE_MODEL_CACHE["model"], BASE_MODEL_CACHE["tokenizer"]


def get_sae_model(sae_repo: str, layer: int, k_sparsity: int, device: str):
    """Downloads, loads, and caches a specific Sparse Autoencoder."""
    sae_key = f"layer_{layer}_k_{k_sparsity}"
    if sae_key not in SAE_MODEL_CACHE:
        print(f"Loading SAE for layer {layer}, k={k_sparsity}...")
        try:
            model_path = hf_hub_download(
                repo_id=sae_repo,
                filename=f"layer_{layer}/k_{k_sparsity}/trainer_0/ae.pt",
            )
            # The config lives alongside ae.pt; load_dictionary needs both
            # in the same directory, so download it too.
            _ = hf_hub_download(
                repo_id=sae_repo,
                filename=f"layer_{layer}/k_{k_sparsity}/trainer_0/config.json",
            )
            ae, config = utils.load_dictionary(
                pathlib.Path(model_path).parent, device=device
            )
            SAE_MODEL_CACHE[sae_key] = ae
            print(f"SAE for {sae_key} loaded.")
        except Exception:
            logger.exception(
                f"Could not load SAE (layer={layer}, k={k_sparsity}). "
                "Check that it exists in the repo."
            )
            raise
    return SAE_MODEL_CACHE[sae_key]
```
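Once loaded, applying an SAE to a hidden state is just the TopK forward pass: encode, zero all but the k largest pre-activations, decode. A minimal NumPy sketch with random stand-in weights (the real weights come from the loaded `ae`; D_MODEL = 480 follows from the 32x expansion to 15,360):

```python
import numpy as np

# Sketch of the TopK SAE forward pass; weights here are random stand-ins.
D_MODEL = 480      # model dimension (15,360 / 32)
DICT_SIZE = 15360  # dictionary size (32x expansion)
K = 64             # active features per token

rng = np.random.default_rng(0)
W_enc = rng.standard_normal((DICT_SIZE, D_MODEL)) * 0.01
W_dec = rng.standard_normal((D_MODEL, DICT_SIZE)) * 0.01
b_dec = np.zeros(D_MODEL)

def topk_sae_forward(x: np.ndarray):
    """Encode, keep only the k largest pre-activations, decode."""
    pre = W_enc @ (x - b_dec)
    acts = np.zeros_like(pre)
    top = np.argpartition(pre, -K)[-K:]    # indices of the k largest values
    acts[top] = np.maximum(pre[top], 0.0)  # ReLU on the surviving features
    return W_dec @ acts + b_dec, acts

x = rng.standard_normal(D_MODEL)
recon, acts = topk_sae_forward(x)
print("active features:", int((acts != 0).sum()))
```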