Coverage for adaro_rl / zoo / main.py: 57%

14 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 07:50 +0000

1from huggingface_hub import hf_hub_download 

2import argparse 

3 

4 

def download_model(model_config: str, local_dir=None) -> str:
    """
    Fetch the checkpoint associated with a zoo configuration.

    Given the *name* of a configuration (e.g. ``"Enduro-v5"``) this utility
    looks it up in :data:`adaro_rl.zoo.configs`, reads the repository ID and
    filename stored on the config's ``agent`` attribute, and downloads the
    file via :func:`huggingface_hub.hf_hub_download`.

    Parameters
    ----------
    model_config : str
        Key in :data:`adaro_rl.zoo.configs` that identifies the desired
        benchmark configuration.
    local_dir : str or None, optional
        Destination directory; if *None* the file is cached in the default
        HuggingFace cache location.

    Returns
    -------
    str
        Local filesystem path of the downloaded checkpoint file, exactly as
        returned by :func:`huggingface_hub.hf_hub_download`.

    Raises
    ------
    KeyError
        If *model_config* is not a key of :data:`adaro_rl.zoo.configs`.
    """
    # Package-relative import: this module has to be imported/run as part of
    # the ``adaro_rl.zoo`` package (e.g. ``python -m adaro_rl.zoo.main``).
    from . import configs

    try:
        cfg = configs[model_config]
    except KeyError as err:
        # A bare KeyError only echoes the missing key; say what was being
        # looked up. Same exception type, so existing callers still work.
        raise KeyError(f"Unknown zoo configuration {model_config!r}") from err

    return hf_hub_download(
        repo_id=cfg.agent.repo_id,
        filename=cfg.agent.filename,
        local_dir=local_dir,
    )

35 

36 

def main(argv: "list[str] | None" = None) -> None:
    """
    Command-line wrapper around :func:`download_model`.

    Parses ``--download-model`` (required) and ``--local-dir`` (optional),
    downloads the checkpoint and prints the resulting local path to stdout.

    Parameters
    ----------
    argv : list of str or None, optional
        Arguments to parse instead of the process command line; *None*
        (the default) makes :mod:`argparse` read ``sys.argv[1:]`` as before.

    Examples
    --------
    .. code-block:: console

        $ python -m adaro_rl.zoo.main --download-model Enduro-v5 --local-dir checkpoints/
    """
    # The first positional parameter of ArgumentParser is ``prog`` (the name
    # shown in the usage line), so the human-readable text must be passed as
    # ``description`` instead.
    parser = argparse.ArgumentParser(description="Download a model from HF hub")
    parser.add_argument(
        "--download-model",
        required=True,
        metavar="CONFIG",
        help="name of the zoo configuration whose checkpoint to download",
    )
    parser.add_argument(
        "--local-dir",
        default=None,
        help="destination directory (default: the HuggingFace cache)",
    )
    args = parser.parse_args(argv)
    path = download_model(args.download_model, args.local_dir)
    # Tell the user where the file ended up; without this the CLI is silent,
    # which is unhelpful when the file lands in the HuggingFace cache.
    print(path)


if __name__ == "__main__":
    main()