知行编程网, 2022-05-17 16:00

Can the "Slimmed-Down" ALBERT Replace BERT?

From | Shisan (十三), reporting from Aofeisi (凹非寺)

Reposted from | QbitAI (量子位)

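It has 80% fewer parameters than BERT, yet it performs better.
This is ALBERT, the "slimmed-down BERT" that Google proposed last year.
The model drew a great deal of attention as soon as it was released, and comparisons between the two quickly became a hot topic.
Recently, an online user named Naman Bansal raised a question:
Should ALBERT be used in place of BERT?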
Whether it can, a side-by-side comparison will tell.



   BERT and ALBERT

Most readers are already familiar with BERT.
Google proposed it in 2018 and pretrained it on a very large corpus of 3.3 billion words.
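The model's innovations are concentrated in pretraining. It uses two tasks, Masked LM and Next Sentence Prediction, to capture word-level and sentence-level representations respectively.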
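To make the Masked LM task concrete, here is a minimal sketch of BERT-style masking. It follows the widely described 80/10/10 recipe: a selected token becomes [MASK] 80% of the time, a random token 10% of the time, and stays unchanged 10% of the time. The function name and the per-position Bernoulli selection are simplifying assumptions. BERT's own data pipeline also caps the number of predictions per sequence.

```python
import random

MASK = "[MASK]"

def mask_tokens(tokens, vocab, mask_prob=0.15, seed=0):
    """BERT-style masking sketch (hypothetical helper, not BERT's code).

    Each position is selected with probability mask_prob. A selected token is
    replaced by [MASK] 80% of the time, by a random vocabulary token 10% of the
    time, and left unchanged 10% of the time. Returns (masked_tokens, labels),
    where labels[i] is the original token at selected positions and None at
    positions that do not contribute to the MLM loss.
    """
    rng = random.Random(seed)
    masked = list(tokens)
    labels = [None] * len(tokens)
    for i, tok in enumerate(tokens):
        if rng.random() >= mask_prob:
            continue  # not selected: no prediction required here
        labels[i] = tok  # the model must recover the original token
        roll = rng.random()
        if roll < 0.8:
            masked[i] = MASK
        elif roll < 0.9:
            masked[i] = rng.choice(vocab)
        # otherwise keep the original token unchanged


    return masked, labels


tokens = "the chef served a spicy mapo tofu with rice".split()
masked, labels = mask_tokens(tokens, vocab=tokens, mask_prob=0.3, seed=1)
print(masked)
print(labels)
```

The model sees the masked tokens. The loss is computed only where a label is present, so it is not rewarded for copying visible tokens.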
BERT's arrival fundamentally changed the relationship between pretrained word representations and specific downstream NLP tasks.
A year later, Google proposed ALBERT ("A Lite BERT"). Its backbone is similar to BERT's: it still uses a Transformer encoder with the GELU activation function.
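GELU weights each input by the standard normal CDF: GELU(x) = x·Φ(x). The plain-Python sketch below shows the exact form and the commonly used tanh approximation, for illustration only. Which variant a given model uses is set by the `hidden_act` field of its config.

```python
import math

def gelu_exact(x: float) -> float:
    """GELU(x) = x * Phi(x), where Phi is the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x: float) -> float:
    """Widely used tanh approximation of GELU."""
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


for x in (-2.0, -0.5, 0.0, 0.5, 2.0):
    print(f"x={x:+.1f}  exact={gelu_exact(x):+.5f}  tanh={gelu_tanh(x):+.5f}")
```

Unlike ReLU, GELU is smooth and lets small negative inputs through with a small weight. Large positive inputs pass almost unchanged, and large negative inputs go to about zero.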
Its biggest achievement is cutting the parameter count by 80% relative to BERT while still achieving better results.
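Most of the savings come from two changes listed in the next paragraph: factorized embeddings and cross-layer parameter sharing. Below is a rough parameter count under illustrative, base-sized assumptions: vocabulary V = 30000, hidden size H = 768, embedding size E = 128, 12 layers, and feed-forward size 3072. It ignores position and segment embeddings, the pooler, and the embedding-projection bias. The numbers are approximate, and the exact reduction depends on the configuration.

```python
def embedding_params(vocab_size, hidden_size, embedding_size=None):
    """Token-embedding parameters: V*H without factorization,
    V*E + E*H with ALBERT-style factorization (projection bias ignored)."""
    if embedding_size is None:
        return vocab_size * hidden_size
    return vocab_size * embedding_size + embedding_size * hidden_size

def encoder_layer_params(hidden_size, intermediate_size):
    """Approximate parameters of one Transformer encoder layer."""
    h, i = hidden_size, intermediate_size
    attention = 4 * (h * h + h)        # Q, K, V and output projections with biases
    ffn = (h * i + i) + (i * h + h)    # two feed-forward linear layers with biases
    layer_norms = 2 * (2 * h)          # two LayerNorms, each with gain and bias
    return attention + ffn + layer_norms

def encoder_params(num_layers, hidden_size, intermediate_size, share_layers):
    per_layer = encoder_layer_params(hidden_size, intermediate_size)
    # With cross-layer sharing, one set of layer weights is reused num_layers times
    return per_layer if share_layers else num_layers * per_layer


V, H, E, I, L = 30000, 768, 128, 3072, 12  # illustrative base-sized values
bert_like = embedding_params(V, H) + encoder_params(L, H, I, share_layers=False)
albert_like = embedding_params(V, H, E) + encoder_params(L, H, I, share_layers=True)
print(f"BERT-like:   {bert_like:,}")
print(f"ALBERT-like: {albert_like:,}")
print(f"ratio:       {albert_like / bert_like:.3f}")
```

Sharing removes the factor of L from the encoder term. Factorization replaces V×H with V×E + E×H, which is much smaller whenever E is much smaller than H.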
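Compared with BERT, the main changes are factorized embedding parameterization, cross-layer parameter sharing, an inter-sentence coherence loss based on sentence-order prediction (SOP) in place of NSP, and the removal of dropout.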
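Here is how SOP differs from BERT's NSP. NSP builds negatives by pairing segment A with a segment from a different document, and the ALBERT authors argue this can be solved largely from topic cues. SOP instead takes two consecutive segments from the same document and swaps their order to form negatives. The sketch below contrasts the two. The function names and the 0/1 label convention are illustrative assumptions.

```python
import random

def make_sop_example(segment_a, segment_b, rng):
    """ALBERT-style sentence-order prediction (SOP).

    segment_a and segment_b are consecutive segments of the SAME document.
    Label 0: original order (A, B). Label 1: swapped order (B, A).
    """
    if rng.random() < 0.5:
        return (segment_a, segment_b), 0
    return (segment_b, segment_a), 1

def make_nsp_example(segment_a, segment_b, other_doc_segment, rng):
    """BERT-style next-sentence prediction (NSP), for contrast.

    Label 0: B really follows A. Label 1: B is replaced by a segment
    taken from a different document.
    """
    if rng.random() < 0.5:
        return (segment_a, segment_b), 0
    return (segment_a, other_doc_segment), 1


rng = random.Random(0)
a = "The mapo tofu was numbing and spicy."
b = "We ordered a second bowl right away."
other = "Parking downtown is hard on weekends."
print("SOP:", make_sop_example(a, b, rng))
print("NSP:", make_nsp_example(a, b, other, rng))
```

In SOP, both classes contain exactly the same two segments. The only signal left is their order, which pushes the model toward discourse coherence rather than topic matching.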
The figure below compares the performance of BERT and ALBERT on the SQuAD and RACE datasets.
[Figure: performance of BERT and ALBERT on SQuAD and RACE]
As the comparison shows, ALBERT achieves strong results.

   How can ALBERT be pretrained on a custom corpus?

To understand ALBERT better, we will now implement it on a custom corpus.
The dataset is a restaurant-review dataset, and the goal is to use ALBERT to identify dish names in the reviews.
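The label format for this task is not spelled out at this point. One common way to frame "find the dish name in a review" is token-level classification: each token is tagged 1 if it belongs to a dish name and 0 otherwise. The sketch below assumes that framing and uses whitespace tokenization. The helper `dish_token_labels` is hypothetical and is not taken from the dataset files. With ALBERT's SentencePiece tokenizer, labels would also have to be aligned to subword pieces.

```python
def dish_token_labels(review_tokens, dish_tokens):
    """Hypothetical labeling scheme: 1 for every token inside an occurrence
    of the dish name, 0 for all other tokens."""
    labels = [0] * len(review_tokens)
    n = len(dish_tokens)
    if n == 0:
        return labels
    for start in range(len(review_tokens) - n + 1):
        if review_tokens[start:start + n] == dish_tokens:
            for k in range(start, start + n):
                labels[k] = 1  # token belongs to this dish-name span
    return labels


review = "the beef noodle soup here beats any beef noodle soup downtown".split()
print(list(zip(review, dish_token_labels(review, "beef noodle soup".split()))))
```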
Step 1: Download the dataset and prepare the files
<section style="padding: 0.5em;line-height: 18px;font-size: 14px;letter-spacing: 0px;font-family: Consolas, Inconsolata, Courier, monospace;border-radius: 0px;background: rgb(29, 31, 33);color: rgb(197, 200, 198);margin-left: 8px;margin-right: 8px;overflow-wrap: normal !important;word-break: normal !important;overflow: auto !important;display: -webkit-box !important;"># Downloading all files and data (Jupyter/Colab cell: a leading ! runs a shell command)<br  />
<br  />
!wget https://github.com/LydiaXiaohongLi/Albert_Finetune_with_Pretrain_on_Custom_Corpus/raw/master/data_toy/dish_name_train.csv<br  />
!wget https://github.com/LydiaXiaohongLi/Albert_Finetune_with_Pretrain_on_Custom_Corpus/raw/master/data_toy/dish_name_val.csv<br  />
!wget https://github.com/LydiaXiaohongLi/Albert_Finetune_with_Pretrain_on_Custom_Corpus/raw/master/data_toy/restaurant_review.txt<br  />
!wget https://github.com/LydiaXiaohongLi/Albert_Finetune_with_Pretrain_on_Custom_Corpus/raw/master/data_toy/restaurant_review_nopunct.txt<br  />
!wget https://github.com/LydiaXiaohongLi/Albert_Finetune_with_Pretrain_on_Custom_Corpus/raw/master/models_toy/albert_config.json<br  />
!wget https://github.com/LydiaXiaohongLi/Albert_Finetune_with_Pretrain_on_Custom_Corpus/raw/master/model_checkpoint/finetune_checkpoint<br  />
!wget https://github.com/LydiaXiaohongLi/Albert_Finetune_with_Pretrain_on_Custom_Corpus/raw/master/model_checkpoint/pretrain_checkpoint<br  />
<br  />
# Creating files and setting up ALBERT<br  />
<br  />
!pip install sentencepiece<br  />
!git clone https://github.com/google-research/ALBERT<br  />
# Note: vocab.txt is not among the files downloaded above; it must already exist locally<br  />
!python ./ALBERT/create_pretraining_data.py --input_file "restaurant_review.txt" --output_file "restaurant_review_train" --vocab_file "vocab.txt" --max_seq_length=64<br  />
!pip install transformers<br  />
!pip install tfrecord<br  /></section>
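`create_pretraining_data.py` reads raw text. BERT's original script, which ALBERT's largely follows, expects one sentence per line with a blank line between documents. Assuming `restaurant_review.txt` follows that convention, the sketch below writes and parses a toy corpus in the same format. The helper names and sample sentences are hypothetical.

```python
import tempfile
from pathlib import Path

def write_corpus(path, documents):
    """Write documents (each a list of sentences) one sentence per line,
    with a blank line separating documents."""
    text = "\n\n".join("\n".join(doc) for doc in documents) + "\n"
    Path(path).write_text(text, encoding="utf-8")

def read_corpus(path):
    """Parse the same format back into a list of documents."""
    documents, current = [], []
    for line in Path(path).read_text(encoding="utf-8").splitlines():
        line = line.strip()
        if line:
            current.append(line)       # another sentence of the current document
        elif current:
            documents.append(current)  # a blank line closes the document
            current = []
    if current:
        documents.append(current)
    return documents


# Usage with a hypothetical two-document toy corpus
docs = [["Great dumplings.", "Service was slow."], ["The hotpot broth was rich."]]
with tempfile.TemporaryDirectory() as tmp:
    corpus = Path(tmp) / "toy_corpus.txt"
    write_corpus(corpus, docs)
    print(read_corpus(corpus))
```

Document boundaries matter because sentence-pair tasks draw their segments from within a document.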
Step 2: Use transformers and define the layers
<section style="padding: 0.5em;line-height: 18px;font-size: 14px;letter-spacing: 0px;font-family: Consolas, Inconsolata, Courier, monospace;border-radius: 0px;background: rgb(29, 31, 33);color: rgb(197, 200, 198);margin-left: 8px;margin-right: 8px;overflow-wrap: normal !important;word-break: normal !important;overflow: auto !important;display: -webkit-box !important;"><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 1</span><span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;">#Defining Layers for ALBERT</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 2</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 3</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">from</span> transformers.modeling_albert <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> AlbertModel, AlbertPreTrainedModel<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 4</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">from</span> transformers.configuration_albert <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit 
!important;">import</span> AlbertConfig<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 5</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> torch.nn <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">as</span> nn<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 6</span><span style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;"><span style="line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">class</span> <span style="line-height: inherit;color: rgb(129, 162, 190);overflow-wrap: inherit !important;word-break: inherit !important;">AlbertSequenceOrderHead</span><span style="line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">(nn.Module)</span>:</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 7</span>    <span style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;"><span style="line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">def</span> <span style="line-height: inherit;color: rgb(129, 162, 190);overflow-wrap: inherit !important;word-break: inherit !important;">__init__</span><span style="line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit 
!important;word-break: inherit !important;">(self, config)</span>:</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 8</span>        super().__init__()<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 9</span>        self.dense = nn.Linear(config.hidden_size, <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">2</span>)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">10</span>        self.bias = nn.Parameter(torch.zeros(<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">2</span>))<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">11</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">12</span>    <span style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;"><span style="line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">def</span> <span style="line-height: inherit;color: rgb(129, 162, 190);overflow-wrap: inherit !important;word-break: inherit !important;">forward</span><span style="line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">(self, hidden_states)</span>:</span><br  
/><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">13</span>        hidden_states = self.dense(hidden_states)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">14</span>        prediction_scores = hidden_states + self.bias<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">15</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">16</span>        <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">return</span> prediction_scores<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">17</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">18</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">from</span> torch.nn <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> CrossEntropyLoss<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">19</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit 
!important;word-break: inherit !important;">from</span> transformers.modeling_bert <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> ACT2FN<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">20</span><span style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;"><span style="line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">class</span> <span style="line-height: inherit;color: rgb(129, 162, 190);overflow-wrap: inherit !important;word-break: inherit !important;">AlbertForPretrain</span><span style="line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">(AlbertPreTrainedModel)</span>:</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">21</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">22</span>    <span style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;"><span style="line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">def</span> <span style="line-height: inherit;color: rgb(129, 162, 190);overflow-wrap: inherit !important;word-break: inherit !important;">__init__</span><span style="line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">(self, config)</span>:</span><br  /><span style="padding-right: 
20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">23</span>        super().__init__(config)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">24</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">25</span>        self.albert = AlbertModel(config)       <br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">26</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">27</span>        <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># For Masked LM</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">28</span>        <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># The original huggingface implementation, created new output weights via dense layer</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">29</span>        <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># However the original Albert </span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: 
inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">30</span>        self.predictions_dense = nn.Linear(config.hidden_size, config.embedding_size)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">31</span>        self.predictions_activation = ACT2FN[config.hidden_act]<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">32</span>        self.predictions_LayerNorm = nn.LayerNorm(config.embedding_size)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">33</span>        self.predictions_bias = nn.Parameter(torch.zeros(config.vocab_size)) <br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">34</span>        self.predictions_decoder = nn.Linear(config.embedding_size, config.vocab_size)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">35</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">36</span>        self.predictions_decoder.weight = self.albert.embeddings.word_embeddings.weight<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">37</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: 
inherit !important;">38</span>        <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># For sequence order prediction</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">39</span>        self.seq_relationship = AlbertSequenceOrderHead(config)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">40</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">41</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">42</span>    <span style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;"><span style="line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">def</span> <span style="line-height: inherit;color: rgb(129, 162, 190);overflow-wrap: inherit !important;word-break: inherit !important;">forward</span><span style="line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">(<br  /><span style="padding-right: 20px;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">43</span>        self,<br  /><span style="padding-right: 20px;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">44</span>        input_ids=None,<br  /><span style="padding-right: 20px;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit 
!important;">45</span>        attention_mask=None,<br  /><span style="padding-right: 20px;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">46</span>        token_type_ids=None,<br  /><span style="padding-right: 20px;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">47</span>        position_ids=None,<br  /><span style="padding-right: 20px;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">48</span>        head_mask=None,<br  /><span style="padding-right: 20px;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">49</span>        inputs_embeds=None,<br  /><span style="padding-right: 20px;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">50</span>        masked_lm_labels=None,<br  /><span style="padding-right: 20px;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">51</span>        seq_relationship_labels=None,<br  /><span style="padding-right: 20px;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">52</span>    )</span>:</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">53</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">54</span>        outputs = self.albert(<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">55</span>            input_ids,<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">56</span>            
attention_mask=attention_mask,<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">57</span>            token_type_ids=token_type_ids,<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">58</span>            position_ids=position_ids,<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">59</span>            head_mask=head_mask,<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">60</span>            inputs_embeds=inputs_embeds,<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">61</span>        )<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">62</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">63</span>        loss_fct = CrossEntropyLoss()<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">64</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">65</span>        sequence_output = outputs[<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit 
!important;word-break: inherit !important;">0</span>]<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">66</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">67</span>        sequence_output = self.predictions_dense(sequence_output)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">68</span>        sequence_output = self.predictions_activation(sequence_output)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">69</span>        sequence_output = self.predictions_LayerNorm(sequence_output)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">70</span>        prediction_scores = self.predictions_decoder(sequence_output)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">71</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">72</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">73</span>        <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> masked_lm_labels <span style="font-size: inherit;line-height: 
inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">is</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">not</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">None</span>:<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">74</span>            masked_lm_loss = loss_fct(prediction_scores.view(<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">-1</span>, self.config.vocab_size)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">75</span>                                      , masked_lm_labels.view(<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">-1</span>))<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">76</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">77</span>        pooled_output = outputs[<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span>]<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">78</span>        
seq_relationship_scores = self.seq_relationship(pooled_output)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">79</span>        <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> seq_relationship_labels <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">is</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">not</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">None</span>:  <br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">80</span>            seq_relationship_loss = loss_fct(seq_relationship_scores.view(<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">-1</span>, <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">2</span>), seq_relationship_labels.view(<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">-1</span>))<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">81</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit 
!important;">82</span>        loss = masked_lm_loss + seq_relationship_loss<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">83</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">84</span>        <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">return</span> loss<br  /></section>
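In the listing above, `masked_lm_loss` and `seq_relationship_loss` are assigned only inside their `if ... is not None:` branches. The final `loss = masked_lm_loss + seq_relationship_loss` adds them unconditionally. Unless they are initialized earlier in the method (that part is outside this excerpt), calling the model without labels would therefore raise `UnboundLocalError`.

The minimal pure-Python sketch below shows what the combined objective computes. It assumes `loss_fct` is `torch.nn.CrossEntropyLoss(ignore_index=-1)`, which is not shown here, so only masked positions (label != -1) contribute to the masked-LM term. `cross_entropy` and `pretraining_loss` are hypothetical helpers. They operate on plain lists of rows, the flattened equivalent of `.view(-1, vocab_size)`, and they skip any loss term whose labels are missing.

```python
import math
from typing import List, Optional

IGNORE_INDEX = -1  # assumed label for unmasked positions (mirrors ignore_index=-1)


def cross_entropy(logits: List[List[float]], labels: List[int],
                  ignore_index: int = IGNORE_INDEX) -> Optional[float]:
    """Mean negative log-likelihood over rows whose label != ignore_index.

    Returns None when every row is ignored.
    """
    total, count = 0.0, 0
    for row, label in zip(logits, labels):
        if label == ignore_index:
            continue  # unmasked token: contributes nothing to the loss
        m = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_z - row[label]  # -log softmax(row)[label]
        count += 1
    return total / count if count else None


def pretraining_loss(prediction_scores: List[List[float]],
                     masked_lm_labels: Optional[List[int]] = None,
                     sop_scores: Optional[List[List[float]]] = None,
                     sop_labels: Optional[List[int]] = None) -> float:
    """Masked-LM loss plus sentence-order (SOP) loss; terms without labels are skipped."""
    loss = 0.0
    if masked_lm_labels is not None:
        mlm = cross_entropy(prediction_scores, masked_lm_labels)
        if mlm is not None:
            loss += mlm
    if sop_labels is not None and sop_scores is not None:
        sop = cross_entropy(sop_scores, sop_labels)
        if sop is not None:
            loss += sop
    return loss
```

In the PyTorch version, the same effect can be had by starting from a zero loss and adding each term only when its labels are provided.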
Step 3: Use the LAMB optimizer and fine-tune ALBERT
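The `Lamb` class below is credited to the cybertronai/pytorch-lamb repository linked in its header comment. As a rough guide to what one step does for a single parameter tensor, here is a minimal pure-Python sketch. The moment updates, the absence of bias correction, and the weight norm clamped to [0, 10] mirror the listing. The listing is cut off after `adam_step`, so two parts of the sketch are our assumption, based on the LAMB paper cited in the docstring (arXiv 1904.00962): the weight-decay term added to the step, and the trust ratio ||w|| / ||step||. That ratio falls back to 1 when either norm is zero, and is always 1 when `adam=True`. `lamb_step` is a hypothetical name, and parameters are flat Python lists rather than tensors.

```python
import math
from typing import Dict, List, Tuple


def lamb_step(w: List[float], g: List[float], state: Dict[str, List[float]],
              lr: float = 1e-3, betas: Tuple[float, float] = (0.9, 0.999),
              eps: float = 1e-6, weight_decay: float = 0.0,
              adam: bool = False) -> List[float]:
    """Return the updated weights after one LAMB step; `state` keeps the moments."""
    beta1, beta2 = betas
    m = state.setdefault("exp_avg", [0.0] * len(w))     # first moment
    v = state.setdefault("exp_avg_sq", [0.0] * len(w))  # second moment
    for i, gi in enumerate(g):
        m[i] = beta1 * m[i] + (1 - beta1) * gi       # m_t
        v[i] = beta2 * v[i] + (1 - beta2) * gi * gi  # v_t (no debiasing, as in the listing)
    # Adam direction plus weight decay (assumption: decay is added to the step, per the paper)
    step = [mi / (math.sqrt(vi) + eps) + weight_decay * wi
            for mi, vi, wi in zip(m, v, w)]
    weight_norm = min(max(math.sqrt(sum(x * x for x in w)), 0.0), 10.0)  # clamp(0, 10)
    step_norm = math.sqrt(sum(x * x for x in step))
    # Layer-wise trust ratio (assumption: falls back to 1 for zero norms or adam=True)
    if adam or weight_norm == 0 or step_norm == 0:
        trust_ratio = 1.0
    else:
        trust_ratio = weight_norm / step_norm
    return [wi - lr * trust_ratio * si for wi, si in zip(w, step)]
```

For example, with `w = [3.0, 4.0]` and gradient `[1.0, 1.0]`, both coordinates move by `lr * 5 / sqrt(2)`. The size of the step is set by the weight norm, not by the raw Adam step.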
<section style="padding: 0.5em;line-height: 18px;font-size: 14px;letter-spacing: 0px;font-family: Consolas, Inconsolata, Courier, monospace;border-radius: 0px;background: rgb(29, 31, 33);color: rgb(197, 200, 198);margin-left: 8px;margin-right: 8px;overflow-wrap: normal !important;word-break: normal !important;overflow: auto !important;display: -webkit-box !important;"><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">  1</span><span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;">#Using LAMB optimizer</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">  2</span><span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;">#LAMB -  "https://github.com/cybertronai/pytorch-lamb"</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">  3</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">  4</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> torch<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">  5</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">from</span> torch.optim <span 
style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> Optimizer<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">  6</span><span style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;"><span style="line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">class</span> <span style="line-height: inherit;color: rgb(129, 162, 190);overflow-wrap: inherit !important;word-break: inherit !important;">Lamb</span><span style="line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">(Optimizer)</span>:</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">  7</span>    <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">r"""Implements Lamb algorithm.<br  /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">  8</span>    It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.<br  /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">  9</span>    Arguments:<br  /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 10</span>        params (iterable): iterable of parameters to optimize or dicts defining<br  /><span style="padding-right: 
20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 11</span>            parameter groups<br  /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 12</span>        lr (float, optional): learning rate (default: 1e-3)<br  /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 13</span>        betas (Tuple[float, float], optional): coefficients used for computing<br  /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 14</span>            running averages of gradient and its square (default: (0.9, 0.999))<br  /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 15</span>        eps (float, optional): term added to the denominator to improve<br  /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 16</span>            numerical stability (default: 1e-6)<br  /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 17</span>        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)<br  /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 18</span>        adam (bool, optional): always use trust ratio = 1, which turns this into<br  /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 19</span>            Adam. 
Useful for comparison purposes.<br  /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 20</span>    .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:<br  /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 21</span>        https://arxiv.org/abs/1904.00962<br  /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 22</span>    """</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 23</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 24</span>    <span style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;"><span style="line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">def</span> <span style="line-height: inherit;color: rgb(129, 162, 190);overflow-wrap: inherit !important;word-break: inherit !important;">__init__</span><span style="line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">(self, params, lr=<span style="line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">1e-3</span>, betas=<span style="line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">(<span style="line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">0.9</span>, <span style="line-height: inherit;overflow-wrap: inherit !important;word-break: 
inherit !important;">0.999</span>)</span>, eps=<span style="line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">1e-6</span>,<br  /><span style="padding-right: 20px;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;"> 25</span>                 weight_decay=<span style="line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">0</span>, adam=False)</span>:</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 26</span>        <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">not</span> <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0.0</span> <= lr:<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 27</span>            <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">raise</span> ValueError(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"Invalid learning rate: {}"</span>.format(lr))<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 28</span>        <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit 
!important;word-break: inherit !important;">if</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">not</span> <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0.0</span> <= eps:<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 29</span>            <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">raise</span> ValueError(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"Invalid epsilon value: {}"</span>.format(eps))<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 30</span>        <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">not</span> <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0.0</span> <= betas[<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>] < <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1.0</span>:<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: 
rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 31</span>            <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">raise</span> ValueError(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"Invalid beta parameter at index 0: {}"</span>.format(betas[<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>]))<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 32</span>        <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">not</span> <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0.0</span> <= betas[<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span>] < <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1.0</span>:<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 33</span>            <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">raise</span> ValueError(<span 
style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"Invalid beta parameter at index 1: {}"</span>.format(betas[<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span>]))<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 34</span>        defaults = dict(lr=lr, betas=betas, eps=eps,<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 35</span>                        weight_decay=weight_decay)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 36</span>        self.adam = adam<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 37</span>        super(Lamb, self).__init__(params, defaults)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 38</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 39</span>    <span style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;"><span style="line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">def</span> <span style="line-height: inherit;color: rgb(129, 162, 190);overflow-wrap: inherit 
!important;word-break: inherit !important;">step</span><span style="line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">(self, closure=None)</span>:</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 40</span>        <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"""Performs a single optimization step.<br  /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 41</span>        Arguments:<br  /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 42</span>            closure (callable, optional): A closure that reevaluates the model<br  /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 43</span>                and returns the loss.<br  /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 44</span>        """</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 45</span>        loss = <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">None</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 46</span>        <span style="font-size: inherit;line-height: inherit;color: 
rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> closure <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">is</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">not</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">None</span>:<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 47</span>            loss = closure()<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 48</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 49</span>        <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> group <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> self.param_groups:<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 50</span>            <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> p <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit 
!important;">in</span> group[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'params'</span>]:<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 51</span>                <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> p.grad <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">is</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">None</span>:<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 52</span>                    <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">continue</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 53</span>                grad = p.grad.data<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 54</span>                <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> grad.is_sparse:<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 55</span>   
                 <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">raise</span> RuntimeError(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'Lamb does not support sparse gradients, consider SparseAdam instead.'</span>)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 56</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 57</span>                state = self.state[p]<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 58</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 59</span>                <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># State initialization</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 60</span>                <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> len(state) == <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>:<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 
95);overflow-wrap: inherit !important;word-break: inherit !important;"> 61</span>                    state[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'step'</span>] = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 62</span>                    <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># Exponential moving average of gradient values</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 63</span>                    state[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'exp_avg'</span>] = torch.zeros_like(p.data)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 64</span>                    <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># Exponential moving average of squared gradient values</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 65</span>                    state[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit 
!important;">'exp_avg_sq'</span>] = torch.zeros_like(p.data)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 66</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 67</span>                exp_avg, exp_avg_sq = state[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'exp_avg'</span>], state[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'exp_avg_sq'</span>]<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 68</span>                beta1, beta2 = group[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'betas'</span>]<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 69</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 70</span>                state[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'step'</span>] += <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 
95);overflow-wrap: inherit !important;word-break: inherit !important;"> 71</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 72</span>                <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># Decay the first and second moment running average coefficient</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 73</span>                <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># m_t</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 74</span>                exp_avg.mul_(beta1).add_(grad, alpha=<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span> - beta1)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 75</span>                <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># v_t</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 76</span>                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span> - beta2)
<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 77</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 78</span>                <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># Paper v3 does not use debiasing.</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 79</span>                <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># bias_correction1 = 1 - beta1 ** state['step']</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 80</span>                <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># bias_correction2 = 1 - beta2 ** state['step']</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 81</span>                <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># Apply bias to lr to avoid broadcast.</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 82</span>                step_size = group[<span style="font-size: 
inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'lr'</span>] <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># * math.sqrt(bias_correction2) / bias_correction1</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 83</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 84</span>                weight_norm = p.data.pow(<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">2</span>).sum().sqrt().clamp(<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>, <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">10</span>)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 85</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 86</span>                adam_step = exp_avg / exp_avg_sq.sqrt().add(group[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'eps'</span>])<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 87</span> 
               <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> group[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'weight_decay'</span>] != <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>:<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 88</span>                    adam_step.add_(p.data, alpha=group[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'weight_decay'</span>])<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 89</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 90</span>                adam_norm = adam_step.pow(<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">2</span>).sum().sqrt()<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 91</span>                <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> weight_norm == <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span> 
<span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">or</span> adam_norm == <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>:<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 92</span>                    trust_ratio = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 93</span>                <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">else</span>:<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 94</span>                    trust_ratio = weight_norm / adam_norm<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 95</span>                state[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'weight_norm'</span>] = weight_norm<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 96</span>                state[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit 
!important;word-break: inherit !important;">'adam_norm'</span>] = adam_norm<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 97</span>                state[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'trust_ratio'</span>] = trust_ratio<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 98</span>                <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> self.adam:<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 99</span>                    trust_ratio = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">100</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">101</span>                p.data.add_(adam_step, alpha=-step_size * trust_ratio)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">102</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">103</span>        <span 
style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">return</span> loss<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">104</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">105</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> time<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">106</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> torch.nn <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">as</span> nn<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">107</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> torch<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">108</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">from</span> tfrecord.torch.dataset <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: 
inherit !important;word-break: inherit !important;">import</span> TFRecordDataset<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">109</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> numpy <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">as</span> np<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">110</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> os<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">111</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">112</span>LEARNING_RATE = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0.001</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">113</span>EPOCH = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">40</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit 
!important;">114</span>BATCH_SIZE = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">2</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">115</span>MAX_GRAD_NORM = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1.0</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">116</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">117</span>print(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">f"--- Resume/Start training ---"</span>)   <br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">118</span>feat_map = {<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"input_ids"</span>: <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"int"</span>, <br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">119</span>           <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"input_mask"</span>: <span 
style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"int"</span>,<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">120</span>           <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"segment_ids"</span>: <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"int"</span>,<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">121</span>           <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"next_sentence_labels"</span>: <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"int"</span>,<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">122</span>           <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"masked_lm_positions"</span>: <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"int"</span>,<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">123</span>           <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 
104);overflow-wrap: inherit !important;word-break: inherit !important;">"masked_lm_ids"</span>: <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"int"</span>}<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">124</span>pretrain_file = <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'restaurant_review_train'</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">125</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">126</span><span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># Create albert pretrain model</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">127</span>config = AlbertConfig.from_json_file(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"albert_config.json"</span>)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">128</span>albert_pretrain = AlbertForPretrain(config)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">129</span><span 
style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># Create optimizer</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">130</span>optimizer = Lamb([{<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"params"</span>: [p <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> n, p <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> list(albert_pretrain.named_parameters())]}], lr=LEARNING_RATE)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">131</span>albert_pretrain.train()<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">132</span>dataset = TFRecordDataset(pretrain_file, index_path = <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">None</span>, description=feat_map)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">133</span>loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">134</span><br  /><span style="padding-right: 
20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">135</span>tmp_loss = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">136</span>start_time = time.time()<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">137</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">138</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> os.path.isfile(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'pretrain_checkpoint'</span>):<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">139</span>    print(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">f"--- Load from checkpoint ---"</span>)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">140</span>    checkpoint = torch.load(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"pretrain_checkpoint"</span>)<br  
/><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">141</span>    albert_pretrain.load_state_dict(checkpoint[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'model_state_dict'</span>])<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">142</span>    optimizer.load_state_dict(checkpoint[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'optimizer_state_dict'</span>])<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">143</span>    epoch = checkpoint[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'epoch'</span>]<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">144</span>    loss = checkpoint[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'loss'</span>]<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">145</span>    losses = checkpoint[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'losses'</span>]<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 
95);overflow-wrap: inherit !important;word-break: inherit !important;">146</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">147</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">else</span>:<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">148</span>    epoch = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">-1</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">149</span>    losses = []<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">150</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> e <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> range(epoch+<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span>, EPOCH):<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">151</span>    <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> batch <span 
style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> loader:<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">152</span>        b_input_ids = batch[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'input_ids'</span>].long() <br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">153</span>        b_token_type_ids = batch[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'segment_ids'</span>].long() <br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">154</span>        b_seq_relationship_labels = batch[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'next_sentence_labels'</span>].long()<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">155</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">156</span>        <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># Convert each batch from the format written by Google's ALBERT</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: 
inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">157</span>        <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># create_pretraining_data.py script into the format expected by</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">158</span>        <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># Hugging Face's PyTorch implementation of ALBERT</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">159</span>        mask_rows = np.nonzero(batch[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'masked_lm_positions'</span>].numpy())[<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>]<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">160</span>        mask_cols = batch[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'masked_lm_positions'</span>].numpy()[batch[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'masked_lm_positions'</span>].numpy()!=<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit 
!important;word-break: inherit !important;">0</span>]<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">161</span>        <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># the attention mask is the padding mask (input_mask), not the MLM positions</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">162</span>        b_attention_mask = batch[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'input_mask'</span>].numpy()<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">163</span>        b_masked_lm_labels = np.full(b_input_ids.shape, <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">-100</span>, dtype=np.int64)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">164</span>        b_masked_lm_labels[mask_rows,mask_cols] = batch[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'masked_lm_ids'</span>].numpy()[batch[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit 
!important;">'masked_lm_positions'</span>].numpy()!=<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>]     <br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">165</span>        b_attention_mask=torch.tensor(b_attention_mask).long()<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">166</span>        b_masked_lm_labels=torch.tensor(b_masked_lm_labels).long()<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">167</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">168</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">169</span>        loss = albert_pretrain(input_ids = b_input_ids<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">170</span>                              , attention_mask = b_attention_mask<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">171</span>                              , token_type_ids = b_token_type_ids<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">172</span>                     
         , masked_lm_labels = b_masked_lm_labels <br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">173</span>                              , seq_relationship_labels = b_seq_relationship_labels)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">174</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">175</span>        <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># clears old gradients</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">176</span>        optimizer.zero_grad()<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">177</span>        <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># backward pass</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">178</span>        loss.backward()<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">179</span>        <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># gradient 
clipping</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">180</span>        torch.nn.utils.clip_grad_norm_(parameters=albert_pretrain.parameters(), max_norm=MAX_GRAD_NORM)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">181</span>        <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># update parameters</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">182</span>        optimizer.step()<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">183</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">184</span>        tmp_loss += loss.detach().item()<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">185</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">186</span>    <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># print metrics and save to checkpoint every epoch</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: 
inherit !important;">187</span>    print(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">f"Epoch: <span style="color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">{e}</span>"</span>)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">188</span>    print(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">f"Train loss: <span style="color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">{(tmp_loss/<span style="line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">20</span>)}</span>"</span>)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">189</span>    print(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">f"Train Time: <span style="color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">{(time.time()-start_time)/<span style="line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">60</span>}</span> mins"</span>)  <br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">190</span>    losses.append(tmp_loss/<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">20</span>)<br  /><span 
style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">191</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">192</span>    tmp_loss = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">193</span>    start_time = time.time()<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">194</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">195</span>    torch.save({<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'model_state_dict'</span>: albert_pretrain.state_dict(),<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'optimizer_state_dict'</span>: optimizer.state_dict(),<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">196</span>               <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'epoch'</span>: e, <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: 
inherit !important;">'loss'</span>: loss,<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'losses'</span>: losses}<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">197</span>           , <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'pretrain_checkpoint'</span>)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">198</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">from</span> matplotlib <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> pyplot <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">as</span> plot<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">199</span>plot.plot(losses)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">200</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">201</span><span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;">#Fine tuning ALBERT</span><br  
/><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">202</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">203</span><span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># At the time of writing, Hugging Face didn't provide the class object for </span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">204</span><span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># AlbertForTokenClassification, hence we write our own definition below</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">205</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">from</span> transformers.modeling_albert <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> AlbertModel, AlbertPreTrainedModel<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">206</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">from</span> transformers.configuration_albert <span style="font-size: inherit;line-height: inherit;color: rgb(178, 
148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> AlbertConfig<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">207</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">from</span> transformers.tokenization_bert <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> BertTokenizer<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">208</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> torch.nn <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">as</span> nn<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">209</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">from</span> torch.nn <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> CrossEntropyLoss<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">210</span><span style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit 
!important;"><span style="line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">class</span> <span style="line-height: inherit;color: rgb(129, 162, 190);overflow-wrap: inherit !important;word-break: inherit !important;">AlbertForTokenClassification</span><span style="line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">(AlbertPreTrainedModel)</span>:</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">211</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">212</span>    <span style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;"><span style="line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">def</span> <span style="line-height: inherit;color: rgb(129, 162, 190);overflow-wrap: inherit !important;word-break: inherit !important;">__init__</span><span style="line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">(self, albert, config)</span>:</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">213</span>        super().__init__(config)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">214</span>        self.num_labels = config.num_labels<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit 
!important;word-break: inherit !important;">215</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">216</span>        self.albert = albert<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">217</span>        self.dropout = nn.Dropout(config.hidden_dropout_prob)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">218</span>        self.classifier = nn.Linear(config.hidden_size, config.num_labels)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">219</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">220</span>    <span style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;"><span style="line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">def</span> <span style="line-height: inherit;color: rgb(129, 162, 190);overflow-wrap: inherit !important;word-break: inherit !important;">forward</span><span style="line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">(<br  /><span style="padding-right: 20px;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">221</span>        self,<br  /><span style="padding-right: 20px;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">222</span>        input_ids=None,<br  /><span 
style="padding-right: 20px;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">223</span>        attention_mask=None,<br  /><span style="padding-right: 20px;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">224</span>        token_type_ids=None,<br  /><span style="padding-right: 20px;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">225</span>        position_ids=None,<br  /><span style="padding-right: 20px;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">226</span>        head_mask=None,<br  /><span style="padding-right: 20px;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">227</span>        inputs_embeds=None,<br  /><span style="padding-right: 20px;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">228</span>        labels=None,<br  /><span style="padding-right: 20px;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">229</span>    )</span>:</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">230</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">231</span>        outputs = self.albert(<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">232</span>            input_ids,<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">233</span>            attention_mask=attention_mask,<br  /><span style="padding-right: 20px;font-size: 
inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">234</span>            token_type_ids=token_type_ids,<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">235</span>            position_ids=position_ids,<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">236</span>            head_mask=head_mask,<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">237</span>            inputs_embeds=inputs_embeds,<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">238</span>        )<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">239</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">240</span>        sequence_output = outputs[<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>]<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">241</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">242</span>        sequence_output = self.dropout(sequence_output)<br  
/><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">243</span>        logits = self.classifier(sequence_output)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">244</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">245</span>        <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">return</span> logits<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">246</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">247</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> numpy <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">as</span> np<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">248</span><span style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;"><span style="line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">def</span> <span style="line-height: inherit;color: rgb(129, 162, 190);overflow-wrap: inherit 
!important;word-break: inherit !important;">label_sent</span><span style="line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">(name_tokens, sent_tokens)</span>:</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">249</span>    label = []<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">250</span>    i = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">251</span>    <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> len(name_tokens)>len(sent_tokens):<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">252</span>        label = [<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>] * len(sent_tokens)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">253</span>    <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">else</span>:<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">254</span>        <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: 
inherit !important;word-break: inherit !important;">while</span> i&lt;len(sent_tokens):<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">255</span>            found_match = <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">False</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">256</span>            <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> name_tokens[<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>] == sent_tokens[i]:       <br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">257</span>                found_match = <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">True</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">258</span>                <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> j <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> range(len(name_tokens)<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit 
!important;word-break: inherit !important;">-1</span>):<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">259</span>                    <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> ((i+j+<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span>)>=len(sent_tokens)):<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">260</span>                        <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">return</span> label + [<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>] * (len(sent_tokens) - len(label))<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">261</span>                    <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> name_tokens[j+<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span>] != sent_tokens[i+j+<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span>]:<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">262</span>                        found_match = <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit 
!important;word-break: inherit !important;">False</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">263</span>                <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> found_match:<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">264</span>                    label.extend(list(np.ones(len(name_tokens)).astype(int)))<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">265</span>                    i = i + len(name_tokens)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">266</span>                <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">else</span>: <br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">267</span>                    label.extend([<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>])<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">268</span>                    i = i+ <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit 
!important;">1</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">269</span>            <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">else</span>:<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">270</span>                label.extend([<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>])<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">271</span>                i=i+<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">272</span>    <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">return</span> label<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">273</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">274</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> pandas <span 
style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">as</span> pd<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">275</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> glob<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">276</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> os<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">277</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">278</span>tokenizer = BertTokenizer(vocab_file=<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"vocab.txt"</span>)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">279</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">280</span>df_data_train = pd.read_csv(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"dish_name_train.csv"</span>)<br  /><span 
style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">281</span>df_data_train[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'name_tokens'</span>] = df_data_train[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'dish_name'</span>].apply(tokenizer.tokenize)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">282</span>df_data_train[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'review_tokens'</span>] = df_data_train.review.apply(tokenizer.tokenize)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">283</span>df_data_train[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'review_label'</span>] = df_data_train.apply(<span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">lambda</span> row: label_sent(row[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'name_tokens'</span>], row[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'review_tokens'</span>]), axis=<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit 
!important;word-break: inherit !important;">1</span>)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">284</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">285</span>df_data_val = pd.read_csv(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"dish_name_val.csv"</span>)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">286</span>df_data_val = df_data_val.dropna().reset_index()<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">287</span>df_data_val[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'name_tokens'</span>] = df_data_val[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'dish_name'</span>].apply(tokenizer.tokenize)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">288</span>df_data_val[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'review_tokens'</span>] = df_data_val.review.apply(tokenizer.tokenize)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit 
!important;">289</span>df_data_val[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'review_label'</span>] = df_data_val.apply(<span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">lambda</span> row: label_sent(row[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'name_tokens'</span>], row[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'review_tokens'</span>]), axis=<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span>)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">290</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">291</span>MAX_LEN = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">64</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">292</span>BATCH_SIZE = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">293</span><span style="font-size: 
inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">from</span> keras.preprocessing.sequence <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> pad_sequences<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">294</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> torch<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">295</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">from</span> torch.utils.data <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> TensorDataset, DataLoader, RandomSampler, SequentialSampler<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">296</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">297</span>tr_inputs = pad_sequences([tokenizer.convert_tokens_to_ids(txt) <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> txt <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit 
!important;">in</span> df_data_train[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'review_tokens'</span>]],maxlen=MAX_LEN, dtype=<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"long"</span>, truncating=<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"post"</span>, padding=<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"post"</span>)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">298</span>tr_tags = pad_sequences(df_data_train[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'review_label'</span>],maxlen=MAX_LEN, padding=<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"post"</span>,dtype=<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"long"</span>, truncating=<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"post"</span>)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">299</span><span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># create the mask to ignore 
the padded elements in the sequences.</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">300</span>tr_masks = [[float(i><span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>) <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> i <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> ii] <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> ii <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> tr_inputs]<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">301</span>tr_inputs = torch.tensor(tr_inputs)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">302</span>tr_tags = torch.tensor(tr_tags)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">303</span>tr_masks = torch.tensor(tr_masks)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">304</span>train_data = TensorDataset(tr_inputs, tr_masks, tr_tags)<br  /><span style="padding-right: 20px;font-size: 
inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">305</span>train_sampler = RandomSampler(train_data)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">306</span>train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=BATCH_SIZE)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">307</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">308</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">309</span>val_inputs = pad_sequences([tokenizer.convert_tokens_to_ids(txt) <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> txt <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> df_data_val[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'review_tokens'</span>]],maxlen=MAX_LEN, dtype=<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"long"</span>, truncating=<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"post"</span>, padding=<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 
104);overflow-wrap: inherit !important;word-break: inherit !important;">"post"</span>)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">310</span>val_tags = pad_sequences(df_data_val[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'review_label'</span>],maxlen=MAX_LEN, padding=<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"post"</span>,dtype=<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"long"</span>, truncating=<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"post"</span>)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">311</span><span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># create the mask to ignore the padded elements in the sequences.</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">312</span>val_masks = [[float(i><span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>) <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> i <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 
187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> ii] <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> ii <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> val_inputs]<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">313</span>val_inputs = torch.tensor(val_inputs)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">314</span>val_tags = torch.tensor(val_tags)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">315</span>val_masks = torch.tensor(val_masks)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">316</span>val_data = TensorDataset(val_inputs, val_masks, val_tags)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">317</span>val_sampler = SequentialSampler(val_data)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">318</span>val_dataloader = DataLoader(val_data, sampler=val_sampler, batch_size=BATCH_SIZE)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">319</span><br  /><span 
style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">320</span>model_tokenclassification = AlbertForTokenClassification(albert_pretrain.albert, config)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">321</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">from</span> torch.optim <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> AdamW<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">322</span>LEARNING_RATE = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0.0000003</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">323</span>FULL_FINETUNING = <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">True</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">324</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> FULL_FINETUNING:<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: 
inherit !important;">325</span>    param_optimizer = list(model_tokenclassification.named_parameters())<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">326</span>    no_decay = [<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'bias'</span>, <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'gamma'</span>, <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'beta'</span>]<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">327</span>    optimizer_grouped_parameters = [<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">328</span>        {<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'params'</span>: [p <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> n, p <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> param_optimizer <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit 
!important;">not</span> any(nd <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> n <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> nd <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> no_decay)],<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">329</span>         <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'weight_decay'</span>: <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0.01</span>},<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">330</span>        {<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'params'</span>: [p <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> n, p <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> param_optimizer <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> any(nd <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit 
!important;word-break: inherit !important;">in</span> n <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> nd <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> no_decay)],<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">331</span>         <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'weight_decay'</span>: <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0.0</span>}<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">332</span>    ]<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">333</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">else</span>:<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">334</span>    param_optimizer = list(model_tokenclassification.classifier.named_parameters()) <br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">335</span>    optimizer_grouped_parameters = [{<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 
104);overflow-wrap: inherit !important;word-break: inherit !important;">"params"</span>: [p <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> n, p <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> param_optimizer]}]<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">336</span>optimizer = AdamW(optimizer_grouped_parameters, lr=LEARNING_RATE)<br  /></section>
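The preprocessing above depends on two conventions. First, every review token is labelled 1 if it belongs to the dish name and 0 otherwise. Second, padded positions are masked out by the rule `float(i>0)`. The sketch below restates both in plain Python so they can be checked without downloading the data. It is a simplified restatement, not the author's `label_sent` function. The helper names `label_tokens`, `pad_post` and `attention_mask` are hypothetical, and the sketch assumes that `[PAD]` has id 0, as it does in the standard BERT `vocab.txt`.

```python
from typing import List

# Assumption: id 0 is [PAD], as in the standard BERT vocab.txt.
# The article's mask rule float(i>0) only works if this holds.
PAD_ID = 0


def label_tokens(name_tokens: List[str], review_tokens: List[str]) -> List[int]:
    """Hypothetical helper: label 1 for tokens inside an exact dish-name match, else 0."""
    labels: List[int] = []
    n = len(name_tokens)
    i = 0
    while i < len(review_tokens):
        # Compare the window starting at position i with the whole dish name.
        if n > 0 and review_tokens[i:i + n] == name_tokens:
            labels.extend([1] * n)  # every sub-token of the dish name is positive
            i += n
        else:
            labels.append(0)
            i += 1
    return labels


def pad_post(seq: List[int], max_len: int, value: int = PAD_ID) -> List[int]:
    """Truncate and pad at the end, mimicking pad_sequences(padding="post", truncating="post")."""
    seq = list(seq[:max_len])  # "post" truncation keeps the first max_len items
    return seq + [value] * (max_len - len(seq))


def attention_mask(ids: List[int]) -> List[float]:
    """1.0 for real tokens and 0.0 for padding, mirroring float(i>0) in the article."""
    return [float(t > PAD_ID) for t in ids]


if __name__ == "__main__":
    review = ["the", "mala", "##tang", "was", "great"]
    name = ["mala", "##tang"]
    print(label_tokens(name, review))  # [0, 1, 1, 0, 0]
    ids = pad_post([101, 2054, 2055, 2001, 2307], max_len=8)  # illustrative ids only
    print(ids)  # [101, 2054, 2055, 2001, 2307, 0, 0, 0]
    print(attention_mask(ids))  # [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0]
```

Padded label positions receive the value 0, which is also the "not a dish name" class. The attention mask is therefore the only thing that distinguishes padding from genuine negatives. For that reason it should be passed to the model and, ideally, taken into account when the loss is computed.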
Step 4: Train the model on the custom corpus
<section style="padding: 0.5em;line-height: 18px;font-size: 14px;letter-spacing: 0px;font-family: Consolas, Inconsolata, Courier, monospace;border-radius: 0px;background: rgb(29, 31, 33);color: rgb(197, 200, 198);margin-left: 8px;margin-right: 8px;overflow-wrap: normal !important;word-break: normal !important;overflow: auto !important;display: -webkit-box !important;"><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">  1</span><span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;">#Training the model</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">  2</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">  3</span><span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># from torch.utils.tensorboard import SummaryWriter</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">  4</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> time<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">  5</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> os.path<br  /><span 
style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">  6</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> torch.nn <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">as</span> nn<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">  7</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> torch<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">  8</span>EPOCH = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">800</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">  9</span>MAX_GRAD_NORM = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1.0</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 10</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 11</span>start_time = time.time()<br  /><span style="padding-right: 20px;font-size: 
inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 12</span>tr_loss, tr_acc, nb_tr_steps = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>, <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>, <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 13</span>eval_loss, eval_acc, nb_eval_steps = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>, <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>, <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 14</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 15</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> os.path.isfile(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit 
!important;">'finetune_checkpoint'</span>):<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 16</span>    print(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">f"--- Load from checkpoint ---"</span>)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 17</span>    checkpoint = torch.load(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"finetune_checkpoint"</span>)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 18</span>    model_tokenclassification.load_state_dict(checkpoint[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'model_state_dict'</span>])<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 19</span>    optimizer.load_state_dict(checkpoint[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'optimizer_state_dict'</span>])<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 20</span>    epoch = checkpoint[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'epoch'</span>]<br  /><span 
style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 21</span>    train_losses = checkpoint[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'train_losses'</span>]<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 22</span>    train_accs = checkpoint[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'train_accs'</span>]<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 23</span>    eval_losses = checkpoint[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'eval_losses'</span>]<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 24</span>    eval_accs = checkpoint[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'eval_accs'</span>]<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 25</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 26</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit 
!important;">else</span>:<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 27</span>    epoch = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">-1</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 28</span>    train_losses,train_accs,eval_losses,eval_accs = [],[],[],[]<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 29</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 30</span>print(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">f"--- Resume/Start training ---"</span>)    <br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 31</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> e <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> range(epoch+<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span>, EPOCH): <br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit 
!important;word-break: inherit !important;"> 32</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 33</span>    <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># TRAIN loop</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 34</span>    model_tokenclassification.train()<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 35</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 36</span>    <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> batch <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> train_dataloader:<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 37</span>        <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># add batch to gpu</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 38</span>        batch = tuple(t <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit 
!important;word-break: inherit !important;">for</span> t <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> batch)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 39</span>        b_input_ids, b_input_mask, b_labels = batch<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 40</span>        <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># forward pass</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 41</span>        b_outputs = model_tokenclassification(b_input_ids, token_type_ids=<span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">None</span>, attention_mask=b_input_mask, labels=b_labels)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 42</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 43</span>        ce_loss_fct = CrossEntropyLoss()<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 44</span>        <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit 
!important;"># Only keep active parts of the loss</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 45</span>        b_active_loss = b_input_mask.view(<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">-1</span>) == <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 46</span>        b_active_logits = b_outputs.view(<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">-1</span>, config.num_labels)[b_active_loss]<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 47</span>        b_active_labels = b_labels.view(<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">-1</span>)[b_active_loss]<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 48</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 49</span>        loss = ce_loss_fct(b_active_logits, b_active_labels)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 
50</span>        acc = torch.mean((torch.max(b_active_logits.detach(),<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span>)[<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span>] == b_active_labels.detach()).float())<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 51</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 52</span>        model_tokenclassification.zero_grad()<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 53</span>        <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># backward pass</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 54</span>        loss.backward()<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 55</span>        <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># track train loss</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 56</span>        tr_loss += loss.item()<br  /><span style="padding-right: 
20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 57</span>        tr_acc += acc<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 58</span>        nb_tr_steps += <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 59</span>        <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># gradient clipping</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 60</span>        torch.nn.utils.clip_grad_norm_(parameters=model_tokenclassification.parameters(), max_norm=MAX_GRAD_NORM)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 61</span>        <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># update parameters</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 62</span>        optimizer.step()<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 63</span><br  /><span style="padding-right: 20px;font-size: 
inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 64</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 65</span>    <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># VALIDATION on validation set</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 66</span>    model_tokenclassification.eval()<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 67</span>    <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> batch <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> val_dataloader:<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 68</span>        batch = tuple(t <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> t <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> batch)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 69</span>        b_input_ids, b_input_mask, b_labels = 
batch<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 70</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 71</span>        <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">with</span> torch.no_grad():<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 72</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 73</span>            b_outputs = model_tokenclassification(b_input_ids, token_type_ids=<span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">None</span>,<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 74</span>                         attention_mask=b_input_mask, labels=b_labels)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 75</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 76</span>            loss_fct = CrossEntropyLoss()<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 77</span>          
  <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># Only keep active parts of the loss</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 78</span>            b_active_loss = b_input_mask.view(<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">-1</span>) == <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 79</span>            b_active_logits = b_outputs.view(<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">-1</span>, config.num_labels)[b_active_loss]<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 80</span>            b_active_labels = b_labels.view(<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">-1</span>)[b_active_loss]<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 81</span>            loss = loss_fct(b_active_logits, b_active_labels)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 82</span>            acc = 
np.mean(np.argmax(b_active_logits.detach().cpu().numpy(), axis=<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span>).flatten() == b_active_labels.detach().cpu().numpy().flatten())<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 83</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 84</span>        eval_loss += loss.mean().item()<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 85</span>        eval_acc += acc<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 86</span>        nb_eval_steps += <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span>    <br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 87</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 88</span>    <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> e % <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">10</span> ==<span style="font-size: inherit;line-height: inherit;color: 
rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>:<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 89</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 90</span>        print(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">f"Epoch: <span style="color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">{e}</span>"</span>)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 91</span>        print(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">f"Train loss: <span style="color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">{(tr_loss/nb_tr_steps)}</span>"</span>)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 92</span>        print(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">f"Train acc: <span style="color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">{(tr_acc/nb_tr_steps)}</span>"</span>)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 93</span>        print(<span 
style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">f"Train Time: <span style="color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">{(time.time()-start_time)/<span style="line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">60</span>}</span> mins"</span>)  <br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 94</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 95</span>        print(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">f"Validation loss: <span style="color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">{eval_loss/nb_eval_steps}</span>"</span>)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 96</span>        print(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">f"Validation Accuracy: <span style="color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">{(eval_acc/nb_eval_steps)}</span>"</span>) <br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 97</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit 
!important;word-break: inherit !important;"> 98</span>        train_losses.append(tr_loss/nb_tr_steps)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 99</span>        train_accs.append(tr_acc/nb_tr_steps)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">100</span>        eval_losses.append(eval_loss/nb_eval_steps)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">101</span>        eval_accs.append(eval_acc/nb_eval_steps)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">102</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">103</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">104</span>        tr_loss, tr_acc, nb_tr_steps = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>, <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>, <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span> <br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit 
!important;">105</span>        eval_loss, eval_acc, nb_eval_steps = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>, <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>, <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span> <br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">106</span>        start_time = time.time() <br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">107</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">108</span>        torch.save({<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'model_state_dict'</span>: model_tokenclassification.state_dict(),<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'optimizer_state_dict'</span>: optimizer.state_dict(),<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">109</span>           <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'epoch'</span>: e, <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: 
inherit !important;word-break: inherit !important;">'train_losses'</span>: train_losses,<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'train_accs'</span>: train_accs, <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'eval_losses'</span>:eval_losses,<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'eval_accs'</span>:eval_accs}<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">110</span>       , <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'finetune_checkpoint'</span>)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">111</span><br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">112</span>plot.plot(train_losses)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">113</span>plot.plot(train_accs)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">114</span>plot.plot(eval_losses)<br  /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">115</span>plot.plot(eval_accs)<br  /><span 
style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">116</span>plot.legend(labels = [<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'train_loss'</span>,<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'train_accuracy'</span>,<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'validation_loss'</span>,<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'validation_accuracy'</span>])<br  /></section>
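The loop above overwrites `finetune_checkpoint` at the end of every epoch. If the notebook is restarted before Step 5, the fine-tuned weights can be restored from that file first. Below is a minimal sketch. It assumes the checkpoint was written by the `torch.save` call above, and that the model and optimizer have been rebuilt exactly as in the earlier steps. `load_finetune_checkpoint` is a hypothetical helper, not part of the original code.

```python
import torch


def load_finetune_checkpoint(model, optimizer, path="finetune_checkpoint"):
    """Restore what the training loop saved with torch.save (hypothetical helper)."""
    # map_location="cpu" allows a checkpoint written on a GPU to be read on a CPU-only machine
    checkpoint = torch.load(path, map_location="cpu")
    # Module.load_state_dict copies values into the model's existing parameters,
    # so the model stays on whatever device it is already on
    model.load_state_dict(checkpoint["model_state_dict"])
    # Optimizer.load_state_dict moves its saved state to match the parameters' device
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    # Per-epoch metric histories, under the same keys used in the torch.save call
    history = {key: checkpoint[key]
               for key in ("train_losses", "train_accs", "eval_losses", "eval_accs")}
    return checkpoint["epoch"], history
```

For example, you could call `epoch, history = load_finetune_checkpoint(model_tokenclassification, optimizer)` before running `predict`. The returned `history` can be passed to the same plotting calls shown above.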
Step 5: Prediction
<section style="padding: 0.5em;line-height: 18px;font-size: 14px;letter-spacing: 0px;font-family: Consolas, Inconsolata, Courier, monospace;border-radius: 0px;background: rgb(29, 31, 33);color: rgb(197, 200, 198);margin-left: 8px;margin-right: 8px;overflow-wrap: normal !important;word-break: normal !important;overflow: auto !important;display: -webkit-box !important;"># Prediction<br  />
# Uses names defined in the earlier steps: tokenizer, model_tokenclassification,<br  />
# MAX_LEN, BATCH_SIZE, pad_sequences (Keras), df_data_train and df_data_val<br  />
import numpy as np<br  />
import torch<br  />
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset<br  />
<br  />
def predict(texts):<br  />
    tokenized_texts = [tokenizer.tokenize(txt) for txt in texts]<br  />
    input_ids = pad_sequences([tokenizer.convert_tokens_to_ids(txt) for txt in tokenized_texts],<br  />
                              maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")<br  />
    # 1.0 for real tokens, 0.0 for padding (padding id is 0)<br  />
    attention_mask = [[float(i > 0) for i in ii] for ii in input_ids]<br  />
<br  />
    dataset = TensorDataset(torch.tensor(input_ids), torch.tensor(attention_mask))<br  />
    dataloader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=BATCH_SIZE)<br  />
<br  />
    # Inference mode (disables dropout); inputs go to whichever device the model is on<br  />
    model_tokenclassification.eval()<br  />
    device = next(model_tokenclassification.parameters()).device<br  />
<br  />
    predicted_labels = []<br  />
    for batch in dataloader:<br  />
        b_input_ids, b_input_mask = tuple(t.to(device) for t in batch)<br  />
        with torch.no_grad():<br  />
            logits = model_tokenclassification(b_input_ids, token_type_ids=None,<br  />
                                               attention_mask=b_input_mask)<br  />
        # Highest-scoring label per token; multiplying by the mask zeroes padding positions<br  />
        predicted_labels.append(np.multiply(np.argmax(logits.detach().cpu().numpy(), axis=2),<br  />
                                            b_input_mask.detach().cpu().numpy()))<br  />
    # Flatten the list of (batch_size, MAX_LEN) arrays into one (len(texts), MAX_LEN) array<br  />
    return np.concatenate(predicted_labels).astype(int), tokenized_texts<br  />
<br  />
def get_dish_candidate_names(predicted_label, tokenized_text):<br  />
    # Positions of all tokens predicted as part of a dish name<br  />
    name_idx_combined = np.where(predicted_label > 0)[0]<br  />
    if len(name_idx_combined) == 0:<br  />
        return None<br  />
    # Split the positions into runs of consecutive indices; each run is one name<br  />
    name_idxs = np.split(name_idx_combined, np.where(np.diff(name_idx_combined) != 1)[0] + 1)<br  />
    name_list = [" ".join(np.take(tokenized_text, name_idx)) for name_idx in name_idxs]<br  />
    # Remove duplicate names (np.unique also sorts them)<br  />
    return np.unique(name_list)<br  />
<br  />
texts = df_data_val.review.values<br  />
predicted_labels, _ = predict(texts)<br  />
df_data_val['predicted_review_label'] = list(predicted_labels)<br  />
df_data_val['predicted_name'] = df_data_val.apply(<br  />
    lambda row: get_dish_candidate_names(row.predicted_review_label, row.review_tokens), axis=1)<br  />
<br  />
texts = df_data_train.review.values<br  />
predicted_labels, _ = predict(texts)<br  />
df_data_train['predicted_review_label'] = list(predicted_labels)<br  />
df_data_train['predicted_name'] = df_data_train.apply(<br  />
    lambda row: get_dish_candidate_names(row.predicted_review_label, row.review_tokens), axis=1)<br  />
<br  />
df_data_val<br  /></section>
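The core of `get_dish_candidate_names` is grouping runs of consecutive "dish" positions into names. Here is a pure-Python sketch of the same grouping step, using `itertools.groupby` in place of the `np.diff`/`np.split` pair. `group_name_spans` is a hypothetical name used only for illustration, and the toy tokens below are made up.

```python
from itertools import groupby


def group_name_spans(predicted_label, tokens):
    """Join runs of consecutive tokens labelled > 0 into candidate names (illustrative sketch)."""
    positive = [i for i, label in enumerate(predicted_label) if label > 0]
    # Inside a run of consecutive indices, (index - position in `positive`) is constant,
    # so grouping on that difference splits the positions into runs
    names = []
    for _, run in groupby(enumerate(positive), key=lambda pair: pair[1] - pair[0]):
        names.append(" ".join(tokens[i] for _, i in run))
    # Deduplicate and sort, as np.unique does; None when no token was labelled
    return sorted(set(names)) or None


tokens = ["the", "chicken", "rice", "and", "laksa", "were", "chicken", "rice"]
labels = [0, 1, 1, 0, 1, 0, 1, 1]
print(group_name_spans(labels, tokens))
```

This prints `['chicken rice', 'laksa']`: the second "chicken rice" run is merged with the first by the deduplication step.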
Experimental results
[Figure: dish names extracted from the reviews by the model]
As the results show, the model successfully extracts dish names from the restaurant reviews.


   Head-to-head comparison

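The hands-on example above shows that, lightweight as it is, ALBERT delivers quite good results.
So, with fewer parameters and good results, can ALBERT simply replace BERT?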
[Figure: comparison of BERT and ALBERT model configurations and performance, including a Speedup column]
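Let's take a closer look at how the two compare experimentally. Speedup here refers to training speed, measured relative to BERT-large.
Because ALBERT has far fewer parameters, much less data has to be exchanged between machines during distributed training. Throughput therefore rises and ALBERT trains faster. At inference time, however, ALBERT still performs the same Transformer computation as BERT, because sharing parameters across layers does not reduce the number of layers that must be run.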
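The throughput argument can be made concrete with rough arithmetic. In plain data-parallel training, every step ends with an all-reduce of the full gradient, so communication traffic scales with the parameter count. The sketch below makes several assumptions: fp32 gradients, 8 workers, and a ring all-reduce, in which each worker sends about 2(N - 1)/N times the gradient size. The parameter counts are the ones reported in the ALBERT paper. Real systems overlap communication with computation, so this shows only the scale of the difference, not actual training times.

```python
def ring_allreduce_bytes_per_worker(num_params: int, num_workers: int,
                                    bytes_per_value: int = 4) -> float:
    """Bytes each worker sends in one ring all-reduce of the full gradient.

    A ring all-reduce (reduce-scatter followed by all-gather) has every worker
    send 2 * (N - 1) / N times the gradient size; fp32 (4-byte) values assumed.
    """
    gradient_bytes = num_params * bytes_per_value
    return 2 * (num_workers - 1) / num_workers * gradient_bytes


# Parameter counts as reported in the ALBERT paper
PARAMS = {"BERT-base": 108_000_000, "ALBERT-base": 12_000_000}

for name, n_params in PARAMS.items():
    megabytes = ring_allreduce_bytes_per_worker(n_params, num_workers=8) / 1e6
    print(f"{name}: ~{megabytes:.0f} MB sent per worker per training step")
```

Under these assumptions the script prints about 756 MB per step for BERT-base and 84 MB for ALBERT-base. The ninefold reduction simply mirrors the ratio of the two parameter counts.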
In summary:
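  • Given the same training time, ALBERT performs better than BERT.

  • Given the same inference time, both ALBERT-base and ALBERT-large perform worse than their BERT counterparts.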
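Both points follow from where ALBERT saves its parameters. Here is a rough count, assuming a 30,000-token vocabulary for both models and ignoring biases, LayerNorm and position/segment embeddings. ALBERT factorizes the embedding (V x E plus E x H instead of V x H) and stores a single shared encoder layer, which cuts the number of stored weights. However, all 12 layers are still executed for every token, so the per-token encoder compute stays the same. The helper functions are illustrative and do not come from the article.

```python
from typing import Optional


def encoder_layer_weights(hidden: int) -> int:
    """Weight-matrix entries in one Transformer encoder layer (biases and LayerNorm ignored)."""
    attention = 4 * hidden * hidden            # Q, K, V and output projections
    feed_forward = 2 * hidden * (4 * hidden)   # hidden -> 4 * hidden -> hidden
    return attention + feed_forward


def model_weights(vocab: int, hidden: int, layers: int,
                  embedding: Optional[int] = None, shared_layers: bool = False) -> int:
    """Approximate stored parameters: token embeddings plus encoder layers."""
    if embedding is None:
        token_embeddings = vocab * hidden                          # BERT: V x H
    else:
        token_embeddings = vocab * embedding + embedding * hidden  # ALBERT: V x E, then E x H
    stored_layers = 1 if shared_layers else layers                 # shared weights are stored once
    return token_embeddings + stored_layers * encoder_layer_weights(hidden)


def encoder_flops_per_token(hidden: int, layers: int) -> int:
    """About 2 FLOPs (multiply + add) per weight per token, for every executed layer."""
    return 2 * layers * encoder_layer_weights(hidden)


bert_base = model_weights(vocab=30_000, hidden=768, layers=12)
albert_base = model_weights(vocab=30_000, hidden=768, layers=12,
                            embedding=128, shared_layers=True)
print(f"BERT-base:   ~{bert_base / 1e6:.0f}M weights")
print(f"ALBERT-base: ~{albert_base / 1e6:.0f}M weights")
# Shared weights are reused, not skipped: both models run 12 layers per token
print(f"Encoder FLOPs per token (both): ~{encoder_flops_per_token(768, 12) / 1e6:.0f}M")
```

The script prints about 108M weights for BERT-base and 11M for ALBERT-base, close to the 108M and 12M reported in the ALBERT paper. The per-token encoder estimate (about 170M FLOPs) is the same for both models. ALBERT's extra E x H projection, which this sketch leaves out, adds only a small amount.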

In addition, Naman Bansal argues that, because of its architecture, ALBERT is somewhat more computationally expensive than BERT.
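Here is one way to read this point; it is our illustration, not Bansal's own analysis. ALBERT's strongest results in the paper come from ALBERT-xxlarge, which has 12 layers with a hidden size of 4096. BERT-large, by comparison, has 24 layers with a hidden size of 1024. Per-layer compute grows with the square of the hidden size, so halving the depth does not make up for the wider layers. The rough estimate below counts about 12 * h^2 weights per layer and 2 FLOPs per weight per token, and ignores attention scores and embeddings.

```python
def encoder_flops_per_token(hidden: int, layers: int) -> int:
    """~2 FLOPs per weight per token; each layer holds about 12 * hidden^2 weights
    (4 * h^2 for the attention projections plus 8 * h^2 for a 4x-wide feed-forward block)."""
    return 2 * layers * 12 * hidden * hidden


# Layer counts and hidden sizes as given in the ALBERT paper
bert_large = encoder_flops_per_token(hidden=1024, layers=24)
albert_xxlarge = encoder_flops_per_token(hidden=4096, layers=12)
print(f"ALBERT-xxlarge / BERT-large encoder compute per token: {albert_xxlarge / bert_large:.1f}x")
```

The script prints a ratio of 8.0x. Under this approximation, the ALBERT configuration that beats BERT-large needs roughly eight times its encoder compute per token, even though it stores fewer parameters.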
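So it remains a case of "you can't have your cake and eat it too": before ALBERT can fully surpass and replace BERT, further research and refinement are needed.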
Original article:
https://medium.com/@namanbansal9909/should-we-shift-from-bert-to-albert-e6fbb7779d3e
The author is a contracted writer for NetEase News · NetEase Hao ("各有态度").

Source: 深度学习这件小事

