---
license: apache-2.0
tags:
- jax
- safetensors
---

# Parametric PerceptNet Fully Trained

## Model Description
## How to use it

### Install the model's package from source

```bash
git clone https://github.com/Jorgvt/paramperceptnet.git
cd paramperceptnet
pip install -e .
```
### 1. Import the necessary libraries

```python
import json

from huggingface_hub import hf_hub_download
import flax
import orbax.checkpoint
from ml_collections import ConfigDict

from paramperceptnet.models import PerceptNet
```
### 2. Download the configuration

```python
config_path = hf_hub_download(repo_id="Jorgvt/ppnet-fully-trained",
                              filename="config.json")
with open(config_path, "r") as f:
    config = ConfigDict(json.load(f))
```
### 3. Download the weights

The weights are provided in two formats. Either option below produces the `params` and `state` dictionaries used in step 4.

#### 3.1. Using `safetensors`

```python
from safetensors.flax import load_file

weights_path = hf_hub_download(repo_id="Jorgvt/ppnet-fully-trained",
                               filename="weights.safetensors")
# The file stores a flat mapping with dot-separated keys; rebuild the nested tree.
variables = load_file(weights_path)
variables = flax.traverse_util.unflatten_dict(variables, sep=".")
state = variables["state"]
params = variables["params"]
```
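A `safetensors` file holds a flat mapping from string keys to arrays. The nested `{"params": ..., "state": ...}` structure therefore has to be rebuilt from the dot-separated key names. The following pure-Python sketch shows what `flax.traverse_util.unflatten_dict(..., sep=".")` does in that step. `unflatten_dotted` is a hypothetical helper for illustration only. The key names (`Conv_0`, `Layer_1`, ...) are also hypothetical and will not match the real checkpoint's layer names.

```python
def unflatten_dotted(flat: dict, sep: str = ".") -> dict:
    """Rebuild a nested dict from keys such as "params.Conv_0.kernel".

    Hypothetical, illustrative re-implementation of the behaviour of
    flax.traverse_util.unflatten_dict(flat, sep=sep); use the Flax function in practice.
    """
    nested: dict = {}
    for key, value in flat.items():
        *parents, leaf = key.split(sep)
        node = nested
        for part in parents:
            # Create each intermediate level the first time it is seen.
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


# Hypothetical flat keys, in the style a safetensors file stores them.
flat = {
    "params.Conv_0.kernel": "kernel-array",
    "params.Conv_0.bias": "bias-array",
    "state.Layer_1.buffer": "buffer-array",
}
tree = unflatten_dotted(flat)
print(sorted(tree))                      # ['params', 'state']
print(tree["params"]["Conv_0"]["bias"])  # bias-array
```

The top-level keys of the rebuilt tree are what the card indexes with `variables["state"]` and `variables["params"]`.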
#### 3.2. Using `msgpack`

```python
import jax
from jax import numpy as jnp

weights_path = hf_hub_download(repo_id="Jorgvt/ppnet-fully-trained",
                               filename="weights.msgpack")
with open(weights_path, "rb") as f:
    variables = orbax.checkpoint.msgpack_utils.msgpack_restore(f.read())
# Convert every restored leaf into a JAX array.
variables = jax.tree_util.tree_map(jnp.array, variables)
state = variables["state"]
params = variables["params"]
```
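If you download both weight files, you may want to confirm that they hold the same weights. The sketch below compares two nested dicts of arrays leaf by leaf with NumPy. `flatten` and `same_weights` are hypothetical helpers written for this illustration, not part of `paramperceptnet`. The toy trees stand in for the real `variables` from options 3.1 and 3.2.

```python
import numpy as np


def flatten(tree: dict, prefix: tuple = ()):
    """Yield (path, leaf) pairs from a nested dict of arrays."""
    for key in sorted(tree):
        value = tree[key]
        path = prefix + (key,)
        if isinstance(value, dict):
            yield from flatten(value, path)
        else:
            yield path, value


def same_weights(a: dict, b: dict, atol: float = 0.0) -> bool:
    """Return True if both trees have the same paths, shapes and values (within atol)."""
    flat_a = dict(flatten(a))
    flat_b = dict(flatten(b))
    if flat_a.keys() != flat_b.keys():
        return False
    # Check shapes first so mismatched leaves are never broadcast against each other.
    return all(
        np.shape(flat_a[p]) == np.shape(flat_b[p])
        and np.allclose(np.asarray(flat_a[p]), np.asarray(flat_b[p]), rtol=0.0, atol=atol)
        for p in flat_a
    )


# Toy trees standing in for the two loaded `variables` dicts.
a = {"params": {"w": np.ones((2, 2))}, "state": {"s": np.zeros(3)}}
b = {"params": {"w": np.ones((2, 2))}, "state": {"s": np.zeros(3)}}
print(same_weights(a, b))  # True
b["params"]["w"] = np.full((2, 2), 2.0)
print(same_weights(a, b))  # False
```

With the real checkpoints, you would call something like `same_weights(variables_st, variables_mp)`. Both variable names are hypothetical, since the card reuses `variables` for each option. `np.asarray` also accepts JAX arrays, so the check works on either format's output.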
### 4. Use the model

```python
from jax import numpy as jnp

model = PerceptNet(config)
# Dummy input batch: one 384x512 RGB image in NHWC layout.
pred = model.apply({"params": params, **state}, jnp.ones((1, 384, 512, 3)))
```
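The call above feeds the model a dummy batch of ones with shape `(1, 384, 512, 3)`, that is, a leading batch axis followed by height, width and RGB channels. A real image needs the same layout. The sketch below uses two hypothetical helpers that are not part of `paramperceptnet`. `to_model_input` converts an `(H, W, 3)` uint8 image into such a batch. It *assumes* the model expects float32 values scaled to [0, 1]; this card does not state that, so check the package source. `rms_distance` computes a per-image root-mean-square difference between two batches of model outputs. PerceptNet-style models are commonly used to compare images through the distance between their outputs, but this particular formula is an illustrative choice, not necessarily the one used to train this model.

```python
import numpy as np


def to_model_input(image_uint8: np.ndarray) -> np.ndarray:
    """Convert an (H, W, 3) uint8 RGB image into a (1, H, W, 3) float32 batch.

    Assumption: the model expects values in [0, 1]; verify against the package.
    """
    if image_uint8.ndim != 3 or image_uint8.shape[-1] != 3:
        raise ValueError(f"expected an (H, W, 3) image, got {image_uint8.shape}")
    image = image_uint8.astype(np.float32) / 255.0
    return image[None, ...]  # add the batch axis -> (1, H, W, 3)


def rms_distance(features_a, features_b) -> np.ndarray:
    """Per-image root-mean-square difference between two batches of outputs."""
    a = np.asarray(features_a, dtype=np.float32)
    b = np.asarray(features_b, dtype=np.float32)
    diff = (a - b).reshape(a.shape[0], -1)  # one row per image in the batch
    return np.sqrt(np.mean(diff ** 2, axis=1))


img = np.zeros((384, 512, 3), dtype=np.uint8)
batch = to_model_input(img)
print(batch.shape, batch.dtype)  # (1, 384, 512, 3) float32
```

With the model loaded as above, a comparison between a reference image and a distorted one could then look like `rms_distance(model.apply(v, ref_batch), model.apply(v, dist_batch))`, where `v = {"params": params, **state}` and both batch names are hypothetical.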