Coverage for adaro_rl / zoo / main.py: 57%
14 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 07:50 +0000
1from huggingface_hub import hf_hub_download
2import argparse
def download_model(model_config: str, local_dir: "str | None" = None) -> str:
    """
    Fetch the checkpoint associated with a zoo configuration.

    Given the *name* of a configuration (e.g. ``"Enduro-v5"``) this utility
    looks it up in :data:`adaro_rl.zoo.configs`, reads the repository ID and
    filename stored on the config's ``agent`` attribute, and downloads the
    file via :func:`huggingface_hub.hf_hub_download`.

    Parameters
    ----------
    model_config : str
        Key in :data:`adaro_rl.zoo.configs` that identifies the desired
        benchmark configuration.
    local_dir : str or None, optional
        Destination directory; if *None* the file is cached in the default
        HuggingFace location.

    Returns
    -------
    str
        Local filesystem path of the downloaded file, exactly as returned by
        :func:`huggingface_hub.hf_hub_download`.

    Raises
    ------
    KeyError
        If *model_config* is not a key of :data:`adaro_rl.zoo.configs`.
    """
    # Imported lazily inside the function -- presumably to avoid import-time
    # cost or a circular import with the package ``__init__`` (confirm).
    from . import configs

    try:
        cfg = configs[model_config]
    except KeyError:
        # Same exception type as before, but with a message that names the
        # problem instead of echoing only the bare key.
        raise KeyError(
            f"Unknown zoo configuration {model_config!r}; expected a key of "
            "adaro_rl.zoo.configs"
        ) from None

    return hf_hub_download(
        repo_id=cfg.agent.repo_id,
        filename=cfg.agent.filename,
        local_dir=local_dir,
    )
def main(argv: "list[str] | None" = None) -> None:
    """
    Command-line wrapper around :func:`download_model`.

    Parses the command line, downloads the requested checkpoint and prints
    the local path of the downloaded file to stdout.

    Parameters
    ----------
    argv : list of str or None, optional
        Arguments to parse instead of ``sys.argv[1:]`` (useful for calling
        the CLI programmatically or from tests). *None*, the default, parses
        the real command line.

    Examples
    --------
    .. code-block:: console

        $ python -m adaro_rl.zoo.main --download-model Enduro-v5 --local-dir checkpoints/
    """
    # The text must be passed as ``description=``: the first positional
    # parameter of ArgumentParser is ``prog`` (the program name printed in
    # the usage line), not the description.
    parser = argparse.ArgumentParser(description="Download a model from HF hub")
    parser.add_argument(
        "--download-model",
        required=True,
        metavar="CONFIG",
        help="name of the zoo configuration whose checkpoint to download "
        "(e.g. Enduro-v5)",
    )
    parser.add_argument(
        "--local-dir",
        default=None,
        help="destination directory (default: the HuggingFace cache)",
    )
    args = parser.parse_args(argv)

    path = download_model(args.download_model, args.local_dir)
    # Report where the file ended up; without --local-dir it lands in the
    # HuggingFace cache, whose location is otherwise not obvious to the user.
    print(path)


if __name__ == "__main__":
    main()