Coverage for tests/tests_advertrain/test_training/test_fire_training.py: 100%

30 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-09-10 08:11 +0000

1import torch 

2import pytest 

3from robustML.advertrain.training.fire_training import FIRETraining 

4 

# Seed the global torch RNG once, when this module is imported.
torch.manual_seed(seed=0)

6 

7 

class MockModel(torch.nn.Module):
    """Tiny stand-in network: a single linear layer mapping 10 features to 2 outputs."""

    def __init__(self):
        super(MockModel, self).__init__()
        # Attribute name `lin` is kept as-is: it determines the state_dict keys.
        self.lin = torch.nn.Linear(in_features=10, out_features=2)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        logits = self.lin(x)
        return logits

15 

16 

class MockOptimizer(torch.optim.Optimizer):
    """Mock optimizer for testing.

    Adds nothing to ``torch.optim.Optimizer``: construction, ``param_groups``
    and every method come straight from the base class.
    """

20 

21 

@pytest.fixture
def mock_model():
    """Provide a new MockModel instance to each test that requests it."""
    model = MockModel()
    return model

25 

26 

@pytest.fixture
def mock_optimizer(mock_model):
    """Provide a MockOptimizer over ``mock_model``'s parameters with empty defaults."""
    params = mock_model.parameters()
    return MockOptimizer(params, defaults={})

30 

31 

@pytest.fixture
def mock_device():
    """Provide the CPU device so the tests do not need an accelerator."""
    cpu = torch.device('cpu')
    return cpu

35 

36 

@pytest.fixture
def fire_training(mock_model, mock_optimizer, mock_device):
    """Build a FIRETraining instance wired to the mock model, optimizer and device.

    Hyper-parameters used: epsilon=0.1, beta=1.0, perturb_steps=20.
    """
    hyperparams = {"epsilon": 0.1, "beta": 1.0, "perturb_steps": 20}
    return FIRETraining(
        model=mock_model,
        optimizer=mock_optimizer,
        device=mock_device,
        **hyperparams,
    )

47 

48 

def test_val_batch(fire_training):
    """val_batch returns a finite, non-negative loss and the size of the batch it was given."""
    # Draw inputs from a dedicated, locally seeded generator. The module-level
    # torch.manual_seed(0) runs at import time, and the global RNG is consumed
    # afterwards (e.g. by nn.Linear init when the mock_model fixture builds the
    # model). Drawing from the global RNG would make this data order-dependent.
    gen = torch.Generator().manual_seed(0)
    x = torch.randn(32, 10, generator=gen)
    y = torch.randint(0, 2, (32,), generator=gen)

    loss, batch_size = fire_training.val_batch(x, y, epoch=1)

    # `loss >= 0` alone would accept +inf; require a finite value as well.
    # as_tensor handles either a Python number or a tensor, whichever val_batch returns.
    assert torch.isfinite(torch.as_tensor(loss)).all()
    assert loss >= 0
    assert batch_size == x.size(0)