MNIST 手写数字集分类问题:基于 AlexNet 神经网络
Python
PyTorch
MNIST
Dataset
Dataloader
Tensorboard
本文介绍如何读取 MNIST 数据集,搭建 AlexNet 简单卷积神经网络,模型训练和验证。模型在验证集的准确率大约 95%。
已开源在 GitHub库
使用 git
下载。进入空目录:
1 git clone https://github.com/isKage/mnist-classification.git
PyTorch 的安装和环境配置可见 zhihu
安装指定依赖(在 requirements.txt 所在的根目录下执行):
1 pip install -r requirements.txt
在根目录下创建 config.py
文件写入本地配置。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 import osimport torchimport warningsfrom datetime import datetimeclass DefaultConfig : model = 'Classification10Class' root = '<路径>/AllData/datasets/hojjatk/mnist-dataset' logdir = './logs' param_path = './checkpoints/' if not os.listdir(param_path): load_model_path = None else : load_model_path = os.path.join( param_path, sorted ( os.listdir(param_path), key=lambda x: datetime.strptime( x.split('_' )[-1 ].split('.pth' )[0 ], "%Y-%m-%d%H%M%S" ) )[-1 ] ) lr = 0.03 max_epochs = 1 batch_size = 64 num_workers = 0 print_feq = 100 if torch.cuda.is_available(): gpu = True device = torch.device('cuda' ) else : gpu = False device = torch.device('cpu' ) def _parse (self, kwargs ): """ 根据字典kwargs 更新 config 参数 """ for k, v in kwargs.items(): if not hasattr (self , k): warnings.warn("Warning: opt has not attribute %s" % k) setattr (self , k, v) config.device = torch.device('cuda:0' ) if config.gpu else torch.device('cpu' ) print ('User config:' ) for k, v in self .__class__.__dict__.items(): if not k.startswith('_' ): print (k, getattr (self , k)) config = DefaultConfig()
1 读取 MNIST 数据集
直接使用 torchvision.datasets.MNIST
会出现网络问题,难以下载。
可以先前往 kaggle 下载。
使用 kaggle 命令下载教程可见 从 Kaggle 下载数据集(mac 和 win 端) 。
然后自定义 get_data.py
的 getData
函数读取数据集。其中 config
为本地配置(包含了一些参数和文件路径)。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 from config import configimport torchvision.datasetsfrom torch.utils.data import DataLoaderdef getData (root=config.root, batch_size=config.batch_size ): train_dataset = torchvision.datasets.MNIST( root=root, train=True , transform=torchvision.transforms.ToTensor(), download=False , ) test_dataset = torchvision.datasets.MNIST( root=root, train=False , transform=torchvision.transforms.ToTensor(), download=False , ) train_data_size = len (train_dataset) test_data_size = len (test_dataset) print ("训练数据集长度为 {}" .format (train_data_size)) print ("测试数据集长度为 {}" .format (test_data_size)) train_dataloader = DataLoader( dataset=train_dataset, batch_size=batch_size, ) test_dataloader = DataLoader( dataset=test_dataset, batch_size=batch_size, ) return train_dataset, test_dataset, train_dataloader, test_dataloader if __name__ == "__main__" : train_dataset, test_dataset, train_dataloader, test_dataloader = getData() img, label = train_dataset[0 ] print (img.shape) print (label)
2 搭建网络
MNIST 数据集较为简单,使用简单的 AlexNet 卷积神经网络即可
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 import timeimport torchfrom torch import nnclass BasicModule (nn.Module): """ 作为基类,继承 nn.Module 但增加了模型保存和加载功能 save and load """ def __init__ (self ): super ().__init__() self .model_name = str (type (self )) def load (self, model_path ): """ 根据模型路径加载模型 :param model_path: 模型路径 :return: 模型 """ self .load_state_dict(torch.load(model_path)) def save (self, filename=None ): """ 保存模型,默认使用 "模型名字 + 时间" 作为文件名,也可以自定义 """ if filename is None : filename = 'checkpoints/' + self .model_name + '_' + time.strftime("%Y-%m-%d%H%M%S" ) + '.pth' torch.save(self .state_dict(), filename) return filename class Classification10Class (BasicModule ): def __init__ (self ): super (Classification10Class, self ).__init__() self .model_name = 'Classification10Class' self .module = nn.Sequential( nn.Conv2d(in_channels=1 , out_channels=16 , kernel_size=5 , stride=1 , padding=2 ), nn.MaxPool2d(kernel_size=2 ), nn.Conv2d(in_channels=16 , out_channels=32 , kernel_size=5 , stride=1 , padding=2 ), nn.MaxPool2d(kernel_size=2 ), nn.Conv2d(in_channels=32 , out_channels=64 , kernel_size=5 , stride=1 , padding=2 ), nn.MaxPool2d(kernel_size=2 ), nn.Flatten(), nn.Linear(in_features=64 * 3 * 3 , out_features=64 ), nn.Linear(in_features=64 , out_features=10 ), ) def forward (self, x ): x = self .module(x) return x if __name__ == '__main__' : classification = Classification10Class() inputs = torch.ones((64 , 1 , 28 , 28 )) outputs = classification(inputs) print (outputs.shape)
3 主程序
主程序 main.py
包含了训练、验证和写入 tensorboard 可视化。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 import modelsfrom config import configfrom get_data import getDataimport osimport torchfrom torch import nnfrom torch.utils.tensorboard import SummaryWriterdef train (**kwargs ): config._parse(kwargs) classification = getattr (models, config.model)() classification.to(config.device) train_dataset, test_dataset, train_dataloader, test_dataloader = getData() test_data_size = len (test_dataset) loss_fn = nn.CrossEntropyLoss() learning_rate = 0.01 optimizer = torch.optim.SGD( params=classification.parameters(), lr=config.lr, ) total_train_step = 0 total_test_step = 0 epochs = config.max_epochs writer = SummaryWriter("./logs" ) for epoch in range (epochs): print ("------------- 第 {} 轮训练开始 -------------" .format (epoch + 1 )) classification.train() for data in train_dataloader: images, targets = data images, targets = images.to(config.device), targets.to(config.device) outputs = classification(images) loss = loss_fn(outputs, targets) optimizer.zero_grad() loss.backward() optimizer.step() total_train_step += 1 if total_train_step % config.print_feq == 0 : print ("训练次数: {}, loss: {}" .format (total_train_step, loss.item())) writer.add_scalar( tag="train_loss (every 100 steps)" , scalar_value=loss.item(), global_step=total_train_step, ) classification.eval () total_test_loss = 0 total_accuracy = 0 with torch.no_grad(): for data in test_dataloader: images, targets = data images, targets = images.to(config.device), targets.to(config.device) outputs = classification(images) loss = loss_fn(outputs, targets) total_test_loss += loss.item() accuracy = (outputs.argmax(axis=1 ) == targets).sum () total_accuracy += accuracy print 
("##### 在测试集上的 loss: {} #####" .format (total_test_loss)) writer.add_scalar( tag="test_loss (every epoch)" , scalar_value=total_test_loss, global_step=epoch, ) print ("##### 在测试集上的正确率: {} #####" .format (total_accuracy / test_data_size)) writer.add_scalar( tag="test_accuracy (every epoch)" , scalar_value=total_accuracy / test_data_size, global_step=epoch, ) classification.save() print ("##### 模型成功保存 #####" ) writer.close() if __name__ == '__main__' : import fire fire.Fire()
4 运行程序
使用 fire
包,从而实现终端训练。
即可运行主程序的 train
函数。