Coverage for tests/tests_advertrain/test_dependencies/test_fire.py: 100%
32 statements
Report generated by coverage.py v7.9.2 on 2025-09-10 08:11 +0000.
1import pytest
2import torch
4from robustML.advertrain.dependencies.fire import (entropy_loss, fire_loss,
5 noise_loss)
# Fix the global torch RNG once, at import time, so module-level randomness
# starts from a known state.
torch.random.manual_seed(0)
class MockModel(torch.nn.Module):
    """Minimal stand-in network: a single linear layer mapping 10 inputs to 2 outputs."""

    def __init__(self):
        super(MockModel, self).__init__()
        # Kept under the name ``lin`` so parameter names are ``lin.weight`` / ``lin.bias``.
        self.lin = torch.nn.Linear(in_features=10, out_features=2)

    def forward(self, x):
        """Return the linear layer's output for ``x`` (no activation applied)."""
        outputs = self.lin(x)
        return outputs
@pytest.fixture
def mock_model():
    """Provide a freshly constructed ``MockModel`` to each requesting test.

    The global RNG is reseeded first. The module-level seed is applied only once,
    at import, so without this the layer's random initial weights would depend on
    which tests ran earlier (``-k`` selection, reordering, xdist, ...).
    """
    torch.manual_seed(0)
    return MockModel()
def test_entropy_loss():
    """``entropy_loss`` on random logits yields a finite, non-negative value."""
    # A private generator keeps the inputs identical regardless of which tests
    # consumed the global RNG before this one.
    gen = torch.Generator().manual_seed(0)
    logits = torch.randn(32, 2, generator=gen)

    loss = entropy_loss(logits)

    # ``>= 0`` on its own would also accept ``inf``; require a finite value too.
    assert torch.isfinite(loss).all(), f"entropy loss is not finite: {loss}"
    assert loss.item() >= 0
def test_fire_loss(mock_model):
    """``fire_loss`` returns four finite, non-negative loss terms for a CPU batch."""
    # A private generator makes the batch independent of global RNG consumption
    # by earlier tests, so failures reproduce under any test selection/order.
    gen = torch.Generator().manual_seed(0)
    x = torch.randn(32, 10, generator=gen)
    y = torch.randint(0, 2, (32,), generator=gen)
    optimizer = torch.optim.Adam(mock_model.parameters(), lr=0.001)
    device = torch.device('cpu')

    total_loss, nat_loss, rob_loss, ent_loss = fire_loss(
        model=mock_model,
        x_natural=x,
        y=y,
        optimizer=optimizer,
        epoch=1,
        device=device,
    )

    terms = {
        "total_loss": total_loss,
        "nat_loss": nat_loss,
        "rob_loss": rob_loss,
        "ent_loss": ent_loss,
    }
    for name, value in terms.items():
        # ``>= 0`` alone would let ``inf`` through; name the failing term in the message.
        assert torch.isfinite(value).all(), f"{name} is not finite: {value}"
        assert value.item() >= 0, f"{name} is negative: {value.item()}"
def test_noise_loss(mock_model):
    """``noise_loss`` returns a finite, non-negative value for a random batch."""
    # A private generator makes the batch independent of global RNG consumption
    # by earlier tests, so failures reproduce under any test selection/order.
    gen = torch.Generator().manual_seed(0)
    x = torch.randn(32, 10, generator=gen)
    y = torch.randint(0, 2, (32,), generator=gen)

    loss = noise_loss(
        model=mock_model,
        x_natural=x,
        y=y,
    )

    # ``>= 0`` on its own would also accept ``inf``; require a finite value too.
    assert torch.isfinite(loss).all(), f"noise loss is not finite: {loss}"
    assert loss.item() >= 0