Coverage for tests/tests_advertrain/test_training/test_classical_training.py: 86%

37 statements  

coverage.py v7.9.2, created at 2025-09-10 08:11 +0000

import pytest
import torch
from torch.nn import Linear, Module
from torch.optim import SGD
from torch.utils.data import DataLoader, TensorDataset

from robustML.advertrain.training.classical_training import ClassicalTraining

torch.manual_seed(0)


class SimpleModel(Module):
    def __init__(self):
        super().__init__()
        self.linear = Linear(in_features=5, out_features=2)  # Example layer

    def forward(self, x):
        return self.linear(x)


@pytest.fixture
def mock_model():
    return SimpleModel()


@pytest.fixture
def mock_optimizer(mock_model):
    return SGD(mock_model.parameters(), lr=0.001)


@pytest.fixture
def mock_loss_func():
    return torch.nn.CrossEntropyLoss()


@pytest.fixture
def mock_device():
    return torch.device("cpu")  # or "cuda" if testing on GPU


@pytest.fixture
def mock_dataloader():
    x = torch.rand(10, 5)  # example data
    y = torch.randint(0, 2, (10,))
    dataset = TensorDataset(x, y)
    return DataLoader(dataset, batch_size=2)


def test_initialization(mock_model, mock_optimizer, mock_loss_func, mock_device):
    training = ClassicalTraining(mock_model, mock_optimizer, mock_loss_func, mock_device)
    assert training.model is mock_model
    assert training.optimizer is mock_optimizer
    assert training.loss_func is mock_loss_func
    assert training.device == mock_device
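
The listed test only exercises the constructor, which is consistent with the reported 86% statement coverage. A minimal sketch of an additional check is shown below; it uses only the attributes already asserted on in test_initialization (model, optimizer, loss_func, device) plus standard PyTorch calls, and deliberately avoids assuming any other ClassicalTraining method, since the rest of the class's API is not visible in this report. The test name is hypothetical.

def test_attributes_support_manual_step(mock_model, mock_optimizer, mock_loss_func,
                                        mock_device, mock_dataloader):
    # Sketch: the wired-up attributes should be usable for one ordinary
    # PyTorch optimization step. Only attributes confirmed by
    # test_initialization are used; no unverified ClassicalTraining methods.
    training = ClassicalTraining(mock_model, mock_optimizer, mock_loss_func, mock_device)
    x, y = next(iter(mock_dataloader))
    x, y = x.to(training.device), y.to(training.device)

    training.optimizer.zero_grad()
    loss = training.loss_func(training.model(x), y)
    loss.backward()
    training.optimizer.step()

    assert torch.isfinite(loss)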