Coverage for tests/tests_advertrain/test_dependencies/test_trades.py: 100%

28 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-09-10 08:11 +0000

1import torch 

2import pytest 

3from robustML.advertrain.dependencies.trades import squared_l2_norm, l2_norm, trades_loss 

4 

# Seed torch's global RNG a single time, when this module is imported; the
# random tensors drawn by the tests below are then reproducible for a given
# test selection and execution order.
torch.random.manual_seed(0)

6 

7 

class MockModel(torch.nn.Module):
    """Minimal single-linear-layer model used as a stand-in in tests.

    Args:
        in_features: Size of the last dimension of the input ``x``.
        out_features: Number of output features per sample (e.g. class logits).
    """

    def __init__(self, in_features: int = 10, out_features: int = 2):
        super().__init__()
        self.lin = torch.nn.Linear(in_features, out_features)

    def forward(self, x):
        """Return ``self.lin(x)``, shaped ``(..., out_features)``."""
        return self.lin(x)

15 

16 

@pytest.fixture
def mock_model():
    """Provide a freshly constructed ``MockModel`` (function scope: one per test)."""
    model = MockModel()
    return model

20 

21 

def test_squared_l2_norm():
    """Check ``squared_l2_norm`` against a direct sum-of-squares reference.

    Whether the helper reduces per sample or over the whole tensor, its
    outputs must be non-negative and must add up to the total sum of squares.
    """
    x = torch.randn(32, 10)
    norm = squared_l2_norm(x)

    assert torch.all(norm >= 0)
    # Reduction-agnostic check: however the result is split up, its pieces
    # must sum to the squared norm of the whole input.
    assert torch.allclose(norm.sum(), (x ** 2).sum(), rtol=1e-4)
    # The zero tensor must have (numerically) zero squared norm.
    assert torch.all(squared_l2_norm(torch.zeros(4, 10)).abs() < 1e-6)

27 

28 

def test_l2_norm():
    """Check ``l2_norm`` is non-negative and consistent with the sum of squares."""
    x = torch.randn(32, 10)
    norm = l2_norm(x)

    assert torch.all(norm >= 0)
    # Squaring each returned norm and summing must recover the total squared
    # norm of the input, whether the norm is taken per sample or globally.
    assert torch.allclose((norm ** 2).sum(), (x ** 2).sum(), rtol=1e-4)
    # The zero tensor must have (numerically) zero norm.
    assert torch.all(l2_norm(torch.zeros(4, 10)).abs() < 1e-6)

34 

35 

def test_trades_loss(mock_model):
    """Run ``trades_loss`` end to end on CPU and sanity-check the returned loss."""
    # NOTE(review): inputs are drawn from [0, 1) on the assumption that
    # trades_loss treats x_natural as image-like data in that range — confirm.
    x = torch.rand(32, 10)
    # Labels in {0, 1}; presumably used as class indices against the model's
    # 2 outputs.
    y = torch.randint(0, 2, (32,))
    optimizer = torch.optim.Adam(mock_model.parameters(), lr=0.001)
    device = torch.device('cpu')

    loss = trades_loss(
        model=mock_model,
        x_natural=x,
        y=y,
        optimizer=optimizer,
        device=device,
    )

    # A usable training loss is a single finite, non-negative value.
    assert loss.numel() == 1
    assert torch.isfinite(loss).all()
    assert loss.item() >= 0