Coverage for tests/tests_advertrain/test_training/test_classical_training.py: 86%
37 statements
Created by coverage.py v7.9.2 at 2025-09-10 08:11 +0000.
1import pytest
2import torch
3from torch.nn import Linear, Module
4from torch.optim import SGD
5from torch.utils.data import DataLoader, TensorDataset
7from robustML.advertrain.training.classical_training import ClassicalTraining
# Seed PyTorch's global RNG once, at import time, so the Linear weight
# initialisation in SimpleModel and the torch.rand / torch.randint fixture
# data are reproducible between runs (given the same test selection/order).
torch.manual_seed(seed=0)
class SimpleModel(Module):
    """Minimal single-linear-layer network used as a stand-in model in tests.

    Args:
        in_features: Size of each input sample. Defaults to 5, the width the
            layer was originally hard-coded to.
        out_features: Size of each output sample. Defaults to 2.
    """

    def __init__(self, in_features: int = 5, out_features: int = 2) -> None:
        # Zero-argument super() is the Python 3 idiom; it is equivalent to
        # super(SimpleModel, self) but does not break if the class is renamed.
        super().__init__()
        self.linear = Linear(in_features=in_features, out_features=out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to ``x`` and return its output."""
        return self.linear(x)
@pytest.fixture()
def mock_model():
    """Provide a freshly constructed ``SimpleModel`` (function-scoped fixture)."""
    model = SimpleModel()
    return model
@pytest.fixture()
def mock_optimizer(mock_model):
    """Provide an SGD optimizer (learning rate 0.001) over ``mock_model``'s parameters.

    ``mock_model`` is injected by pytest, so within one test this optimizer
    tracks the same model instance the test itself receives.
    """
    learning_rate = 0.001
    return SGD(params=mock_model.parameters(), lr=learning_rate)
@pytest.fixture()
def mock_loss_func():
    """Provide a ``CrossEntropyLoss`` module constructed with default settings."""
    criterion = torch.nn.CrossEntropyLoss()
    return criterion
@pytest.fixture()
def mock_device():
    """Provide the CPU device.

    Change the device string to ``"cuda"`` to exercise the code on a GPU.
    """
    device = torch.device("cpu")
    return device
@pytest.fixture()
def mock_dataloader():
    """Provide a small DataLoader of random samples.

    Ten samples with 5 features each (the ``in_features`` of SimpleModel's
    linear layer), paired with integer labels drawn from {0, 1}, served in
    batches of 2 in their original order (DataLoader does not shuffle by
    default).
    """
    num_samples = 10
    # Draw features before labels to keep the same global-RNG consumption order.
    features = torch.rand(num_samples, 5)
    labels = torch.randint(0, 2, (num_samples,))
    return DataLoader(TensorDataset(features, labels), batch_size=2)
def test_initialization(mock_model, mock_optimizer, mock_loss_func, mock_device):
    """ClassicalTraining should store the collaborators it was constructed with."""
    trainer = ClassicalTraining(mock_model, mock_optimizer, mock_loss_func, mock_device)

    # Identity checks: the stored objects must be the very ones passed in.
    expected_by_attr = {
        "model": mock_model,
        "optimizer": mock_optimizer,
        "loss_func": mock_loss_func,
    }
    for attr_name, expected in expected_by_attr.items():
        assert getattr(trainer, attr_name) is expected

    # The device is compared by equality rather than identity.
    assert trainer.device == mock_device