知行编程网知行编程网  2022-03-16 20:00 知行编程网 隐藏边栏 |   抢沙发  11 
文章评分 0 次,平均分 0.0

迁移学习不好懂?要不看看这个Pytorch项目

来自 | Medium

编辑 | 元子

报道 | 新智元

【导读】迁移学习是一个非常重要的机器学习技术,已被广泛应用于机器学习的许多应用中。本文的目标是让读者理解迁移学习的意义,了解转学习的重要性,并学会使用PyTorch进行实践。

元学习有一个非常重要的理念是在较少样本量的情况下,让机器能够自己学会学习。这一点和迁移学习非常相似。吴恩达曾经说过"迁移学习将会是继监督学习之后的下一个机器学习商业成功的驱动力"。


相比而言,依赖大量数据进行训练的其他机器学习手段,对数据和算力的依赖有点过于严重。况且,数据和算力那么贵!


迁移学习的一大特色,就是“将一个任务环境中学到的东西用来提升在另一个任务环境中模型的泛化能力”。


没有GPU也没关系,可以使用谷歌的免费GPU服务,通过谷歌Colab来训练模型。


借TensorFlow 2.0发布之际,就让我们对比一下,通过PyTorch来更直观、更深入的了解迁移学习。


前期准备


本次旅程,我们将使用预先训练的网络,来构建用于疟疾检测的图像分类器,这个分类器只需要将得到的数据,分为“感染”“未感染”两类。


我们将要用到的图像数据集可以在这里下载👇

https://drive.google.com/open?id=16DbIOMCtCuRuMdYF64MPv3iLqpSG6tfv


经过预先训练的网络在ImageNet上进行了训练,其中包含120万张1000个类别的图像,


用到的模型是torchvision.models,它有6种不同的架构我们可以使用。


torchvision.models具有模型性能的细分以及可以使用的层数(由模型附带的数字表示)。


加载所有必需的包和库:




将数据进行可视化:



下图是感染的图


迁移学习不好懂?要不看看这个Pytorch项目


定义转换并加载进数据


转换是将一个图形、表达式或函数转换为另一个图形、表达式或函数的过程。


我们需要为训练、测试以及验证数据定义一些转换。值得注意的,可能有的类别图像太少,不够进行转换,为了增加网络识别的图像数量,我们执行所谓的数据增强。


在训练期间,我们随机裁剪、缩放和旋转图像,以便在每个时期,网络会看到同一图像的不同变化,提高实验的准确性



接下来加载数据集。最简单的方法是用torchvision的dataset.ImageFolder。


加载imageFolder后,我们将数据拆分为20%验证集和10%测试集; 然后将它传递给DataLoader。


它接收一个类似从ImageFolder获得的数据集,并返回批量图像及其相应的标签(可以将改组设置为true以在时期内引入变化)。



模型训练流程


1. 加载预先训练的模型


PyTorch以及几乎所有其他深度学习框架,都使用CUDA来有效地计算GPU上的前向和后向传递。


在PyTorch中,我们使用model.cuda()将模型参数和其他张量移动到GPU内存,或者从GPU移回,


2. 冻结卷积层并使用自定义分类器替换完全连接的层



冻结模型参数允许我们为早期卷积层保留预训练模型的权重,其目的是用于特征提取。


然后我们定义我们的全连接网络,他将作为输入神经元,示例代码中是1024,这个数字取决于预训练模型的输入神经元,和自定义隐藏层。


我们还定义了要使用的激活函数,和有助于通过随机关闭层中的神经元,以强制在剩余节点之间共享信息,来避免过度拟合的丢失。


在我们定义了自定义全连接网络之后,我们将其连接到预先训练好的模型的完全连接网络。


接下来我们定义损失函数,优化器,并通过将模型移动到GPU来准备训练模型。


3. 为特定任务训练自定义分类器


在训练期间,我们遍历每个时期的DataLoader。 对于每个batch,使用标准函数计算损失。使用loss.backward()方法计算相对于模型参数的损失梯度。


optimizer.zero_grad()负责清除任何累积的梯度,因为我们会一遍又一遍地计算梯度。


optimizer.step()使用具有动量的随机梯度下降(Adam)更新模型参数。


为了防止过度拟合,我们使用一种称为早期停止的强大技术。背后的想法很简单,当验证数据集上的性能开始降低时停止训练。



在耐心地等待训练过程完成并保存最佳模型参数的检查点之后,让我们加载检查点并在看不见的数据(测试数据)上测试模型的性能。


从磁盘加载已保存的模型



在看不见的数据上测试加载的模型。 我们对看不见的数据有90%的准确率,这在第一次尝试时非常令人印象深刻。



现在我们对模型有了信心,现在是时候进行一些预测并将结果可视化了。



迁移学习不好懂?要不看看这个Pytorch项目


好。教程到此就结束了。我们使用PyTorch,利用迁移学习建立了一个疟疾分类器的应用。


接下来,我们可以继续的完善代码,或者可以再做几个其他同类型的应用。


参考链接:

https://heartbeat.fritz.ai/transfer-learning-with-pytorch-cfcb69016c72

—完—

为您推荐

一文读懂 12种卷积方法

如何向 5 岁小孩解释什么是SVM ?

【新纪录】90秒训练AlexNet!

AI圣经 PRML《模式识别与机器学习》

我的2019秋招算法面经


本篇文章来源于: 深度学习这件小事

本文为原创文章,版权归所有,欢迎分享本文,转载请保留出处!

知行编程网
知行编程网 关注:1    粉丝:1
这个人很懒,什么都没写

发表评论

表情 格式 链接 私密 签到
扫一扫二维码分享