<?xml version="1.0" encoding="utf-8" standalone="yes"?><rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom"><channel><title>PyTorch Recipe | Haobin Tan</title><link>https://haobin-tan.netlify.app/tags/pytorch-recipe/</link><atom:link href="https://haobin-tan.netlify.app/tags/pytorch-recipe/index.xml" rel="self" type="application/rss+xml"/><description>PyTorch Recipe</description><generator>Hugo Blox Builder (https://hugoblox.com)</generator><language>en-us</language><lastBuildDate>Mon, 24 May 2021 00:00:00 +0000</lastBuildDate><image><url>https://haobin-tan.netlify.app/media/icon_hu7d15bc7db65c8eaf7a4f66f5447d0b42_15095_512x512_fill_lanczos_center_3.png</url><title>PyTorch Recipe</title><link>https://haobin-tan.netlify.app/tags/pytorch-recipe/</link></image><item><title>🧾 PyTorch Recipes</title><link>https://haobin-tan.netlify.app/docs/ai/pytorch/pytorch-recipes/</link><pubDate>Mon, 07 Sep 2020 00:00:00 +0000</pubDate><guid>https://haobin-tan.netlify.app/docs/ai/pytorch/pytorch-recipes/</guid><description>&lt;p>This section provides a lot of useful recipes that make use of specific PyTorch features.&lt;/p></description></item><item><title>🔥 Transfer Learning for Computer Vision</title><link>https://haobin-tan.netlify.app/docs/ai/pytorch/pytorch-recipes/transfer-learning/</link><pubDate>Tue, 03 Nov 2020 00:00:00 +0000</pubDate><guid>https://haobin-tan.netlify.app/docs/ai/pytorch/pytorch-recipes/transfer-learning/</guid><description>&lt;h2 id="handling-settings-for-training-and-valiadtion-phase-flexibly">Handling settings for the training and validation phases flexibly&lt;/h2>
&lt;p>💡 Use a Python dictionary&lt;/p>
&lt;ul>
&lt;li>with the phase (&lt;code>'train'&lt;/code> or &lt;code>'val'&lt;/code>) as the key, so each phase gets its own transforms, dataset and dataloader&lt;/li>
&lt;/ul>
&lt;p>For example:&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="n">data_transforms&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">{&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># For training: data augmentation and normalization&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="s1">&amp;#39;train&amp;#39;&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="n">transforms&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Compose&lt;/span>&lt;span class="p">([&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">transforms&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">RandomResizedCrop&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">224&lt;/span>&lt;span class="p">),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">transforms&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">RandomHorizontalFlip&lt;/span>&lt;span class="p">(),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">transforms&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ToTensor&lt;/span>&lt;span class="p">(),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">transforms&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Normalize&lt;/span>&lt;span class="p">([&lt;/span>&lt;span class="mf">0.485&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mf">0.456&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mf">0.406&lt;/span>&lt;span class="p">],&lt;/span> &lt;span class="p">[&lt;/span>&lt;span class="mf">0.229&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mf">0.224&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mf">0.225&lt;/span>&lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">]),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># For validation: only normalization&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="s1">&amp;#39;val&amp;#39;&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="n">transforms&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Compose&lt;/span>&lt;span class="p">([&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">transforms&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Resize&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">256&lt;/span>&lt;span class="p">),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">transforms&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">CenterCrop&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">224&lt;/span>&lt;span class="p">),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">transforms&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ToTensor&lt;/span>&lt;span class="p">(),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">transforms&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Normalize&lt;/span>&lt;span class="p">([&lt;/span>&lt;span class="mf">0.485&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mf">0.456&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mf">0.406&lt;/span>&lt;span class="p">],&lt;/span> &lt;span class="p">[&lt;/span>&lt;span class="mf">0.229&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mf">0.224&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mf">0.225&lt;/span>&lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="p">}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">data_dir&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="s1">&amp;#39;hymenoptera_data&amp;#39;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">image_datasets&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">{&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="n">datasets&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ImageFolder&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">os&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">path&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">join&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">data_dir&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">data_transforms&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">for&lt;/span> &lt;span class="n">x&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;train&amp;#39;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="s1">&amp;#39;val&amp;#39;&lt;/span>&lt;span class="p">]}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">dataloaders&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">{&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">utils&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">data&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">DataLoader&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">image_datasets&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">],&lt;/span> &lt;span class="n">batch_size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">4&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">shuffle&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_workers&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">4&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">for&lt;/span> &lt;span class="n">x&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;train&amp;#39;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="s1">&amp;#39;val&amp;#39;&lt;/span>&lt;span class="p">]}&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">dataset_sizes&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">{&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="nb">len&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">image_datasets&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">])&lt;/span> &lt;span class="k">for&lt;/span> &lt;span class="n">x&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;train&amp;#39;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="s1">&amp;#39;val&amp;#39;&lt;/span>&lt;span class="p">]}&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;h2 id="general-function-to-train-a-model">General function to train a model&lt;/h2>
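&lt;p>Besides the usual training and validation loop, this function will:&lt;/p>
&lt;ul>
&lt;li>schedule the learning rate&lt;/li>
&lt;li>save the best model (the weights with the highest validation accuracy)&lt;/li>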
&lt;/ul>
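&lt;p>As a rough illustration of both points, here is a minimal sketch (not part of the original recipe) that uses a hypothetical one-layer stand-in model. It records the learning rate that &lt;code>StepLR(step_size=7, gamma=0.1)&lt;/code> produces in each epoch, and it shows why the best weights are saved with &lt;code>copy.deepcopy&lt;/code>: &lt;code>state_dict()&lt;/code> returns tensors that share storage with the live parameters, so an un-copied reference would keep changing as training continues.&lt;/p>

```python
import copy

import torch
from torch import nn, optim
from torch.optim import lr_scheduler

# Hypothetical tiny stand-in model; any nn.Module behaves the same way here.
model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Multiply the LR by 0.1 every 7 epochs (same settings as the recipe's StepLR).
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

lrs = []
for epoch in range(15):
    lrs.append(optimizer.param_groups[0]["lr"])
    optimizer.step()   # in real training this follows loss.backward()
    scheduler.step()   # once per epoch, after the optimizer steps

print(lrs)  # 0.001 for epochs 0-6, ~0.0001 for 7-13, ~1e-05 for epoch 14

# Why deepcopy? state_dict() tensors share storage with the live parameters.
shallow = model.state_dict()
snapshot = copy.deepcopy(model.state_dict())
with torch.no_grad():
    model.weight.add_(1.0)  # simulate a later parameter update

print(torch.equal(shallow["weight"], model.weight))   # True: follows the model
print(torch.equal(snapshot["weight"], model.weight))  # False: frozen copy
```

&lt;p>In &lt;code>train_model&lt;/code> the scheduler is likewise stepped once per epoch, right after the training phase.&lt;/p>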
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="k">def&lt;/span> &lt;span class="nf">train_model&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">criterion&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">optimizer&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">scheduler&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_epochs&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">25&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="s2">&amp;#34;&amp;#34;&amp;#34;
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2"> scheduler is an LR scheduler object from torch.optim.lr_scheduler
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2"> &amp;#34;&amp;#34;&amp;#34;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">since&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">time&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">time&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">best_model_wts&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">copy&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">deepcopy&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">state_dict&lt;/span>&lt;span class="p">())&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">best_acc&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mf">0.0&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">for&lt;/span> &lt;span class="n">epoch&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="nb">range&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_epochs&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="sa">f&lt;/span>&lt;span class="s1">&amp;#39;Epoch &lt;/span>&lt;span class="si">{&lt;/span>&lt;span class="n">epoch&lt;/span>&lt;span class="si">}&lt;/span>&lt;span class="s1">&amp;#39;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s1">&amp;#39;-&amp;#39;&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">10&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># each epoch has a training and validation phase&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">for&lt;/span> &lt;span class="n">phase&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;train&amp;#39;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="s1">&amp;#39;val&amp;#39;&lt;/span>&lt;span class="p">]:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">if&lt;/span> &lt;span class="n">phase&lt;/span> &lt;span class="o">==&lt;/span> &lt;span class="s1">&amp;#39;train&amp;#39;&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">train&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="c1"># set model to training mode&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">else&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">eval&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="c1"># set model to evaluate mode&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">running_loss&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mf">0.0&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">running_corrects&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">0&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">for&lt;/span> &lt;span class="n">inputs&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">labels&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="n">dataloaders&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="n">phase&lt;/span>&lt;span class="p">]:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">inputs&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">inputs&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">to&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">device&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">labels&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">labels&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">to&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">device&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># zero the params gradients&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">optimizer&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">zero_grad&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># forward&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># track history if only in train&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">with&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">set_grad_enabled&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">phase&lt;/span> &lt;span class="o">==&lt;/span> &lt;span class="s1">&amp;#39;train&amp;#39;&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">outputs&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">model&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">inputs&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">_&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">preds&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">max&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">outputs&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">loss&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">criterion&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">outputs&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">labels&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># backward + optimize only in trianing phase&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">if&lt;/span> &lt;span class="n">phase&lt;/span> &lt;span class="o">==&lt;/span> &lt;span class="s1">&amp;#39;train&amp;#39;&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">loss&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">backward&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">optimizer&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">step&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># statistics&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">running_loss&lt;/span> &lt;span class="o">+=&lt;/span> &lt;span class="n">loss&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">item&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="n">inputs&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">shape&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="mi">0&lt;/span>&lt;span class="p">]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">running_corrects&lt;/span> &lt;span class="o">+=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">sum&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">preds&lt;/span> &lt;span class="o">==&lt;/span> &lt;span class="n">labels&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">data&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">if&lt;/span> &lt;span class="n">phase&lt;/span> &lt;span class="o">==&lt;/span> &lt;span class="s1">&amp;#39;train&amp;#39;&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">scheduler&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">step&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">epoch_loss&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">running_loss&lt;/span> &lt;span class="o">/&lt;/span> &lt;span class="n">dataset_sizes&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="n">phase&lt;/span>&lt;span class="p">]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">epoch_acc&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">running_corrects&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">double&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="o">/&lt;/span> &lt;span class="n">dataset_sizes&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="n">phase&lt;/span>&lt;span class="p">]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="sa">f&lt;/span>&lt;span class="s1">&amp;#39;&lt;/span>&lt;span class="si">{&lt;/span>&lt;span class="n">phase&lt;/span>&lt;span class="si">}&lt;/span>&lt;span class="s1"> Loss: &lt;/span>&lt;span class="si">{&lt;/span>&lt;span class="n">epoch_loss&lt;/span>&lt;span class="si">:&lt;/span>&lt;span class="s1">.4f&lt;/span>&lt;span class="si">}&lt;/span>&lt;span class="s1">, Acc: &lt;/span>&lt;span class="si">{&lt;/span>&lt;span class="n">epoch_acc&lt;/span>&lt;span class="si">:&lt;/span>&lt;span class="s1">.4f&lt;/span>&lt;span class="si">}&lt;/span>&lt;span class="s1">&amp;#39;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># deep copy the model&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">if&lt;/span> &lt;span class="n">phase&lt;/span> &lt;span class="o">==&lt;/span> &lt;span class="s1">&amp;#39;val&amp;#39;&lt;/span> &lt;span class="ow">and&lt;/span> &lt;span class="n">epoch_acc&lt;/span> &lt;span class="o">&amp;gt;&lt;/span> &lt;span class="n">best_acc&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">best_acc&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">epoch_acc&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">best_model_wts&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">copy&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">deepcopy&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">state_dict&lt;/span>&lt;span class="p">())&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">print&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">time_elapsed&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">time&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">time&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="o">-&lt;/span> &lt;span class="n">since&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="sa">f&lt;/span>&lt;span class="s1">&amp;#39;Training complete in &lt;/span>&lt;span class="si">{&lt;/span>&lt;span class="n">time_elapsed&lt;/span> &lt;span class="o">//&lt;/span> &lt;span class="mi">60&lt;/span>&lt;span class="si">:&lt;/span>&lt;span class="s1">.0f&lt;/span>&lt;span class="si">}&lt;/span>&lt;span class="s1">m &lt;/span>&lt;span class="si">{&lt;/span>&lt;span class="n">time_elapsed&lt;/span> &lt;span class="o">%&lt;/span> &lt;span class="mi">60&lt;/span>&lt;span class="si">:&lt;/span>&lt;span class="s1">.0f&lt;/span>&lt;span class="si">}&lt;/span>&lt;span class="s1">s&amp;#39;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="sa">f&lt;/span>&lt;span class="s1">&amp;#39;Best val Acc: &lt;/span>&lt;span class="si">{&lt;/span>&lt;span class="n">best_acc&lt;/span>&lt;span class="si">:&lt;/span>&lt;span class="s1">.4f&lt;/span>&lt;span class="si">}&lt;/span>&lt;span class="s1">&amp;#39;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># load best model weights&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">load_state_dict&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">best_model_wts&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="n">model&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;h2 id="major-transfer-learning-scenarios">Major Transfer Learning scenarios&lt;/h2>
&lt;p>In practice, very few people train an entire Convolutional Network from scratch (with random initialization), because it is relatively rare to have a dataset of sufficient size. &lt;strong>Instead, it is common to pretrain a ConvNet on a very large dataset&lt;/strong> (e.g. ImageNet, which contains 1.2 million images with 1000 categories), &lt;strong>and then use the ConvNet either as an initialization or a fixed feature extractor for the task of interest.&lt;/strong>&lt;/p>
&lt;h3 id="convnet-as-fixed-feature-extractor">ConvNet as fixed feature extractor&lt;/h3>
&lt;ol>
&lt;li>Take a ConvNet pretrained on ImageNet&lt;/li>
&lt;li>Remove the last fully connected layer (its outputs are the 1000 class scores of the original ImageNet task)&lt;/li>
&lt;li>Treat the rest of the ConvNet as a fixed feature extractor for the new dataset. (We call these features &lt;strong>CNN codes&lt;/strong>.)&lt;/li>
&lt;/ol>
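&lt;p>Steps 2 and 3 can be sketched as follows. To keep the example small and download-free, it uses a hypothetical &lt;code>TinyConvNet&lt;/code> in place of a pretrained model. Replacing its last layer with &lt;code>nn.Identity()&lt;/code> makes the network return the features that used to feed the classifier (the CNN codes) instead of class scores.&lt;/p>

```python
import torch
from torch import nn


class TinyConvNet(nn.Module):
    """Hypothetical stand-in for a pretrained ConvNet: features + final FC layer."""

    def __init__(self, num_classes=1000):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x):
        return self.fc(self.features(x))


net = TinyConvNet()
net.fc = nn.Identity()  # step 2: drop the last FC layer (the class scores)
net.eval()

with torch.no_grad():  # step 3: a fixed extractor needs no gradients
    images = torch.randn(4, 3, 32, 32)  # dummy batch of 4 RGB images
    cnn_codes = net(images)

print(cnn_codes.shape)  # torch.Size([4, 8]): one 8-d feature vector per image
```

&lt;p>For torchvision's ResNet-18, &lt;code>fc.in_features&lt;/code> is 512, so the same replacement would yield 512-dimensional CNN codes, which can then be fed to a classifier trained on the new dataset.&lt;/p>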
&lt;h4 id="implementation-with-pytorch">Implementation with PyTorch&lt;/h4>
&lt;ul>
&lt;li>We freeze the weights of the entire network except the final fully connected layer.&lt;/li>
&lt;li>This last fully connected layer is replaced with a new one with random weights and &lt;strong>only this layer is trained&lt;/strong>.&lt;/li>
&lt;/ul>
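&lt;p>A quick way to check that only the new layer is trainable is to list the parameters that still have &lt;code>requires_grad=True&lt;/code>. The sketch below does this on a hypothetical &lt;code>TinyBackboneNet&lt;/code> (a stand-in for ResNet-18, which would otherwise need pretrained weights), following the same freeze-then-replace order as the recipe.&lt;/p>

```python
from torch import nn, optim


class TinyBackboneNet(nn.Module):
    """Hypothetical stand-in for a pretrained network with a final `fc` layer."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)
        self.fc = nn.Linear(8, 1000)

    def forward(self, x):
        # Global average pooling over height and width, then the classifier.
        return self.fc(self.conv(x).mean(dim=(2, 3)))


model = TinyBackboneNet()

# 1) Freeze every (pretrained) parameter.
for param in model.parameters():
    param.requires_grad = False

# 2) Replace the head; parameters of new modules have requires_grad=True.
model.fc = nn.Linear(model.fc.in_features, 2)

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['fc.weight', 'fc.bias']

# Hand only the trainable parameters to the optimizer.
optimizer = optim.SGD(
    (p for p in model.parameters() if p.requires_grad), lr=0.001, momentum=0.9
)
```

&lt;p>Filtering on &lt;code>requires_grad&lt;/code> selects the same parameters as &lt;code>model.fc.parameters()&lt;/code> here, but it keeps working if more layers are unfrozen later.&lt;/p>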
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Load pretrained model&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model_conv&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torchvision&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">models&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">resnet18&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">pretrained&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Freeze all the network&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">for&lt;/span> &lt;span class="n">param&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="n">model_conv&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">parameters&lt;/span>&lt;span class="p">():&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">param&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">requires_grad&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="kc">False&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Parameters of newly constructed modules have requires_grad=True by default&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># in other words, now we freeze all the network except the final layer&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">num_features&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">model_conv&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">in_features&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model_conv&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">num_features&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model_conv&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">model_conv&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">to&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">device&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">criterion&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">CrossEntropyLoss&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Observe that only parameters of final layer are being optimized as&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># opposed to before.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">optimizer_conv&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">optim&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">SGD&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model_conv&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">parameters&lt;/span>&lt;span class="p">(),&lt;/span> &lt;span class="n">lr&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mf">0.001&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">momentum&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mf">0.9&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Decay LR by a factor of 0.1 every 7 epochs&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">exp_lr_scheduler&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">lr_scheduler&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">StepLR&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">optimizer_conv&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">step_size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">7&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">gamma&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mf">0.1&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;p>Train and evaluate:&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="n">model_conv&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">train_model&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model_conv&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">criterion&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">optimizer_conv&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">exp_lr_scheduler&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_epochs&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">25&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;h3 id="fine-tuning-the-convnet">Fine-tuning the ConvNet&lt;/h3>
&lt;p>The second strategy is to not only replace and retrain the classifier on top of the ConvNet on the new dataset, but to &lt;strong>also fine-tune the weights of the pretrained network by continuing the backpropagation&lt;/strong>. You can fine-tune all layers of the ConvNet, or keep some of the earlier layers fixed (due to overfitting concerns) and fine-tune only a higher-level portion of the network.&lt;/p>
&lt;p>Motivation: the earlier layers of a ConvNet contain more generic features (e.g. edge detectors or color blob detectors) that should be useful to many tasks, while later layers become progressively more specific to the details of the classes contained in the original dataset.&lt;/p>
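&lt;p>As a minimal sketch of the partial fine-tuning described above, assuming torchvision’s ResNet-18 and a hypothetical 2-class task, one could freeze every layer except the last residual block (&lt;code>layer4&lt;/code>) and the new classifier. The model is built without pretrained weights only to keep the sketch self-contained; in practice you would load pretrained weights.&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">import torch.nn as nn
import torch.optim as optim
from torchvision import models

# Assumption: in practice, load pretrained ImageNet weights here.
model = models.resnet18()
num_classes = 2  # hypothetical binary task
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Freeze all parameters first ...
for param in model.parameters():
    param.requires_grad = False

# ... then unfreeze the highest-level block and the new classifier.
for module in (model.layer4, model.fc):
    for param in module.parameters():
        param.requires_grad = True

# Only pass the trainable parameters to the optimizer.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(trainable_params, lr=0.001, momentum=0.9)
&lt;/code>&lt;/pre>&lt;/div>
&lt;p>Frozen BatchNorm layers still update their running statistics while the model is in &lt;code>train()&lt;/code> mode; whether to also put them in &lt;code>eval()&lt;/code> mode is a design choice.&lt;/p>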
&lt;h4 id="implementation-with-pytorch-1">Implementation with PyTorch&lt;/h4>
&lt;ul>
&lt;li>Instead of random initialization, we initialize the network with pretrained weights, e.g. those of a network trained on the ImageNet-1k (1000-class ImageNet) dataset.&lt;/li>
&lt;li>The rest of the training proceeds as usual.&lt;/li>
&lt;/ul>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Load a pretrained model&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model_ft&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">models&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">resnet18&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">pretrained&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Reset the final fully connected layer according to specific task&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">num_features&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">model_ft&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">in_features&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">num_classes&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">2&lt;/span> &lt;span class="c1"># assuming a binary classification task&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model_ft&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">num_features&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_classes&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model_ft&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">model_ft&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">to&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">device&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">criterion&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">CrossEntropyLoss&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">optimizer_ft&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">optim&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">SGD&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model_ft&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">parameters&lt;/span>&lt;span class="p">(),&lt;/span> &lt;span class="n">lr&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mf">0.001&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">momentum&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mf">0.9&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Decay learning rate by a factor of 0.1 every 7 epochs&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">exp_lr_scheduler&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">lr_scheduler&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">StepLR&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">optimizer_ft&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">step_size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">7&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">gamma&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mf">0.1&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;p>Train and evaluate:&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="n">model_ft&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">train_model&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model_ft&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">criterion&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">criterion&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">optimizer&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">optimizer_ft&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">scheduler&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">exp_lr_scheduler&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_epochs&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">25&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;h2 id="when-and-how-to-fine-tune">When and how to fine-tune?&lt;/h2>
&lt;p>The two most important factors are:&lt;/p>
&lt;ul>
&lt;li>size of the new dataset (small or big)&lt;/li>
&lt;li>its similarity to the original dataset&lt;/li>
&lt;/ul>
&lt;p>Keep in mind that &lt;strong>ConvNet features are more generic in early layers and more specific to the original dataset in later layers.&lt;/strong>&lt;/p>
&lt;p>Common rules of thumb for navigating the 4 major scenarios:&lt;/p>
&lt;ol>
&lt;li>
&lt;p>&lt;strong>&lt;em>New dataset is small and similar to original dataset&lt;/em>.&lt;/strong>&lt;/p>
&lt;p>Since the dataset is small, fine-tuning the ConvNet is not a good idea because of overfitting concerns. Since the data is similar to the original data, we expect the higher-level features in the ConvNet to be relevant to this dataset as well. Hence, the best idea might be to train a linear classifier on the CNN codes (the activations of the layer just before the original classifier).&lt;/p>
&lt;/li>
&lt;li>
&lt;p>&lt;strong>&lt;em>New dataset is large and similar to the original dataset&lt;/em>.&lt;/strong>&lt;/p>
&lt;p>Since we have more data, we can be more confident that we won’t overfit if we fine-tune through the full network.&lt;/p>
&lt;/li>
&lt;li>
&lt;p>&lt;strong>&lt;em>New dataset is small but very different from the original dataset&lt;/em>.&lt;/strong>&lt;/p>
&lt;p>Since the dataset is small, it is likely best to train only a linear classifier. Since the dataset is very different, it might not be best to train the classifier on features from the top of the network, which are more dataset-specific. Instead, it might work better to train the linear classifier (e.g. an SVM) on activations from somewhere earlier in the network.&lt;/p>
&lt;/li>
&lt;li>
&lt;p>&lt;strong>&lt;em>New dataset is large and very different from the original dataset&lt;/em>.&lt;/strong>&lt;/p>
&lt;p>Since the dataset is very large, we may expect that we can afford to train a ConvNet from scratch. However, in practice it is very often still beneficial to initialize with weights from a pretrained model. In this case, we would have enough data and confidence to fine-tune through the entire network.&lt;/p>
&lt;/li>
&lt;/ol>
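&lt;p>For scenario 3, a minimal sketch could cut the network after an intermediate block, pool the activations into one vector per image, and train only a linear layer on top. It assumes torchvision’s ResNet-18 cut after &lt;code>layer3&lt;/code> (an illustrative choice of depth), a hypothetical 2-class task, and a dummy batch of random images; a plain &lt;code>nn.Linear&lt;/code> layer stands in for the SVM.&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">import torch
import torch.nn as nn
from torchvision import models

# Assumption: in practice, load pretrained ImageNet weights here.
backbone = models.resnet18()

# Keep the layers up to layer3 (dropping layer4, avgpool and fc), then
# global-average-pool the 256-channel activations into one vector per image.
feature_extractor = nn.Sequential(
    backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool,
    backbone.layer1, backbone.layer2, backbone.layer3,
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)
feature_extractor.eval()  # use BatchNorm running statistics
for param in feature_extractor.parameters():
    param.requires_grad = False  # the backbone stays fixed

images = torch.randn(4, 3, 224, 224)  # dummy batch of 4 RGB images
with torch.no_grad():
    features = feature_extractor(images)  # shape (4, 256)

num_classes = 2  # hypothetical
classifier = nn.Linear(features.shape[1], num_classes)  # the only trainable part
logits = classifier(features)  # shape (4, 2)
&lt;/code>&lt;/pre>&lt;/div>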
&lt;h3 id="pratical-advices">Practical advice&lt;/h3>
&lt;ul>
&lt;li>&lt;strong>&lt;em>Constraints from pretrained models&lt;/em>.&lt;/strong>
&lt;ul>
&lt;li>Note that if you wish to use a pretrained network, you may be slightly constrained in terms of the architecture you can use for your new dataset. For example, you can’t arbitrarily take out Conv layers from the pretrained network.&lt;/li>
&lt;li>However, some changes are straightforward: due to parameter sharing, you can easily run a pretrained network on images of a different spatial size. This is clearly evident for Conv/Pool layers, because their forward function is &lt;em>independent&lt;/em> of the input volume’s spatial size (as long as the strides “fit”).&lt;/li>
&lt;li>For FC layers this still holds, because an FC layer can be converted into a Convolutional layer: for example, in AlexNet the final pooling volume before the first FC layer has size [6x6x256], so the FC layer looking at this volume is equivalent to a Convolutional layer with a 6x6 receptive field, applied with padding 0.&lt;/li>
&lt;/ul>
&lt;/li>
&lt;li>&lt;strong>&lt;em>Learning rates&lt;/em>.&lt;/strong>
&lt;ul>
&lt;li>It’s common to use a &lt;strong>smaller&lt;/strong> learning rate for ConvNet weights that are being fine-tuned, in comparison to the (randomly-initialized) weights for the new linear classifier that computes the class scores of your new dataset.&lt;/li>
&lt;li>This is because we expect that the ConvNet weights are relatively good, so we don’t wish to distort them too quickly and too much (especially while the new Linear Classifier above them is being trained from random initialization).&lt;/li>
&lt;/ul>
&lt;/li>
&lt;/ul>
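&lt;p>The FC-to-conv equivalence can be checked numerically. The sketch below uses small hypothetical sizes (8 channels, a 6x6 input volume, 10 outputs) rather than AlexNet’s; it copies the weights of an &lt;code>nn.Linear&lt;/code> layer into an &lt;code>nn.Conv2d&lt;/code> with a 6x6 kernel and no padding, and compares the outputs.&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">import torch
import torch.nn as nn

# Hypothetical small sizes: C input channels, KxK input volume, OUT units.
C, K, OUT = 8, 6, 10
fc = nn.Linear(C * K * K, OUT)

# Equivalent conv layer: one KxK filter per output unit, padding 0.
conv = nn.Conv2d(C, OUT, kernel_size=K, padding=0)
with torch.no_grad():
    conv.weight.copy_(fc.weight.view(OUT, C, K, K))
    conv.bias.copy_(fc.bias)

x = torch.randn(2, C, K, K)
fc_out = fc(x.flatten(1))      # shape (2, OUT)
conv_out = conv(x).flatten(1)  # 1x1 spatial output, flattened to (2, OUT)

# On a larger input, the conv version slides over it and yields a score map.
big = torch.randn(2, C, K + 4, K + 4)
big_out = conv(big)  # shape (2, OUT, 5, 5)
&lt;/code>&lt;/pre>&lt;/div>
&lt;p>On the 6x6 input both layers produce the same scores; on the 10x10 input, the convolutional version produces a 5x5 map of scores, one per 6x6 window, which is what lets the converted network run on larger images.&lt;/p>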
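&lt;p>One way to give pretrained weights a smaller learning rate in PyTorch is to use optimizer parameter groups. A minimal sketch, assuming torchvision’s ResNet-18, a hypothetical 2-class head, and illustrative learning rates (0.0001 for the pretrained backbone, 0.001 for the new classifier):&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">import torch.nn as nn
import torch.optim as optim
from torchvision import models

# Assumption: in practice, load pretrained weights here.
model = models.resnet18()
model.fc = nn.Linear(model.fc.in_features, 2)  # new, randomly initialized head

# Split parameters into the new head and the pretrained backbone.
head_params = list(model.fc.parameters())
head_ids = {id(p) for p in head_params}
backbone_params = [p for p in model.parameters() if id(p) not in head_ids]

# Each dict is a parameter group; a group without "lr" uses the default lr.
optimizer = optim.SGD(
    [
        {"params": backbone_params, "lr": 1e-4},  # smaller lr for pretrained weights
        {"params": head_params},                  # falls back to lr=1e-3
    ],
    lr=1e-3,
    momentum=0.9,
)
&lt;/code>&lt;/pre>&lt;/div>
&lt;p>Schedulers such as &lt;code>StepLR&lt;/code> multiply every group’s learning rate by the same factor, so the ratio between the two groups is preserved during training.&lt;/p>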
&lt;h2 id="google-colab-notebook">Google Colab Notebook&lt;/h2>
&lt;p>&lt;a href="https://colab.research.google.com/drive/1pCckg5u_8tnJ1lHVThFHsUnDsD999PPd?authuser=1">Colab Notebook&lt;/a>&lt;/p>
&lt;h2 id="reference">Reference&lt;/h2>
&lt;ul>
&lt;li>
&lt;p>&lt;a href="https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html#">Transfer Learning for Computer Vision Tutorial&lt;/a>&lt;/p>
&lt;/li>
&lt;li>
&lt;p>&lt;a href="https://cs231n.github.io/transfer-learning/">CS231n-Transfer Learning&lt;/a>&lt;/p>
&lt;/li>
&lt;/ul></description></item><item><title>Saving and Loading Checkpoints</title><link>https://haobin-tan.netlify.app/docs/ai/pytorch/pytorch-recipes/saving-and-loading-checkpoints/</link><pubDate>Fri, 06 Nov 2020 00:00:00 +0000</pubDate><guid>https://haobin-tan.netlify.app/docs/ai/pytorch/pytorch-recipes/saving-and-loading-checkpoints/</guid><description>&lt;h2 id="motivation">Motivation&lt;/h2>
&lt;p>Saving and loading a general checkpoint model for inference or resuming training can be helpful for picking up where we last left off.&lt;/p>
&lt;p>When saving a general checkpoint, you must save more than just the model’s &lt;code>state_dict&lt;/code>. It is important to also save the optimizer’s &lt;code>state_dict&lt;/code>, as this contains buffers and parameters that are updated as the model trains. Other items that you may want to save include:&lt;/p>
&lt;ul>
&lt;li>the epoch you left off on,&lt;/li>
&lt;li>the latest recorded training loss,&lt;/li>
&lt;li>external &lt;code>torch.nn.Embedding&lt;/code> layers,&lt;/li>
&lt;li>and more, based on your own algorithm.&lt;/li>
&lt;/ul>
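&lt;p>To see what the optimizer’s &lt;code>state_dict&lt;/code> actually holds, here is a small sketch using a toy &lt;code>nn.Linear&lt;/code> model (a stand-in, not the network from the example below). SGD with momentum only creates its momentum buffers on the first &lt;code>step()&lt;/code>, which is why they must be saved to resume training faithfully.&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)  # toy stand-in model
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Before any step: hyperparameters only, no per-parameter state yet.
state_before = optimizer.state_dict()
print(list(state_before.keys()))  # ['state', 'param_groups']
print(len(state_before["state"]))  # 0

# One training step creates a momentum buffer for each parameter.
loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()

state_after = optimizer.state_dict()
# Keys of "state" are parameter indices (0 = weight, 1 = bias).
print(sorted(state_after["state"].keys()))  # [0, 1]
print("momentum_buffer" in state_after["state"][0])  # True
&lt;/code>&lt;/pre>&lt;/div>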
&lt;h2 id="how-to-save-and-load-checkpoints">How to save and load checkpoints?&lt;/h2>
&lt;p>To &lt;strong>save&lt;/strong> multiple components, organize them in a dictionary and use &lt;code>torch.save()&lt;/code> to serialize the dictionary. A common PyTorch convention is to save these checkpoints using the &lt;code>.tar&lt;/code> file extension.&lt;/p>
&lt;p>To &lt;strong>load&lt;/strong> the items,&lt;/p>
&lt;ol>
&lt;li>first initialize the model and optimizer,&lt;/li>
&lt;li>then load the dictionary locally using &lt;code>torch.load()&lt;/code>. From here, you can access the saved items by querying the dictionary as usual and restore the model and optimizer states with their &lt;code>load_state_dict()&lt;/code> methods.&lt;/li>
&lt;/ol>
&lt;h2 id="example">Example&lt;/h2>
&lt;h3 id="1-import-necessary-libraries-for-loading-our-data">1. Import the necessary libraries&lt;/h3>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch.nn&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">nn&lt;/span>
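&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch.optim&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">optim&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch.nn.functional&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">F&lt;/span>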
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;h3 id="2-define-and-intialize-the-neural-network">2. Define and initialize the neural network&lt;/h3>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">Net&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Module&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="fm">__init__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">super&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Net&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="fm">__init__&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">conv1&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Conv2d&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">6&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">5&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">pool&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">MaxPool2d&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">conv2&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Conv2d&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">6&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">16&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">5&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc1&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">16&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">5&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">5&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">120&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc2&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">120&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">84&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc3&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">84&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">10&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="nf">forward&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">pool&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">F&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">relu&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">conv1&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">)))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">pool&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">F&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">relu&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">conv2&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">)))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">view&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="o">-&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">16&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">5&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">5&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">F&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">relu&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc1&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">F&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">relu&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc2&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc3&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="n">x&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">net&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Net&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;h3 id="3-initialize-the-optimizer">3. Initialize the optimizer&lt;/h3>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="n">optimizer&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">optim&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">SGD&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">net&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">parameters&lt;/span>&lt;span class="p">(),&lt;/span> &lt;span class="n">lr&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mf">0.001&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">momentum&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mf">0.9&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;h3 id="4-saving-the-general-checkpoint">4. Saving the general checkpoint&lt;/h3>
&lt;ol>
&lt;li>Collect all relevant information.&lt;/li>
&lt;li>Build the checkpoint dictionary.&lt;/li>
&lt;li>Save the checkpoint using &lt;code>torch.save()&lt;/code>.&lt;/li>
&lt;/ol>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Additional information&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">EPOCH&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">5&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">PATH&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="s2">&amp;#34;model.pt&amp;#34;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">LOSS&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mf">0.4&lt;/span> &lt;span class="c1"># just dummy number&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">save&lt;/span>&lt;span class="p">({&lt;/span>&lt;span class="s1">&amp;#39;epoch&amp;#39;&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="n">EPOCH&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="s1">&amp;#39;model_state_dict&amp;#39;&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="n">net&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">state_dict&lt;/span>&lt;span class="p">(),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="s1">&amp;#39;optimizer_state_dict&amp;#39;&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="n">optimizer&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">state_dict&lt;/span>&lt;span class="p">(),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="s1">&amp;#39;loss&amp;#39;&lt;/span>&lt;span class="p">:&lt;/span> &lt;span class="n">LOSS&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">},&lt;/span> &lt;span class="n">PATH&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;h3 id="5-load-the-general-checkpoint">5. Loading the general checkpoint&lt;/h3>
&lt;ol>
&lt;li>First initialize the model and optimizer&lt;/li>
&lt;li>Then load the checkpoint &lt;code>dictionary&lt;/code> locally&lt;/li>
&lt;/ol>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># initialize the model and optimizer&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Net&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">optimizer&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">optim&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">SGD&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">parameters&lt;/span>&lt;span class="p">(),&lt;/span> &lt;span class="n">lr&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mf">0.001&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">momentum&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mf">0.9&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># load checkpoint&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">checkpoint&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">load&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">PATH&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">load_state_dict&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">checkpoint&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;model_state_dict&amp;#39;&lt;/span>&lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">optimizer&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">load_state_dict&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">checkpoint&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;optimizer_state_dict&amp;#39;&lt;/span>&lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">epoch&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">checkpoint&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;epoch&amp;#39;&lt;/span>&lt;span class="p">]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">loss&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">checkpoint&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="s1">&amp;#39;loss&amp;#39;&lt;/span>&lt;span class="p">]&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;ol start="3">
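&lt;li>Call &lt;code>model.eval()&lt;/code> before inference (so that layers such as dropout and batch normalization behave in evaluation mode), or &lt;code>model.train()&lt;/code> to resume training.&lt;/li>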
&lt;/ol>
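&lt;p>Putting the save and load steps together, a minimal end-to-end sketch of resuming training might look like the following. It assumes a toy &lt;code>nn.Linear&lt;/code> stand-in for &lt;code>Net&lt;/code>, a hypothetical helper &lt;code>make_model_and_optimizer()&lt;/code>, a temporary file path, and a hypothetical total of 8 epochs; training resumes from the epoch after the one that was saved.&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim


def make_model_and_optimizer():
    """Hypothetical helper: build a fresh model and its optimizer."""
    model = nn.Linear(4, 2)  # toy stand-in for Net
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    return model, optimizer


model, optimizer = make_model_and_optimizer()
# One training step so the optimizer has momentum buffers worth saving.
model(torch.randn(8, 4)).sum().backward()
optimizer.step()

path = os.path.join(tempfile.mkdtemp(), "checkpoint.tar")
torch.save({"epoch": 5,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": 0.4}, path)

# Later: build fresh objects first, then restore their saved state.
new_model, new_optimizer = make_model_and_optimizer()
checkpoint = torch.load(path)
new_model.load_state_dict(checkpoint["model_state_dict"])
new_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1

num_epochs = 8  # hypothetical total
new_model.train()  # use new_model.eval() instead for inference
for epoch in range(start_epoch, num_epochs):
    pass  # hypothetical: run one epoch of the usual training loop here
&lt;/code>&lt;/pre>&lt;/div>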
&lt;h2 id="google-colab-notebook">Google Colab Notebook&lt;/h2>
&lt;p>&lt;a href="https://colab.research.google.com/drive/1PlsftZnPEvyWkJUXIoM5M3a-UA1RTXhl?authuser=1">Colab Notebook&lt;/a>&lt;/p>
&lt;h2 id="reference">Reference&lt;/h2>
&lt;ul>
&lt;li>&lt;a href="https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html">Saving and Loading a General Checkpoint in PyTorch&lt;/a>&lt;/li>
&lt;/ul></description></item><item><title>nn ModuleList vs. Sequential</title><link>https://haobin-tan.netlify.app/docs/ai/pytorch/pytorch-recipes/nn-modulelist-vs-sequental/</link><pubDate>Mon, 09 Nov 2020 00:00:00 +0000</pubDate><guid>https://haobin-tan.netlify.app/docs/ai/pytorch/pytorch-recipes/nn-modulelist-vs-sequental/</guid><description>&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch.nn&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">nn&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch.nn.functional&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">F&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;h2 id="nnmodule">&lt;code>nn.Module&lt;/code>&lt;/h2>
&lt;ul>
&lt;li>Base class for all neural network modules&lt;/li>
&lt;li>Custom models MUST &lt;em>subclass&lt;/em> it&lt;/li>
&lt;/ul>
&lt;h3 id="example">Example&lt;/h3>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">Net&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Module&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="fm">__init__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">in_c&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">n_classes&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">super&lt;/span>&lt;span class="p">()&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="fm">__init__&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">conv1&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Conv2d&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">in_c&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">32&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">kernel_size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">stride&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">padding&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">bn1&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">BatchNorm2d&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">32&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">conv2&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Conv2d&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">32&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">64&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">kernel_size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">stride&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">padding&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">bn2&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">BatchNorm2d&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">32&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc1&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">32&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">28&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">28&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">1024&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc2&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">1024&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">n_classes&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="nf">forward&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">conv1&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">bn1&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">F&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">relu&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">conv2&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">bn2&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">F&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">relu&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">view&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="o">-&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">32&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">28&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">28&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="c1"># flat&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc1&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">F&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">sigmoid&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc2&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="n">x&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Net&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">10&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-fallback" data-lang="fallback">&lt;span class="line">&lt;span class="cl">Net(
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> (fc1): Linear(in_features=25088, out_features=1024, bias=True)
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> (fc2): Linear(in_features=1024, out_features=10, bias=True)
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">)
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;h2 id="nnsequential">&lt;code>nn.Sequential&lt;/code>&lt;/h2>
&lt;p>&lt;a href="https://pytorch.org/docs/stable/nn.html?highlight=sequential#torch.nn.Sequential">Sequential&lt;/a> is a container of Modules that can be stacked together and run at the same time.&lt;/p>
&lt;ul>
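&lt;li>The &lt;code>nn.Module&lt;/code>&amp;rsquo;s stored in &lt;code>nn.Sequential&lt;/code> are connected in a cascade, in the order they are passed in&lt;/li>
&lt;li>&lt;code>nn.Sequential&lt;/code> has its own &lt;code>forward()&lt;/code> method, which runs the stored modules in that order
&lt;ul>
&lt;li>We have to make sure that the output size of each block matches the input size of the following block.&lt;/li>
&lt;/ul>
&lt;/li>
&lt;li>&lt;code>nn.Sequential&lt;/code> is itself a subclass of &lt;code>nn.Module&lt;/code>, so it can be used anywhere a &lt;code>nn.Module&lt;/code> is expected (e.g. as a submodule of another model)&lt;/li>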
&lt;/ul>
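&lt;p>Because the &lt;code>forward()&lt;/code> of &lt;code>nn.Sequential&lt;/code> simply feeds each module&amp;rsquo;s output into the next one, it helps to check the shapes once. The sketch below assumes 1-channel 28x28 inputs (as implied by the &lt;code>28 * 28&lt;/code> used on this page). It stacks the two convolution blocks, compares the container&amp;rsquo;s output with a hand-written loop over the same modules, and shows that the flattened feature size is &lt;code>64 * 28 * 28&lt;/code>, since &lt;code>kernel_size=3, stride=1, padding=1&lt;/code> keeps the spatial size unchanged.&lt;/p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# The two convolution blocks, stacked into one container.
# kernel_size=3, stride=1, padding=1 keeps the 28x28 spatial size unchanged.
features = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
    nn.BatchNorm2d(32),
    nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
    nn.BatchNorm2d(64),  # must match the 64 output channels of the conv before it
    nn.ReLU(),
)
features.eval()  # use BatchNorm running statistics so the output is deterministic

x = torch.randn(2, 1, 28, 28)  # assumed input: a batch of two 1-channel 28x28 images
with torch.no_grad():
    out = features(x)  # Sequential.forward() runs the modules in stored order

    # Equivalent manual cascade: feed each module's output into the next one
    manual = x
    for module in features:
        manual = module(manual)

print(out.shape)                    # torch.Size([2, 64, 28, 28])
print(out[0].numel())               # 50176 == 64 * 28 * 28 features per sample
print(torch.allclose(out, manual))  # True
```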
&lt;h3 id="example-1">Example&lt;/h3>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">NetSequential&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Module&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="fm">__init__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">in_c&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">n_classes&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">super&lt;/span>&lt;span class="p">()&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="fm">__init__&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">conv_block1&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Sequential&lt;/span>&lt;span class="p">(&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Conv2d&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">in_c&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">32&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">kernel_size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">stride&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">padding&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">BatchNorm2d&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">32&lt;/span>&lt;span class="p">),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ReLU&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">conv_block2&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Sequential&lt;/span>&lt;span class="p">(&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Conv2d&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">32&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">64&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">kernel_size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">stride&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">padding&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">BatchNorm2d&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">64&lt;/span>&lt;span class="p">),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ReLU&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">decoder&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Sequential&lt;/span>&lt;span class="p">(&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">32&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">28&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">28&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">1024&lt;/span>&lt;span class="p">),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Sigmoid&lt;/span>&lt;span class="p">(),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">1024&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">n_classes&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="nf">forward&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">conv_block1&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">conv_block2&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">view&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="o">-&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">32&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">28&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">28&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">decode&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="n">x&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">NetSequential&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">10&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-fallback" data-lang="fallback">&lt;span class="line">&lt;span class="cl">NetSequential(
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> (conv_block1): Sequential(
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> (2): ReLU()
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> )
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> (conv_block2): Sequential(
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> (2): ReLU()
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> )
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> (decoder): Sequential(
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> (0): Linear(in_features=25088, out_features=1024, bias=True)
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> (1): Sigmoid()
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> (2): Linear(in_features=1024, out_features=10, bias=True)
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> )
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">)
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;h2 id="nnmodulelist">&lt;code>nn.ModuleList&lt;/code>&lt;/h2>
&lt;blockquote>
&lt;p>&lt;a href="https://pytorch.org/docs/stable/generated/torch.nn.ModuleList.html">Documentation&lt;/a>:&lt;/p>
&lt;p>Holds submodules in a list.&lt;/p>
&lt;p>&lt;a href="https://pytorch.org/docs/stable/generated/torch.nn.ModuleList.html#torch.nn.ModuleList">&lt;code>ModuleList&lt;/code>&lt;/a> can be indexed like a regular Python list, but modules it contains are properly registered, and will be visible by all&lt;a href="https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module">&lt;code>Module&lt;/code>&lt;/a> methods.&lt;/p>
&lt;/blockquote>
&lt;ul>
&lt;li>Does NOT implement a &lt;code>forward()&lt;/code> method, because it does not define a neural network: there is no connection between the &lt;code>nn.Module&lt;/code>&amp;rsquo;s it stores.&lt;/li>
&lt;li>We can use it to store &lt;code>nn.Module&lt;/code>&amp;rsquo;s just as we use Python lists to store other types of objects (integers, strings, etc.). Unlike with a plain Python list, however, PyTorch is “aware” of the &lt;code>nn.Module&lt;/code>&amp;rsquo;s inside an &lt;code>nn.ModuleList&lt;/code>: they are registered as submodules, so their parameters show up in &lt;code>model.parameters()&lt;/code> and are moved by &lt;code>model.to(device)&lt;/code>.&lt;/li>
&lt;li>The execution order of the &lt;code>nn.Module&lt;/code>&amp;rsquo;s stored in an &lt;code>nn.ModuleList&lt;/code> is defined in &lt;code>forward()&lt;/code>, which we have to implement explicitly ourselves.&lt;/li>
&lt;/ul>
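&lt;p>The difference between a plain Python list and &lt;code>nn.ModuleList&lt;/code> is easiest to see by counting registered parameters. A minimal sketch (the class names &lt;code>WithPythonList&lt;/code> and &lt;code>WithModuleList&lt;/code> are made up for illustration); it also shows that calling a &lt;code>nn.ModuleList&lt;/code> directly fails, because it has no &lt;code>forward()&lt;/code> of its own:&lt;/p>

```python
import torch.nn as nn


class WithPythonList(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain Python list: the layers are NOT registered as submodules
        self.layers = [nn.Linear(4, 4), nn.Linear(4, 2)]


class WithModuleList(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList: the layers ARE registered as submodules
        self.layers = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 2)])


def n_params(model):
    # Total number of parameter values visible through model.parameters()
    return sum(p.numel() for p in model.parameters())


print(n_params(WithPythonList()))  # 0
print(n_params(WithModuleList()))  # 30 = (4*4 + 4) + (4*2 + 2)

# nn.ModuleList inherits nn.Module's unimplemented forward(), so calling it raises
raised = False
try:
    nn.ModuleList([nn.Linear(4, 2)])(None)
except NotImplementedError:
    raised = True
print(raised)  # True
```

&lt;p>Since &lt;code>model.parameters()&lt;/code> is what we usually hand to the optimizer, layers kept in a plain list would silently never be trained.&lt;/p>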
&lt;h3 id="example-2">Example&lt;/h3>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">NetModuleList&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Module&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="fm">__init__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">in_c&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">n_classes&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">super&lt;/span>&lt;span class="p">()&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="fm">__init__&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">module_list&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ModuleList&lt;/span>&lt;span class="p">([&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Conv2d&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">in_c&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">32&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">kernel_size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">stride&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">padding&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">BatchNorm2d&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">32&lt;/span>&lt;span class="p">),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ReLU&lt;/span>&lt;span class="p">(),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Conv2d&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">32&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">64&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">kernel_size&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">stride&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">padding&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">BatchNorm2d&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">64&lt;/span>&lt;span class="p">),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ReLU&lt;/span>&lt;span class="p">(),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Flatten&lt;/span>&lt;span class="p">(),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">32&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">28&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">28&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">1024&lt;/span>&lt;span class="p">),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Sigmoid&lt;/span>&lt;span class="p">(),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">1024&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">n_classes&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="nf">forward&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">for&lt;/span> &lt;span class="n">module&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">module_list&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">module&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="n">x&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">NetModuleList&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">10&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-fallback" data-lang="fallback">&lt;span class="line">&lt;span class="cl">NetModuleList(
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> (module_list): ModuleList(
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> (2): ReLU()
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> (5): ReLU()
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> (6): Flatten(start_dim=1, end_dim=-1)
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> (7): Linear(in_features=25088, out_features=1024, bias=True)
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> (8): Sigmoid()
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> (9): Linear(in_features=1024, out_features=10, bias=True)
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> )
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">)
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;h2 id="nnsequential-vs-nnmodulelist">&lt;code>nn.Sequential&lt;/code> vs. &lt;code>nn.ModuleList&lt;/code>&lt;/h2>
&lt;table>
&lt;thead>
&lt;tr>
&lt;th>&lt;/th>
&lt;th>&lt;code>nn.Sequential&lt;/code>&lt;/th>
&lt;th>&lt;code>nn.ModuleList&lt;/code>&lt;/th>
&lt;/tr>
&lt;/thead>
&lt;tbody>
&lt;tr>
&lt;td>Has &lt;code>forward()&lt;/code>?&lt;/td>
&lt;td>✅&lt;/td>
&lt;td>❌&lt;/td>
&lt;/tr>
&lt;tr>
&lt;td>Connection between &lt;code>nn.Modules&lt;/code> stored inside?&lt;/td>
&lt;td>✅&lt;/td>
&lt;td>❌&lt;/td>
&lt;/tr>
&lt;tr>
&lt;td>Execution order = stored order?&lt;/td>
&lt;td>✅&lt;/td>
&lt;td>❌&lt;/td>
&lt;/tr>
&lt;tr>
&lt;td>Advantages&lt;/td>
&lt;td>succinct&lt;/td>
&lt;td>flexible&lt;/td>
&lt;/tr>
&lt;/tbody>
&lt;/table>
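&lt;p>The &amp;ldquo;execution order&amp;rdquo; row can be illustrated with a small sketch, assuming arbitrary layer sizes: a &lt;code>nn.ModuleList&lt;/code> whose hand-written loop simply follows the stored order computes the same thing as a &lt;code>nn.Sequential&lt;/code> built from the same modules, which is exactly the case where &lt;code>Sequential&lt;/code> is the more succinct choice.&lt;/p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Arbitrary layer sizes, chosen only for illustration
layers = nn.ModuleList([nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3)])


def forward_with_list(x):
    # With nn.ModuleList we write the loop ourselves; here it follows the stored
    # order, but we could just as well skip, repeat or reorder layers
    for layer in layers:
        x = layer(x)
    return x


# nn.Sequential runs the same (shared, not copied) modules in stored order automatically
seq = nn.Sequential(*layers)

x = torch.randn(5, 8)
with torch.no_grad():
    y_list = forward_with_list(x)
    y_seq = seq(x)

print(torch.allclose(y_list, y_seq))  # True
print(seq[0] is layers[0])            # True: both containers hold the same Linear layer
```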
&lt;h3 id="when-to-use-which">When to use which?&lt;/h3>
&lt;ul>
&lt;li>Use &lt;code>Module&lt;/code> when we have a big block composed of multiple smaller blocks&lt;/li>
&lt;li>Use &lt;code>Sequential&lt;/code> when we want to create a small block from layers that are simply applied one after another&lt;/li>
&lt;li>Use &lt;code>ModuleList&lt;/code> when we need to iterate through layers or building blocks and do something a plain cascade cannot express (e.g. apply them in a custom order, or collect intermediate outputs)&lt;/li>
&lt;/ul>
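&lt;p>The three guidelines can be combined. In the sketch below (the helper &lt;code>conv_block&lt;/code> and the class &lt;code>Encoder&lt;/code> are hypothetical names, and 1-channel 28x28 inputs are assumed), each small block is a &lt;code>Sequential&lt;/code>, the blocks are kept in a &lt;code>ModuleList&lt;/code> so that &lt;code>forward()&lt;/code> can return every intermediate feature map, and the whole thing is a &lt;code>Module&lt;/code>.&lt;/p>

```python
import torch
import torch.nn as nn


def conv_block(in_c, out_c):
    """Small block built from layers -> nn.Sequential (hypothetical helper)."""
    return nn.Sequential(
        nn.Conv2d(in_c, out_c, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_c),
        nn.ReLU(),
    )


class Encoder(nn.Module):
    """Big block composed of smaller blocks -> nn.Module (hypothetical example)."""

    def __init__(self, channels):
        super().__init__()
        # One block per consecutive channel pair; kept in a ModuleList so we can iterate
        self.blocks = nn.ModuleList(
            [conv_block(c_in, c_out) for c_in, c_out in zip(channels[:-1], channels[1:])]
        )

    def forward(self, x):
        features = []  # every block's output, which nn.Sequential would not expose
        for block in self.blocks:
            x = block(x)
            features.append(x)
        return features


encoder = Encoder([1, 32, 64])
feats = encoder(torch.randn(2, 1, 28, 28))
print([f.shape for f in feats])
# [torch.Size([2, 32, 28, 28]), torch.Size([2, 64, 28, 28])]
```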
&lt;h2 id="reference">Reference&lt;/h2>
&lt;ul>
&lt;li>
&lt;p>&lt;a href="https://discuss.pytorch.org/t/when-should-i-use-nn-modulelist-and-when-should-i-use-nn-sequential/5463">When should I use nn.ModuleList and when should I use nn.Sequential?&lt;/a>&lt;/p>
&lt;/li>
&lt;li>
&lt;p>&lt;a href="https://towardsdatascience.com/pytorch-how-and-when-to-use-module-sequential-modulelist-and-moduledict-7a54597b5f17">Pytorch: how and when to use Module, Sequential, ModuleList and ModuleDict&lt;/a>&lt;/p>
&lt;/li>
&lt;li>
&lt;p>&lt;a href="https://zhuanlan.zhihu.com/p/64990232">PyTorch 中的 ModuleList 和 Sequential: 区别和使用场景&lt;/a>&lt;/p>
&lt;/li>
&lt;/ul></description></item><item><title>🔥 Custom Datasets and Transforms</title><link>https://haobin-tan.netlify.app/docs/ai/pytorch/pytorch-recipes/custom-dataset-transform/</link><pubDate>Thu, 26 Nov 2020 00:00:00 +0000</pubDate><guid>https://haobin-tan.netlify.app/docs/ai/pytorch/pytorch-recipes/custom-dataset-transform/</guid><description>&lt;h2 id="custom-dataset">Custom Dataset&lt;/h2>
&lt;p>In order to use our custom dataset, we need to&lt;/p>
&lt;ul>
&lt;li>
&lt;p>inherit from &lt;code>torch.utils.data.Dataset&lt;/code>, an abstract class representing a dataset.&lt;/p>
&lt;/li>
&lt;li>
&lt;p>override&lt;/p>
&lt;ul>
&lt;li>&lt;code>__len__&lt;/code> so that &lt;code>len(dataset)&lt;/code> returns the size of the dataset.&lt;/li>
&lt;li>&lt;code>__getitem__&lt;/code> to support indexing, so that &lt;code>dataset[i]&lt;/code> returns the &lt;em>i&lt;/em>-th sample.&lt;/li>
&lt;/ul>
&lt;/li>
&lt;/ul>
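&lt;p>As a minimal illustration of these two methods, here is a hypothetical in-memory toy dataset (&lt;code>SquaresDataset&lt;/code> is a made-up name) whose &lt;em>i&lt;/em>-th sample is the pair &lt;code>(i, i**2)&lt;/code>, wrapped in a &lt;code>DataLoader&lt;/code> for batching:&lt;/p>

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is (i, i**2) as 1-element float tensors."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # len(dataset) -> number of samples
        return self.n

    def __getitem__(self, index):
        # dataset[index] -> the index-th sample as an (input, target) pair
        x = torch.tensor([float(index)])
        y = torch.tensor([float(index) ** 2])
        return x, y


dataset = SquaresDataset(10)
print(len(dataset))  # 10
print(dataset[3])    # (tensor([3.]), tensor([9.]))

# DataLoader calls __len__ and __getitem__ for us and stacks samples into batches
loader = DataLoader(dataset, batch_size=4, shuffle=False)
xb, yb = next(iter(loader))
print(xb.shape, yb.shape)  # torch.Size([4, 1]) torch.Size([4, 1])
```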
&lt;p>The skeleton is as follows:&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">torch.utils.data.dataset&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">Dataset&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">MyCustomDataset&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Dataset&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="fm">__init__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="o">...&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># initial logic, e.g.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># read csv&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># assign data transformation&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># ...&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="fm">__getitem__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">index&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="s2">&amp;#34;&amp;#34;&amp;#34;Get the &lt;/span>&lt;span class="si">{index}&lt;/span>&lt;span class="s2">-th sample&amp;#34;&amp;#34;&amp;#34;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># Note: the return value can be customized depending on application&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">img&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">label&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="fm">__len__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="n">count&lt;/span> &lt;span class="c1"># of how many examples(images?) you have&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;h3 id="example">Example&lt;/h3>
&lt;p>Let&amp;rsquo;s take the &lt;a href="http://yann.lecun.com/exdb/mnist/">MNIST&lt;/a> dataset as an example. Assume that the CSV file is located at &lt;code>CSV_PATH&lt;/code> and is structured as follows:&lt;/p>
&lt;ul>
&lt;li>
&lt;p>One instance/sample per line&lt;/p>
&lt;ul>
&lt;li>The first column is the digit label (0 - 9)&lt;/li>
&lt;li>The remaining 784 columns hold the pixel values of the 28x28 image ($28 \times 28 = 784$)&lt;/li>
&lt;li>I.e., each sample consists of a digit image and its label&lt;/li>
&lt;/ul>
&lt;/li>
&lt;li>
&lt;p>There are 5000 lines in total, i.e. 5000 samples.&lt;/p>
&lt;ul>
&lt;li>We want to use the first 4000 samples for training and validation (the first 3000 for training, the next 1000 for validation),&lt;/li>
&lt;li>and the remaining 1000 samples for testing.&lt;/li>
&lt;/ul>
&lt;/li>
&lt;/ul>
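&lt;p>The layout and the 3000/1000/1000 split can be checked with a small NumPy sketch. It assumes a randomly generated array as a stand-in for the real CSV contents (in the dataset class below, the array comes from &lt;code>np.genfromtxt&lt;/code>): column 0 holds the label and columns 1 to 784 hold the pixels.&lt;/p>

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in for the CSV: 5000 rows = 1 label column + 784 pixel columns
labels = rng.integers(0, 10, size=(5000, 1))
pixels = rng.integers(0, 256, size=(5000, 784))
all_data = np.hstack([labels, pixels]).astype("uint8")
print(all_data.shape)  # (5000, 785)

# First 4000 rows for training + validation, last 1000 rows for testing
train, test = all_data[:4000], all_data[4000:]
# Of those 4000 rows: first 3000 for training, remaining 1000 for validation
train, val = train[:3000], train[3000:]
print(len(train), len(val), len(test))  # 3000 1000 1000

# Separate labels from pixels and reshape each flat row into a 28x28 image
images = train[:, 1:].reshape(-1, 28, 28)
targets = train[:, 0]
print(images.shape, targets.shape)  # (3000, 28, 28) (3000,)
```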
&lt;p>Let&amp;rsquo;s implement our custom MNIST dataset:&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">torch.utils.data&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">Dataset&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">MyMNIST&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">Dataset&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">TRAIN&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">VALID&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">TEST&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="fm">__init__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">csv_file&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">usage&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">TRAIN&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">transform&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">None&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">label_transform&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">None&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="s2">&amp;#34;&amp;#34;&amp;#34;
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2"> Args:
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2"> csv_file (string): Path to the csv file
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2"> usage (int): usage of the dataset (train/validation/test)
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2"> transform (callable, optional): Optional transform to be applied on the image.
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2"> label_transform (callable, optional): Optional transform to be applied on the label.
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2"> &amp;#34;&amp;#34;&amp;#34;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">transform&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">transform&lt;/span> &lt;span class="c1"># image preprocessing&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">label_transform&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">label_transform&lt;/span> &lt;span class="c1"># label preprocessing&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># load from csv file&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">all_data&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">np&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">genfromtxt&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">csv_file&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">delimiter&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s1">&amp;#39;,&amp;#39;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">dtype&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s1">&amp;#39;uint8&amp;#39;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># 5000 lines in csv file --&amp;gt; 5000 instances&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># training set: first 3000 lines&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># validation set: 3000 - 4000 &lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># test set: last 1000 lines&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">train&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">test&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">all_data&lt;/span>&lt;span class="p">[:&lt;/span>&lt;span class="mi">4000&lt;/span>&lt;span class="p">],&lt;/span> &lt;span class="n">all_data&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="mi">4000&lt;/span>&lt;span class="p">:]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">train&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">val&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">train&lt;/span>&lt;span class="p">[:&lt;/span>&lt;span class="mi">3000&lt;/span>&lt;span class="p">],&lt;/span> &lt;span class="n">train&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="mi">3000&lt;/span>&lt;span class="p">:]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># choose lines based on specified usage&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">if&lt;/span> &lt;span class="n">usage&lt;/span> &lt;span class="o">==&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">TRAIN&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">images&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">train&lt;/span>&lt;span class="p">[:,&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">:]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">labels&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">train&lt;/span>&lt;span class="p">[:,&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">]&lt;/span> &lt;span class="c1"># first column is label of the digit &lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">elif&lt;/span> &lt;span class="n">usage&lt;/span> &lt;span class="o">==&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">VALID&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">images&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">val&lt;/span>&lt;span class="p">[:,&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">:]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">labels&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">val&lt;/span>&lt;span class="p">[:,&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">else&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">images&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">test&lt;/span>&lt;span class="p">[:,&lt;/span> &lt;span class="mi">1&lt;/span>&lt;span class="p">:]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">labels&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">test&lt;/span>&lt;span class="p">[:,&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="fm">__getitem__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">index&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">image&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">label&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">images&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="n">index&lt;/span>&lt;span class="p">],&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">labels&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="n">index&lt;/span>&lt;span class="p">]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">if&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">transform&lt;/span> &lt;span class="ow">is&lt;/span> &lt;span class="ow">not&lt;/span> &lt;span class="kc">None&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">image&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">transform&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">image&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">if&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">label_transform&lt;/span> &lt;span class="ow">is&lt;/span> &lt;span class="ow">not&lt;/span> &lt;span class="kc">None&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">label&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">label_transform&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">label&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># convert label to Tensor of dtype long&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">label&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">as_tensor&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">label&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">dtype&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">long&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="n">image&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">label&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="fm">__len__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="nb">len&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">labels&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;p>Use our custom MNIST dataset:&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">torchvision&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">transforms&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># apply normalizaton and convertion to Tensor before using the dataset&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">preprocess_transform&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">transforms&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Compose&lt;/span>&lt;span class="p">([&lt;/span>&lt;span class="n">transforms&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ToTensor&lt;/span>&lt;span class="p">(),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">transforms&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Normalize&lt;/span>&lt;span class="p">((&lt;/span>&lt;span class="mf">0.1&lt;/span>&lt;span class="p">,),&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="mf">0.4&lt;/span>&lt;span class="p">))])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># let&amp;#39;s say we use the dataset for testing&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">my_mnist&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">MyMNIST&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">csv_file&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">CSV_PATH&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">usage&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">MyMNIST&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">TEST&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">transform&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">preprocess_transform&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;h2 id="custom-transform-and-augmentation">Custom transform and augmentation&lt;/h2>
&lt;p>The example code above makes use of the transforms provided by &lt;code>torchvision.transforms&lt;/code>. We can also implement our own custom transforms.&lt;/p>
&lt;p>To do this, we need to write them as &lt;strong>callable&lt;/strong> classes:&lt;/p>
&lt;ul>
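&lt;li>inherit from the &lt;code>object&lt;/code> class (implicit in Python 3, so this is optional)&lt;/li>
&lt;li>implement &lt;code>__init__&lt;/code> if the transform needs parameters&lt;/li>
&lt;li>define the desired transformation in the &lt;code>__call__(self, image)&lt;/code> method, so that an instance can be called like a function&lt;/li>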
&lt;/ul>
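&lt;p>Following these rules, an augmentation transform can also carry its own parameters. The sketch below shows a hypothetical &lt;code>AddGaussianNoise&lt;/code> transform (it is not part of &lt;code>torchvision&lt;/code>) that assumes its input is already a float tensor with values in [0, 1], e.g. the output of &lt;code>transforms.ToTensor()&lt;/code>. The noise level is stored once in &lt;code>__init__&lt;/code> and applied in &lt;code>__call__&lt;/code>; the optional &lt;code>torch.Generator&lt;/code> is only there to make the demo reproducible.&lt;/p>

```python
import torch


class AddGaussianNoise(object):
    """Hypothetical augmentation: add zero-mean Gaussian noise to a float tensor."""

    def __init__(self, std=0.1, generator=None):
        # parameters are stored once and reused on every call
        self.std = std
        self.generator = generator

    def __call__(self, image):
        noise = torch.randn(image.shape, generator=self.generator) * self.std
        # keep pixel values inside [0, 1] after adding the noise
        return (image + noise).clamp(0.0, 1.0)


gen = torch.Generator().manual_seed(0)
augment = AddGaussianNoise(std=0.05, generator=gen)

image = torch.full((1, 28, 28), 0.5)  # dummy 1x28x28 grayscale image
noisy = augment(image)                # the instance is called like a function
print(noisy.shape)                    # torch.Size([1, 28, 28])
```

&lt;p>Because it is a plain callable, such a transform can be placed inside &lt;code>transforms.Compose&lt;/code> after &lt;code>transforms.ToTensor()&lt;/code>, typically only for the training split.&lt;/p>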
&lt;h3 id="example-1">Example&lt;/h3>
&lt;p>For example, let&amp;rsquo;s implement two custom transforms:&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">MyNormalizer&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="nb">object&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="s2">&amp;#34;&amp;#34;&amp;#34;Normalize image&amp;#34;&amp;#34;&amp;#34;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="fm">__call__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">image&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="s2">&amp;#34;&amp;#34;&amp;#34;
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2"> Only works for our custom MNIST dataset: Devide the pixel values by 255
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2"> Generally, normalization should work as follows:
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2"> data_normalized = (data - data.mean) / data.std
&lt;/span>&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="s2"> &amp;#34;&amp;#34;&amp;#34;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">image&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">image&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mf">1.0&lt;/span> &lt;span class="o">/&lt;/span> &lt;span class="mi">255&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="n">image&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">MyToTensor&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="nb">object&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="s2">&amp;#34;&amp;#34;&amp;#34;Convert image to PyTorch Tensor&amp;#34;&amp;#34;&amp;#34;&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="fm">__call__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">image&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">image&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">from_numpy&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">image&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">float&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="n">image&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;h3 id="use-custom-transform-in-our-custom-mnist-dataset">Use custom transform in our custom MNIST dataset&lt;/h3>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="n">preprocess_transform&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">transforms&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Compose&lt;/span>&lt;span class="p">([&lt;/span>&lt;span class="n">MyToTensor&lt;/span>&lt;span class="p">(),&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">MyNormalizer&lt;/span>&lt;span class="p">()])&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">my_mnist&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">MyMNIST&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">csv_file&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">CSV_PATH&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">usage&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">MyMNIST&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">TEST&lt;/span>&lt;span class="p">,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">transform&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">preprocess_transform&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;h2 id="reference">Reference&lt;/h2>
&lt;ul>
&lt;li>
&lt;p>&lt;a href="https://pytorch.org/tutorials/beginner/data_loading_tutorial.html#compose-transforms">WRITING CUSTOM DATASETS, DATALOADERS AND TRANSFORMS&lt;/a>&lt;/p>
&lt;/li>
&lt;li>
&lt;p>&lt;a href="https://github.com/utkuozbulak/pytorch-custom-dataset-examples">pytorch-custom-dataset-examples&lt;/a>&lt;/p>
&lt;/li>
&lt;/ul></description></item><item><title>🔥🧾 General Training Steps Using PyTorch</title><link>https://haobin-tan.netlify.app/docs/ai/pytorch/pytorch-recipes/pytorch-training-steps/</link><pubDate>Thu, 26 Nov 2020 00:00:00 +0000</pubDate><guid>https://haobin-tan.netlify.app/docs/ai/pytorch/pytorch-recipes/pytorch-training-steps/</guid><description>&lt;p>Open in &lt;a href="https://colab.research.google.com/drive/1OWujzsbTVMrSL-HhKy98abhbf4-y23SD?usp=sharing">Google Colab&lt;/a>&lt;/p>
&lt;p>General steps:&lt;/p>
&lt;ol>
&lt;li>
&lt;p>Set &lt;code>device&lt;/code>&lt;/p>
&lt;/li>
&lt;li>
&lt;p>Set &lt;code>Dataset&lt;/code> and &lt;code>DataLoader&lt;/code>&lt;/p>
&lt;/li>
&lt;li>
&lt;p>Define network model&lt;/p>
&lt;/li>
&lt;li>
&lt;p>Build network model&lt;/p>
&lt;/li>
&lt;li>
&lt;p>Define loss function and optimizer&lt;/p>
&lt;/li>
&lt;li>
&lt;p>Define training process&lt;/p>
&lt;/li>
&lt;li>
&lt;p>Train the model&lt;/p>
&lt;/li>
&lt;li>
&lt;p>Store/Load weights&lt;/p>
&lt;/li>
&lt;/ol></description></item><item><title>Saving and Loading Models</title><link>https://haobin-tan.netlify.app/docs/ai/pytorch/pytorch-recipes/save-load-model/</link><pubDate>Sun, 17 Jan 2021 00:00:00 +0000</pubDate><guid>https://haobin-tan.netlify.app/docs/ai/pytorch/pytorch-recipes/save-load-model/</guid><description>&lt;p>Three core functions for saving and loading models:&lt;/p>
&lt;ol>
&lt;li>
&lt;p>&lt;a href="https://pytorch.org/docs/stable/torch.html?highlight=save#torch.save">torch.save&lt;/a>&lt;/p>
&lt;p>Saves a serialized object to disk. This function uses Python’s &lt;a href="https://docs.python.org/3/library/pickle.html">pickle&lt;/a> utility for serialization. Models, tensors, and dictionaries of all kinds of objects can be saved using this function.&lt;/p>
&lt;/li>
&lt;li>
&lt;p>&lt;a href="https://pytorch.org/docs/stable/generated/torch.load.html#torch.load">torch.load&lt;/a>&lt;/p>
&lt;p>Uses &lt;a href="https://docs.python.org/3/library/pickle.html">pickle&lt;/a>’s unpickling facilities to deserialize pickled object files to memory.&lt;/p>
&lt;/li>
&lt;li>
&lt;p>&lt;a href="https://pytorch.org/docs/stable/nn.html?highlight=load_state_dict#torch.nn.Module.load_state_dict">torch.nn.Module.load_state_dict&lt;/a>&lt;/p>
&lt;p>Loads a model’s parameter dictionary using a deserialized &lt;em>state_dict&lt;/em>.&lt;/p>
&lt;/li>
&lt;/ol>
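&lt;p>A minimal sketch of how the three functions fit together, assuming a small stand-in model (&lt;code>nn.Linear(4, 2)&lt;/code>) and an in-memory &lt;code>io.BytesIO&lt;/code> buffer in place of a file path such as &lt;code>model.pt&lt;/code>: &lt;code>torch.save&lt;/code> serializes the model&amp;rsquo;s &lt;em>state_dict&lt;/em>, &lt;code>torch.load&lt;/code> deserializes it, and &lt;code>load_state_dict&lt;/code> copies the tensors into a freshly constructed model of the same architecture.&lt;/p>

```python
import io

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)

# 1. torch.save: serialize the state_dict (a dict of tensors) via pickle
buffer = io.BytesIO()  # stands in for a file path, e.g. "model.pt"
torch.save(model.state_dict(), buffer)

# 2. torch.load: deserialize the saved object back into memory
buffer.seek(0)
state_dict = torch.load(buffer)

# 3. load_state_dict: copy the tensors into a model with the same architecture
restored = nn.Linear(4, 2)
restored.load_state_dict(state_dict)
restored.eval()  # evaluation mode (matters for layers such as dropout/batchnorm)

x = torch.randn(3, 4)
print(torch.equal(model(x), restored(x)))  # True
```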
&lt;h2 id="state_dict">&lt;code>state_dict&lt;/code>&lt;/h2>
&lt;p>In PyTorch,&lt;/p>
&lt;ul>
&lt;li>
&lt;p>the learnable parameters (i.e. weights and biases) of a &lt;code>torch.nn.Module&lt;/code> model are contained in the model’s &lt;strong>parameters&lt;/strong> (accessed with &lt;code>model.parameters()&lt;/code>). A &lt;strong>state_dict&lt;/strong> is simply a Python &lt;strong>dictionary object&lt;/strong> that maps each layer to its parameter tensor.&lt;/p>
&lt;ul>
&lt;li>Note that only layers with &lt;strong>learnable&lt;/strong> parameters (convolutional layers, linear layers, etc.) and registered buffers (e.g. batchnorm’s &lt;code>running_mean&lt;/code>) have entries in the model’s &lt;strong>state_dict&lt;/strong>.&lt;/li>
&lt;/ul>
&lt;/li>
&lt;li>
&lt;p>Optimizer objects (&lt;code>torch.optim&lt;/code>) also have a &lt;strong>state_dict&lt;/strong>, which contains information about the optimizer’s state, as well as the hyperparameters used.&lt;/p>
&lt;/li>
&lt;/ul>
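&lt;p>To see which entries end up in a &lt;em>state_dict&lt;/em>, the sketch below prints the keys of a small &lt;code>nn.Sequential&lt;/code> model, assumed only for illustration. The linear layer contributes its weight and bias, the batch-norm layer additionally contributes its registered buffers (&lt;code>running_mean&lt;/code>, &lt;code>running_var&lt;/code>, &lt;code>num_batches_tracked&lt;/code>), and &lt;code>nn.ReLU&lt;/code>, which has no parameters, contributes nothing. The optimizer&amp;rsquo;s &lt;em>state_dict&lt;/em> has the top-level keys &lt;code>state&lt;/code> and &lt;code>param_groups&lt;/code>, the latter holding hyperparameters such as the learning rate.&lt;/p>

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Sequential(
    nn.Linear(4, 3),    # learnable: weight, bias
    nn.BatchNorm1d(3),  # learnable: weight, bias; buffers: running statistics
    nn.ReLU(),          # no parameters, hence no state_dict entries
)

for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))
# 0.weight (3, 4)
# 0.bias (3,)
# 1.weight (3,)
# 1.bias (3,)
# 1.running_mean (3,)
# 1.running_var (3,)
# 1.num_batches_tracked ()

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
opt_state = optimizer.state_dict()
print(list(opt_state.keys()))              # ['state', 'param_groups']
print(opt_state["param_groups"][0]["lr"])  # 0.01
```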
&lt;p>Because &lt;strong>state_dict&lt;/strong> objects are Python dictionaries, they can be easily saved, updated, altered, and restored, adding a great deal of modularity to PyTorch models and optimizers.&lt;/p>
&lt;h3 id="example">Example&lt;/h3>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch.nn&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">nn&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch.functional&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">F&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch.optim&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">optim&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">TheModelClass&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Module&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="fm">__init__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">super&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">TheModelClass&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="fm">__init__&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">conv1&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Conv2d&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">6&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">5&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">pool&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">MaxPool2d&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">2&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">conv2&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Conv2d&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">6&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">16&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">5&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc1&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">16&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">5&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">5&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">120&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc2&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">120&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">84&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc3&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">84&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">10&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="nf">forward&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">pool&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">F&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">relu&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">conv1&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">)))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">pool&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">F&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">relu&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">conv2&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">)))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">view&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="o">-&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">16&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">5&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">5&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">F&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">relu&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc1&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">F&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">relu&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc2&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">fc3&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="n">x&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Initialize model&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">TheModelClass&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Initialize optimizer&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">optimizer&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">optim&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">SGD&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">parameters&lt;/span>&lt;span class="p">(),&lt;/span> &lt;span class="n">lr&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mf">0.001&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">momentum&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mf">0.9&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Print model&amp;#39;s state_dict&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;Model&amp;#39;s state_dict:&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">for&lt;/span> &lt;span class="n">param_tensor&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">state_dict&lt;/span>&lt;span class="p">():&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">param_tensor&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="s2">&amp;#34;&lt;/span>&lt;span class="se">\t&lt;/span>&lt;span class="s2">&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">state_dict&lt;/span>&lt;span class="p">()[&lt;/span>&lt;span class="n">param_tensor&lt;/span>&lt;span class="p">]&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">size&lt;/span>&lt;span class="p">())&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-fallback" data-lang="fallback">&lt;span class="line">&lt;span class="cl">Model&amp;#39;s state_dict:
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">conv1.weight torch.Size([6, 3, 5, 5])
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">conv1.bias torch.Size([6])
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">conv2.weight torch.Size([16, 6, 5, 5])
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">conv2.bias torch.Size([16])
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">fc1.weight torch.Size([120, 400])
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">fc1.bias torch.Size([120])
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">fc2.weight torch.Size([84, 120])
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">fc2.bias torch.Size([84])
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">fc3.weight torch.Size([10, 84])
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">fc3.bias torch.Size([10])
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Print optimizer&amp;#39;s state_dict&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;Optimizer&amp;#39;s state_dict:&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">for&lt;/span> &lt;span class="n">var_name&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="n">optimizer&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">state_dict&lt;/span>&lt;span class="p">():&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">var_name&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="s2">&amp;#34;&lt;/span>&lt;span class="se">\t&lt;/span>&lt;span class="s2">&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">optimizer&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">state_dict&lt;/span>&lt;span class="p">()[&lt;/span>&lt;span class="n">var_name&lt;/span>&lt;span class="p">])&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-fallback" data-lang="fallback">&lt;span class="line">&lt;span class="cl">Optimizer&amp;#39;s state_dict:
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">state {}
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">param_groups [{&amp;#39;lr&amp;#39;: 0.001, &amp;#39;momentum&amp;#39;: 0.9, &amp;#39;dampening&amp;#39;: 0, &amp;#39;weight_decay&amp;#39;: 0, &amp;#39;nesterov&amp;#39;: False, &amp;#39;params&amp;#39;: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;h2 id="saving--loading-model-for-inference">Saving &amp;amp; Loading Model for Inference&lt;/h2>
&lt;h3 id="saveload-state_dict-recommended">Save/Load &lt;code>state_dict&lt;/code> (Recommended)&lt;/h3>
&lt;p>&lt;strong>Save&lt;/strong>:&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">save&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">state_dict&lt;/span>&lt;span class="p">(),&lt;/span> &lt;span class="n">PATH&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;p>&lt;strong>Load&lt;/strong>:&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">TheModelClass&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="o">*&lt;/span>&lt;span class="n">args&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="o">**&lt;/span>&lt;span class="n">kwargs&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">load_state_dict&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">load&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">PATH&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">eval&lt;/span>&lt;span class="p">()&lt;/span>
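&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;p>When saving a model for inference, it is only necessary to save the trained model’s learned parameters. Saving the model’s &lt;em>state_dict&lt;/em> with the &lt;code>torch.save()&lt;/code> function will give you the most flexibility for restoring the model later.&lt;/p>
&lt;p>Note that &lt;code>load_state_dict()&lt;/code> takes a dictionary object, not a path to a saved file, so the saved object must first be deserialized with &lt;code>torch.load()&lt;/code>. Also call &lt;code>model.eval()&lt;/code> before running inference so that dropout and batch normalization layers switch to evaluation mode; otherwise inference results will be inconsistent.&lt;/p>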
&lt;p>A common PyTorch convention is to save models using either a &lt;code>.pt&lt;/code> or &lt;code>.pth&lt;/code> file extension.&lt;/p>
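&lt;p>A minimal, self-contained round trip of the recommended approach might look like the sketch below. &lt;code>TinyNet&lt;/code> and the temporary file path are hypothetical stand-ins for &lt;code>TheModelClass&lt;/code> and &lt;code>PATH&lt;/code>; the check at the end confirms that the reloaded model reproduces the original model’s outputs.&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    # Hypothetical stand-in for TheModelClass
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        return self.fc(self.dropout(x))


model = TinyNet()
path = os.path.join(tempfile.mkdtemp(), "tiny_net.pt")  # hypothetical PATH

# Save only the learned parameters (and buffers), not the class itself
torch.save(model.state_dict(), path)

# Re-create the architecture in code, then load the saved parameters into it
restored = TinyNet()
restored.load_state_dict(torch.load(path))
restored.eval()  # dropout becomes a no-op in evaluation mode

model.eval()
x = torch.rand(3, 4)
with torch.no_grad():
    same = torch.allclose(model(x), restored(x))
print(same)  # True
&lt;/code>&lt;/pre>&lt;/div>&lt;p>Because only tensors are stored, the file does not depend on where &lt;code>TinyNet&lt;/code> is defined; the class just has to be available in code when the model is rebuilt.&lt;/p>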
&lt;h3 id="saveload-entire-model">Save/Load entire model&lt;/h3>
&lt;p>&lt;strong>Save:&lt;/strong>&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">save&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">PATH&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;p>&lt;strong>Load:&lt;/strong>&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Model class must be defined somewhere&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">load&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">PATH&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">eval&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;p>This save/load process uses the most intuitive syntax and involves the least amount of code. Saving a model in this way will save the &lt;strong>entire&lt;/strong> module using Python’s &lt;a href="https://docs.python.org/3/library/pickle.html">pickle&lt;/a> module.&lt;/p>
&lt;p>🔴 The disadvantage of this approach is that the serialized data is bound to the specific classes and the exact directory structure used when the model is saved. This is because pickle does not save the model class itself. Rather, it saves a path to the file containing the class, which is used during load time. Because of this, your code can break in various ways when used in other projects or after refactors.&lt;/p>
&lt;h2 id="saving--loading-a-general-checkpoint-for-inference-andor-resuming-training">Saving &amp;amp; Loading a General Checkpoint for Inference and/or Resuming Training&lt;/h2>
&lt;p>See: &lt;a href="https://haobin-tan.netlify.app/docs/ai/pytorch/pytorch-recipes/saving-and-loading-checkpoints/">Saving and Loading Checkpoints&lt;/a>&lt;/p>
&lt;h2 id="reference">Reference&lt;/h2>
&lt;ul>
&lt;li>
&lt;p>&lt;a href="https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-and-loading-models">SAVING AND LOADING MODELS&lt;/a>&lt;/p>
&lt;/li>
&lt;li>
&lt;p>&lt;a href="https://stackoverflow.com/questions/59095824/what-is-the-difference-between-pt-pth-and-pwf-extentions-in-pytorch">What is the difference between .pt, .pth and .pwf extentions in PyTorch?&lt;/a>&lt;/p>
&lt;/li>
&lt;/ul></description></item><item><title>Data Augmentation</title><link>https://haobin-tan.netlify.app/docs/ai/pytorch/pytorch-recipes/data_augmentation/</link><pubDate>Sun, 17 Jan 2021 00:00:00 +0000</pubDate><guid>https://haobin-tan.netlify.app/docs/ai/pytorch/pytorch-recipes/data_augmentation/</guid><description>&lt;h2 id="what-is-data-augmentation">What is data augmentation?&lt;/h2>
&lt;p>Since it is often hard to collect enough data to train a neural network, we can use &lt;strong>image augmentation: a process of creating new training examples from existing ones. To make a new sample, you slightly change an original image.&lt;/strong>&lt;/p>
&lt;p>For instance, you could make a new image a little brighter; you could cut a piece from the original image; you could make a new image by mirroring the original one, etc. Here are some examples of transformations of the original image that will create a new training sample:&lt;/p>
&lt;p>&lt;img src="https://raw.githubusercontent.com/EckoTan0804/upic-repo/master/uPic/augmentation.jpg" alt="augmentation">&lt;/p>
&lt;p>By applying those transformations to the original training dataset, you could create an almost unlimited number of new training samples.&lt;/p>
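&lt;p>The three kinds of change mentioned above (brighter, cropped, mirrored) can be sketched directly on an image tensor. This is only an illustration, assuming a float image of shape &lt;code>(C, H, W)&lt;/code> with values in &lt;code>[0, 1]&lt;/code>; the brightness factor and the crop position and size are arbitrary choices.&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">import torch

torch.manual_seed(0)
image = torch.rand(3, 32, 32)  # stand-in for a real RGB image, values in [0, 1]

# Make the image a little brighter (scale pixel values, keep them in [0, 1])
brighter = (image * 1.2).clamp(0.0, 1.0)

# Cut a 24 x 24 piece out of the original image
top, left = 4, 4
cropped = image[:, top:top + 24, left:left + 24]

# Mirror the original image horizontally (flip along the width axis)
mirrored = torch.flip(image, dims=[2])

for name, sample in [("brighter", brighter), ("cropped", cropped), ("mirrored", mirrored)]:
    print(name, tuple(sample.shape))
&lt;/code>&lt;/pre>&lt;/div>&lt;p>Each variant keeps the same label as the original image, which is what makes it a new training sample.&lt;/p>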
&lt;h2 id="premise-of-data-augmentation">Premise of data augmentation&lt;/h2>
&lt;p>A &lt;a href="https://nanonets.com/blog/human-pose-estimation-2d-guide/">convolutional neural network&lt;/a> that can robustly classify objects even if its placed in different orientations is said to have the property called &lt;strong>invariance&lt;/strong>. More specifically, a CNN can be invariant to &lt;strong>translation, viewpoint, size&lt;/strong> or &lt;strong>illumination&lt;/strong> (Or a combination of the above).&lt;/p>
&lt;h2 id="when-to-apply-augmentation">When to apply augmentation?&lt;/h2>
&lt;p>The answer may seem quite obvious; we do augmentation &lt;strong>before&lt;/strong> we feed the data to the model.&lt;/p>
&lt;p>However, we have two options here:&lt;/p>
&lt;ul>
&lt;li>&lt;strong>Offline augmentation&lt;/strong>
&lt;ul>
&lt;li>Preferred for relatively &lt;strong>smaller datasets&lt;/strong>&lt;/li>
&lt;li>Increases the size of the dataset by a factor equal to the number of transformations performed
&lt;ul>
&lt;li>For example, by &lt;strong>flipping&lt;/strong> all my images, I would &lt;strong>increase the size&lt;/strong> of my dataset by a &lt;strong>factor of 2&lt;/strong>&lt;/li>
&lt;/ul>
&lt;/li>
&lt;/ul>
&lt;/li>
&lt;li>&lt;strong>Online augmentation / Augmentation on the fly&lt;/strong>
&lt;ul>
&lt;li>Preferred for &lt;strong>larger datasets&lt;/strong>, as we can’t afford the explosive increase in size.&lt;/li>
&lt;li>Perform transformations &lt;strong>on the mini-batches&lt;/strong> that we would feed to our model.&lt;/li>
&lt;/ul>
&lt;/li>
&lt;/ul>
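&lt;p>As a rough sketch of offline augmentation (assuming the whole dataset fits in memory as one tensor), flipping every image once before training and storing the copies alongside the originals doubles the dataset size, matching the factor of 2 above. The random &lt;code>images&lt;/code> and &lt;code>labels&lt;/code> tensors are placeholders for real data.&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">import torch

# Placeholder dataset: 100 RGB images of size 32 x 32 with 10 classes
images = torch.rand(100, 3, 32, 32)
labels = torch.randint(0, 10, (100,))

# Offline augmentation: precompute horizontally flipped copies once, before training
flipped = torch.flip(images, dims=[3])  # dim 3 is the width axis for (N, C, H, W)
aug_images = torch.cat([images, flipped], dim=0)
aug_labels = torch.cat([labels, labels], dim=0)  # a flip does not change the label

print(len(images), "->", len(aug_images))  # 100 -> 200
&lt;/code>&lt;/pre>&lt;/div>&lt;p>Online augmentation, by contrast, leaves the stored tensors untouched and applies random transforms to each mini-batch as it is loaded, so memory use does not grow.&lt;/p>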
&lt;h2 id="use-data-augmentation-in-the-right-way">Use data augmentation in the right way&lt;/h2>
&lt;p>‼️ &lt;strong>Do NOT add irrelevant data!!!&lt;/strong>&lt;/p>
&lt;p>Sometimes not all augmentation techniques make sense for a dataset. Consider the following car example:&lt;/p>
&lt;figure>&lt;img src="https://raw.githubusercontent.com/EckoTan0804/upic-repo/master/uPic/1*vW3KGPp_w0wN6k3gYVlVHA.jpeg"
alt="The first image (from the left) is the original, the second one is flipped horizontally, the third one is rotated by 180 degrees, and the last one is rotated by 90 degrees (clockwise).">&lt;figcaption>
&lt;p>The first image (from the left) is the original, the second one is flipped horizontally, the third one is rotated by 180 degrees, and the last one is rotated by 90 degrees (clockwise).&lt;/p>
&lt;/figcaption>
&lt;/figure>
&lt;p>They are pictures of the same car, but our target application may NEVER see cars presented in these orientations. For example, if we are going to classify cars on the road, only the second image (the horizontally flipped one) makes sense to include in the dataset.&lt;/p>
&lt;h2 id="how-to-conduct-data-augmentation-in-pytorch">How to conduct data augmentation in PyTorch?&lt;/h2>
&lt;h3 id="use-torchvisiontransforms">Use &lt;code>torchvision.transforms&lt;/code>&lt;/h3>
&lt;ul>
&lt;li>Provides common image transformations&lt;/li>
&lt;li>Can be chained together using &lt;code>transforms.Compose&lt;/code>&lt;/li>
&lt;/ul>
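&lt;p>A typical training pipeline chains several random transforms with &lt;code>transforms.Compose&lt;/code>. The sketch below assumes a torchvision version in which these transforms accept image tensors of shape &lt;code>(C, H, W)&lt;/code>; the chosen transforms and parameters are illustrative, not recommendations.&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">import torch
from torchvision import transforms

# Random transforms applied in order, every time the pipeline is called
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(24),        # random crop, resized to 24 x 24
    transforms.RandomHorizontalFlip(p=0.5),  # mirror with probability 0.5
    transforms.ColorJitter(brightness=0.2),  # random brightness change
])

image = torch.rand(3, 32, 32)  # stand-in for a real image tensor in [0, 1]
augmented = train_transform(image)
print(augmented.shape)  # torch.Size([3, 24, 24])
&lt;/code>&lt;/pre>&lt;/div>&lt;p>Calling &lt;code>train_transform&lt;/code> twice on the same image will generally give two different results, because each transform draws new random parameters per call.&lt;/p>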
&lt;h3 id="-use-albumentationshttpsgithubcomalbumentations-teamalbumentations">🔥 Use &lt;a href="https://github.com/albumentations-team/albumentations">&lt;code>albumentations&lt;/code>&lt;/a>&lt;/h3>
&lt;h4 id="demo">Demo&lt;/h4>
&lt;p>&lt;a href="https://albumentations-demo.herokuapp.com/">Demo&lt;/a> for viewing different augmentation transformations&lt;/p>
&lt;h2 id="when-will-data-augmentation-be-applied-in-pytorch">When will data augmentation be applied in PyTorch?&lt;/h2>
&lt;p>In every epoch, the dataloader applies a fresh set of random operations &lt;strong>“on the fly”&lt;/strong>, i.e., the augmentation happens inside this line:&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="k">for&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">data&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">target&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="n">dataloader&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;p>Instead of showing the exact same items at every epoch, you are showing a variant that has been changed in a different way. So after three epochs, you would have seen three random variants of each item in a dataset.&lt;/p>
&lt;p>Note that each image is transformed randomly on the fly, so NO new images are generated or stored, and the length of the &lt;code>Dataset&lt;/code> stays the SAME.&lt;/p>
&lt;p>If you want to perform more augmentation and bring more variability to the dataset, just increase the number of epochs.&lt;/p>
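&lt;p>The following sketch illustrates this with a hypothetical &lt;code>FlipDataset&lt;/code> that applies a random horizontal flip inside &lt;code>__getitem__&lt;/code>. The dataset length never changes, but the same stored image can come out flipped in one epoch and unflipped in the next.&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">import torch
from torch.utils.data import DataLoader, Dataset


class FlipDataset(Dataset):
    # Hypothetical dataset that augments each image on the fly
    def __init__(self, images):
        self.images = images

    def __len__(self):
        return len(self.images)  # stays the same: no images are added

    def __getitem__(self, idx):
        image = self.images[idx]
        if torch.rand(1).item() > 0.5:  # random horizontal flip with p = 0.5
            image = torch.flip(image, dims=[2])
        return image


torch.manual_seed(0)
dataset = FlipDataset(torch.rand(4, 3, 8, 8))
loader = DataLoader(dataset, batch_size=2, shuffle=False)

epoch_variants = []
for epoch in range(3):
    flags = []
    for batch_idx, batch in enumerate(loader):  # augmentation happens in this loop
        for j, sample in enumerate(batch):
            original = dataset.images[batch_idx * loader.batch_size + j]
            flags.append(not torch.equal(sample, original))  # True if flipped
    epoch_variants.append(flags)
    print(f"epoch {epoch}: dataset length = {len(dataset)}, flipped = {flags}")
&lt;/code>&lt;/pre>&lt;/div>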
&lt;blockquote>
&lt;p>Reference:&lt;/p>
&lt;ul>
&lt;li>&lt;a href="https://discuss.pytorch.org/t/data-augmentation-in-pytorch/7925">Data augmentation in PyTorch&lt;/a>&lt;/li>
&lt;li>&lt;a href="https://discuss.pytorch.org/t/transform-and-image-data-augmentation/71942">Transform and Image Data Augmentation&lt;/a>&lt;/li>
&lt;li>&lt;a href="https://discuss.pytorch.org/t/basic-question-about-torchvision-transforms/40213">Basic question about torchvision.transforms&lt;/a>&lt;/li>
&lt;/ul>
&lt;/blockquote>
&lt;h2 id="reference">Reference&lt;/h2>
&lt;ul>
&lt;li>&lt;a href="https://nanonets.com/blog/data-augmentation-how-to-use-deep-learning-when-you-have-limited-data-part-2/">Data Augmentation | How to use Deep Learning when you have Limited Data — Part 2&lt;/a>&lt;/li>
&lt;/ul></description></item><item><title>TorchScript</title><link>https://haobin-tan.netlify.app/docs/ai/pytorch/pytorch-recipes/torchscript/</link><pubDate>Wed, 21 Apr 2021 00:00:00 +0000</pubDate><guid>https://haobin-tan.netlify.app/docs/ai/pytorch/pytorch-recipes/torchscript/</guid><description>&lt;p>&lt;img src="https://raw.githubusercontent.com/EckoTan0804/upic-repo/master/uPic/%E6%88%AA%E5%B1%8F2021-04-21%2017.17.23.png" alt="截屏2021-04-21 17.17.23">&lt;/p>
&lt;h2 id="torchscript">TorchScript&lt;/h2>
&lt;ul>
&lt;li>A PyTorch model’s journey from Python to C++ is enabled by &lt;strong>TorchScript&lt;/strong>, a representation of a PyTorch model that can be understood, compiled and serialized by the TorchScript compiler.&lt;/li>
&lt;li>Any TorchScript program can be saved from a Python process and loaded in a process where there is NO Python dependency. In other words, a TorchScript program can be run &lt;strong>independently&lt;/strong> from Python, such as in a standalone C++ program.&lt;/li>
&lt;li>This makes it possible to train models in PyTorch using familiar tools in Python and then export the model via TorchScript to a production environment where Python programs may be disadvantageous for performance and multi-threading reasons.&lt;/li>
&lt;li>👍 Advantage
&lt;ul>
&lt;li>TorchScript code can be invoked in its own interpreter, which is basically a restricted Python interpreter. This interpreter does not acquire the Global Interpreter Lock, and so many requests can be processed on the same instance simultaneously.&lt;/li>
&lt;li>This format allows us to save the whole model to disk and load it into another environment, such as a server written in a language other than Python.&lt;/li>
&lt;li>TorchScript gives us a representation in which we can do compiler optimizations on the code to provide more efficient execution.&lt;/li>
&lt;li>TorchScript allows us to interface with many backend/device runtimes that require a broader view of the program than individual operators.&lt;/li>
&lt;/ul>
&lt;/li>
&lt;/ul>
&lt;h3 id="steps-for-loading-a-pytorch-model-in-c">Steps for Loading a PyTorch Model in C++&lt;/h3>
&lt;ol>
&lt;li>Convert PyTorch model to TorchScript&lt;/li>
&lt;li>Serialize script module to a file&lt;/li>
&lt;li>Load script module in C++&lt;/li>
&lt;li>Execute script module in C++&lt;/li>
&lt;/ol>
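&lt;p>Steps 1 and 2 happen in Python. The sketch below, using a hypothetical &lt;code>TinyModel&lt;/code> and file name, traces a model, serializes it with &lt;code>torch.jit.save&lt;/code>, and loads it back with &lt;code>torch.jit.load&lt;/code> as a Python stand-in for steps 3 and 4. On the C++ side, the corresponding loader in LibTorch is &lt;code>torch::jit::load&lt;/code>.&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">import os
import tempfile

import torch


class TinyModel(torch.nn.Module):
    # Hypothetical model used only for this illustration
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.linear(x))


model = TinyModel().eval()
example = torch.rand(1, 4)

# 1. Convert the PyTorch model to TorchScript (here via tracing)
script_module = torch.jit.trace(model, example)

# 2. Serialize the script module to a file
path = os.path.join(tempfile.mkdtemp(), "tiny_model.pt")
torch.jit.save(script_module, path)

# 3./4. Load and execute it (Python stand-in for the C++ side)
loaded = torch.jit.load(path)
x = torch.rand(5, 4)
with torch.no_grad():
    print(torch.allclose(loaded(x), model(x)))  # True
&lt;/code>&lt;/pre>&lt;/div>&lt;p>The saved file contains both the parameters and the TorchScript code, so loading it does not require the &lt;code>TinyModel&lt;/code> class definition.&lt;/p>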
&lt;h2 id="convert-pytorch-model-to-torch-script">Convert PyTorch Model to Torch Script&lt;/h2>
&lt;p>There are wo ways to convert a PyTorch model to Torch Script&lt;/p>
&lt;ul>
&lt;li>&lt;a href="#tracing">Tracing&lt;/a>&lt;/li>
&lt;li>&lt;a href="#scripting">Scripting&lt;/a>&lt;/li>
&lt;/ul>
&lt;h3 id="tracing">Tracing&lt;/h3>
&lt;ul>
&lt;li>A mechanism in which
&lt;ul>
&lt;li>the structure of the model is captured by evaluating it once using example inputs and&lt;/li>
&lt;li>recording the flow of those inputs through the model.&lt;/li>
&lt;/ul>
&lt;/li>
&lt;li>Suitable for models that make limited use of control flow&lt;/li>
&lt;li>Function: &lt;code>torch.jit.trace&lt;/code>&lt;/li>
&lt;/ul>
&lt;h4 id="example">Example&lt;/h4>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">MyCell&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Module&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="fm">__init__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">super&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">MyCell&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="fm">__init__&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">linear&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">4&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">4&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="nf">forward&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">h&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">new_h&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">tanh&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">linear&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="n">h&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="n">new_h&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">new_h&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">my_cell&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">MyCell&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">x&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">h&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">rand&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">4&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">rand&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">4&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">traced_cell&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">jit&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">trace&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">my_cell&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">h&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;p>What happens under the hood when we call &lt;code>torch.jit.trace&lt;/code>, passing in the &lt;code>Module&lt;/code> and an example input?&lt;/p>
&lt;ul>
&lt;li>Invoked the &lt;code>Module&lt;/code>&lt;/li>
&lt;li>Recorded the operations that occurred when the &lt;code>Module&lt;/code> was run&lt;/li>
&lt;li>Created an instance of &lt;code>torch.jit.ScriptModule&lt;/code>&lt;/li>
&lt;/ul>
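&lt;p>These points can be checked directly. The snippet below repeats the &lt;code>MyCell&lt;/code> definition from the example above so that it runs on its own, then confirms that the traced object is a &lt;code>torch.jit.ScriptModule&lt;/code> and that it computes the same result as the original module.&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">import torch


class MyCell(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h


my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h))

print(isinstance(traced_cell, torch.jit.ScriptModule))  # True
# The traced module replays the recorded operations and matches eager mode
print(torch.allclose(traced_cell(x, h)[0], my_cell(x, h)[0]))  # True
&lt;/code>&lt;/pre>&lt;/div>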
&lt;p>TorchScript records its definitions in an &lt;strong>Intermediate Representation&lt;/strong> (or &lt;strong>IR&lt;/strong>), commonly referred to in deep learning as a &lt;em>graph&lt;/em> (we can examine the graph with the &lt;code>.graph&lt;/code> property).&lt;/p>
&lt;p>A more readable option is the &lt;code>.code&lt;/code> property, which gives a Python-syntax interpretation of the code:&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">traced_cell&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">code&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;p>Out:&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-txt" data-lang="txt">&lt;span class="line">&lt;span class="cl">def forward(self,
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> input: Tensor,
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> h: Tensor) -&amp;gt; Tuple[Tensor, Tensor]:
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> _0 = torch.add((self.linear).forward(input, ), h, alpha=1)
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> _1 = torch.tanh(_0)
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> return (_1, _1)
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;h3 id="scripting">Scripting&lt;/h3>
&lt;p>If our code uses data-dependent control flow (if-else, loops&amp;hellip;), tracing is unsuitable, because it only records the operations that were executed for the example input. In this case, we use a &lt;strong>script compiler&lt;/strong>, which analyzes our Python source code directly to transform it into TorchScript. The function for compiling the module is &lt;code>torch.jit.script&lt;/code>.&lt;/p>
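&lt;p>The difference can be seen with a small function whose result depends on a data-dependent &lt;code>if&lt;/code>. In this illustrative sketch (&lt;code>flip_sign_if_negative&lt;/code> is a hypothetical name), tracing with a positive example input bakes in only the branch that was taken, while &lt;code>torch.jit.script&lt;/code> compiles both branches.&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">import torch


def flip_sign_if_negative(x):
    # Data-dependent control flow: the branch depends on the values in x
    if x.sum() > 0:
        return x
    return -x


positive = torch.ones(3)
negative = -torch.ones(3)

# Tracing with a positive input records only the "return x" branch
# (PyTorch emits a TracerWarning about converting a tensor to a Python bool)
traced = torch.jit.trace(flip_sign_if_negative, positive)

# Scripting analyzes the source code and keeps the if-statement
scripted = torch.jit.script(flip_sign_if_negative)

print(traced(negative))    # tensor([-1., -1., -1.])  (wrong branch baked in)
print(scripted(negative))  # tensor([1., 1., 1.])
&lt;/code>&lt;/pre>&lt;/div>&lt;p>The warning emitted during tracing is a useful signal that a model contains control flow and should be scripted instead.&lt;/p>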
&lt;h4 id="example-1">Example&lt;/h4>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">MyModule&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Module&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">def&lt;/span> &lt;span class="fm">__init__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">N&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">M&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="nb">super&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">MyModule&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="fm">__init__&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">weight&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Parameter&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">rand&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">N&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">M&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">def&lt;/span> &lt;span class="nf">forward&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="nb">input&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">if&lt;/span> &lt;span class="nb">input&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">sum&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="o">&amp;gt;&lt;/span> &lt;span class="mi">0&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">            &lt;span class="n">output&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">weight&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">mv&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="nb">input&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">else&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">            &lt;span class="n">output&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">weight&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="nb">input&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="k">return&lt;/span> &lt;span class="n">output&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">my_module&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">MyModule&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">10&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="mi">20&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">sm&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">jit&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">script&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">my_module&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;p>&lt;code>sm&lt;/code> is an instance of &lt;code>ScriptModule&lt;/code> that is ready for serialization.&lt;/p>
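&lt;p>To see why tracing is unsuitable for data-dependent control flow, the sketch below converts the same module both ways (the &lt;code>Gate&lt;/code> module is a hypothetical toy, not part of the tutorial). The trace only records the branch taken for the example input, whereas the scripted version keeps both branches:&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">import torch


class Gate(torch.nn.Module):
    # Toy module (hypothetical) whose output depends on the input values.
    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        return x + 10


gate = Gate()
pos = torch.ones(3)
neg = -torch.ones(3)

# Tracing runs the module once on `pos`, so only the "x * 2" branch is recorded.
# PyTorch emits a TracerWarning about converting a tensor to a Python bool here.
traced_gate = torch.jit.trace(gate, pos)

# Scripting compiles the Python source, so the if/else survives.
scripted_gate = torch.jit.script(gate)

print(traced_gate(neg))    # tensor([-2., -2., -2.])  (wrong branch for neg)
print(scripted_gate(neg))  # tensor([9., 9., 9.])
&lt;/code>&lt;/pre>&lt;/div>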
&lt;h3 id="mixing-scripting-and-tracing">Mixing Scripting and Tracing&lt;/h3>
&lt;p>In many cases, either tracing or scripting alone is the easier way to convert a model to TorchScript. The two can also be composed, so that each part of a model is converted with the approach that suits its particular requirements.&lt;/p>
&lt;p>&lt;strong>Scripted functions can call traced functions.&lt;/strong>&lt;/p>
&lt;ul>
&lt;li>
&lt;p>Useful when we need control flow around a simple feed-forward model.&lt;/p>
&lt;/li>
&lt;li>
&lt;p>Example&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">def&lt;/span> &lt;span class="nf">foo&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">y&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">return&lt;/span> &lt;span class="mi">2&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="n">x&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="n">y&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">traced_foo&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">jit&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">trace&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">foo&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">rand&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">rand&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">)))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nd">@torch.jit.script&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">def&lt;/span> &lt;span class="nf">bar&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">return&lt;/span> &lt;span class="n">traced_foo&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;/li>
&lt;/ul>
&lt;p>&lt;strong>Traced functions can call script functions.&lt;/strong>&lt;/p>
&lt;ul>
&lt;li>
&lt;p>Useful when a small part of a model requires some control flow, even though most of the model is just a feed-forward network.&lt;/p>
&lt;/li>
&lt;li>
&lt;p>Control flow inside a script function called by a traced function is preserved correctly.&lt;/p>
&lt;/li>
&lt;li>
&lt;p>Example&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nd">@torch.jit.script&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">def&lt;/span> &lt;span class="nf">foo&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">y&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">if&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">max&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="o">&amp;gt;&lt;/span> &lt;span class="n">y&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">max&lt;/span>&lt;span class="p">():&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">r&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">x&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">else&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">r&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">y&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">return&lt;/span> &lt;span class="n">r&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">def&lt;/span> &lt;span class="nf">bar&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">y&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">z&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">return&lt;/span> &lt;span class="n">foo&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">y&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">+&lt;/span> &lt;span class="n">z&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">traced_bar&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">jit&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">trace&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">bar&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">rand&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">rand&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">rand&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">3&lt;/span>&lt;span class="p">)))&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;/li>
&lt;/ul>
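&lt;p>The same composition also works at the module level. In the sketch below (the &lt;code>Block&lt;/code> and &lt;code>Repeat&lt;/code> classes are hypothetical), a plain feed-forward block is traced and then called in a loop from a scripted module, so the number of iterations &lt;code>n&lt;/code> stays dynamic instead of being unrolled at trace time:&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">import torch


class Block(torch.nn.Module):
    # Pure feed-forward part: safe to trace.
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))


class Repeat(torch.nn.Module):
    # Control flow around the block: must be scripted to keep the loop.
    def __init__(self, block):
        super().__init__()
        self.block = block

    def forward(self, x, n: int):
        for _ in range(n):
            x = self.block(x)
        return x


traced_block = torch.jit.trace(Block(), torch.rand(2, 4))
repeat = torch.jit.script(Repeat(traced_block))

x = torch.rand(2, 4)
print(repeat(x, 3).shape)  # torch.Size([2, 4])
print(repeat.code)         # the for loop appears in the TorchScript code
&lt;/code>&lt;/pre>&lt;/div>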
&lt;h2 id="saving-aand-loading-script-module">Saving and Loading Script Modules&lt;/h2>
&lt;ul>
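&lt;li>Save: the &lt;code>save()&lt;/code> method of a &lt;code>ScriptModule&lt;/code> (or, equivalently, &lt;code>torch.jit.save()&lt;/code>)&lt;/li>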
&lt;li>Load: &lt;code>torch.jit.load()&lt;/code>&lt;/li>
&lt;/ul>
&lt;p>Example:&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torchvision&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># An instance of your model.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torchvision&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">models&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">resnet18&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># An example input you would normally provide to your model&amp;#39;s forward() method.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">example&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">rand&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">224&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">224&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">traced_script_module&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">jit&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">trace&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">example&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;ul>
&lt;li>
&lt;p>Save:&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="n">traced_script_module&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">save&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;traced_resnet_model.pt&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;/li>
&lt;li>
&lt;p>Load:&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="n">traced_resnet&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">jit&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">load&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;traced_resnet_model.pt&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;/li>
&lt;/ul>
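&lt;p>A minimal round trip, assuming a small stand-in model instead of ResNet-18 so that it runs quickly on the CPU. &lt;code>torch.jit.save()&lt;/code> accepts a file-like object as well as a path, and the &lt;code>map_location&lt;/code> argument of &lt;code>torch.jit.load()&lt;/code> controls the device the loaded tensors are placed on. Comparing outputs before and after is a cheap sanity check that serialization did not change the model:&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">import io

import torch

# Small stand-in model (hypothetical); any traced or scripted module works the same way.
model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.ReLU()).eval()
example = torch.rand(1, 8)
traced = torch.jit.trace(model, example)

# Serialize into an in-memory buffer instead of a file on disk.
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)

# Load it back onto the CPU (useful if the module was saved from a GPU).
loaded = torch.jit.load(buffer, map_location="cpu")

print(torch.allclose(traced(example), loaded(example)))  # True
&lt;/code>&lt;/pre>&lt;/div>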
&lt;h2 id="reference">Reference&lt;/h2>
&lt;ul>
&lt;li>
&lt;p>&lt;a href="https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html#basics-of-torchscript">Introduction to TorchScript&lt;/a>&lt;/p>
&lt;/li>
&lt;li>
&lt;p>&lt;a href="https://pytorch.org/tutorials/advanced/cpp_export.html#step-2-serializing-your-script-module-to-a-file">Loading a TorchScript Model in C++&lt;/a>&lt;/p>
&lt;/li>
&lt;li>
&lt;p>&lt;a href="https://pytorch.org/docs/stable/jit.html#">TorchScript&lt;/a>&lt;/p>
&lt;/li>
&lt;/ul></description></item><item><title>Performance Measurement</title><link>https://haobin-tan.netlify.app/docs/ai/pytorch/pytorch-recipes/measure_fps/</link><pubDate>Mon, 24 May 2021 00:00:00 +0000</pubDate><guid>https://haobin-tan.netlify.app/docs/ai/pytorch/pytorch-recipes/measure_fps/</guid><description>&lt;h2 id="main-issues-of-time-measurement">Main Issues of Time Measurement&lt;/h2>
&lt;h3 id="gpu-execution-mechanism-asynchronous-execution">GPU Execution Mechanism: Asynchronous Execution&lt;/h3>
&lt;p>In multithreaded or multi-device programming, two blocks of code that are independent can be executed in parallel. This means that the second block may be executed before the first is finished. This process is referred to as &lt;strong>asynchronous execution&lt;/strong>.&lt;/p>
&lt;p>&lt;img src="https://raw.githubusercontent.com/EckoTan0804/upic-repo/master/uPic/Figure-1_white.png" alt="img">&lt;/p>
&lt;p>In deep learning, we rely on this kind of execution constantly, because &lt;strong>GPU operations are asynchronous by default&lt;/strong>.&lt;/p>
&lt;ul>
&lt;li>More specifically, when we call a function that uses the GPU, its operations are enqueued to the particular device, but not necessarily executed until later. This allows us to execute more computations in parallel, including operations on the CPU or on other GPUs.&lt;/li>
&lt;/ul>
&lt;span style="color:green">
&lt;p>Asynchronous execution offers major advantages for deep learning, most notably a large reduction in total run time, because CPU and GPU work can overlap.&lt;/p>
&lt;ul>
&lt;li>For example, when running inference on multiple batches, the second batch can be preprocessed on the CPU while the first batch is fed forward through the network on the GPU. It is therefore beneficial to exploit asynchronous execution whenever possible at inference time.&lt;/li>
&lt;/ul>
&lt;/span>
&lt;span style="color:red">
&lt;p>However, asynchronous execution can be the cause of many headaches when it comes to time measurements.&lt;/p>
&lt;ul>
&lt;li>When we measure time with Python&amp;rsquo;s &lt;code>time&lt;/code> module, the measurement happens on the CPU (the host). Because GPU execution is asynchronous, the line of code that stops the timer can run before the GPU has finished its work. As a result, the measured time can be inaccurate or even unrelated to the actual inference time.&lt;/li>
&lt;/ul>
&lt;/span>
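&lt;p>The sketch below illustrates this pitfall with Python&amp;rsquo;s &lt;code>time.perf_counter()&lt;/code>, using a plain matrix multiplication as a hypothetical stand-in workload. It assumes a CUDA device if one is available and falls back to the CPU otherwise (where execution is synchronous, so both numbers should be similar). On a GPU, the unsynchronized measurement typically captures little more than the time needed to enqueue the kernel:&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">import time

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
a = torch.randn(1024, 1024, device=device)
b = torch.randn(1024, 1024, device=device)


def sync():
    # Block the CPU until all queued GPU work is done (no-op on the CPU).
    if device.type == "cuda":
        torch.cuda.synchronize()


c = a @ b  # untimed first call absorbs one-off setup costs
sync()

# Naive: on a GPU, the timer may stop as soon as the kernel has been queued.
start = time.perf_counter()
c = a @ b
naive_ms = (time.perf_counter() - start) * 1000
sync()  # drain the queue before the next measurement

# Synchronized: wait for the GPU before reading the clock again.
start = time.perf_counter()
c = a @ b
sync()
synced_ms = (time.perf_counter() - start) * 1000

print(f"without sync: {naive_ms:.3f} ms, with sync: {synced_ms:.3f} ms")
&lt;/code>&lt;/pre>&lt;/div>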
&lt;h3 id="gpu-warm-up">GPU Warm-up&lt;/h3>
&lt;p>A modern GPU device can exist in one of several different power states.&lt;/p>
&lt;p>When the GPU is NOT being used and persistence mode (which keeps the driver loaded and the GPU initialized even when no application is using it) is not enabled, &lt;strong>the GPU will automatically reduce its power state to a very low level, sometimes even a complete shutdown&lt;/strong>. In a lower power state, the GPU shuts down different pieces of hardware, including memory subsystems, internal subsystems, or even compute cores and caches.&lt;/p>
&lt;p>In a low power state, any program that attempts to interact with the GPU causes the driver to load and/or initialize the GPU. This driver-load behavior matters for benchmarking: applications that trigger GPU initialization can incur up to 3 seconds of latency, due to the memory scrubbing performed for error-correcting code (ECC).&lt;/p>
&lt;ul>
&lt;li>For instance, if a network takes 10 milliseconds per example, running 1000 examples takes about 10 seconds in total, so up to 3 seconds of GPU initialization would inflate the measured time by as much as 30%.&lt;/li>
&lt;/ul>
&lt;h2 id="the-correct-way-to-measure-inference-time">The Correct Way to Measure Inference Time&lt;/h2>
&lt;ul>
&lt;li>Before we make any time measurements, we run some dummy examples through the network to do a ‘&lt;strong>GPU warm-up&lt;/strong>.’ This will automatically initialize the GPU and prevent it from going into power-saving mode when we measure time.&lt;/li>
&lt;li>Next, we use &lt;code>torch.cuda.Event&lt;/code> objects to record timestamps on the GPU and measure the time elapsed between them.
&lt;ul>
&lt;li>It is crucial here to call &lt;code>torch.cuda.synchronize()&lt;/code>. It synchronizes the host and the device (i.e., the CPU and the GPU): the CPU waits until all work queued on the GPU, including the recorded end event, has finished, so the elapsed time is read only after the GPU computation is complete. This overcomes the problem caused by asynchronous execution.&lt;/li>
&lt;/ul>
&lt;/li>
&lt;/ul>
&lt;h3 id="code-snippet">Code Snippet&lt;/h3>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torchvision.models&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">models&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">numpy&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">np&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">tqdm&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">tqdm&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">device&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">device&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;cuda&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
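&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">models&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">resnet18&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">pretrained&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">to&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">device&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">eval&lt;/span>&lt;span class="p">()&lt;/span>  &lt;span class="c1"># inference mode for layers such as BatchNorm and Dropout&lt;/span>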
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">dummy_input&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">randn&lt;/span>&lt;span class="p">([&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">1024&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2048&lt;/span>&lt;span class="p">],&lt;/span> &lt;span class="n">dtype&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">float&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">to&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">device&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Init loggers&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">WARMUP_REPETITION&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">100&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">MEASURE_REPETITION&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">300&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">starter&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">ender&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">cuda&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Event&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">enable_timing&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">cuda&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Event&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">enable_timing&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">infer_times&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">np&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">zeros&lt;/span>&lt;span class="p">((&lt;/span>&lt;span class="n">MEASURE_REPETITION&lt;/span>&lt;span class="p">,&lt;/span>&lt;span class="mi">1&lt;/span>&lt;span class="p">))&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># GPU warm-up&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">with&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">no_grad&lt;/span>&lt;span class="p">():&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">for&lt;/span> &lt;span class="n">_&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="n">tqdm&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="nb">range&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">WARMUP_REPETITION&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">desc&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s2">&amp;#34;GPU warm-up&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">total&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">WARMUP_REPETITION&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">_&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">model&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">dummy_input&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Measure performance&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">with&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">no_grad&lt;/span>&lt;span class="p">():&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">    &lt;span class="k">for&lt;/span> &lt;span class="n">rep&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="n">tqdm&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="nb">range&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">MEASURE_REPETITION&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">desc&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s2">&amp;#34;Measuring inference time&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">total&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">MEASURE_REPETITION&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">starter&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">record&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">_&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">model&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">dummy_input&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">ender&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">record&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="c1"># Wait until the GPU has finished, including the end event&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">cuda&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">synchronize&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">curr_time&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">starter&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">elapsed_time&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">ender&lt;/span>&lt;span class="p">)&lt;/span>  &lt;span class="c1"># time unit is milliseconds&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">curr_time&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">curr_time&lt;/span> &lt;span class="o">/&lt;/span> &lt;span class="mi">1000&lt;/span>  &lt;span class="c1"># ms -&amp;gt; s&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">        &lt;span class="n">infer_times&lt;/span>&lt;span class="p">[&lt;/span>&lt;span class="n">rep&lt;/span>&lt;span class="p">]&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">curr_time&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">mean_time&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">np&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">sum&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">infer_times&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">/&lt;/span> &lt;span class="n">MEASURE_REPETITION&lt;/span>  &lt;span class="c1"># in seconds&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">std_time&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">np&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">std&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">infer_times&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="sa">f&lt;/span>&lt;span class="s2">&amp;#34;Mean: &lt;/span>&lt;span class="si">{&lt;/span>&lt;span class="n">mean_time&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">1000&lt;/span>&lt;span class="si">:&lt;/span>&lt;span class="s2">.3f&lt;/span>&lt;span class="si">}&lt;/span>&lt;span class="s2"> ms, Std: &lt;/span>&lt;span class="si">{&lt;/span>&lt;span class="n">std_time&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="mi">1000&lt;/span>&lt;span class="si">:&lt;/span>&lt;span class="s2">.3f&lt;/span>&lt;span class="si">}&lt;/span>&lt;span class="s2"> ms&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="sa">f&lt;/span>&lt;span class="s2">&amp;#34;FPS: &lt;/span>&lt;span class="si">{&lt;/span>&lt;span class="mi">1&lt;/span> &lt;span class="o">/&lt;/span> &lt;span class="n">mean_time&lt;/span>&lt;span class="si">:&lt;/span>&lt;span class="s2">.3f&lt;/span>&lt;span class="si">}&lt;/span>&lt;span class="s2">&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-txt" data-lang="txt">&lt;span class="line">&lt;span class="cl">GPU warm-up: 100%|██████████| 100/100 [00:04&amp;lt;00:00, 24.66it/s]
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">Measuring inference time: 100%|██████████| 300/300 [00:13&amp;lt;00:00, 21.52it/s]
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">Mean: 44.390 ms, Std: 0.890 ms
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">FPS: 22.528
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;h2 id="common-mistakes-when-measuring-time">Common Mistakes when Measuring Time&lt;/h2>
&lt;p>When we measure the latency of a network, our goal is to &lt;strong>measure only the forward pass of the network (i.e., the inference)&lt;/strong>, nothing more and nothing less.&lt;/p>
&lt;p>Some common mistakes are listed below:&lt;/p>
&lt;h3 id="transferring-data-between-the-host-and-the-device">Transferring data between the host and the device&lt;/h3>
&lt;p>One of the most common mistakes involves the &lt;strong>transfer of data between the CPU and GPU&lt;/strong> while taking time measurements. This usually happens unintentionally when a tensor is created on the CPU and inference is then performed on the GPU. The host-to-device copy (and the associated memory allocation on the GPU) takes a considerable amount of time, which inflates the measured inference time.&lt;/p>
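&lt;p>A sketch of the difference, assuming a hypothetical &lt;code>torch.nn.Linear&lt;/code> stand-in model and falling back to the CPU when no GPU is available. In the first measurement, the host-to-device copy (&lt;code>.to(device)&lt;/code>) is inside the timed region; in the second, the input is moved to the device before the timer starts, so only the forward pass is measured. The &lt;code>sync&lt;/code> and &lt;code>timed&lt;/code> helpers are illustrative, not part of PyTorch:&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">import time

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(2048, 2048).to(device).eval()  # hypothetical stand-in model
batch_cpu = torch.randn(256, 2048)  # created on the host (CPU)


def sync():
    # Wait for queued GPU work; nothing to wait for on the CPU.
    if device.type == "cuda":
        torch.cuda.synchronize()


def timed(fn):
    # Return the wall-clock time of fn() in milliseconds, synchronized on both ends.
    sync()
    start = time.perf_counter()
    fn()
    sync()
    return (time.perf_counter() - start) * 1000


with torch.no_grad():
    model(batch_cpu.to(device))  # warm-up, untimed

    # Mistake: the copy to the device is part of the measurement.
    with_copy_ms = timed(lambda: model(batch_cpu.to(device)))

    # Better: transfer first, then time only the forward pass.
    batch = batch_cpu.to(device)
    forward_ms = timed(lambda: model(batch))

print(f"with copy: {with_copy_ms:.3f} ms, forward only: {forward_ms:.3f} ms")
&lt;/code>&lt;/pre>&lt;/div>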
&lt;h3 id="not-using-gpu-warm-up">Not using GPU warm-up&lt;/h3>
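&lt;p>The first run on the GPU triggers its initialization. GPU initialization can take up to 3 seconds, which makes a huge difference when the timing is in terms of milliseconds. Running a number of untimed dummy inferences first keeps this one-time cost out of the measurement.&lt;/p>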
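&lt;p>A warm-up step might look like the sketch below. The toy model and the helper name &lt;code>warm_up&lt;/code> are hypothetical, and the default of 100 iterations is an assumption that mirrors the warm-up count in the output above; fewer iterations may well suffice.&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU()).to(device).eval()  # hypothetical toy model
dummy_input = torch.randn(1, 3, 32, 32, device=device)


def warm_up(model, x, iters=100):
    # Run untimed forward passes so that one-time GPU initialization
    # is not included in the later measurements
    with torch.no_grad():
        for _ in range(iters):
            model(x)
    if x.is_cuda:
        # Wait until all warm-up kernels have finished before timing starts
        torch.cuda.synchronize()
    return iters


warm_up(model, dummy_input)
&lt;/code>&lt;/pre>&lt;/div>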
&lt;h3 id="using-standard-cpu-timing">Using standard CPU timing&lt;/h3>
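&lt;p>The most common mistake is to measure time without synchronization. CUDA operations are executed asynchronously: a call such as &lt;code>model(x)&lt;/code> returns as soon as the kernels are queued, not when they have finished. A CPU timer read right after the call therefore measures little more than the kernel launches, unless &lt;code>torch.cuda.synchronize()&lt;/code> is called before the timer is read.&lt;/p>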
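&lt;p>The sketch below contrasts two synchronized ways of timing one forward pass: a CPU timer wrapped in &lt;code>torch.cuda.synchronize()&lt;/code> calls, and CUDA events (the approach used in the snippets of this post). The toy model is hypothetical, and the event-based variant only runs on a GPU.&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">import time

import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU()).to(device).eval()  # hypothetical toy model
x = torch.randn(1, 3, 32, 32, device=device)


def time_with_cpu_timer(model, x):
    # Without the synchronize() calls, perf_counter() would be read while the
    # GPU is still working, so mostly the kernel launches would be measured
    if x.is_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    with torch.no_grad():
        model(x)
    if x.is_cuda:
        torch.cuda.synchronize()
    return time.perf_counter() - start  # seconds


def time_with_cuda_events(model, x):
    # GPU-only alternative: the events are recorded on the CUDA stream itself
    starter = torch.cuda.Event(enable_timing=True)
    ender = torch.cuda.Event(enable_timing=True)
    with torch.no_grad():
        starter.record()
        model(x)
        ender.record()
    torch.cuda.synchronize()  # make sure ender has been reached before reading it
    return starter.elapsed_time(ender) / 1000  # elapsed_time() returns milliseconds


elapsed = time_with_cpu_timer(model, x)
if x.is_cuda:
    elapsed_events = time_with_cuda_events(model, x)
&lt;/code>&lt;/pre>&lt;/div>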
&lt;h3 id="taking-only-one-sample">Taking only one sample&lt;/h3>
&lt;p>A common mistake is to use &lt;strong>only one&lt;/strong> sample and report it as the run-time.&lt;/p>
&lt;p>Like many processes in computer science, the feed-forward pass of a neural network has a (small) stochastic component. The variance of the run-time can be significant, especially when measuring a low-latency network. It is therefore essential to &lt;strong>run the network over several examples and then average the results&lt;/strong> (300 examples can be a good number).&lt;/p>
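&lt;p>A minimal sketch of the repeat-and-average step, assuming a hypothetical callable &lt;code>time_one_run&lt;/code> that performs one properly synchronized forward pass and returns its duration in seconds. Dummy durations stand in for a real model so that the statistics can be checked by hand.&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">import statistics


def measure_latency(time_one_run, repetitions=300):
    # time_one_run is a hypothetical callable that performs ONE synchronized
    # forward pass and returns its duration in seconds
    times = [time_one_run() for _ in range(repetitions)]
    mean = statistics.mean(times)
    std = statistics.pstdev(times)  # population std, like numpy's default np.std
    return mean, std


# Usage with dummy durations instead of a real model
fake_durations = iter([0.04, 0.05, 0.06] * 100)
mean_time, std_time = measure_latency(lambda: next(fake_durations))
&lt;/code>&lt;/pre>&lt;/div>
&lt;p>Using the population standard deviation keeps the result consistent with &lt;code>np.std&lt;/code>, whose default is &lt;code>ddof=0&lt;/code>.&lt;/p>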
&lt;h2 id="measuring-fps">Measuring FPS&lt;/h2>
&lt;p>Once we have measured the inference time per image (in seconds), Frames Per Second (FPS) can be easily computed:
&lt;/p>
$$
FPS = \frac{1}{\text{inference time per image}}
$$
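&lt;p>As a worked instance (the numbers are only illustrative), a mean inference time of 44.390 ms, i.e. 0.04439 s, per image gives:&lt;/p>
$$
FPS = \frac{1}{0.04439\ \text{s}} \approx 22.53
$$
&lt;p>If the network processes several images in one batch, the inference time per image is the measured batch time divided by the batch size.&lt;/p>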
&lt;h2 id="measuring-throughput">Measuring Throughput&lt;/h2>
&lt;p>The &lt;strong>throughput&lt;/strong> of a neural network is defined as &lt;strong>the maximal number of input instances the network can process in a unit of time&lt;/strong> (e.g., one second). To achieve maximal throughput, we want to process as many instances in parallel as possible. The effective parallelism is data-, model-, and device-dependent.&lt;/p>
&lt;p>Thus, to correctly measure throughput we perform the following two steps:&lt;/p>
&lt;ol>
&lt;li>
&lt;p>We estimate the &lt;strong>optimal batch size&lt;/strong> that allows for maximum parallelism&lt;/p>
&lt;ul>
&lt;li>Rule of thumb: reach the memory limit of our GPU for the given data type&lt;/li>
&lt;li>Using a loop, we increase the batch size by one until a &lt;code>RuntimeError&lt;/code> (CUDA out of memory) is raised. The last batch size that succeeded is the largest one the GPU can process for our model and its input data.&lt;/li>
&lt;/ul>
&lt;/li>
&lt;li>
&lt;p>Given this optimal batch size, we measure the number of instances the network can process in one second.&lt;/p>
&lt;ul>
&lt;li>We process many batches (100 batches will be a sufficient number) and then use the following formula:
$$
\frac{\text{\#batches} \times \text{batch size}}{\text{total time in seconds}}
$$
This formula gives the number of examples our network can process in one second.&lt;/li>
&lt;/ul>
&lt;/li>
&lt;/ol>
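&lt;p>The batch-size search of step 1 can be sketched as below. This is a minimal, framework-agnostic sketch: &lt;code>find_max_batch_size&lt;/code>, &lt;code>run_batch&lt;/code>, and &lt;code>make_torch_runner&lt;/code> are hypothetical names, the cap of 4096 is arbitrary, and recognizing an out-of-memory failure by the text &lt;code>'out of memory'&lt;/code> in the error message is an assumption about how PyTorch reports CUDA allocation failures.&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">def find_max_batch_size(run_batch, start=1, max_batch_size=4096):
    # Grow the batch size by one until run_batch raises an out-of-memory
    # RuntimeError, then return the last size that succeeded (0 if none did).
    # max_batch_size is an arbitrary safety cap.
    largest_ok = 0
    for batch_size in range(start, max_batch_size + 1):
        try:
            run_batch(batch_size)
        except RuntimeError as err:
            # CUDA OOM surfaces as a RuntimeError (torch.cuda.OutOfMemoryError
            # subclasses it in recent PyTorch versions); re-raise anything else
            if 'out of memory' not in str(err):
                raise
            break
        largest_ok = batch_size
    return largest_ok


def make_torch_runner(model, device, input_shape=(3, 1024, 2048)):
    # Hypothetical helper that builds run_batch for a PyTorch model
    import torch

    def run_batch(batch_size):
        x = torch.randn(batch_size, *input_shape, device=device)
        with torch.no_grad():
            model(x)

    return run_batch


# Usage (requires a GPU):
# optimal_batch_size = find_max_batch_size(make_torch_runner(model, 'cuda'))
&lt;/code>&lt;/pre>&lt;/div>
&lt;p>Increasing by one is simple but may need many trials when the limit is large; doubling the batch size until it fails and then bisecting between the last two values would find the same limit in fewer runs.&lt;/p>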
&lt;h3 id="code-snippet-1">Code Snippet&lt;/h3>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torchvision.models&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">models&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">numpy&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">np&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">tqdm&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">tqdm&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Assume that we have estimated the optimal batch size&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">device&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">device&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;cuda&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">models&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">resnet18&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">pretrained&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">to&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">device&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">dummy_input&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">randn&lt;/span>&lt;span class="p">([&lt;/span>&lt;span class="n">optimal_batch_size&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">3&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">1024&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="mi">2048&lt;/span>&lt;span class="p">],&lt;/span> &lt;span class="n">dtype&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">float&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">to&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">device&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Init loggers&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">MEASURE_REPETITION&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">300&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">starter&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">ender&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">cuda&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Event&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">enable_timing&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">cuda&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Event&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">enable_timing&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="kc">True&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">total_time&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="mi">0&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Measure performance&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">with&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">no_grad&lt;/span>&lt;span class="p">():&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">for&lt;/span> &lt;span class="n">rep&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="n">tqdm&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="nb">range&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">MEASURE_REPETITION&lt;/span>&lt;span class="p">),&lt;/span> &lt;span class="n">desc&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="s2">&amp;#34;Measuring throughput&amp;#34;&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">total&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="n">MEASURE_REPETITION&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">starter&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">record&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">_&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">model&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">dummy_input&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">ender&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">record&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># Wait for GPU sync&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">cuda&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">synchronize&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">curr_time&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">starter&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">elapsed_time&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">ender&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">/&lt;/span> &lt;span class="mi">1000&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">total_time&lt;/span> &lt;span class="o">+=&lt;/span> &lt;span class="n">curr_time&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">throughput&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">(&lt;/span>&lt;span class="n">MEASURE_REPETITION&lt;/span> &lt;span class="o">*&lt;/span> &lt;span class="n">optimal_batch_size&lt;/span>&lt;span class="p">)&lt;/span> &lt;span class="o">/&lt;/span> &lt;span class="n">total_time&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="sa">f&lt;/span>&lt;span class="s2">&amp;#34;Final Throughput: &lt;/span>&lt;span class="si">{&lt;/span>&lt;span class="n">throughput&lt;/span>&lt;span class="si">}&lt;/span>&lt;span class="s2">&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;h2 id="compute-flops">Compute FLOPs&lt;/h2>
&lt;p>First, we have to clearly distinguish between &lt;strong>FLOPS&lt;/strong> and &lt;strong>FLOPs&lt;/strong>:&lt;/p>
&lt;ul>
&lt;li>&lt;strong>FLOPS&lt;/strong> (floating-point operations per second) is a measure of computer (hardware) performance, useful in fields of scientific computing that rely on floating-point calculations.&lt;/li>
&lt;li>&lt;strong>FLOPs&lt;/strong> (floating-point operations) is the number of floating-point operations, used as a metric for the complexity of a model or an algorithm.&lt;/li>
&lt;/ul>
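&lt;p>To make the definition of FLOPs concrete, the sketch below counts the FLOPs of a single convolution by hand. It assumes the common convention that one multiply-accumulate (MAC) counts as one FLOP (other sources count it as two) and ignores the bias; the function names are hypothetical, and the ResNet-18 first layer with a 224x224 input is only an example.&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">def conv2d_output_size(size, kernel, stride=1, padding=0, dilation=1):
    # Output length along one spatial dimension, following nn.Conv2d's formula
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv2d_flops(c_in, c_out, h_in, w_in, kernel, stride=1, padding=0, groups=1):
    # Assumed convention: one MAC = one FLOP, bias ignored.
    # Each output element needs (c_in / groups) * kernel * kernel MACs.
    h_out = conv2d_output_size(h_in, kernel, stride, padding)
    w_out = conv2d_output_size(w_in, kernel, stride, padding)
    return c_out * h_out * w_out * (c_in // groups) * kernel * kernel


# First convolution of ResNet-18: 3 -> 64 channels, 7x7 kernel, stride 2, padding 3
flops = conv2d_flops(3, 64, 224, 224, kernel=7, stride=2, padding=3)
print(flops)  # 118013952, about 0.118 GFLOPs (twice that if a MAC counts as 2 FLOPs)
&lt;/code>&lt;/pre>&lt;/div>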
&lt;p>To compute FLOPs, we can use &lt;a href="https://github.com/facebookresearch/fvcore">&lt;code>fvcore&lt;/code>&lt;/a>, which counts one fused multiply-add as one flop. (For more details, see &lt;a href="https://github.com/facebookresearch/fvcore/blob/main/docs/flop_count.md">Flop Counter for PyTorch Models&lt;/a>.)&lt;/p>
&lt;p>Code example:&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="kn">from&lt;/span> &lt;span class="nn">fvcore.nn&lt;/span> &lt;span class="kn">import&lt;/span> &lt;span class="n">FlopCountAnalysis&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">def&lt;/span> &lt;span class="nf">get_FLOPs&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">dummy_input&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">flops&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">FlopCountAnalysis&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">dummy_input&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="n">flops&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">total&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;h2 id="reference">Reference&lt;/h2>
&lt;ul>
&lt;li>
&lt;p>&lt;a href="https://deci.ai/the-correct-way-to-measure-inference-time-of-deep-neural-networks/">The Correct Way to Measure Inference Time of Deep Neural Networks&lt;/a>&lt;/p>
&lt;/li>
&lt;li>
&lt;p>&lt;a href="https://zhuanlan.zhihu.com/p/137719986">CNN 模型所需的计算力flops是什么？怎么计算？&lt;/a>&lt;/p>
&lt;/li>
&lt;/ul></description></item><item><title>📈 Training</title><link>https://haobin-tan.netlify.app/docs/ai/pytorch/training/</link><pubDate>Mon, 07 Sep 2020 00:00:00 +0000</pubDate><guid>https://haobin-tan.netlify.app/docs/ai/pytorch/training/</guid><description>&lt;p>This section includes some practical tips and tools for training of neural networks with PyTorch.&lt;/p></description></item><item><title>‼️ Issues &amp; Gotchas</title><link>https://haobin-tan.netlify.app/docs/ai/pytorch/pytorch-issues/</link><pubDate>Mon, 07 Sep 2020 00:00:00 +0000</pubDate><guid>https://haobin-tan.netlify.app/docs/ai/pytorch/pytorch-issues/</guid><description>&lt;p>This section summaries some issues and gotchas which may occur in practice.&lt;/p></description></item></channel></rss>