<?xml version="1.0" encoding="utf-8" standalone="yes"?><rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom"><channel><title>PyTorch Issues | Haobin Tan</title><link>https://haobin-tan.netlify.app/tags/pytorch-issues/</link><atom:link href="https://haobin-tan.netlify.app/tags/pytorch-issues/index.xml" rel="self" type="application/rss+xml"/><description>PyTorch Issues</description><generator>Hugo Blox Builder (https://hugoblox.com)</generator><language>en-us</language><lastBuildDate>Wed, 09 Mar 2022 00:00:00 +0000</lastBuildDate><image><url>https://haobin-tan.netlify.app/media/icon_hu7d15bc7db65c8eaf7a4f66f5447d0b42_15095_512x512_fill_lanczos_center_3.png</url><title>PyTorch Issues</title><link>https://haobin-tan.netlify.app/tags/pytorch-issues/</link></image><item><title>Model Registration</title><link>https://haobin-tan.netlify.app/docs/ai/pytorch/pytorch-issues/model_registration/</link><pubDate>Wed, 09 Mar 2022 00:00:00 +0000</pubDate><guid>https://haobin-tan.netlify.app/docs/ai/pytorch/pytorch-issues/model_registration/</guid><description>&lt;p>Before training a model, every module that needs to be trained must be correctly registered. Otherwise, the unregistered modules will NOT be trained, and no error or exception is thrown to tell us so. Moreover, when we call &lt;code>model.cuda()&lt;/code>, the unregistered modules stay on the CPU and are not moved to the GPU. Because everything fails silently, this gotcha is usually hard to notice.&lt;/p>
&lt;h2 id="when-does-this-gotcha-usually-occurs">When does this gotcha usually occur?&lt;/h2>
&lt;ol>
&lt;li>We use Python&amp;rsquo;s &lt;code>list&lt;/code> or &lt;code>dict&lt;/code> to hold modules but forget to wrap it with &lt;a href="https://pytorch.org/docs/stable/generated/torch.nn.ModuleList.html">&lt;code>nn.ModuleList&lt;/code>&lt;/a> or &lt;a href="https://pytorch.org/docs/stable/generated/torch.nn.ModuleDict.html">&lt;code>nn.ModuleDict&lt;/code>&lt;/a>.
&lt;ul>
&lt;li>In this case, PyTorch cannot recognize the elements as trainable modules. Therefore, they are NOT registered and will NOT be trained.&lt;/li>
&lt;/ul>
&lt;/li>
&lt;li>An attribute of the model is a Python &lt;code>list&lt;/code> or &lt;code>dict&lt;/code> of modules, but it is not wrapped with &lt;a href="https://pytorch.org/docs/stable/generated/torch.nn.ModuleList.html">&lt;code>nn.ModuleList&lt;/code>&lt;/a> or &lt;a href="https://pytorch.org/docs/stable/generated/torch.nn.ModuleDict.html">&lt;code>nn.ModuleDict&lt;/code>&lt;/a>.&lt;/li>
&lt;/ol>
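&lt;p>To make the consequence concrete, here is a minimal CPU-only sketch. The class names &lt;code>PlainListNet&lt;/code>, &lt;code>ModuleListNet&lt;/code> and &lt;code>ModuleDictNet&lt;/code> and the helper &lt;code>count_params&lt;/code> are hypothetical and exist only for illustration. It counts the parameters that &lt;code>model.parameters()&lt;/code> exposes, which is what is normally handed to an optimizer.&lt;/p>

```python
import torch
import torch.nn as nn


class PlainListNet(nn.Module):
    """Hypothetical model that keeps its layers in a plain Python list."""

    def __init__(self):
        super().__init__()
        # Plain list: the Linear layers are NOT registered as submodules.
        self.layers = [nn.Linear(4, 4) for _ in range(2)]


class ModuleListNet(nn.Module):
    """Same layers, but wrapped in nn.ModuleList so they are registered."""

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])


class ModuleDictNet(nn.Module):
    """Layers stored by name in nn.ModuleDict, which also registers them."""

    def __init__(self):
        super().__init__()
        self.heads = nn.ModuleDict({"a": nn.Linear(4, 2), "b": nn.Linear(4, 3)})


def count_params(model):
    # Sum the number of elements over all parameters the model exposes.
    return sum(p.numel() for p in model.parameters())


plain_count = count_params(PlainListNet())
list_count = count_params(ModuleListNet())
dict_count = count_params(ModuleDictNet())
print(plain_count, list_count, dict_count)
```

&lt;p>In this sketch the plain-list model exposes no parameters at all, so an optimizer built from &lt;code>model.parameters()&lt;/code> would have nothing from those layers to update.&lt;/p>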
&lt;h3 id="example">Example&lt;/h3>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="kn">import&lt;/span> &lt;span class="nn">torch.nn&lt;/span> &lt;span class="k">as&lt;/span> &lt;span class="nn">nn&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">DummyModule&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Module&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="fm">__init__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">super&lt;/span>&lt;span class="p">()&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="fm">__init__&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="nf">forward&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;dummy&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="n">x&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">Net&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Module&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="fm">__init__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_dummy_modules&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">4&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">super&lt;/span>&lt;span class="p">()&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="fm">__init__&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># Here self.dummy_module_list is just a Python list,&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># as we do not wrap it with nn.ModuleList&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">dummy_module_list&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">[&lt;/span>&lt;span class="n">DummyModule&lt;/span>&lt;span class="p">()&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">cuda&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="k">for&lt;/span> &lt;span class="n">_&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="nb">range&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">num_dummy_modules&lt;/span>&lt;span class="p">)]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="sa">f&lt;/span>&lt;span class="s2">&amp;#34;#dummy modules: &lt;/span>&lt;span class="si">{&lt;/span>&lt;span class="nb">len&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">dummy_module_list&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="si">}&lt;/span>&lt;span class="s2">&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="nf">forward&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">for&lt;/span> &lt;span class="n">dummy_module&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">dummy_module_list&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">dummy_module&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="n">x&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;p>Now we initialize the model and move it to GPU:&lt;/p>
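&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="c1"># Assumption: a CUDA device is available, as the .cuda() calls in Net require&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">device&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">torch&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">device&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="s2">&amp;#34;cuda&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Net&lt;/span>&lt;span class="p">()&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">to&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">device&lt;/span>&lt;span class="p">)&lt;/span>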
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-txt" data-lang="txt">&lt;span class="line">&lt;span class="cl">#dummy modules: 4
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">Net()
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;p>The printed &lt;code>Net&lt;/code> contains no submodules: the 4 &lt;code>DummyModule&lt;/code> instances are not registered.&lt;/p>
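&lt;p>Printing the model is one way to spot the problem. The sketch below shows a few programmatic checks on a hypothetical &lt;code>HiddenNet&lt;/code> (not part of the example above) that has one registered layer and one layer hidden in a plain list. As an assumption made so that it runs without a GPU, it uses a dtype conversion via &lt;code>.to(torch.float64)&lt;/code> as a stand-in for a device move; like &lt;code>.cuda()&lt;/code>, it is applied only to registered submodules.&lt;/p>

```python
import torch
import torch.nn as nn


class HiddenNet(nn.Module):
    """Hypothetical model: one registered layer, one hidden in a plain list."""

    def __init__(self):
        super().__init__()
        self.registered = nn.Linear(2, 2)
        self.hidden = [nn.Linear(2, 2)]  # plain list: not registered


# Convert the model; only registered submodules are visited.
model = HiddenNet().to(torch.float64)

# Direct children that PyTorch knows about.
num_children = len(list(model.children()))

# The registered layer was converted, the hidden one keeps the default dtype.
registered_dtype = model.registered.weight.dtype
hidden_dtype = model.hidden[0].weight.dtype

# The state dict (what gets saved in a checkpoint) lacks the hidden layer too.
state_keys = sorted(model.state_dict().keys())

print(num_children, registered_dtype, hidden_dtype, state_keys)
```

&lt;p>If the number of children or the keys of &lt;code>state_dict()&lt;/code> do not match the layers we expect, some module was probably left unregistered.&lt;/p>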
&lt;p>Now we use &lt;code>nn.ModuleList&lt;/code> to wrap &lt;code>self.dummy_module_list&lt;/code> and convert its elements to registered trainable modules.&lt;/p>
&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="k">class&lt;/span> &lt;span class="nc">Net&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">Module&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="fm">__init__&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">num_dummy_modules&lt;/span>&lt;span class="o">=&lt;/span>&lt;span class="mi">4&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">super&lt;/span>&lt;span class="p">()&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="fm">__init__&lt;/span>&lt;span class="p">()&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">dummy_module_list&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="p">[&lt;/span>&lt;span class="n">DummyModule&lt;/span>&lt;span class="p">()&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">cuda&lt;/span>&lt;span class="p">()&lt;/span> &lt;span class="k">for&lt;/span> &lt;span class="n">_&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="nb">range&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">num_dummy_modules&lt;/span>&lt;span class="p">)]&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="c1"># Register elements in self.dummy_module_list as trainable modules&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">dummy_module_list&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">nn&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">ModuleList&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">dummy_module_list&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="sa">f&lt;/span>&lt;span class="s2">&amp;#34;#dummy modules: &lt;/span>&lt;span class="si">{&lt;/span>&lt;span class="nb">len&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">dummy_module_list&lt;/span>&lt;span class="p">)&lt;/span>&lt;span class="si">}&lt;/span>&lt;span class="s2">&amp;#34;&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">def&lt;/span> &lt;span class="nf">forward&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="bp">self&lt;/span>&lt;span class="p">,&lt;/span> &lt;span class="n">x&lt;/span>&lt;span class="p">):&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">for&lt;/span> &lt;span class="n">dummy_module&lt;/span> &lt;span class="ow">in&lt;/span> &lt;span class="bp">self&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">dummy_module_list&lt;/span>&lt;span class="p">:&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="n">x&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">dummy_module&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">x&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> &lt;span class="k">return&lt;/span> &lt;span class="n">x&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-python" data-lang="python">&lt;span class="line">&lt;span class="cl">&lt;span class="n">model&lt;/span> &lt;span class="o">=&lt;/span> &lt;span class="n">Net&lt;/span>&lt;span class="p">()&lt;/span>&lt;span class="o">.&lt;/span>&lt;span class="n">to&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">device&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">&lt;span class="nb">print&lt;/span>&lt;span class="p">(&lt;/span>&lt;span class="n">model&lt;/span>&lt;span class="p">)&lt;/span>
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;div class="highlight">&lt;pre tabindex="0" class="chroma">&lt;code class="language-txt" data-lang="txt">&lt;span class="line">&lt;span class="cl">#dummy modules: 4
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">Net(
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> (dummy_module_list): ModuleList(
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> (0): DummyModule()
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> (1): DummyModule()
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> (2): DummyModule()
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> (3): DummyModule()
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl"> )
&lt;/span>&lt;/span>&lt;span class="line">&lt;span class="cl">)
&lt;/span>&lt;/span>&lt;/code>&lt;/pre>&lt;/div>&lt;h2 id="references">References&lt;/h2>
&lt;ul>
&lt;li>&lt;a href="https://hellojialee.github.io/2020/05/28/Pytorch%E5%AE%9E%E7%94%A8%E6%8C%87%E5%8D%97/">Network model construction&lt;/a>&lt;/li>
&lt;/ul></description></item></channel></rss>