Coverage for tests/tests_advertrain/test_dependencies/test_fire.py: 100%

32 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-09-10 08:11 +0000

1import pytest 

2import torch 

3 

4from robustML.advertrain.dependencies.fire import (entropy_loss, fire_loss, 

5 noise_loss) 

6 

# Seed torch's global RNG once, when this module is imported, so the random
# tensors drawn in this module start from a fixed state.
_GLOBAL_SEED = 0
torch.manual_seed(_GLOBAL_SEED)

8 

9 

class MockModel(torch.nn.Module):
    """Minimal stand-in model: a single affine layer from 10 inputs to 2 outputs."""

    def __init__(self):
        super(MockModel, self).__init__()
        # Keep the attribute name `lin` so parameter names stay
        # `lin.weight` / `lin.bias`.
        self.lin = torch.nn.Linear(in_features=10, out_features=2)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return its output."""
        outputs = self.lin(x)
        return outputs

17 

18 

@pytest.fixture(autouse=True)
def _reseed_torch_rng():
    """Reseed torch's global RNG before every test in this module.

    A module-level ``torch.manual_seed`` runs only once, at import time
    (during collection). Without a reseed, the random values a test draws
    depend on which tests ran before it and on test selection (``-k``,
    reordering). Autouse fixtures are set up before other fixtures in the
    same scope, so this also makes the weights built by ``mock_model``
    reproducible.
    """
    torch.manual_seed(0)


@pytest.fixture
def mock_model():
    """Provide a freshly constructed ``MockModel`` for each test."""
    return MockModel()

22 

23 

def test_entropy_loss():
    """entropy_loss on random logits yields a finite, non-negative value."""
    # A private generator keeps the inputs independent of how much of the
    # global RNG stream other tests (or other modules) have already consumed.
    gen = torch.Generator().manual_seed(0)
    logits = torch.randn(32, 2, generator=gen)

    loss = entropy_loss(logits)

    # `>= 0` on its own would accept +inf, so also require a finite value.
    assert torch.isfinite(loss).all()
    assert loss.item() >= 0

29 

30 

def test_fire_loss(mock_model):
    """fire_loss returns four loss terms that are each finite and non-negative."""
    # A private generator keeps the batch independent of global RNG state
    # left behind by earlier tests.
    gen = torch.Generator().manual_seed(0)
    x = torch.randn(32, 10, generator=gen)
    y = torch.randint(0, 2, (32,), generator=gen)
    optimizer = torch.optim.Adam(mock_model.parameters(), lr=0.001)
    device = torch.device('cpu')

    # Unpacking into four names also asserts that exactly four values come back.
    total_loss, nat_loss, rob_loss, ent_loss = fire_loss(
        model=mock_model,
        x_natural=x,
        y=y,
        optimizer=optimizer,
        epoch=1,
        device=device
    )

    terms = (
        ("total", total_loss),
        ("natural", nat_loss),
        ("robust", rob_loss),
        ("entropy", ent_loss),
    )
    for name, value in terms:
        # `>= 0` alone would accept +inf; a finite check catches divergence.
        assert torch.isfinite(value).all(), f"{name} loss is not finite"
        assert value.item() >= 0, f"{name} loss is negative"

50 

51 

def test_noise_loss(mock_model):
    """noise_loss on a random batch yields a finite, non-negative value."""
    # A private generator keeps the batch independent of global RNG state
    # left behind by earlier tests.
    gen = torch.Generator().manual_seed(0)
    x = torch.randn(32, 10, generator=gen)
    y = torch.randint(0, 2, (32,), generator=gen)

    loss = noise_loss(
        model=mock_model,
        x_natural=x,
        y=y
    )

    # `>= 0` on its own would accept +inf, so also require a finite value.
    assert torch.isfinite(loss).all()
    assert loss.item() >= 0