Coverage for tests/tests_advertrain/test_metrics.py: 100%

35 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-10-01 08:42 +0000

1import os 

2 


4 

5import pytest 

6import torch 

7 

8from robustAI.advertrain.metrics import Metrics 

9 

10 

@pytest.fixture
def sample_data():
    """Provide a tiny two-sample batch for exercising ``Metrics``.

    Returns:
        A tuple ``(x, y, pred, loss)``: a 2x2 integer input tensor, the
        integer targets ``[1, 0]``, predictions equal to those targets, and a
        one-element loss tensor holding ``0.5``.
    """
    inputs = torch.tensor([[1, 2], [3, 4]])
    labels = torch.tensor([1, 0])
    # Predictions deliberately match the labels, so the batch should yield
    # one true positive and one true negative (see test_metrics_update).
    predictions = labels.clone()
    batch_loss = torch.tensor([0.5])
    return inputs, labels, predictions, batch_loss

18 

19 

def test_initial_state():
    """A freshly constructed ``Metrics`` has all counters and the loss at zero."""
    fresh = Metrics()
    for counter in ("TP", "TN", "FP", "FN"):
        assert getattr(fresh, counter) == 0
    assert fresh.loss == 0.0

27 

28 

def test_metrics_update(sample_data):
    """A single ``update`` on the fixture batch records one TP and one TN.

    The stored ``loss`` is expected to equal the batch loss scaled by the
    number of samples in the batch.
    """
    inputs, labels, predictions, batch_loss = sample_data
    tracker = Metrics()
    tracker.update(inputs, labels, predictions, batch_loss)

    expected_counts = {"TP": 1, "TN": 1, "FP": 0, "FN": 0}
    for name, expected in expected_counts.items():
        assert getattr(tracker, name) == expected
    assert tracker.loss == batch_loss.item() * len(inputs)

39 

40 

def test_get_metrics(sample_data):
    """``get_metrics`` returns accuracy, loss, precision, recall and F1.

    The fixture batch has one true positive, one true negative and no errors,
    so every ratio should be (approximately) 1. The accuracy and precision
    expectations keep the ``+ 1e-8`` denominator terms this test has always
    used.
    """
    x, y, pred, loss = sample_data
    metrics = Metrics()
    metrics.update(x, y, pred, loss)

    # Bind the returned loss to its own name rather than rebinding the
    # fixture's ``loss`` tensor used in the update above.
    accuracy, _avg_loss, precision, recall, f1_score = metrics.get_metrics()

    # pytest.approx: exact float equality would fail on any harmless change
    # to the evaluation order or precision of the formulas in ``Metrics``.
    assert accuracy == pytest.approx((1 + 1) / (2 + 1e-8))
    assert precision == pytest.approx(1 / (1 + 1e-8))
    # No false negatives and no false positives, so recall and F1 are ~1 too.
    assert recall == pytest.approx(1.0)
    assert f1_score == pytest.approx(1.0)
    # NOTE(review): ``_avg_loss`` is not pinned here. Whether ``get_metrics``
    # normalises the accumulated loss per sample or per batch is not visible
    # from this file; confirm and add an assertion.