PlantRNA-FM Sparse Autoencoders
These are a series of sparse autoencoders (SAEs) for the PlantRNA-FM model detailed here:
They were trained using dictionary_learning, and are TopK SAEs with k=64. The dictionary size is 15,360 (a 32x expansion of the model dimension). All layers were trained with the same learning rate, 1e-4. The parameters come from a hyperparameter sweep: on layer 6 they give the Pareto-optimal trade-off between fraction of variance explained (FVE) and dead features. Here are the metrics for all layers:
| layer | dead_feature_pct | fve      |
|------:|-----------------:|---------:|
| 0     | 70.110677        | 0.999368 |
| 1     | 3.046875         | 0.991308 |
| 2     | 1.197917         | 0.982633 |
| 3     | 0.208333         | 0.971918 |
| 4     | 0.078125         | 0.958037 |
| 5     | 0.0              | 0.950458 |
| 6     | 0.00651          | 0.948665 |
| 7     | 0.0              | 0.952286 |
| 8     | 0.0              | 0.963454 |
| 9     | 0.00651          | 0.974378 |
| 10    | 0.143229         | 0.982892 |
| 11    | 0.188802         | 0.990729 |
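For reference, the two metrics in the table can be computed from SAE reconstructions and feature activations roughly as below. This is a minimal numpy sketch using one common definition of each metric (global variance ratio for FVE; a feature counts as dead if it never fires on the evaluation set), not the exact evaluation code used for the sweep:

```python
import numpy as np

def fve(x, x_hat):
    """Fraction of variance explained by the reconstruction x_hat of x."""
    residual_var = np.var(x - x_hat)
    return 1.0 - residual_var / np.var(x)

def dead_feature_pct(feature_acts):
    """Percentage of dictionary features that never fire on the eval set.

    feature_acts: (n_tokens, dict_size) array of SAE feature activations.
    """
    fired = (feature_acts > 0).any(axis=0)
    return 100.0 * (1.0 - fired.mean())

# Toy check: a perfect reconstruction has an FVE of 1.0, and a feature
# column that is all zeros counts as dead.
x = np.random.randn(100, 8)
assert fve(x, x) == 1.0
acts = np.abs(np.random.randn(100, 4)) + 1e-6
acts[:, 0] = 0.0  # feature 0 never fires
print(dead_feature_pct(acts))  # 25.0 (1 dead feature out of 4)
```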
There's some evidence there that the earlier layers are overparameterised by this dictionary size (layer 0 in particular has ~70% dead features), but those layers are probably not that interesting anyway.
I load the SAEs for investigation like this:
```python
import logging
import pathlib

from transformers import AutoTokenizer, AutoModel
from huggingface_hub import hf_hub_download
from dictionary_learning import utils

logger = logging.getLogger(__name__)

SAE_MODEL_CACHE = {}
BASE_MODEL_CACHE = {}


def get_base_model(base_model_name: str, device: str):
    """Loads and caches the base PlantRNA-FM model and tokenizer."""
    if "model" not in BASE_MODEL_CACHE:
        print("Loading base model and tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        model = AutoModel.from_pretrained(base_model_name).to(device)
        model.eval()
        BASE_MODEL_CACHE["model"] = model
        BASE_MODEL_CACHE["tokenizer"] = tokenizer
        print("Base model and tokenizer loaded.")
    return BASE_MODEL_CACHE["model"], BASE_MODEL_CACHE["tokenizer"]


def get_sae_model(sae_repo: str, layer: int, k_sparsity: int, device: str):
    """Downloads, loads, and caches a specific sparse autoencoder."""
    sae_key = f"layer_{layer}_k_{k_sparsity}"
    if sae_key not in SAE_MODEL_CACHE:
        print(f"Loading SAE for layer {layer}, k={k_sparsity}...")
        try:
            model_path = hf_hub_download(
                repo_id=sae_repo,
                filename=f"layer_{layer}/k_{k_sparsity}/trainer_0/ae.pt",
            )
            # The config lives next to the weights; load_dictionary needs both.
            _ = hf_hub_download(
                repo_id=sae_repo,
                filename=f"layer_{layer}/k_{k_sparsity}/trainer_0/config.json",
            )
            ae, config = utils.load_dictionary(
                pathlib.Path(model_path).parents[0], device=device
            )
            SAE_MODEL_CACHE[sae_key] = ae
            print(f"SAE for {sae_key} loaded.")
        except Exception:
            logger.error(
                f"Could not load SAE (layer={layer}, k={k_sparsity}). "
                "Check if it exists in the repo."
            )
            raise
    return SAE_MODEL_CACHE[sae_key]
```
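Once an SAE is loaded, getting per-token feature activations means running the base model with `output_hidden_states=True` and passing the chosen layer's hidden states through the SAE's encoder. The sketch below avoids any downloads by using a small stand-in TopK encoder in place of the loaded `ae` (dictionary_learning dictionaries expose `encode`/`decode` methods; the stand-in only mimics the encode step, and the dimensions follow from the numbers above: a 15,360-feature dictionary at 32x expansion implies a 480-dimensional model):

```python
import torch

MODEL_DIM = 480   # implied by the 32x expansion of a 15,360-feature dictionary
DICT_SIZE = 15_360
K = 64

class TopKEncoder(torch.nn.Module):
    """Stand-in for the loaded `ae`: a linear encoder followed by TopK.

    The real object returned by get_sae_model is a dictionary_learning TopK
    SAE with `encode`/`decode` methods; this only mimics its encode step.
    """
    def __init__(self):
        super().__init__()
        self.W_enc = torch.nn.Linear(MODEL_DIM, DICT_SIZE)

    def encode(self, x):
        pre = self.W_enc(x)
        # Keep only the k largest pre-activations per token; zero the rest.
        topk = torch.topk(pre, K, dim=-1)
        acts = torch.zeros_like(pre)
        acts.scatter_(-1, topk.indices, topk.values)
        return acts

# In real use, the hidden states would come from the base model, e.g.:
#   outputs = model(**tokens, output_hidden_states=True)
#   hidden = outputs.hidden_states[layer]   # (batch, seq_len, MODEL_DIM)
hidden = torch.randn(1, 10, MODEL_DIM)      # dummy activations for the sketch
ae = TopKEncoder()
feature_acts = ae.encode(hidden)

# At most K features are nonzero at every token position.
assert int((feature_acts != 0).sum(dim=-1).max()) <= K
print(feature_acts.shape)  # torch.Size([1, 10, 15360])
```

The sparse `feature_acts` tensor is the starting point for most investigations: max-activating sequences per feature, feature co-occurrence, and so on.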