Coverage for tests/tests_advertrain/test_training/test_fire_training.py: 100%
30 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-09-10 08:11 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-09-10 08:11 +0000
1import torch
2import pytest
3from robustML.advertrain.training.fire_training import FIRETraining
# Seed torch's global RNG at import time so the random inputs drawn in the tests are reproducible.
torch.manual_seed(seed=0)
class MockModel(torch.nn.Module):
    """Tiny stand-in network: one linear layer mapping 10 input features to 2 outputs."""

    def __init__(self):
        super().__init__()
        # The attribute name ``lin`` determines the parameter names ("lin.weight", "lin.bias").
        self.lin = torch.nn.Linear(in_features=10, out_features=2)

    def forward(self, x):
        """Return the linear layer applied to ``x``."""
        out = self.lin(x)
        return out
class MockOptimizer(torch.optim.Optimizer):
    """Minimal optimizer used as a test double.

    ``torch.optim.Optimizer.step`` raises ``NotImplementedError`` in the base
    class, so without an override any code path that steps the optimizer would
    crash. This ``step`` deliberately leaves every parameter unchanged.
    """

    def step(self, closure=None):
        """Perform no parameter update; evaluate ``closure`` if one is given.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss, following the ``torch.optim`` convention.

        Returns:
            Whatever ``closure`` returns, or ``None`` when no closure is given.
        """
        if closure is None:
            return None
        # Built-in optimizers evaluate the closure with autograd enabled; mirror that.
        with torch.enable_grad():
            return closure()
@pytest.fixture
def mock_model():
    """Provide a freshly constructed ``MockModel`` to each requesting test."""
    model = MockModel()
    return model
@pytest.fixture
def mock_optimizer(mock_model):
    """Provide a ``MockOptimizer`` over ``mock_model``'s parameters with empty defaults."""
    params = mock_model.parameters()
    return MockOptimizer(params, defaults={})
@pytest.fixture
def mock_device():
    """Provide the CPU device used by the tests."""
    cpu = torch.device("cpu")
    return cpu
@pytest.fixture
def fire_training(mock_model, mock_optimizer, mock_device):
    """Build a ``FIRETraining`` wired to the mock model, optimizer and device.

    Hyper-parameters are fixed for the test: ``epsilon=0.1``, ``beta=1.0``,
    ``perturb_steps=20``.
    """
    hyperparams = dict(epsilon=0.1, beta=1.0, perturb_steps=20)
    return FIRETraining(
        model=mock_model,
        optimizer=mock_optimizer,
        device=mock_device,
        **hyperparams,
    )
def test_val_batch(fire_training):
    """``val_batch`` yields a non-negative loss and reports the batch size it processed."""
    n_samples, n_features, n_classes = 32, 10, 2
    # Draw inputs before labels to keep the same RNG consumption order.
    inputs = torch.randn(n_samples, n_features)
    labels = torch.randint(0, n_classes, (n_samples,))

    loss, batch_size = fire_training.val_batch(inputs, labels, epoch=1)

    assert loss >= 0
    assert batch_size == inputs.size(0)