From | Shisan (十三), reporting from Aofeisi (凹非寺)
Reposted from | QbitAI (量子位)
Should you replace BERT with ALBERT?
BERT vs. ALBERT
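ALBERT reduces BERT's parameter count mainly through two changes. First, it factorizes the token-embedding matrix into a V × E lookup table plus an E × H projection. Second, it shares one set of Transformer-layer weights across all layers. The arithmetic can be sketched as follows. The sizes are illustrative assumptions that roughly match the base configurations (V = 30,000, E = 128, H = 768, feed-forward size 3,072, 12 layers). The embedding count covers weights only, and position/segment embeddings and the pooler are left out.

```python
def embedding_params(vocab_size, hidden_size, embedding_size=None):
    """Token-embedding weights: V*H for BERT, V*E + E*H for ALBERT's factorization."""
    if embedding_size is None:
        return vocab_size * hidden_size
    return vocab_size * embedding_size + embedding_size * hidden_size


def encoder_layer_params(hidden_size, intermediate_size):
    """Weights and biases of one Transformer encoder layer."""
    attention = 4 * (hidden_size * hidden_size + hidden_size)  # Q, K, V and output projections
    feed_forward = (hidden_size * intermediate_size + intermediate_size
                    + intermediate_size * hidden_size + hidden_size)
    layer_norms = 2 * (2 * hidden_size)  # two LayerNorms, each with a gain and a bias vector
    return attention + feed_forward + layer_norms


def encoder_params(num_layers, hidden_size, intermediate_size, share_layers):
    """ALBERT-style sharing stores one layer's weights and reuses them num_layers times."""
    per_layer = encoder_layer_params(hidden_size, intermediate_size)
    return per_layer if share_layers else num_layers * per_layer


V, E, H, FFN, LAYERS = 30_000, 128, 768, 3_072, 12  # illustrative, roughly base-sized
print("BERT   embeddings:", embedding_params(V, H))
print("ALBERT embeddings:", embedding_params(V, H, E))
print("BERT   encoder:   ", encoder_params(LAYERS, H, FFN, share_layers=False))
print("ALBERT encoder:   ", encoder_params(LAYERS, H, FFN, share_layers=True))
```

With these sizes, the token-embedding parameters drop from 23,040,000 to 3,938,304. Sharing a single layer across all 12 cuts the encoder from 85,054,464 to 7,087,872 parameters. The same factorization explains why the masked-LM head in `AlbertForPretrain` projects `hidden_size` down to `embedding_size` before the tied decoder.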
How do you pre-train ALBERT on a custom corpus?
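`AlbertForPretrain` combines two objectives: masked language modelling and sentence-order prediction (SOP). ALBERT uses SOP instead of BERT's next-sentence prediction. NSP draws its negative examples from other documents. SOP instead takes two consecutive segments and builds a negative by swapping their order, so the model has to learn inter-sentence coherence rather than topic. Below is a minimal sketch of building one SOP input. The helper names `make_sop_example` and `build_pair` are hypothetical. The label convention (0 = original order, 1 = swapped) is an assumption, so check it against the labels `create_pretraining_data.py` actually writes.

```python
import random


def make_sop_example(segment_a, segment_b, rng=random):
    """Build one sentence-order-prediction example from two consecutive segments.

    Returns (first, second, label): label 0 keeps the original order,
    label 1 marks a swapped pair (assumed convention).
    """
    if rng.random() < 0.5:
        return segment_b, segment_a, 1
    return segment_a, segment_b, 0


def build_pair(tokens_a, tokens_b):
    """Pack two segments BERT-style: [CLS] A [SEP] B [SEP], with segment ids 0 then 1."""
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    return tokens, segment_ids


first, second, label = make_sop_example(["the", "soup", "was", "cold"], ["so", "we", "left"])
tokens, segment_ids = build_pair(first, second)
print(tokens, segment_ids, label)
```

After vocabulary lookup, the packed `tokens` become `input_ids`. The `segment_ids` correspond to `token_type_ids`, and the label goes to `seq_relationship_labels`.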
<section style="padding: 0.5em;line-height: 18px;font-size: 14px;letter-spacing: 0px;font-family: Consolas, Inconsolata, Courier, monospace;border-radius: 0px;background: rgb(29, 31, 33);color: rgb(197, 200, 198);margin-left: 8px;margin-right: 8px;overflow-wrap: normal !important;word-break: normal !important;overflow: auto !important;display: -webkit-box !important;"><span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># Downloading all files and data</span><br />
<br />
!wget https://github.com/LydiaXiaohongLi/Albert_Finetune_with_Pretrain_on_Custom_Corpus/raw/master/data_toy/dish_name_train.csv<br />
!wget https://github.com/LydiaXiaohongLi/Albert_Finetune_with_Pretrain_on_Custom_Corpus/raw/master/data_toy/dish_name_val.csv<br />
!wget https://github.com/LydiaXiaohongLi/Albert_Finetune_with_Pretrain_on_Custom_Corpus/raw/master/data_toy/restaurant_review.txt<br />
!wget https://github.com/LydiaXiaohongLi/Albert_Finetune_with_Pretrain_on_Custom_Corpus/raw/master/data_toy/restaurant_review_nopunct.txt<br />
!wget https://github.com/LydiaXiaohongLi/Albert_Finetune_with_Pretrain_on_Custom_Corpus/raw/master/models_toy/albert_config.json<br />
!wget https://github.com/LydiaXiaohongLi/Albert_Finetune_with_Pretrain_on_Custom_Corpus/raw/master/model_checkpoint/finetune_checkpoint<br />
!wget https://github.com/LydiaXiaohongLi/Albert_Finetune_with_Pretrain_on_Custom_Corpus/raw/master/model_checkpoint/pretrain_checkpoint<br />
<br />
<span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># Creating files and setting up ALBERT</span><br />
<br />
!pip install sentencepiece<br />
!git clone https://github.com/google-research/ALBERT<br />
<span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># vocab.txt is not among the downloads above: create it first, one token per line, including [UNK], [CLS], [SEP] and [MASK]</span><br />
!python ./ALBERT/create_pretraining_data.py --input_file <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"restaurant_review.txt"</span> --output_file <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"restaurant_review_train"</span> --vocab_file <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"vocab.txt"</span> --max_seq_length=<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">64</span><br />
!pip install transformers<br />
!pip install tfrecord<br /></section>
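The `create_pretraining_data.py` call above reads `vocab.txt`, but none of the downloaded files provides one. Below is a minimal sketch for generating it from `restaurant_review_nopunct.txt`. It rests on several assumptions:

- The corpus can be split on whitespace.
- Whole words are acceptable as vocabulary entries. There are no sub-word pieces, so unseen words become `[UNK]`.
- When no SentencePiece model is passed, the script uses a BERT-style tokenizer that needs only the special tokens listed in the code.

`build_vocab` is a hypothetical helper, not part of the ALBERT repo. It has to run before the pre-training-data command.

```python
from collections import Counter

# Special tokens looked up by name; [PAD] comes first so it receives id 0
SPECIAL_TOKENS = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]


def build_vocab(corpus_path, vocab_path, min_count=1):
    """Write a one-token-per-line vocab file from a whitespace-tokenized corpus.

    Words are lower-cased (assuming the script's lower-casing is left on) and
    ordered by descending frequency, with ties broken alphabetically.
    Returns the full token list in file order.
    """
    counts = Counter()
    with open(corpus_path, encoding="utf-8") as f:
        for line in f:
            counts.update(line.lower().split())
    words = [w for w, c in sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))
             if c >= min_count]
    tokens = SPECIAL_TOKENS + words
    with open(vocab_path, "w", encoding="utf-8") as f:
        f.write("\n".join(tokens) + "\n")
    return tokens
```

For example, run `build_vocab("restaurant_review_nopunct.txt", "vocab.txt")`. Punctuation in `restaurant_review.txt` is not in this vocabulary and would likely be mapped to `[UNK]`.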
<section style="padding: 0.5em;line-height: 18px;font-size: 14px;letter-spacing: 0px;font-family: Consolas, Inconsolata, Courier, monospace;border-radius: 0px;background: rgb(29, 31, 33);color: rgb(197, 200, 198);margin-left: 8px;margin-right: 8px;overflow-wrap: normal !important;word-break: normal !important;overflow: auto !important;display: -webkit-box !important;"><span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># Defining layers for ALBERT</span><br />
<span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># Imports follow the current transformers layout; transformers.modeling_albert and transformers.modeling_bert exist only in older releases</span><br />
<br />
import torch<br />
import torch.nn as nn<br />
from torch.nn import CrossEntropyLoss<br />
from transformers import AlbertConfig, AlbertModel, AlbertPreTrainedModel<br />
from transformers.activations import ACT2FN<br />
<br />
<br />
class AlbertSequenceOrderHead(nn.Module):<br />
&nbsp;&nbsp;&nbsp;&nbsp;<span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># Two-way classifier on the pooled output: original segment order vs. swapped</span><br />
&nbsp;&nbsp;&nbsp;&nbsp;def __init__(self, config):<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;super().__init__()<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;self.dense = nn.Linear(config.hidden_size, 2)<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;self.bias = nn.Parameter(torch.zeros(2))<br />
<br />
&nbsp;&nbsp;&nbsp;&nbsp;def forward(self, hidden_states):<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;hidden_states = self.dense(hidden_states)<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;prediction_scores = hidden_states + self.bias<br />
<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;return prediction_scores<br />
<br />
<br />
class AlbertForPretrain(AlbertPreTrainedModel):<br />
<br />
&nbsp;&nbsp;&nbsp;&nbsp;def __init__(self, config):<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;super().__init__(config)<br />
<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;self.albert = AlbertModel(config)<br />
<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># For masked LM</span><br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># The original Hugging Face implementation created new output weights via a dense layer.</span><br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># However, the original ALBERT ties the output weights to the input word embeddings, as done below.</span><br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;self.predictions_dense = nn.Linear(config.hidden_size, config.embedding_size)<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;self.predictions_activation = ACT2FN[config.hidden_act]<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;self.predictions_LayerNorm = nn.LayerNorm(config.embedding_size)<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;self.predictions_bias = nn.Parameter(torch.zeros(config.vocab_size))<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;self.predictions_decoder = nn.Linear(config.embedding_size, config.vocab_size)<br />
<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># Tie the decoder weights to the word embeddings and use predictions_bias (otherwise unused) as the decoder bias</span><br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;self.predictions_decoder.weight = self.albert.embeddings.word_embeddings.weight<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;self.predictions_decoder.bias = self.predictions_bias<br />
<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># For sequence order prediction</span><br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;self.seq_relationship = AlbertSequenceOrderHead(config)<br />
<br />
&nbsp;&nbsp;&nbsp;&nbsp;def forward(<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;self,<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;input_ids=None,<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;attention_mask=None,<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;token_type_ids=None,<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;position_ids=None,<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;head_mask=None,<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;inputs_embeds=None,<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;masked_lm_labels=None,<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;seq_relationship_labels=None,<br />
&nbsp;&nbsp;&nbsp;&nbsp;):<br />
<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;outputs = self.albert(<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;input_ids,<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;attention_mask=attention_mask,<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;token_type_ids=token_type_ids,<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;position_ids=position_ids,<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;head_mask=head_mask,<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;inputs_embeds=inputs_embeds,<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;)<br />
<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;loss_fct = CrossEntropyLoss()<br />
<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;sequence_output = outputs[0]<br />
<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;sequence_output = self.predictions_dense(sequence_output)<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;sequence_output = self.predictions_activation(sequence_output)<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;sequence_output = self.predictions_LayerNorm(sequence_output)<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;prediction_scores = self.predictions_decoder(sequence_output)<br />
<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;pooled_output = outputs[1]<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;seq_relationship_scores = self.seq_relationship(pooled_output)<br />
<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># Without both label sets there is no loss to compute, so return the raw scores</span><br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;if masked_lm_labels is None or seq_relationship_labels is None:<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;return prediction_scores, seq_relationship_scores<br />
<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># Unmasked positions in masked_lm_labels must be -100 (CrossEntropyLoss's default ignore_index)</span><br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size),<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;masked_lm_labels.view(-1))<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;seq_relationship_loss = loss_fct(seq_relationship_scores.view(-1, 2), seq_relationship_labels.view(-1))<br />
<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;loss = masked_lm_loss + seq_relationship_loss<br />
<br />
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;return loss<br /></section>
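`AlbertForPretrain.forward` expects dense `masked_lm_labels` of shape (batch, seq_len). BERT-style pre-training records, by contrast, store the masked targets sparsely as positions, ids and weights, with zero weight for padded slots. The sketch below converts one form into the other. It assumes the TFRecord written by `create_pretraining_data.py` uses the BERT-style feature names `masked_lm_positions`, `masked_lm_ids` and `masked_lm_weights`, and `to_mlm_labels` is a hypothetical helper. Every position that was not masked is set to -100, the default `ignore_index` of `torch.nn.CrossEntropyLoss`, so only masked tokens contribute to the loss.

```python
import torch


def to_mlm_labels(input_ids, masked_lm_positions, masked_lm_ids, masked_lm_weights,
                  ignore_index=-100):
    """Scatter sparse masked-LM targets into a dense (batch, seq_len) label tensor.

    input_ids:           LongTensor (batch, seq_len)
    masked_lm_positions: LongTensor (batch, max_predictions), padded with 0
    masked_lm_ids:       LongTensor (batch, max_predictions), padded with 0
    masked_lm_weights:   FloatTensor (batch, max_predictions), 0.0 marks a padded slot
    """
    labels = torch.full_like(input_ids, ignore_index)
    real = masked_lm_weights > 0  # drop padded prediction slots so position 0 is not labelled
    batch_idx = (torch.arange(input_ids.size(0), device=input_ids.device)
                 .unsqueeze(1)
                 .expand_as(masked_lm_positions))
    labels[batch_idx[real], masked_lm_positions[real]] = masked_lm_ids[real]
    return labels
```

The SOP labels from the same record would be flattened with `.view(-1)` and passed as `seq_relationship_labels`. The `input_mask` and `segment_ids` features (again assuming those names) would go to `attention_mask` and `token_type_ids`.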
<section style="padding: 0.5em;line-height: 18px;font-size: 14px;letter-spacing: 0px;font-family: Consolas, Inconsolata, Courier, monospace;border-radius: 0px;background: rgb(29, 31, 33);color: rgb(197, 200, 198);margin-left: 8px;margin-right: 8px;overflow-wrap: normal !important;word-break: normal !important;overflow: auto !important;display: -webkit-box !important;"><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 1</span><span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;">#Using LAMB optimizer</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 2</span><span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;">#LAMB - "https://github.com/cybertronai/pytorch-lamb"</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 3</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 4</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> torch<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 5</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">from</span> torch.optim <span 
style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> Optimizer<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 6</span><span style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;"><span style="line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">class</span> <span style="line-height: inherit;color: rgb(129, 162, 190);overflow-wrap: inherit !important;word-break: inherit !important;">Lamb</span><span style="line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">(Optimizer)</span>:</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 7</span> <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">r"""Implements Lamb algorithm.<br /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 8</span> It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.<br /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 9</span> Arguments:<br /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 10</span> params (iterable): iterable of parameters to optimize or dicts defining<br /><span style="padding-right: 20px;line-height: inherit;color: 
rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 11</span> parameter groups<br /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 12</span> lr (float, optional): learning rate (default: 1e-3)<br /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 13</span> betas (Tuple[float, float], optional): coefficients used for computing<br /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 14</span> running averages of gradient and its square (default: (0.9, 0.999))<br /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 15</span> eps (float, optional): term added to the denominator to improve<br /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 16</span> numerical stability (default: 1e-6)<br /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 17</span> weight_decay (float, optional): decoupled weight decay, added to the update direction before the trust ratio is computed (default: 0)<br /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 18</span> adam (bool, optional): always use trust ratio = 1, which turns this into<br /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 19</span> Adam. 
Useful for comparison purposes.<br /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 20</span> .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:<br /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 21</span> https://arxiv.org/abs/1904.00962<br /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 22</span> """</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 23</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 24</span> <span style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;"><span style="line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">def</span> <span style="line-height: inherit;color: rgb(129, 162, 190);overflow-wrap: inherit !important;word-break: inherit !important;">__init__</span><span style="line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">(self, params, lr=<span style="line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">1e-3</span>, betas=<span style="line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">(<span style="line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">0.9</span>, <span style="line-height: inherit;overflow-wrap: inherit !important;word-break: inherit 
!important;">0.999</span>)</span>, eps=<span style="line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">1e-6</span>,<br /><span style="padding-right: 20px;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;"> 25</span> weight_decay=<span style="line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">0</span>, adam=False)</span>:</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 26</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">not</span> <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0.0</span> <= lr:<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 27</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">raise</span> ValueError(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"Invalid learning rate: {}"</span>.format(lr))<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 28</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> <span 
style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">not</span> <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0.0</span> <= eps:<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 29</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">raise</span> ValueError(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"Invalid epsilon value: {}"</span>.format(eps))<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 30</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">not</span> <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0.0</span> <= betas[<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>] < <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1.0</span>:<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit 
!important;"> 31</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">raise</span> ValueError(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"Invalid beta parameter at index 0: {}"</span>.format(betas[<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>]))<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 32</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">not</span> <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0.0</span> <= betas[<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span>] < <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1.0</span>:<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 33</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">raise</span> ValueError(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit 
!important;word-break: inherit !important;">"Invalid beta parameter at index 1: {}"</span>.format(betas[<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span>]))<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 34</span> defaults = dict(lr=lr, betas=betas, eps=eps,<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 35</span> weight_decay=weight_decay)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 36</span> self.adam = adam<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 37</span> super(Lamb, self).__init__(params, defaults)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 38</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 39</span> <span style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;"><span style="line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">def</span> <span style="line-height: inherit;color: rgb(129, 162, 190);overflow-wrap: inherit !important;word-break: inherit !important;">step</span><span style="line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: 
inherit !important;">(self, closure=None)</span>:</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 40</span> <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"""Performs a single optimization step.<br /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 41</span> Arguments:<br /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 42</span> closure (callable, optional): A closure that reevaluates the model<br /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 43</span> and returns the loss.<br /><span style="padding-right: 20px;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 44</span> """</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 45</span> loss = <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">None</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 46</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> closure <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: 
inherit !important;">is</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">not</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">None</span>:<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 47</span> loss = closure()<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 48</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 49</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> group <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> self.param_groups:<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 50</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> p <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> group[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'params'</span>]:<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: 
rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 51</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> p.grad <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">is</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">None</span>:<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 52</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">continue</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 53</span> grad = p.grad.data<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 54</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> grad.is_sparse:<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 55</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">raise</span> RuntimeError(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'Lamb does not support sparse 
gradients, consider SparseAdam instead.'</span>)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 56</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 57</span> state = self.state[p]<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 58</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 59</span> <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># State initialization</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 60</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> len(state) == <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>:<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 61</span> state[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'step'</span>] = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span><br /><span 
style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 62</span> <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># Exponential moving average of gradient values</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 63</span> state[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'exp_avg'</span>] = torch.zeros_like(p.data)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 64</span> <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># Exponential moving average of squared gradient values</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 65</span> state[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'exp_avg_sq'</span>] = torch.zeros_like(p.data)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 66</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 67</span> exp_avg, exp_avg_sq = state[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 
104);overflow-wrap: inherit !important;word-break: inherit !important;">'exp_avg'</span>], state[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'exp_avg_sq'</span>]<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 68</span> beta1, beta2 = group[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'betas'</span>]<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 69</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 70</span> state[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'step'</span>] += <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 71</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 72</span> <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># Decay the first and second moment running average coefficient</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 
95);overflow-wrap: inherit !important;word-break: inherit !important;"> 73</span> <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># m_t</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 74</span> exp_avg.mul_(beta1).add_(grad, alpha=<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span> - beta1)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 75</span> <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># v_t</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 76</span> exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span> - beta2)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 77</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 78</span> <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># Paper v3 does not use debiasing.</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 
95);overflow-wrap: inherit !important;word-break: inherit !important;"> 79</span> <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># bias_correction1 = 1 - beta1 ** state['step']</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 80</span> <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># bias_correction2 = 1 - beta2 ** state['step']</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 81</span> <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># Apply bias to lr to avoid broadcast.</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 82</span> step_size = group[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'lr'</span>] <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># * math.sqrt(bias_correction2) / bias_correction1</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 83</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 84</span> weight_norm = p.data.pow(<span 
style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">2</span>).sum().sqrt().clamp(<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>, <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">10</span>)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 85</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 86</span> adam_step = exp_avg / exp_avg_sq.sqrt().add(group[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'eps'</span>])<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 87</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> group[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'weight_decay'</span>] != <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>:<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 88</span> adam_step.add_(group[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 
104);overflow-wrap: inherit !important;word-break: inherit !important;">'weight_decay'</span>] * p.data)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 89</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 90</span> adam_norm = adam_step.pow(<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">2</span>).sum().sqrt()<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 91</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> weight_norm == <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">or</span> adam_norm == <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>:<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 92</span> trust_ratio = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit 
!important;"> 93</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">else</span>:<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 94</span> trust_ratio = weight_norm / adam_norm<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 95</span> state[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'weight_norm'</span>] = weight_norm<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 96</span> state[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'adam_norm'</span>] = adam_norm<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 97</span> state[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'trust_ratio'</span>] = trust_ratio<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 98</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> self.adam:<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit 
!important;word-break: inherit !important;"> 99</span> trust_ratio = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">100</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">101</span> p.data.add_(adam_step, alpha=-step_size * trust_ratio)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">102</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">103</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">return</span> loss<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">104</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">105</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> time<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">106</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: 
inherit !important;">import</span> torch.nn <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">as</span> nn<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">107</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> torch<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">108</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">from</span> tfrecord.torch.dataset <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> TFRecordDataset<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">109</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> numpy <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">as</span> np<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">110</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> os<br /><span style="padding-right: 20px;font-size: 
inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">111</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">112</span>LEARNING_RATE = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0.001</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">113</span>EPOCH = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">40</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">114</span>BATCH_SIZE = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">2</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">115</span>MAX_GRAD_NORM = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1.0</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">116</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">117</span>print(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit 
!important;word-break: inherit !important;">f"--- Resume/Start training ---"</span>) <br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">118</span>feat_map = {<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"input_ids"</span>: <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"int"</span>, <br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">119</span> <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"input_mask"</span>: <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"int"</span>,<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">120</span> <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"segment_ids"</span>: <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"int"</span>,<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">121</span> <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"next_sentence_labels"</span>: <span 
style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"int"</span>,<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">122</span> <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"masked_lm_positions"</span>: <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"int"</span>,<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">123</span> <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"masked_lm_ids"</span>: <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"int"</span>}<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">124</span>pretrain_file = <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'restaurant_review_train'</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">125</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">126</span><span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 
150);overflow-wrap: inherit !important;word-break: inherit !important;"># Create albert pretrain model</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">127</span>config = AlbertConfig.from_json_file(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"albert_config.json"</span>)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">128</span>albert_pretrain = AlbertForPretrain(config)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">129</span><span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># Create optimizer</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">130</span>optimizer = Lamb([{<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"params"</span>: [p <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> n, p <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> list(albert_pretrain.named_parameters())]}], lr=LEARNING_RATE)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit 
!important;">131</span>albert_pretrain.train()<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">132</span>dataset = TFRecordDataset(pretrain_file, index_path = <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">None</span>, description=feat_map)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">133</span>loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">134</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">135</span>tmp_loss = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">136</span>start_time = time.time()<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">137</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">138</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> os.path.isfile(<span style="font-size: 
inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'pretrain_checkpoint'</span>):<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">139</span> print(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">f"--- Load from checkpoint ---"</span>)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">140</span> checkpoint = torch.load(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"pretrain_checkpoint"</span>)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">141</span> albert_pretrain.load_state_dict(checkpoint[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'model_state_dict'</span>])<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">142</span> optimizer.load_state_dict(checkpoint[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'optimizer_state_dict'</span>])<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">143</span> epoch = checkpoint[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: 
inherit !important;word-break: inherit !important;">'epoch'</span>]<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">144</span> loss = checkpoint[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'loss'</span>]<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">145</span> losses = checkpoint[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'losses'</span>]<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">146</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">147</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">else</span>:<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">148</span> epoch = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">-1</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">149</span> losses = []<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: 
inherit !important;">150</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> e <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> range(epoch+<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span>, EPOCH):<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">151</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> batch <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> loader:<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">152</span> b_input_ids = batch[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'input_ids'</span>].long() <br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">153</span> b_token_type_ids = batch[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'segment_ids'</span>].long() <br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">154</span> b_seq_relationship_labels = 
batch[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'next_sentence_labels'</span>].long()<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">155</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">156</span> <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># Convert the decoded TFRecord features (written by Google's ALBERT</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">157</span> <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># create_pretraining_data.py script) into the format expected by</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">158</span> <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># Hugging Face's PyTorch implementation of ALBERT</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">159</span> mask_rows = np.nonzero(batch[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'masked_lm_positions'</span>].numpy())[<span style="font-size: 
inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>]<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">160</span> mask_cols = batch[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'masked_lm_positions'</span>].numpy()[batch[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'masked_lm_positions'</span>].numpy()!=<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>]<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">161</span> <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># attend to every non-padding token (input_mask), not only the masked positions</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">162</span> b_attention_mask = batch[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'input_mask'</span>].numpy()<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">163</span> b_masked_lm_labels = np.zeros((b_input_ids.shape[<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>],<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit 
!important;">64</span>),dtype=np.int64) - <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">100</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">164</span> b_masked_lm_labels[mask_rows,mask_cols] = batch[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'masked_lm_ids'</span>].numpy()[batch[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'masked_lm_positions'</span>].numpy()!=<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>] <br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">165</span> b_attention_mask=torch.tensor(b_attention_mask).long()<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">166</span> b_masked_lm_labels=torch.tensor(b_masked_lm_labels).long()<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">167</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">168</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">169</span> loss = 
albert_pretrain(input_ids = b_input_ids<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">170</span> , attention_mask = b_attention_mask<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">171</span> , token_type_ids = b_token_type_ids<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">172</span> , masked_lm_labels = b_masked_lm_labels <br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">173</span> , seq_relationship_labels = b_seq_relationship_labels)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">174</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">175</span> <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># clears old gradients</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">176</span> optimizer.zero_grad()<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">177</span> <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit 
!important;"># backward pass</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">178</span> loss.backward()<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">179</span> <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># gradient clipping</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">180</span> torch.nn.utils.clip_grad_norm_(parameters=albert_pretrain.parameters(), max_norm=MAX_GRAD_NORM)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">181</span> <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># update parameters</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">182</span> optimizer.step()<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">183</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">184</span> tmp_loss += loss.detach().item()<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">185</span><br 
/><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">186</span> <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># print metrics and save to checkpoint every epoch</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">187</span> print(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">f"Epoch: <span style="color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">{e}</span>"</span>)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">188</span> print(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">f"Train loss: <span style="color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">{(tmp_loss/<span style="line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">20</span>)}</span>"</span>)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">189</span> print(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">f"Train Time: <span style="color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">{(time.time()-start_time)/<span 
style="line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">60</span>}</span> mins"</span>) <br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">190</span> losses.append(tmp_loss/<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">20</span>)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">191</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">192</span> tmp_loss = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">193</span> start_time = time.time()<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">194</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">195</span> torch.save({<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'model_state_dict'</span>: albert_pretrain.state_dict(),<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'optimizer_state_dict'</span>: 
optimizer.state_dict(),<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">196</span> <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'epoch'</span>: e, <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'loss'</span>: loss,<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'losses'</span>: losses}<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">197</span> , <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'pretrain_checkpoint'</span>)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">198</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">from</span> matplotlib <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> pyplot <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">as</span> plot<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">199</span>plot.plot(losses)<br /><span style="padding-right: 20px;font-size: 
inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">200</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">201</span><span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># Fine-tuning ALBERT</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">202</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">203</span><span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># At the time of writing, Hugging Face did not provide an AlbertForTokenClassification</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">204</span><span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># class, so we define our own below</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">205</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">from</span> transformers.modeling_albert <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit 
!important;">import</span> AlbertModel, AlbertPreTrainedModel<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">206</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">from</span> transformers.configuration_albert <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> AlbertConfig<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">207</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">from</span> transformers.tokenization_bert <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> BertTokenizer<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">208</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> torch.nn <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">as</span> nn<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">209</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">from</span> 
torch.nn <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> CrossEntropyLoss<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">210</span><span style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;"><span style="line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">class</span> <span style="line-height: inherit;color: rgb(129, 162, 190);overflow-wrap: inherit !important;word-break: inherit !important;">AlbertForTokenClassification</span><span style="line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">(AlbertPreTrainedModel)</span>:</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">211</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">212</span> <span style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;"><span style="line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">def</span> <span style="line-height: inherit;color: rgb(129, 162, 190);overflow-wrap: inherit !important;word-break: inherit !important;">__init__</span><span style="line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">(self, albert, config)</span>:</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: 
rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">213</span> super().__init__(config)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">214</span> self.num_labels = config.num_labels<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">215</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">216</span> self.albert = albert<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">217</span> self.dropout = nn.Dropout(config.hidden_dropout_prob)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">218</span> self.classifier = nn.Linear(config.hidden_size, config.num_labels)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">219</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">220</span> <span style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;"><span style="line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">def</span> <span style="line-height: inherit;color: rgb(129, 162, 190);overflow-wrap: inherit !important;word-break: inherit !important;">forward</span><span 
style="line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">(<br /><span style="padding-right: 20px;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">221</span> self,<br /><span style="padding-right: 20px;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">222</span> input_ids=None,<br /><span style="padding-right: 20px;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">223</span> attention_mask=None,<br /><span style="padding-right: 20px;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">224</span> token_type_ids=None,<br /><span style="padding-right: 20px;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">225</span> position_ids=None,<br /><span style="padding-right: 20px;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">226</span> head_mask=None,<br /><span style="padding-right: 20px;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">227</span> inputs_embeds=None,<br /><span style="padding-right: 20px;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">228</span> labels=None,<br /><span style="padding-right: 20px;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">229</span> )</span>:</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">230</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">231</span> outputs = self.albert(<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 
95);overflow-wrap: inherit !important;word-break: inherit !important;">232</span> input_ids,<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">233</span> attention_mask=attention_mask,<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">234</span> token_type_ids=token_type_ids,<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">235</span> position_ids=position_ids,<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">236</span> head_mask=head_mask,<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">237</span> inputs_embeds=inputs_embeds,<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">238</span> )<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">239</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">240</span> sequence_output = outputs[<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>]<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit 
!important;word-break: inherit !important;">241</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">242</span> sequence_output = self.dropout(sequence_output)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">243</span> logits = self.classifier(sequence_output)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">244</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">245</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">return</span> logits<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">246</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">247</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> numpy <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">as</span> np<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">248</span><span style="font-size: inherit;color: inherit;line-height: inherit;overflow-wrap: inherit 
!important;word-break: inherit !important;"><span style="line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">def</span> <span style="line-height: inherit;color: rgb(129, 162, 190);overflow-wrap: inherit !important;word-break: inherit !important;">label_sent</span><span style="line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">(name_tokens, sent_tokens)</span>:</span><span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;">  # label each review token 1 if it is part of the dish name, else 0</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">249</span> label = []<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">250</span> i = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">251</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> len(name_tokens)>len(sent_tokens):<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">252</span> label = np.zeros(len(sent_tokens))<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">253</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">else</span>:<br /><span style="padding-right: 
20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">254</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">while</span> i&lt;len(sent_tokens):<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">255</span> found_match = <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">False</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">256</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> name_tokens[<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>] == sent_tokens[i]: <br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">257</span> found_match = <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">True</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">258</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> j <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit 
!important;word-break: inherit !important;">in</span> range(len(name_tokens)<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">-1</span>):<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">259</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> ((i+j+<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span>)>=len(sent_tokens)):<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">260</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">return</span> label + [<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>] * (len(sent_tokens) - len(label))  <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># name cannot fit in the rest: label remaining tokens 0</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">261</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> name_tokens[j+<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span>] != sent_tokens[i+j+<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span>]:<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">262</span> found_match = <span style="font-size: 
inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">False</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">263</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> found_match:<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">264</span> label.extend(list(np.ones(len(name_tokens)).astype(int)))<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">265</span> i = i + len(name_tokens)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">266</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">else</span>: <br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">267</span> label.extend([<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>])<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">268</span> i = i+ <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span><br /><span style="padding-right: 
20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">269</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">else</span>:<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">270</span> label.extend([<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>])<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">271</span> i=i+<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">272</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">return</span> label<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">273</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">274</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> pandas <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: 
inherit !important;">as</span> pd<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">275</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> glob<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">276</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> os<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">277</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">278</span>tokenizer = BertTokenizer(vocab_file=<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"vocab.txt"</span>)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">279</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">280</span>df_data_train = pd.read_csv(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"dish_name_train.csv"</span>)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: 
inherit !important;">281</span>df_data_train[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'name_tokens'</span>] = df_data_train[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'dish_name'</span>].apply(tokenizer.tokenize)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">282</span>df_data_train[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'review_tokens'</span>] = df_data_train.review.apply(tokenizer.tokenize)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">283</span>df_data_train[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'review_label'</span>] = df_data_train.apply(<span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">lambda</span> row: label_sent(row[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'name_tokens'</span>], row[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'review_tokens'</span>]), axis=<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span>)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 
147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">284</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">285</span>df_data_val = pd.read_csv(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"dish_name_val.csv"</span>)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">286</span>df_data_val = df_data_val.dropna().reset_index(drop=<span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">True</span>)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">287</span>df_data_val[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'name_tokens'</span>] = df_data_val[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'dish_name'</span>].apply(tokenizer.tokenize)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">288</span>df_data_val[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'review_tokens'</span>] = df_data_val.review.apply(tokenizer.tokenize)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">289</span>df_data_val[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit 
!important;">'review_label'</span>] = df_data_val.apply(<span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">lambda</span> row: label_sent(row[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'name_tokens'</span>], row[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'review_tokens'</span>]), axis=<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span>)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">290</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">291</span>MAX_LEN = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">64</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">292</span>BATCH_SIZE = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">293</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">from</span> keras.preprocessing.sequence <span 
style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> pad_sequences<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">294</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> torch<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">295</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">from</span> torch.utils.data <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> TensorDataset, DataLoader, RandomSampler, SequentialSampler<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">296</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">297</span>tr_inputs = pad_sequences([tokenizer.convert_tokens_to_ids(txt) <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> txt <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> df_data_train[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit 
!important;">'review_tokens'</span>]],maxlen=MAX_LEN, dtype=<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"long"</span>, truncating=<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"post"</span>, padding=<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"post"</span>)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">298</span>tr_tags = pad_sequences(df_data_train[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'review_label'</span>],maxlen=MAX_LEN, padding=<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"post"</span>,dtype=<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"long"</span>, truncating=<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"post"</span>)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">299</span><span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># create the mask to ignore the padded elements in the sequences.</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit 
!important;word-break: inherit !important;">300</span>tr_masks = [[float(i><span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>) <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> i <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> ii] <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> ii <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> tr_inputs]<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">301</span>tr_inputs = torch.tensor(tr_inputs)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">302</span>tr_tags = torch.tensor(tr_tags)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">303</span>tr_masks = torch.tensor(tr_masks)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">304</span>train_data = TensorDataset(tr_inputs, tr_masks, tr_tags)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">305</span>train_sampler = RandomSampler(train_data)<br /><span 
style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">306</span>train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=BATCH_SIZE)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">307</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">308</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">309</span>val_inputs = pad_sequences([tokenizer.convert_tokens_to_ids(txt) <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> txt <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> df_data_val[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'review_tokens'</span>]],maxlen=MAX_LEN, dtype=<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"long"</span>, truncating=<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"post"</span>, padding=<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"post"</span>)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 
95);overflow-wrap: inherit !important;word-break: inherit !important;">310</span>val_tags = pad_sequences(df_data_val[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'review_label'</span>],maxlen=MAX_LEN, padding=<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"post"</span>,dtype=<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"long"</span>, truncating=<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"post"</span>)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">311</span><span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># create the mask to ignore the padded elements in the sequences.</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">312</span>val_masks = [[float(i><span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>) <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> i <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> ii] <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit 
!important;word-break: inherit !important;">for</span> ii <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> val_inputs]<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">313</span>val_inputs = torch.tensor(val_inputs)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">314</span>val_tags = torch.tensor(val_tags)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">315</span>val_masks = torch.tensor(val_masks)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">316</span>val_data = TensorDataset(val_inputs, val_masks, val_tags)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">317</span>val_sampler = SequentialSampler(val_data)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">318</span>val_dataloader = DataLoader(val_data, sampler=val_sampler, batch_size=BATCH_SIZE)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">319</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">320</span>model_tokenclassification = 
AlbertForTokenClassification(albert_pretrain.albert, config)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">321</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">from</span> torch.optim <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> Adam<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">322</span>LEARNING_RATE = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0.0000003</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">323</span>FULL_FINETUNING = <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">True</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">324</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> FULL_FINETUNING:<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">325</span> param_optimizer = list(model_tokenclassification.named_parameters())<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 
95);overflow-wrap: inherit !important;word-break: inherit !important;">326</span> no_decay = [<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'bias'</span>, <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'gamma'</span>, <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'beta'</span>]<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">327</span> optimizer_grouped_parameters = [<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">328</span> {<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'params'</span>: [p <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> n, p <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> param_optimizer <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">not</span> any(nd <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> n <span style="font-size: inherit;line-height: 
inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> nd <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> no_decay)],<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">329</span> <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'weight_decay'</span>: <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0.01</span>},<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">330</span> {<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'params'</span>: [p <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> n, p <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> param_optimizer <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> any(nd <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> n <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> nd <span style="font-size: 
inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> no_decay)],<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">331</span> <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'weight_decay'</span>: <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0.0</span>}<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">332</span> ]<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">333</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">else</span>:<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">334</span> param_optimizer = list(model_tokenclassification.classifier.named_parameters()) <br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">335</span> optimizer_grouped_parameters = [{<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"params"</span>: [p <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> n, p <span 
style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> param_optimizer]}]<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">336</span>optimizer = Adam(optimizer_grouped_parameters, lr=LEARNING_RATE)<br /></section>
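The data preparation above post-pads and post-truncates every review to MAX_LEN token ids, derives an attention mask from the non-zero ids, and splits the model parameters into a decayed and a non-decayed group for the optimizer. The minimal sketch below reproduces that logic in plain Python so it can be checked without Keras or PyTorch. The helper names `pad_post`, `attention_masks` and `group_params` are hypothetical, and the sketch assumes, as the original mask code does, that id 0 is the padding id and never a real token. The group dictionaries use the `weight_decay` key because that is the key `torch.optim.Adam` reads from each parameter group.

```python
def pad_post(seqs, maxlen, value=0):
    """Pad or truncate each id sequence at the end to exactly `maxlen` items.

    Intended to match keras pad_sequences(..., padding="post", truncating="post")
    for plain lists of ints.
    """
    out = []
    for seq in seqs:
        seq = list(seq)[:maxlen]                          # truncating="post": keep the first maxlen ids
        out.append(seq + [value] * (maxlen - len(seq)))   # padding="post": append pad ids at the end
    return out


def attention_masks(padded):
    """1.0 for real tokens, 0.0 for padding (assumes pad id 0 is never a real token id)."""
    return [[float(tok > 0) for tok in row] for row in padded]


def group_params(named_params, no_decay=("bias", "gamma", "beta"), decay=0.01):
    """Split (name, param) pairs into a weight-decayed group and a non-decayed group.

    Membership uses substring matching on the parameter name, as in the code above.
    """
    decayed = [p for n, p in named_params if not any(nd in n for nd in no_decay)]
    not_decayed = [p for n, p in named_params if any(nd in n for nd in no_decay)]
    return [
        {"params": decayed, "weight_decay": decay},
        {"params": not_decayed, "weight_decay": 0.0},
    ]


if __name__ == "__main__":
    padded = pad_post([[5, 6, 7], [1, 2, 3, 4, 5]], maxlen=4)
    print(padded)
    print(attention_masks(padded))
    print(group_params([("encoder.weight", "w"), ("encoder.bias", "b"), ("norm.gamma", "g")]))
```

Because membership in the no-decay group is decided by substring matching, it is worth printing `named_parameters()` for the actual model: some ALBERT implementations name their LayerNorm parameters `weight` and `bias` rather than `gamma` and `beta`, in which case only the biases would be excluded from decay.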
<section style="padding: 0.5em;line-height: 18px;font-size: 14px;letter-spacing: 0px;font-family: Consolas, Inconsolata, Courier, monospace;border-radius: 0px;background: rgb(29, 31, 33);color: rgb(197, 200, 198);margin-left: 8px;margin-right: 8px;overflow-wrap: normal !important;word-break: normal !important;overflow: auto !important;display: -webkit-box !important;"><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 1</span><span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;">#Training the model</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 2</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 3</span><span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># from torch.utils.tensorboard import SummaryWriter</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 4</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> time<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 5</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> os.path<br /><span 
style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 6</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> torch.nn <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">as</span> nn<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 7</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">import</span> torch<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 8</span>EPOCH = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">800</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 9</span>MAX_GRAD_NORM = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1.0</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 10</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 11</span>start_time = time.time()<br /><span style="padding-right: 20px;font-size: inherit;line-height: 
inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 12</span>tr_loss, tr_acc, nb_tr_steps = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>, <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>, <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 13</span>eval_loss, eval_acc, nb_eval_steps = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>, <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>, <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 14</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 15</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> os.path.isfile(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'finetune_checkpoint'</span>):<br 
/><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 16</span> print(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">f"--- Load from checkpoint ---"</span>)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 17</span> checkpoint = torch.load(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">"finetune_checkpoint"</span>)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 18</span> model_tokenclassification.load_state_dict(checkpoint[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'model_state_dict'</span>])<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 19</span> optimizer.load_state_dict(checkpoint[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'optimizer_state_dict'</span>])<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 20</span> epoch = checkpoint[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'epoch'</span>]<br /><span style="padding-right: 20px;font-size: inherit;line-height: 
inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 21</span> train_losses = checkpoint[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'train_losses'</span>]<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 22</span> train_accs = checkpoint[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'train_accs'</span>]<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 23</span> eval_losses = checkpoint[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'eval_losses'</span>]<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 24</span> eval_accs = checkpoint[<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'eval_accs'</span>]<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 25</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 26</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">else</span>:<br /><span style="padding-right: 20px;font-size: 
inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 27</span> epoch = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">-1</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 28</span> train_losses, train_accs, eval_losses, eval_accs = [], [], [], []<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 29</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 30</span>print(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">f"--- Resume/Start training ---"</span>) <br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 31</span><span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> e <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> range(epoch+<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span>, EPOCH): <br /><span style="padding-right: 20px;font-size: 
inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 33</span> <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># TRAIN loop</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 34</span> model_tokenclassification.train()<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 35</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 36</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> batch <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> train_dataloader:<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 37</span> <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># move batch to the same device as the model (e.g. GPU)</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 38</span> batch = tuple(t.to(next(model_tokenclassification.parameters()).device) <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> t <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 
187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> batch)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 39</span> b_input_ids, b_input_mask, b_labels = batch<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 40</span> <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># forward pass</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 41</span> b_outputs = model_tokenclassification(b_input_ids, token_type_ids=<span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">None</span>, attention_mask=b_input_mask, labels=b_labels)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 42</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 43</span> ce_loss_fct = CrossEntropyLoss()<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 44</span> <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># Only keep active parts of the loss</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: 
inherit !important;word-break: inherit !important;"> 45</span> b_active_loss = b_input_mask.view(<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">-1</span>) == <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 46</span> b_active_logits = b_outputs.view(<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">-1</span>, config.num_labels)[b_active_loss]<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 47</span> b_active_labels = b_labels.view(<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">-1</span>)[b_active_loss]<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 48</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 49</span> loss = ce_loss_fct(b_active_logits, b_active_labels)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 50</span> acc = torch.mean((torch.max(b_active_logits.detach(),<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit 
!important;">1</span>)[<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span>] == b_active_labels.detach()).float())<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 51</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 52</span> model_tokenclassification.zero_grad()<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 53</span> <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># backward pass</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 54</span> loss.backward()<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 55</span> <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># track train loss and accuracy as plain Python floats (not GPU tensors)</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 56</span> tr_loss += loss.item()<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 57</span> tr_acc += acc.item()<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: 
rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 58</span> nb_tr_steps += <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 59</span> <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># gradient clipping</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 60</span> torch.nn.utils.clip_grad_norm_(parameters=model_tokenclassification.parameters(), max_norm=MAX_GRAD_NORM)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 61</span> <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># update parameters</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 62</span> optimizer.step()<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 63</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 64</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 
65</span> <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># VALIDATION on validation set</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 66</span> model_tokenclassification.eval()<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 67</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> batch <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> val_dataloader:<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 68</span> batch = tuple(t.to(next(model_tokenclassification.parameters()).device) <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">for</span> t <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">in</span> batch)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 69</span> b_input_ids, b_input_mask, b_labels = batch<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 70</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: 
inherit !important;"> 71</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">with</span> torch.no_grad():<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 72</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 73</span> b_outputs = model_tokenclassification(b_input_ids, token_type_ids=<span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">None</span>,<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 74</span> attention_mask=b_input_mask, labels=b_labels)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 75</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 76</span> loss_fct = CrossEntropyLoss()<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 77</span> <span style="font-size: inherit;line-height: inherit;color: rgb(150, 152, 150);overflow-wrap: inherit !important;word-break: inherit !important;"># Only keep active parts of the loss</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 78</span> b_active_loss = 
b_input_mask.view(<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">-1</span>) == <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 79</span> b_active_logits = b_outputs.view(<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">-1</span>, config.num_labels)[b_active_loss]<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 80</span> b_active_labels = b_labels.view(<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">-1</span>)[b_active_loss]<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 81</span> loss = loss_fct(b_active_logits, b_active_labels)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 82</span> acc = np.mean(np.argmax(b_active_logits.detach().cpu().numpy(), axis=<span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span>).flatten() == b_active_labels.detach().cpu().numpy().flatten())<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 
83</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 84</span> eval_loss += loss.item()<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 85</span> eval_acc += acc<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 86</span> nb_eval_steps += <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">1</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 87</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 88</span> <span style="font-size: inherit;line-height: inherit;color: rgb(178, 148, 187);overflow-wrap: inherit !important;word-break: inherit !important;">if</span> e % <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">10</span> == <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>:<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 89</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 90</span> print(<span style="font-size: 
inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">f"Epoch: <span style="color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">{e}</span>"</span>)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 91</span> print(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">f"Train loss: <span style="color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">{(tr_loss/nb_tr_steps)}</span>"</span>)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 92</span> print(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">f"Train acc: <span style="color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">{(tr_acc/nb_tr_steps)}</span>"</span>)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 93</span> print(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">f"Train Time: <span style="color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">{(time.time()-start_time)/<span style="line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">60</span>}</span> mins"</span>) <br /><span style="padding-right: 20px;font-size: inherit;line-height: 
inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 94</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 95</span> print(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">f"Validation loss: <span style="color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">{eval_loss/nb_eval_steps}</span>"</span>)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 96</span> print(<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">f"Validation Accuracy: <span style="color: inherit;line-height: inherit;overflow-wrap: inherit !important;word-break: inherit !important;">{(eval_acc/nb_eval_steps)}</span>"</span>) <br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 97</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 98</span> train_losses.append(tr_loss/nb_tr_steps)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;"> 99</span> train_accs.append(tr_acc/nb_tr_steps)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">100</span> eval_losses.append(eval_loss/nb_eval_steps)<br /><span 
style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">101</span> eval_accs.append(eval_acc/nb_eval_steps)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">102</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">103</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">104</span> tr_loss, tr_acc, nb_tr_steps = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>, <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>, <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span> <br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">105</span> eval_loss, eval_acc, nb_eval_steps = <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>, <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span>, <span style="font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">0</span> <br /><span style="padding-right: 20px;font-size: 
inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">106</span> start_time = time.time() <br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">107</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">108</span> torch.save({<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'model_state_dict'</span>: model_tokenclassification.state_dict(),<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'optimizer_state_dict'</span>: optimizer.state_dict(),<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">109</span> <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'epoch'</span>: e, <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'train_losses'</span>: train_losses,<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'train_accs'</span>: train_accs, <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'eval_losses'</span>:eval_losses,<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit 
!important;">'eval_accs'</span>:eval_accs}<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">110</span> , <span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'finetune_checkpoint'</span>)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">111</span><br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">112</span>plot.plot(train_losses)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">113</span>plot.plot(train_accs)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">114</span>plot.plot(eval_losses)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">115</span>plot.plot(eval_accs)<br /><span style="padding-right: 20px;font-size: inherit;line-height: inherit;color: rgb(222, 147, 95);overflow-wrap: inherit !important;word-break: inherit !important;">116</span>plot.legend(labels = [<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'train_loss'</span>,<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'train_accuracy'</span>,<span style="font-size: inherit;line-height: inherit;color: rgb(181, 
189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'validation_loss'</span>,<span style="font-size: inherit;line-height: inherit;color: rgb(181, 189, 104);overflow-wrap: inherit !important;word-break: inherit !important;">'validation_accuracy'</span>])<br /></section>
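In the training and validation loops above, the "Only keep active parts of the loss" step flattens the logits to shape (batch × seq_len, num_labels), keeps only the positions where the attention mask is 1, and computes cross-entropy and accuracy over those positions alone. The following is a minimal plain-Python sketch of that computation (no PyTorch), assuming logits, labels and mask are given as nested lists; the function name `masked_token_loss_and_acc` is hypothetical and not part of the original notebook.

```python
import math
from typing import List, Tuple


def masked_token_loss_and_acc(
    logits: List[List[List[float]]],  # [batch][seq_len][num_labels]
    labels: List[List[int]],          # [batch][seq_len]
    mask: List[List[int]],            # [batch][seq_len]; 1 = real token, 0 = padding
) -> Tuple[float, float]:
    """Mean cross-entropy and accuracy over positions where mask == 1.

    Hypothetical helper that mirrors the flatten-then-select pattern
    logits.view(-1, num_labels)[mask.view(-1) == 1] used in the PyTorch loop.
    """
    # Flatten batch and sequence dimensions, keeping only active (unpadded) tokens.
    active = [
        (tok_logits, lab)
        for seq_logits, seq_labels, seq_mask in zip(logits, labels, mask)
        for tok_logits, lab, m in zip(seq_logits, seq_labels, seq_mask)
        if m == 1
    ]
    if not active:
        raise ValueError("mask selects no tokens")

    total_loss, correct = 0.0, 0
    for tok_logits, lab in active:
        # Numerically stable log-sum-exp: subtract the max logit before exponentiating.
        mx = max(tok_logits)
        log_sum_exp = mx + math.log(sum(math.exp(x - mx) for x in tok_logits))
        total_loss += log_sum_exp - tok_logits[lab]  # equals -log softmax(logits)[lab]
        # Predicted label = argmax over labels (same role as torch.max(..., 1)[1]).
        pred = max(range(len(tok_logits)), key=tok_logits.__getitem__)
        correct += int(pred == lab)

    n = len(active)
    return total_loss / n, correct / n


if __name__ == "__main__":
    # One sequence of 3 tokens with 2 labels; the third position is padding (mask 0).
    loss, acc = masked_token_loss_and_acc(
        [[[2.0, 0.0], [0.0, 2.0], [5.0, -5.0]]], [[0, 1, 1]], [[1, 1, 0]]
    )
    print(f"loss={loss:.4f} acc={acc:.2f}")  # loss=0.1269 acc=1.00
```

Masking before averaging keeps padding positions, whose labels are only filler values, from inflating the loss or distorting the accuracy: with the mask set to all ones in the same example, the padded third token would be "mispredicted" and add a loss term of roughly 10, and accuracy would drop to 2/3.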
<section style="padding: 0.5em;line-height: 18px;font-size: 14px;letter-spacing: 0px;font-family: Consolas, Inconsolata, Courier, monospace;border-radius: 0px;background: rgb(29, 31, 33);color: rgb(197, 200, 198);margin-left: 8px;margin-right: 8px;overflow-wrap: normal !important;word-break: normal !important;overflow: auto !important;display: -webkit-box !important;"><pre style="margin: 0;"><code>#Prediction
import numpy as np
import torch
from torch.utils.data import TensorDataset, SequentialSampler, DataLoader

# tokenizer, pad_sequences, MAX_LEN, BATCH_SIZE, model_tokenclassification,
# df_data_train and df_data_val are defined in the earlier steps

def predict(texts):
    # Disable dropout so that predictions are deterministic
    model_tokenclassification.eval()
    # Run on whichever device (CPU or GPU) holds the model weights
    device = next(model_tokenclassification.parameters()).device

    # Tokenize, map tokens to ids, then pad/truncate every review to MAX_LEN
    tokenized_texts = [tokenizer.tokenize(txt) for txt in texts]
    input_ids = pad_sequences([tokenizer.convert_tokens_to_ids(txt) for txt in tokenized_texts],
                              maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")
    # Attention mask: 1.0 for real tokens, 0.0 for the zero padding
    attention_mask = [[float(i > 0) for i in ii] for ii in input_ids]

    input_ids = torch.tensor(input_ids)
    attention_mask = torch.tensor(attention_mask)

    # Sequential sampling keeps predictions in the same order as the input texts
    dataset = TensorDataset(input_ids, attention_mask)
    datasampler = SequentialSampler(dataset)
    dataloader = DataLoader(dataset, sampler=datasampler, batch_size=BATCH_SIZE)

    predicted_labels = []

    for batch in dataloader:
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask = batch

        with torch.no_grad():
            logits = model_tokenclassification(b_input_ids, token_type_ids=None,
                                               attention_mask=b_input_mask)

        # Highest-scoring label per token; multiplying by the mask zeroes out padding
        predicted_labels.append(np.multiply(np.argmax(logits.detach().cpu().numpy(), axis=2),
                                            b_input_mask.detach().cpu().numpy()))
    # np.concatenate(predicted_labels) flattens the list of (batch_size, max_len) arrays
    # into a single array of shape (len(texts), max_len)
    return np.concatenate(predicted_labels).astype(int), tokenized_texts


def get_dish_candidate_names(predicted_label, tokenized_text):
    name_lists = []
    # Positions of tokens predicted to be part of a dish name (label > 0)
    if len(np.where(predicted_label > 0)[0]) > 0:
        name_idx_combined = np.where(predicted_label > 0)[0]
        # Split into runs of consecutive positions; each run is one candidate name
        name_idxs = np.split(name_idx_combined, np.where(np.diff(name_idx_combined) != 1)[0] + 1)
        name_lists.append([" ".join(np.take(tokenized_text, name_idx)) for name_idx in name_idxs])
        # Remove duplicate names from name_lists
        name_lists = np.unique(name_lists)
        return name_lists
    else:
        return None


texts = df_data_val.review.values
predicted_labels, _ = predict(texts)
df_data_val['predicted_review_label'] = list(predicted_labels)
df_data_val['predicted_name'] = df_data_val.apply(
    lambda row: get_dish_candidate_names(row.predicted_review_label, row.review_tokens), axis=1)

texts = df_data_train.review.values
predicted_labels, _ = predict(texts)
df_data_train['predicted_review_label'] = list(predicted_labels)
df_data_train['predicted_name'] = df_data_train.apply(
    lambda row: get_dish_candidate_names(row.predicted_review_label, row.review_tokens), axis=1)

df_data_val
</code></pre></section>
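The span-extraction step inside get_dish_candidate_names (take the positions with a positive label, then cut them wherever two neighbouring positions are not adjacent) can be checked on its own. Below is a minimal standalone sketch of that same numpy.diff / numpy.split idea; the helper names group_consecutive and spans_from_labels and the toy review are hypothetical, and unlike the original function this version returns an empty list instead of None when no dish name is found.

```python
import numpy as np


def group_consecutive(positions):
    """Split a sorted 1-D array of token positions into runs of consecutive indices."""
    positions = np.asarray(positions)
    if positions.size == 0:
        return []
    # A new run starts wherever the gap to the previous position is not exactly 1
    breaks = np.where(np.diff(positions) != 1)[0] + 1
    return np.split(positions, breaks)


def spans_from_labels(labels, tokens):
    """Join the tokens of each run of positive labels into one candidate name."""
    positions = np.where(np.asarray(labels) > 0)[0]
    names = [" ".join(tokens[i] for i in run) for run in group_consecutive(positions)]
    # Deduplicate and sort, mirroring the np.unique call in get_dish_candidate_names
    return sorted(set(names))


# Hypothetical toy review: label 1 marks tokens predicted to belong to a dish name
tokens = ["the", "chicken", "rice", "was", "great", "and", "chicken", "rice", "again", "laksa"]
labels = [0, 1, 1, 0, 0, 0, 1, 1, 0, 1]
print(spans_from_labels(labels, tokens))
```

On the toy input this prints ['chicken rice', 'laksa']: positions 1-2 and 6-7 form two runs with the same text, which are collapsed into one name, just as np.unique does in the original.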
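Model comparison
- Given the same training time, ALBERT achieves better results than BERT.
- Given the same inference time, both ALBERT base and ALBERT large perform worse than BERT.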
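The comparison above does not spell out why inference time favours BERT. One commonly cited reason is that ALBERT's factorized embeddings and cross-layer parameter sharing shrink the number of stored weights but not the per-token computation, because the single shared layer is still applied once per layer. The back-of-the-envelope sketch below illustrates this under assumed base-size settings (vocabulary 30000, H=768, FFN 3072, 12 layers, ALBERT embedding size E=128), counting only the embedding tables and the weight matrices of each layer and ignoring biases, position/segment embeddings, LayerNorm and the pooler; the figures are approximations, not official parameter counts.

```python
# Assumed base-size configuration (approximation, not an official spec)
VOCAB = 30000   # vocabulary size, taken to be the same for both models
HIDDEN = 768    # hidden size H
FFN = 3072      # feed-forward inner size
LAYERS = 12     # number of Transformer layers
EMBED = 128     # ALBERT's factorized embedding size E


def layer_weights(h=HIDDEN, ffn=FFN):
    # Q, K, V and output projections (4 * H * H) plus the two FFN matrices (2 * H * FFN)
    return 4 * h * h + 2 * h * ffn


def bert_weights():
    # Full V x H embedding table and LAYERS independent layers
    return VOCAB * HIDDEN + LAYERS * layer_weights()


def albert_weights():
    # V x E table projected up to H, and one layer whose weights all layers share
    return VOCAB * EMBED + EMBED * HIDDEN + layer_weights()


def encoder_macs_per_token():
    # A linear layer costs roughly (in * out) multiply-adds per token, and the shared
    # layer is still run LAYERS times, so this figure is the same for BERT and ALBERT
    return LAYERS * layer_weights()


print(f"BERT-base   ~{bert_weights() / 1e6:.1f}M weights")
print(f"ALBERT-base ~{albert_weights() / 1e6:.1f}M weights")
print(f"encoder multiply-adds per token (both): ~{encoder_macs_per_token() / 1e6:.1f}M")
```

Under these assumptions ALBERT-base stores roughly a tenth of the weights of BERT-base, yet its encoder performs the same number of multiply-adds per token, which is consistent with it not being faster at inference.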
https://medium.com/@namanbansal9909/should-we-shift-from-bert-to-albert-e6fbb7779d3e
The author is a contracted writer for the NetEase News · NetEase Hao “各有态度” program.
<pre style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="padding-right: 0em;padding-left: 0em;max-width: 100%;letter-spacing: 0.544px;white-space: normal;color: rgb(0, 0, 0);font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;widows: 1;line-height: 1.75em;box-sizing: border-box !important;overflow-wrap: break-word !important;"><strong style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;letter-spacing: 0.5px;font-size: 14px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><strong style="max-width: 100%;font-size: 16px;letter-spacing: 0.544px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;letter-spacing: 0.5px;box-sizing: border-box !important;overflow-wrap: break-word !important;">—</span></strong>End<strong style="max-width: 100%;font-size: 16px;letter-spacing: 0.544px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;letter-spacing: 0.5px;font-size: 14px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><strong style="max-width: 100%;font-size: 16px;letter-spacing: 0.544px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;letter-spacing: 0.5px;box-sizing: border-box !important;overflow-wrap: break-word !important;">—</span></strong></span></strong></span></strong></section>
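This article comes from: 深度学习这件小事
This article is original content and its copyright belongs to 知行编程网. You are welcome to share it; please credit the source when reposting.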